├── data ├── __init__.py └── text_data.py ├── modules ├── lm │ ├── __init__.py │ └── lm_lstm.py ├── decoders │ ├── __init__.py │ ├── decoder_help.py │ ├── decoder_helper.py │ ├── decoder.py │ └── dec_lstm.py ├── encoders │ ├── __init__.py │ ├── encoder.py │ ├── enc_lstm.py │ └── gaussian_encoder.py ├── discriminators │ ├── __init__.py │ └── discriminator_linear.py ├── __init__.py ├── utils.py └── vae.py ├── config ├── config_yahoo.py ├── config_ptb.py ├── config_snli.py ├── config_yahoo_label.py └── config_short_yelp.py ├── exp_utils.py ├── scripts ├── sampling_training_labels.py ├── unsupervised_cluster.py └── multi-bleu.perl ├── README.md ├── prepare_data.py ├── lm.py ├── text_get_mean.py ├── utils.py ├── text_ss_ft.py ├── text_beta.py └── text_anneal_fb.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_data import * -------------------------------------------------------------------------------- /modules/lm/__init__.py: -------------------------------------------------------------------------------- 1 | from .lm_lstm import * -------------------------------------------------------------------------------- /modules/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .dec_lstm import * 2 | -------------------------------------------------------------------------------- /modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .enc_lstm import * 2 | -------------------------------------------------------------------------------- /modules/discriminators/__init__.py: -------------------------------------------------------------------------------- 1 | from .discriminator_linear import * -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders import * 2 | from .decoders import * 3 | from .vae import * 4 | from .lm import * 5 | # from .plotter import * 6 | from .utils import * 7 | from .discriminators import * -------------------------------------------------------------------------------- /config/config_yahoo.py: -------------------------------------------------------------------------------- 1 | 2 | params={ 3 | 'enc_type': 'lstm', 4 | 'dec_type': 'lstm', 5 | 'nz': 32, 6 | 'ni': 512, 7 | 'enc_nh': 1024, 8 | 'dec_nh': 1024, 9 | 'dec_dropout_in': 0.5, 10 | 'dec_dropout_out': 0.5, 11 | 'batch_size': 32, 12 | 'epochs': 100, 13 | 'test_nepoch': 5, 14 | 'train_data': 'datasets/yahoo_data/yahoo.train.txt', 15 | 'val_data': 'datasets/yahoo_data/yahoo.valid.txt', 16 | 'test_data': 'datasets/yahoo_data/yahoo.test.txt' 17 | } 18 | -------------------------------------------------------------------------------- /config/config_ptb.py: -------------------------------------------------------------------------------- 1 | params={ 2 | 'enc_type': 'lstm', 3 | 'dec_type': 'lstm', 4 | 'nz': 32, 5 | 'ni': 256, 6 | 'enc_nh': 256, 7 | 'dec_nh': 256, 8 | 'log_niter': 50, 9 | 'dec_dropout_in': 0.5, 10 | 'dec_dropout_out': 0.5, 11 | 'batch_size': 32, 12 | 'epochs': 100, 13 | 'test_nepoch': 5, 14 | 'train_data': 'datasets/ptb_data/ptb.train.txt', 15 | 'val_data': 'datasets/ptb_data/ptb.valid.txt', 16 | 'test_data': 'datasets/ptb_data/ptb.test.txt', 17 | 'ais_prior': 'normal', 18 | 'ais_T': 500, 19 | 'ais_K': 2223 20 | } 21 | 
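A note on how these config modules are consumed: the training scripts later in this repo (e.g. `lm.py` and `text_get_mean.py`) import `config.config_<dataset>` by name and merge its `params` dict into the parsed command-line arguments. A minimal sketch of that pattern (the `load_params` helper name here is only for illustration):

```
import argparse
import importlib

def load_params(dataset, cli_args):
    # e.g. dataset="ptb" resolves to config/config_ptb.py above
    config_module = importlib.import_module("config.config_%s" % dataset)
    # merge the params dict into the argparse namespace, as lm.py does below
    return argparse.Namespace(**vars(cli_args), **config_module.params)
```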
-------------------------------------------------------------------------------- /config/config_snli.py: -------------------------------------------------------------------------------- 1 | params={ 2 | 'enc_type': 'lstm', 3 | 'dec_type': 'lstm', 4 | 'nz': 32, 5 | 'ni': 128, 6 | 'enc_nh': 512, 7 | 'dec_nh': 512, 8 | 'log_niter': 50, 9 | 'dec_dropout_in': 0.5, 10 | 'dec_dropout_out': 0.5, 11 | 'batch_size': 32, 12 | 'epochs': 100, 13 | 'test_nepoch': 5, 14 | 'train_data': 'datasets/snli_data/snli.train.txt', 15 | 'val_data': 'datasets/snli_data/snli.valid.txt', 16 | 'test_data': 'datasets/snli_data/snli.test.txt', 17 | 'ais_prior': 'normal', 18 | 'ais_T': 500, 19 | 'ais_K': 3, 20 | "label": False 21 | } 22 | -------------------------------------------------------------------------------- /modules/decoders/decoder_help.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class BeamSearchNode(object): 4 | def __init__(self, hiddenstate, previousNode, wordId, logProb, length): 5 | ''' 6 | :param hiddenstate: 7 | :param previousNode: 8 | :param wordId: 9 | :param logProb: 10 | :param length: 11 | ''' 12 | self.h = hiddenstate 13 | self.prevNode = previousNode 14 | self.wordid = wordId 15 | self.logp = logProb 16 | self.leng = length 17 | 18 | def eval(self, alpha=1.0): 19 | reward = 0 20 | # Add here a function for shaping a reward 21 | 22 | return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward 23 | -------------------------------------------------------------------------------- /modules/decoders/decoder_helper.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class BeamSearchNode(object): 4 | def __init__(self, hiddenstate, previousNode, wordId, logProb, length): 5 | ''' 6 | :param hiddenstate: 7 | :param previousNode: 8 | :param wordId: 9 | :param logProb: 10 | :param length: 11 | ''' 12 | self.h = hiddenstate 13 | self.prevNode = previousNode 14 | self.wordid = wordId 15 | self.logp = logProb 16 | self.leng = length 17 | 18 | def eval(self, alpha=1.0): 19 | reward = 0 20 | # Add here a function for shaping a reward 21 | 22 | return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward 23 | 24 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def safe_log(z): 4 | return torch.log(z + 1e-7) 5 | 6 | def log_sum_exp(value, dim=None, keepdim=False): 7 | """Numerically stable implementation of the operation 8 | value.exp().sum(dim, keepdim).log() 9 | """ 10 | if dim is not None: 11 | m, _ = torch.max(value, dim=dim, keepdim=True) 12 | value0 = value - m 13 | if keepdim is False: 14 | m = m.squeeze(dim) 15 | return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) 16 | else: 17 | m = torch.max(value) 18 | sum_exp = torch.sum(torch.exp(value - m)) 19 | return m + torch.log(sum_exp) 20 | 21 | 22 | def generate_grid(zmin, zmax, dz, device, ndim=2): 23 | """generate a 1- or 2-dimensional grid 24 | Returns: Tensor, int 25 | Tensor: The grid tensor with shape (k^2, 2), 26 | where k=(zmax - zmin)/dz 27 | int: k 28 | """ 29 | 30 | if ndim == 2: 31 | x = torch.arange(zmin, zmax, dz) 32 | k = x.size(0) 33 | 34 | x1 = x.unsqueeze(1).repeat(1, k).view(-1) 35 | x2 = x.repeat(k) 36 | 37 | return torch.cat((x1.unsqueeze(-1), x2.unsqueeze(-1)), dim=-1).to(device), k 38 | 39 | elif ndim == 1: 40 | return torch.arange(zmin, zmax, 
dz).unsqueeze(1).to(device) -------------------------------------------------------------------------------- /exp_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os, shutil 3 | 4 | import numpy as np 5 | 6 | import torch 7 | 8 | 9 | def logging(s, log_path, print_=True, log_=True): 10 | print(s) 11 | # if print_: 12 | # print(s) 13 | if log_: 14 | with open(log_path, 'a+') as f_log: 15 | f_log.write(s + '\n') 16 | 17 | def get_logger(log_path, **kwargs): 18 | return functools.partial(logging, log_path=log_path, **kwargs) 19 | 20 | def create_exp_dir(dir_path, scripts_to_save=None, debug=False): 21 | if debug: 22 | print('Debug Mode : no experiment dir created') 23 | return functools.partial(logging, log_path=None, log_=False) 24 | 25 | if os.path.exists(dir_path): 26 | print("Path {} exists. Remove and remake.".format(dir_path)) 27 | shutil.rmtree(dir_path) 28 | 29 | os.makedirs(dir_path) 30 | 31 | print('Experiment dir : {}'.format(dir_path)) 32 | if scripts_to_save is not None: 33 | script_path = os.path.join(dir_path, 'scripts') 34 | if not os.path.exists(script_path): 35 | os.makedirs(script_path) 36 | for script in scripts_to_save: 37 | dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) 38 | shutil.copyfile(script, dst_file) 39 | 40 | return get_logger(log_path=os.path.join(dir_path, 'log.txt')) 41 | 42 | def save_checkpoint(model, optimizer, path, epoch): 43 | torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch))) -------------------------------------------------------------------------------- /scripts/sampling_training_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script samples a certain number of training samples for fine-tuning 3 | """ 4 | 5 | import os 6 | import argparse 7 | import numpy as np 8 | 9 | parser = argparse.ArgumentParser(description='sampling') 10 | parser.add_argument('--num_label', type=int) 11 | parser.add_argument('--seed', type=int, default="783435") 12 | parser.add_argument('--split_even', action="store_true", default=False) 13 | parser.add_argument('--dataset', type=str) 14 | 15 | args = parser.parse_args() 16 | 17 | np.random.seed(args.seed) 18 | 19 | data_dir = os.path.join("datasets/{}_data".format(args.dataset)) 20 | fout = open(os.path.join(data_dir, "{}.train.{}.txt".format(args.dataset, args.num_label)), "w") 21 | 22 | if args.split_even: 23 | label2text = {} 24 | with open(os.path.join(data_dir, "{}.train.txt".format(args.dataset))) as fin: 25 | for line in fin: 26 | label, text = line.strip().split("\t") 27 | label = int(label) 28 | 29 | if label in label2text: 30 | label2text[label] += [text] 31 | else: 32 | label2text[label] = [text] 33 | 34 | 35 | nlabel = len(label2text) 36 | sample_per_label = int(args.num_label / nlabel) 37 | for i in range(nlabel): 38 | index = np.random.choice(len(label2text[i]), sample_per_label, replace=False) 39 | for j in index: 40 | fout.write("{}\t{}\n".format(i, label2text[i][j])) 41 | 42 | else: 43 | with open(os.path.join(data_dir, "{}.train.txt".format(args.dataset))) as fin: 44 | text = fin.readlines() 45 | index = np.random.choice(len(text), args.num_label, replace=False) 46 | for i in index: 47 | fout.write(text[i]) 48 | 49 | fout.close() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Surprisingly Effective Fix 
for Deep Latent Variable Modeling of Text 2 | 3 | This is a PyTorch implementation of the following [paper](https://arxiv.org/abs/1909.00868): 4 | 5 | ``` 6 | A Surprisingly Effective Fix for Deep Latent Variable Modeling of Text 7 | Bohan Li*, Junxian He*, Graham Neubig, Taylor Berg-Kirkpatrick, Yiming Yang 8 | EMNLP 2019 9 | ``` 10 | 11 | Please contact bohanl1@cs.cmu.edu if you have any questions. 12 | 13 | ## Requirements 14 | 15 | * Python >= 3.6 16 | * PyTorch >= 1.0 17 | * pip install editdistance 18 | 19 | ## Data 20 | 21 | Datasets used in this paper can be downloaded with: 22 | 23 | ``` 24 | python prepare_data.py 25 | ``` 26 | 27 | ## Usage 28 | 29 | Train an AE first 30 | ``` 31 | python text_beta.py \ 32 | --dataset yahoo \ 33 | --beta 0 \ 34 | --lr 0.5 35 | ``` 36 | 37 | Train a VAE with our method 38 | ``` 39 | ae_exp_dir=exp_yahoo_beta/yahoos_lr0.5_beta0.0_drop0.5 40 | python text_anneal_fb.py \ 41 | --dataset yahoo \ 42 | --load_path ${ae_exp_dir}/model.pt \ 43 | --reset_dec \ 44 | --kl_start 0 \ 45 | --warm_up 10 \ 46 | --target_kl 8 \ 47 | --fb 2 \ 48 | --lr 0.5 49 | ``` 50 | 51 | Logs, models, and samples will be saved in the folder `exp`. 52 | 53 | 54 | ## Reference 55 | 56 | ``` 57 | @inproceedings{li2019emnlp, 58 | title = {A Surprisingly Effective Fix for Deep Latent Variable Modeling of Text}, 59 | author = {Bohan Li and Junxian He and Graham Neubig and Taylor Berg-Kirkpatrick and Yiming Yang}, 60 | booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 61 | address = {Hong Kong}, 62 | month = {November}, 63 | year = {2019} 64 | } 65 | 66 | ``` 67 | 68 | ## Acknowledgements 69 | 70 | A large portion of this repo is borrowed from https://github.com/jxhe/vae-lagging-encoder 71 | 72 | -------------------------------------------------------------------------------- /modules/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ..utils import log_sum_exp 6 | 7 | class EncoderBase(nn.Module): 8 | """docstring for EncoderBase""" 9 | def __init__(self): 10 | super(EncoderBase, self).__init__() 11 | 12 | def forward(self, x): 13 | """ 14 | Args: 15 | x: (batch_size, *) 16 | 17 | Returns: the tensors required to parameterize a distribution. 18 | E.g. 
for Gaussian encoder it returns the mean and variance tensors 19 | 20 | """ 21 | 22 | raise NotImplementedError 23 | 24 | def sample(self, input, nsamples): 25 | """sampling from the encoder 26 | Returns: Tensor1 27 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 28 | """ 29 | 30 | raise NotImplementedError 31 | 32 | def encode(self, input, nsamples): 33 | """perform the encoding and compute the KL term 34 | 35 | Returns: Tensor1, Tensor2 36 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 37 | Tensor2: the tenor of KL for each x with shape [batch] 38 | 39 | """ 40 | 41 | raise NotImplementedError 42 | 43 | 44 | def eval_inference_dist(self, x, z, param=None): 45 | """this function computes log q(z | x) 46 | Args: 47 | z: tensor 48 | different z points that will be evaluated, with 49 | shape [batch, nsamples, nz] 50 | Returns: Tensor1 51 | Tensor1: log q(z|x) with shape [batch, nsamples] 52 | """ 53 | 54 | raise NotImplementedError 55 | 56 | def calc_mi(self, x): 57 | """Approximate the mutual information between x and z 58 | I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z)) 59 | 60 | Returns: Float 61 | 62 | """ 63 | 64 | raise NotImplementedError 65 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import requests 3 | import tarfile 4 | import os 5 | 6 | def download_file_from_google_drive(id, destination): 7 | URL = "https://docs.google.com/uc?export=download" 8 | 9 | session = requests.Session() 10 | 11 | response = session.get(URL, params = { 'id' : id }, stream = True) 12 | token = get_confirm_token(response) 13 | 14 | if token: 15 | params = { 'id' : id, 'confirm' : token } 16 | response = session.get(URL, params = params, stream = True) 17 | 18 | save_response_content(response, destination) 19 | 20 | def get_confirm_token(response): 21 | for key, value in response.cookies.items(): 22 | if key.startswith('download_warning'): 23 | return value 24 | 25 | return None 26 | 27 | def save_response_content(response, destination): 28 | CHUNK_SIZE = 32768 29 | 30 | with open(destination, "wb") as f: 31 | for chunk in response.iter_content(CHUNK_SIZE): 32 | if chunk: # filter out keep-alive new chunks 33 | f.write(chunk) 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser(description="data downloading") 37 | parser.add_argument('--dataset', choices=["yahoo", "snli", "short_yelp", "all"], 38 | default="all", help='dataset to use') 39 | 40 | args = parser.parse_args() 41 | 42 | if not os.path.exists("datasets"): 43 | os.makedirs("datasets") 44 | 45 | os.chdir("datasets") 46 | 47 | yahoo_id = "13azGlTuGdzWLCmgDmQPmvb_jcexVWX7i" 48 | snli_id = "11NHEPxV7OrqmODozxQGezSU8093iUlUx" 49 | short_yelp_id = "18h8UYr801qr-USCYRzZySi_DSbUkHW71" 50 | 51 | if args.dataset == "yahoo": 52 | file_id = [yahoo_id] 53 | elif args.dataset == "snli": 54 | file_id = [snli_id] 55 | elif args.dataset == "short_yelp": 56 | file_id = [short_yelp_id] 57 | else: 58 | file_id = [yahoo_id, snli_id, short_yelp_id] 59 | 60 | destination = "datasets.tar.gz" 61 | 62 | for file_id_e in file_id: 63 | download_file_from_google_drive(file_id_e, destination) 64 | tar = tarfile.open(destination, "r:gz") 65 | tar.extractall() 66 | tar.close() 67 | os.remove(destination) 68 | 69 | os.chdir("../") 70 | 71 | -------------------------------------------------------------------------------- /modules/decoders/decoder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DecoderBase(nn.Module): 6 | """docstring for Decoder""" 7 | def __init__(self): 8 | super(DecoderBase, self).__init__() 9 | 10 | 11 | def freeze(self): 12 | for param in self.parameters(): 13 | param.requires_grad = False 14 | 15 | def decode(self, x, z): 16 | """ 17 | Args: 18 | x: (batch_size, seq_len) 19 | z: (batch_size, n_sample, nz) 20 | 21 | Returns: Tensor1 22 | Tensor1: the output logits with size (batch_size * n_sample, seq_len, vocab_size) 23 | 24 | """ 25 | 26 | raise NotImplementedError 27 | 28 | def reconstruct_error(self, x, z): 29 | """reconstruction loss 30 | Args: 31 | x: (batch_size, *) 32 | z: (batch_size, n_sample, nz) 33 | Returns: 34 | loss: (batch_size, n_sample). Loss 35 | across different sentence and z 36 | """ 37 | 38 | raise NotImplementedError 39 | 40 | def beam_search_decode(self, z, K): 41 | """beam search decoding 42 | Args: 43 | z: (batch_size, nz) 44 | K: the beam size 45 | 46 | Returns: List1 47 | List1: the decoded word sentence list 48 | """ 49 | 50 | raise NotImplementedError 51 | 52 | def sample_decode(self, z): 53 | """sampling from z 54 | Args: 55 | z: (batch_size, nz) 56 | 57 | Returns: List1 58 | List1: the decoded word sentence list 59 | """ 60 | 61 | raise NotImplementedError 62 | 63 | def greedy_decode(self, z): 64 | """greedy decoding from z 65 | Args: 66 | z: (batch_size, nz) 67 | 68 | Returns: List1 69 | List1: the decoded word sentence list 70 | """ 71 | 72 | raise NotImplementedError 73 | 74 | def log_probability(self, x, z): 75 | """ 76 | Args: 77 | x: (batch_size, *) 78 | z: (batch_size, n_sample, nz) 79 | Returns: 80 | log_p: (batch_size, n_sample). 81 | log_p(x|z) across different x and z 82 | """ 83 | 84 | raise NotImplementedError 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /modules/discriminators/discriminator_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 6 | 7 | class LinearDiscriminator(nn.Module): 8 | """docstring for LinearDiscriminator""" 9 | def __init__(self, args, encoder): 10 | super(LinearDiscriminator, self).__init__() 11 | self.args = args 12 | 13 | self.encoder = encoder 14 | if args.freeze_enc: 15 | for param in self.encoder.parameters(): 16 | param.requires_grad = False 17 | 18 | self.linear = nn.Linear(args.nz, args.ncluster) 19 | self.dropout = nn.Dropout(0.5) 20 | self.loss = nn.CrossEntropyLoss(reduction="none") 21 | 22 | def get_performance(self, batch_data, batch_labels): 23 | mu, _ = self.encoder(batch_data) 24 | if not self.args.freeze_enc: 25 | mu = self.dropout(mu) 26 | logits = self.linear(mu) 27 | loss = self.loss(logits, batch_labels) 28 | 29 | _, pred = torch.max(logits, dim=1) 30 | correct = torch.eq(pred, batch_labels).float().sum().item() 31 | 32 | return loss, correct 33 | 34 | 35 | class MLPDiscriminator(nn.Module): 36 | """docstring for LinearDiscriminator""" 37 | def __init__(self, args, encoder): 38 | super(MLPDiscriminator, self).__init__() 39 | self.args = args 40 | 41 | self.encoder = encoder 42 | if args.freeze_enc: 43 | for param in self.encoder.parameters(): 44 | param.requires_grad = False 45 | 46 | self.feats = nn.Sequential( 47 | nn.Linear(args.nz, args.nz), 48 | nn.ReLU(), 49 | 
nn.Dropout(0.3), 50 | nn.Linear(args.nz, args.nz), 51 | nn.ReLU(), 52 | nn.Dropout(0.3), 53 | nn.Linear(args.nz, args.ncluster), 54 | ) 55 | self.loss = nn.CrossEntropyLoss(reduction="none") 56 | 57 | def get_performance(self, batch_data, batch_labels): 58 | mu, _ = self.encoder(batch_data) 59 | logits = self.feats(mu) 60 | loss = self.loss(logits, batch_labels) 61 | 62 | _, pred = torch.max(logits, dim=1) 63 | correct = torch.eq(pred, batch_labels).float().sum().item() 64 | 65 | return loss, correct 66 | -------------------------------------------------------------------------------- /scripts/unsupervised_cluster.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | from sklearn.mixture import GaussianMixture 5 | from sklearn.decomposition import PCA 6 | from scipy.optimize import linear_sum_assignment 7 | 8 | parser = argparse.ArgumentParser(description='GMM unsupervised clustering') 9 | parser.add_argument('--exp_dir', type=str) 10 | parser.add_argument('--num', type=int, default=10) 11 | parser.add_argument('--pca_num', type=int, default=0) 12 | parser.add_argument('--one2one', action="store_true", default=False) 13 | 14 | args = parser.parse_args() 15 | 16 | gmm = GaussianMixture(n_components=args.num, tol=1e-3, max_iter=200, n_init=1, verbose=1) 17 | 18 | if args.pca_num > 0: 19 | pca = PCA(n_components=args.pca_num) 20 | 21 | 22 | train_x = np.loadtxt(os.path.join(args.exp_dir, "train.vec"), delimiter="\t") 23 | valid_x = np.loadtxt(os.path.join(args.exp_dir, "val.vec"), delimiter="\t") 24 | test_x = np.loadtxt(os.path.join(args.exp_dir, "test.vec"), delimiter="\t") 25 | 26 | if args.pca_num > 0: 27 | pca.fit(train_x) 28 | 29 | train_x = pca.transform(train_x) 30 | valid_x = pca.transform(valid_x) 31 | test_x = pca.transform(test_x) 32 | 33 | print(train_x.shape) 34 | 35 | print("start fitting gmm on training data") 36 | gmm.fit(train_x) 37 | 38 | valid_pred_y = gmm.predict(valid_x) 39 | valid_true_y = np.loadtxt(os.path.join(args.exp_dir, "val.label"), dtype=np.int) 40 | 41 | if args.one2one: 42 | print("linear assignment") 43 | cost_matrix = np.zeros((args.num, args.num)) 44 | 45 | for i, j in zip(valid_pred_y, valid_true_y): 46 | cost_matrix[i,j] -= 1 47 | 48 | row_ind, col_ind = linear_sum_assignment(cost_matrix) 49 | else: 50 | # (nsamples, ncomponents) 51 | valid_score = gmm.predict_proba(valid_x) 52 | valid_max_index = np.argmax(valid_score, axis=0) 53 | col_ind = {} 54 | for i in range(args.num): 55 | col_ind[i] = valid_true_y[valid_max_index[i]] 56 | 57 | print(col_ind) 58 | correct = 0. 59 | for i, j in zip(valid_pred_y, valid_true_y): 60 | if col_ind[i] == j: 61 | correct += 1 62 | print("validation acc {}".format(correct / len(valid_pred_y))) 63 | 64 | test_pred_y = gmm.predict(test_x) 65 | test_true_y = np.loadtxt(os.path.join(args.exp_dir, "test.label"), dtype=np.int) 66 | correct = 0. 67 | for i, j in zip(test_pred_y, test_true_y): 68 | if col_ind[i] == j: 69 | correct += 1 70 | print("test acc {}".format(correct / len(test_pred_y))) 71 | 72 | train_pred_y = gmm.predict(train_x) 73 | train_true_y = np.loadtxt(os.path.join(args.exp_dir, "train.label"), dtype=np.int) 74 | correct = 0. 
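# count training sentences whose predicted GMM cluster, mapped through col_ind, matches the gold label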
75 | for i, j in zip(train_pred_y, train_true_y): 76 | if col_ind[i] == j: 77 | correct += 1 78 | print("train acc {}".format(correct / len(train_pred_y))) 79 | -------------------------------------------------------------------------------- /config/config_yahoo_label.py: -------------------------------------------------------------------------------- 1 | 2 | params={ 3 | 'enc_type': 'lstm', 4 | 'dec_type': 'lstm', 5 | 'nz': 32, 6 | 'ni': 512, 7 | 'enc_nh': 1024, 8 | 'dec_nh': 1024, 9 | 'dec_dropout_in': 0.5, 10 | 'dec_dropout_out': 0.5, 11 | 'batch_size': 32, 12 | 'epochs': 100, 13 | 'test_nepoch': 5, 14 | 'train_data': 'datasets/yahoo_label_data/yahoo.train.txt', 15 | 'val_data': 'datasets/yahoo_label_data/yahoo.valid.txt', 16 | 'test_data': 'datasets/yahoo_label_data/yahoo.test.txt', 17 | 'label': True 18 | } 19 | 20 | 21 | params_ss_100={ 22 | 'enc_type': 'lstm', 23 | 'dec_type': 'lstm', 24 | 'nz': 32, 25 | 'ni': 512, 26 | 'enc_nh': 1024, 27 | 'dec_nh': 1024, 28 | 'dec_dropout_in': 0.5, 29 | 'dec_dropout_out': 0.5, 30 | 'epochs': 1000, 31 | 'test_nepoch': 5, 32 | 'train_data': 'datasets/yahoo_label_data/yahoo.train.100.txt', 33 | 'val_data': 'datasets/yahoo_label_data/yahoo.valid.txt', 34 | 'test_data': 'datasets/yahoo_label_data/yahoo.test.txt', 35 | 'vocab_file': 'datasets/yahoo_label_data/yahoo.vocab', 36 | 'ncluster': 10, 37 | 'label': True 38 | } 39 | 40 | params_ss_500={ 41 | 'enc_type': 'lstm', 42 | 'dec_type': 'lstm', 43 | 'nz': 32, 44 | 'ni': 512, 45 | 'enc_nh': 1024, 46 | 'dec_nh': 1024, 47 | 'dec_dropout_in': 0.5, 48 | 'dec_dropout_out': 0.5, 49 | 'epochs': 1000, 50 | 'test_nepoch': 5, 51 | 'train_data': 'datasets/yahoo_label_data/yahoo.train.500.txt', 52 | 'val_data': 'datasets/yahoo_label_data/yahoo.valid.txt', 53 | 'test_data': 'datasets/yahoo_label_data/yahoo.test.txt', 54 | 'vocab_file': 'datasets/yahoo_label_data/yahoo.vocab', 55 | 'ncluster': 10, 56 | 'label': True 57 | } 58 | 59 | params_ss_1000={ 60 | 'enc_type': 'lstm', 61 | 'dec_type': 'lstm', 62 | 'nz': 32, 63 | 'ni': 512, 64 | 'enc_nh': 1024, 65 | 'dec_nh': 1024, 66 | 'dec_dropout_in': 0.5, 67 | 'dec_dropout_out': 0.5, 68 | 'epochs': 1000, 69 | 'test_nepoch': 5, 70 | 'train_data': 'datasets/yahoo_label_data/yahoo.train.1000.txt', 71 | 'val_data': 'datasets/yahoo_label_data/yahoo.valid.txt', 72 | 'test_data': 'datasets/yahoo_label_data/yahoo.test.txt', 73 | 'vocab_file': 'datasets/yahoo_label_data/yahoo.vocab', 74 | 'ncluster': 10, 75 | 'label': True 76 | } 77 | 78 | params_ss_2000={ 79 | 'enc_type': 'lstm', 80 | 'dec_type': 'lstm', 81 | 'nz': 32, 82 | 'ni': 512, 83 | 'enc_nh': 1024, 84 | 'dec_nh': 1024, 85 | 'dec_dropout_in': 0.5, 86 | 'dec_dropout_out': 0.5, 87 | 'epochs': 1000, 88 | 'test_nepoch': 5, 89 | 'train_data': 'datasets/yahoo_label_data/yahoo.train.2000.txt', 90 | 'val_data': 'datasets/yahoo_label_data/yahoo.valid.txt', 91 | 'test_data': 'datasets/yahoo_label_data/yahoo.test.txt', 92 | 'vocab_file': 'datasets/yahoo_label_data/yahoo.vocab', 93 | 'ncluster': 10, 94 | 'label': True 95 | } 96 | 97 | params_ss_10000={ 98 | 'enc_type': 'lstm', 99 | 'dec_type': 'lstm', 100 | 'nz': 32, 101 | 'ni': 512, 102 | 'enc_nh': 1024, 103 | 'dec_nh': 1024, 104 | 'dec_dropout_in': 0.5, 105 | 'dec_dropout_out': 0.5, 106 | 'epochs': 1000, 107 | 'test_nepoch': 5, 108 | 'train_data': 'datasets/yahoo_label_data/yahoo.train.10000.txt', 109 | 'val_data': 'datasets/yahoo_label_data/yahoo.valid.txt', 110 | 'test_data': 'datasets/yahoo_label_data/yahoo.test.txt', 111 | 'vocab_file': 'datasets/yahoo_label_data/yahoo.vocab', 112 | 
'ncluster': 10, 113 | 'label': True 114 | } -------------------------------------------------------------------------------- /config/config_short_yelp.py: -------------------------------------------------------------------------------- 1 | params={ 2 | 'enc_type': 'lstm', 3 | 'dec_type': 'lstm', 4 | 'nz': 32, 5 | 'ni': 128, 6 | 'enc_nh': 512, 7 | 'dec_nh': 512, 8 | 'log_niter': 50, 9 | 'dec_dropout_in': 0.5, 10 | 'dec_dropout_out': 0.5, 11 | 'batch_size': 32, 12 | 'epochs': 100, 13 | 'test_nepoch': 5, 14 | 'train_data': 'datasets/short_yelp_data/short_yelp.train.txt', 15 | 'val_data': 'datasets/short_yelp_data/short_yelp.valid.txt', 16 | 'test_data': 'datasets/short_yelp_data/short_yelp.test.txt', 17 | "label": True 18 | } 19 | 20 | 21 | params_ss_100={ 22 | 'enc_type': 'lstm', 23 | 'dec_type': 'lstm', 24 | 'nz': 32, 25 | 'ni': 128, 26 | 'enc_nh': 512, 27 | 'dec_nh': 512, 28 | 'log_niter': 50, 29 | 'dec_dropout_in': 0.5, 30 | 'dec_dropout_out': 0.5, 31 | # 'batch_size': 32, 32 | 'epochs': 100, 33 | 'test_nepoch': 5, 34 | 'train_data': 'datasets/short_yelp_data/short_yelp.train.100.txt', 35 | 'val_data': 'datasets/short_yelp_data/short_yelp.valid.txt', 36 | 'test_data': 'datasets/short_yelp_data/short_yelp.test.txt', 37 | 'vocab_file': 'datasets/short_yelp_data/vocab.txt', 38 | 'ncluster': 10, 39 | "label": True 40 | } 41 | 42 | params_ss_500={ 43 | 'enc_type': 'lstm', 44 | 'dec_type': 'lstm', 45 | 'nz': 32, 46 | 'ni': 128, 47 | 'enc_nh': 512, 48 | 'dec_nh': 512, 49 | 'log_niter': 50, 50 | 'dec_dropout_in': 0.5, 51 | 'dec_dropout_out': 0.5, 52 | # 'batch_size': 32, 53 | 'epochs': 100, 54 | 'test_nepoch': 5, 55 | 'train_data': 'datasets/short_yelp_data/short_yelp.train.500.txt', 56 | 'val_data': 'datasets/short_yelp_data/short_yelp.valid.txt', 57 | 'test_data': 'datasets/short_yelp_data/short_yelp.test.txt', 58 | 'vocab_file': 'datasets/short_yelp_data/vocab.txt', 59 | 'ncluster': 10, 60 | "label": True 61 | } 62 | 63 | params_ss_1000={ 64 | 'enc_type': 'lstm', 65 | 'dec_type': 'lstm', 66 | 'nz': 32, 67 | 'ni': 128, 68 | 'enc_nh': 512, 69 | 'dec_nh': 512, 70 | 'log_niter': 50, 71 | 'dec_dropout_in': 0.5, 72 | 'dec_dropout_out': 0.5, 73 | # 'batch_size': 32, 74 | 'epochs': 100, 75 | 'test_nepoch': 5, 76 | 'train_data': 'datasets/short_yelp_data/short_yelp.train.1000.txt', 77 | 'val_data': 'datasets/short_yelp_data/short_yelp.valid.txt', 78 | 'test_data': 'datasets/short_yelp_data/short_yelp.test.txt', 79 | 'vocab_file': 'datasets/short_yelp_data/vocab.txt', 80 | 'ncluster': 10, 81 | "label": True 82 | } 83 | 84 | 85 | params_ss_2000={ 86 | 'enc_type': 'lstm', 87 | 'dec_type': 'lstm', 88 | 'nz': 32, 89 | 'ni': 128, 90 | 'enc_nh': 512, 91 | 'dec_nh': 512, 92 | 'log_niter': 50, 93 | 'dec_dropout_in': 0.5, 94 | 'dec_dropout_out': 0.5, 95 | # 'batch_size': 32, 96 | 'epochs': 100, 97 | 'test_nepoch': 5, 98 | 'train_data': 'datasets/short_yelp_data/short_yelp.train.2000.txt', 99 | 'val_data': 'datasets/short_yelp_data/short_yelp.valid.txt', 100 | 'test_data': 'datasets/short_yelp_data/short_yelp.test.txt', 101 | 'vocab_file': 'datasets/short_yelp_data/vocab.txt', 102 | 'ncluster': 10, 103 | "label": True 104 | } 105 | 106 | 107 | params_ss_10000={ 108 | 'enc_type': 'lstm', 109 | 'dec_type': 'lstm', 110 | 'nz': 32, 111 | 'ni': 128, 112 | 'enc_nh': 512, 113 | 'dec_nh': 512, 114 | 'log_niter': 50, 115 | 'dec_dropout_in': 0.5, 116 | 'dec_dropout_out': 0.5, 117 | # 'batch_size': 32, 118 | 'epochs': 100, 119 | 'test_nepoch': 5, 120 | 'train_data': 'datasets/short_yelp_data/short_yelp.train.10000.txt', 121 
| 'val_data': 'datasets/short_yelp_data/short_yelp.valid.txt', 122 | 'test_data': 'datasets/short_yelp_data/short_yelp.test.txt', 123 | 'vocab_file': 'datasets/short_yelp_data/vocab.txt', 124 | 'ncluster': 10, 125 | "label": True 126 | } 127 | -------------------------------------------------------------------------------- /modules/encoders/enc_lstm.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 7 | from .gaussian_encoder import GaussianEncoderBase 8 | from ..utils import log_sum_exp 9 | 10 | class GaussianLSTMEncoder(GaussianEncoderBase): 11 | """Gaussian LSTM Encoder with constant-length input""" 12 | def __init__(self, args, vocab_size, model_init, emb_init): 13 | super(GaussianLSTMEncoder, self).__init__() 14 | self.ni = args.ni 15 | self.nh = args.enc_nh 16 | self.nz = args.nz 17 | self.args = args 18 | 19 | self.embed = nn.Embedding(vocab_size, args.ni) 20 | 21 | self.lstm = nn.LSTM(input_size=args.ni, 22 | hidden_size=args.enc_nh, 23 | num_layers=1, 24 | batch_first=True, 25 | dropout=0) 26 | 27 | self.linear = nn.Linear(args.enc_nh, 2 * args.nz, bias=False) 28 | 29 | self.reset_parameters(model_init, emb_init) 30 | 31 | def reset_parameters(self, model_init, emb_init): 32 | # for name, param in self.lstm.named_parameters(): 33 | # # self.initializer(param) 34 | # if 'bias' in name: 35 | # nn.init.constant_(param, 0.0) 36 | # # model_init(param) 37 | # elif 'weight' in name: 38 | # model_init(param) 39 | 40 | # model_init(self.linear.weight) 41 | # emb_init(self.embed.weight) 42 | for param in self.parameters(): 43 | model_init(param) 44 | emb_init(self.embed.weight) 45 | 46 | 47 | def forward(self, input): 48 | """ 49 | Args: 50 | x: (batch_size, seq_len) 51 | 52 | Returns: Tensor1, Tensor2 53 | Tensor1: the mean tensor, shape (batch, nz) 54 | Tensor2: the logvar tensor, shape (batch, nz) 55 | """ 56 | 57 | # (batch_size, seq_len-1, args.ni) 58 | word_embed = self.embed(input) 59 | 60 | _, (last_state, last_cell) = self.lstm(word_embed) 61 | 62 | mean, logvar = self.linear(last_state).chunk(2, -1) 63 | 64 | # fix variance as a pre-defined value 65 | if self.args.fix_var > 0: 66 | logvar = mean.new_tensor([[[math.log(self.args.fix_var)]]]).expand_as(mean) 67 | 68 | return mean.squeeze(0), logvar.squeeze(0) 69 | 70 | # def eval_inference_mode(self, x): 71 | # """compute the mode points in the inference distribution 72 | # (in Gaussian case) 73 | # Returns: Tensor 74 | # Tensor: the posterior mode points with shape (*, nz) 75 | # """ 76 | 77 | # # (batch_size, nz) 78 | # mu, logvar = self.forward(x) 79 | 80 | 81 | class VarLSTMEncoder(GaussianLSTMEncoder): 82 | """Gaussian LSTM Encoder with variable-length input""" 83 | def __init__(self, args, vocab_size, model_init, emb_init): 84 | super(VarLSTMEncoder, self).__init__(args, vocab_size, model_init, emb_init) 85 | 86 | 87 | def forward(self, input): 88 | """ 89 | Args: 90 | input: tuple which contains x and sents_len 91 | x: (batch_size, seq_len) 92 | sents_len: long tensor of sentence lengths 93 | 94 | Returns: Tensor1, Tensor2 95 | Tensor1: the mean tensor, shape (batch, nz) 96 | Tensor2: the logvar tensor, shape (batch, nz) 97 | """ 98 | 99 | input, sents_len = input 100 | # (batch_size, seq_len, args.ni) 101 | word_embed = self.embed(input) 102 | 103 | packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), 
batch_first=True) 104 | 105 | _, (last_state, last_cell) = self.lstm(packed_embed) 106 | 107 | mean, logvar = self.linear(last_state).chunk(2, -1) 108 | 109 | return mean.squeeze(0), logvar.squeeze(0) 110 | 111 | def encode(self, input, nsamples): 112 | """perform the encoding and compute the KL term 113 | Args: 114 | input: tuple which contains x and sents_len 115 | 116 | Returns: Tensor1, Tensor2 117 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 118 | Tensor2: the tenor of KL for each x with shape [batch] 119 | 120 | """ 121 | 122 | # (batch_size, nz) 123 | mu, logvar = self.forward(input) 124 | 125 | # (batch, nsamples, nz) 126 | z = self.reparameterize(mu, logvar, nsamples) 127 | 128 | KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1) 129 | 130 | return z, KL 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /modules/encoders/gaussian_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .encoder import EncoderBase 6 | from ..utils import log_sum_exp 7 | 8 | class GaussianEncoderBase(EncoderBase): 9 | """docstring for EncoderBase""" 10 | def __init__(self): 11 | super(GaussianEncoderBase, self).__init__() 12 | 13 | def freeze(self): 14 | for param in self.parameters(): 15 | param.requires_grad = False 16 | 17 | def forward(self, x): 18 | """ 19 | Args: 20 | x: (batch_size, *) 21 | 22 | Returns: Tensor1, Tensor2 23 | Tensor1: the mean tensor, shape (batch, nz) 24 | Tensor2: the logvar tensor, shape (batch, nz) 25 | """ 26 | 27 | raise NotImplementedError 28 | 29 | def encode_stats(self, x): 30 | 31 | return self.forward(x) 32 | 33 | def sample(self, input, nsamples): 34 | """sampling from the encoder 35 | Returns: Tensor1 36 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 37 | """ 38 | 39 | # (batch_size, nz) 40 | mu, logvar = self.forward(input) 41 | 42 | # (batch, nsamples, nz) 43 | z = self.reparameterize(mu, logvar, nsamples) 44 | 45 | return z, (mu, logvar) 46 | 47 | def encode(self, input, nsamples): 48 | """perform the encoding and compute the KL term 49 | 50 | Returns: Tensor1, Tensor2 51 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 52 | Tensor2: the tenor of KL for each x with shape [batch] 53 | 54 | """ 55 | 56 | # (batch_size, nz) 57 | mu, logvar = self.forward(input) 58 | 59 | # (batch, nsamples, nz) 60 | z = self.reparameterize(mu, logvar, nsamples) 61 | 62 | KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1) 63 | 64 | return z, KL 65 | 66 | def reparameterize(self, mu, logvar, nsamples=1): 67 | """sample from posterior Gaussian family 68 | Args: 69 | mu: Tensor 70 | Mean of gaussian distribution with shape (batch, nz) 71 | 72 | logvar: Tensor 73 | logvar of gaussian distibution with shape (batch, nz) 74 | 75 | Returns: Tensor 76 | Sampled z with shape (batch, nsamples, nz) 77 | """ 78 | batch_size, nz = mu.size() 79 | std = logvar.mul(0.5).exp() 80 | 81 | mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz) 82 | std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz) 83 | 84 | eps = torch.zeros_like(std_expd).normal_() 85 | 86 | return mu_expd + torch.mul(eps, std_expd) 87 | 88 | def eval_inference_dist(self, x, z, param=None): 89 | """this function computes log q(z | x) 90 | Args: 91 | z: tensor 92 | different z points that will be evaluated, with 93 | shape [batch, nsamples, nz] 94 | Returns: Tensor1 95 | Tensor1: log 
q(z|x) with shape [batch, nsamples] 96 | """ 97 | 98 | nz = z.size(2) 99 | 100 | if not param: 101 | mu, logvar = self.forward(x) 102 | else: 103 | mu, logvar = param 104 | 105 | # (batch_size, 1, nz) 106 | mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1) 107 | var = logvar.exp() 108 | 109 | # (batch_size, nsamples, nz) 110 | dev = z - mu 111 | 112 | # (batch_size, nsamples) 113 | log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \ 114 | 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1)) 115 | 116 | return log_density 117 | 118 | def calc_mi(self, x): 119 | """Approximate the mutual information between x and z 120 | I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z)) 121 | 122 | Returns: Float 123 | 124 | """ 125 | 126 | # [x_batch, nz] 127 | mu, logvar = self.forward(x) 128 | 129 | x_batch, nz = mu.size() 130 | 131 | # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1) 132 | neg_entropy = (-0.5 * nz * math.log(2 * math.pi)- 0.5 * (1 + logvar).sum(-1)).mean() 133 | 134 | # [z_batch, 1, nz] 135 | z_samples = self.reparameterize(mu, logvar, 1) 136 | 137 | # [1, x_batch, nz] 138 | mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0) 139 | var = logvar.exp() 140 | 141 | # (z_batch, x_batch, nz) 142 | dev = z_samples - mu 143 | 144 | # (z_batch, x_batch) 145 | log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \ 146 | 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1)) 147 | 148 | # log q(z): aggregate posterior 149 | # [z_batch] 150 | log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch) 151 | 152 | return (neg_entropy - log_qz.mean(-1)).item() -------------------------------------------------------------------------------- /scripts/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
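# Reads a tokenized hypothesis from STDIN and reference translations from <reference> (or <reference>0, <reference>1, ...),
# then reports corpus-level BLEU computed from modified 1-4 gram precisions and a brevity penalty.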
5 | 6 | # $Id$ 7 | #no warnings; 8 | use warnings; 9 | use strict; 10 | 11 | my $lowercase = 0; 12 | if ($ARGV[0] eq "-lc") { 13 | $lowercase = 1; 14 | shift; 15 | } 16 | 17 | my $stem = $ARGV[0]; 18 | if (!defined $stem) { 19 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 20 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 21 | exit(1); 22 | } 23 | 24 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 25 | 26 | my @REF; 27 | my $ref=0; 28 | while(-e "$stem$ref") { 29 | &add_to_ref("$stem$ref",\@REF); 30 | $ref++; 31 | } 32 | &add_to_ref($stem,\@REF) if -e $stem; 33 | die("ERROR: could not find reference file $stem") unless scalar @REF; 34 | 35 | # add additional references explicitly specified on the command line 36 | shift; 37 | foreach my $stem (@ARGV) { 38 | &add_to_ref($stem,\@REF) if -e $stem; 39 | } 40 | 41 | 42 | 43 | sub add_to_ref { 44 | my ($file,$REF) = @_; 45 | my $s=0; 46 | if ($file =~ /.gz$/) { 47 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 48 | } else { 49 | open(REF,$file) or die "Can't read $file"; 50 | } 51 | while() { 52 | chop; 53 | push @{$$REF[$s++]}, $_; 54 | } 55 | close(REF); 56 | } 57 | 58 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 59 | my $s=0; 60 | while() { 61 | chop; 62 | $_ = lc if $lowercase; 63 | my @WORD = split; 64 | my %REF_NGRAM = (); 65 | my $length_translation_this_sentence = scalar(@WORD); 66 | my ($closest_diff,$closest_length) = (9999,9999); 67 | foreach my $reference (@{$REF[$s]}) { 68 | # print "$s $_ <=> $reference\n"; 69 | $reference = lc($reference) if $lowercase; 70 | my @WORD = split(' ',$reference); 71 | my $length = scalar(@WORD); 72 | my $diff = abs($length_translation_this_sentence-$length); 73 | if ($diff < $closest_diff) { 74 | $closest_diff = $diff; 75 | $closest_length = $length; 76 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 77 | } elsif ($diff == $closest_diff) { 78 | $closest_length = $length if $length < $closest_length; 79 | # from two references with the same closeness to me 80 | # take the *shorter* into account, not the "first" one. 81 | } 82 | for(my $n=1;$n<=4;$n++) { 83 | my %REF_NGRAM_N = (); 84 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 85 | my $ngram = "$n"; 86 | for(my $w=0;$w<$n;$w++) { 87 | $ngram .= " ".$WORD[$start+$w]; 88 | } 89 | $REF_NGRAM_N{$ngram}++; 90 | } 91 | foreach my $ngram (keys %REF_NGRAM_N) { 92 | if (!defined($REF_NGRAM{$ngram}) || 93 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 94 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 95 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 96 | } 97 | } 98 | } 99 | } 100 | $length_translation += $length_translation_this_sentence; 101 | $length_reference += $closest_length; 102 | for(my $n=1;$n<=4;$n++) { 103 | my %T_NGRAM = (); 104 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 105 | my $ngram = "$n"; 106 | for(my $w=0;$w<$n;$w++) { 107 | $ngram .= " ".$WORD[$start+$w]; 108 | } 109 | $T_NGRAM{$ngram}++; 110 | } 111 | foreach my $ngram (keys %T_NGRAM) { 112 | $ngram =~ /^(\d+) /; 113 | my $n = $1; 114 | # my $corr = 0; 115 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 116 | $TOTAL[$n] += $T_NGRAM{$ngram}; 117 | if (defined($REF_NGRAM{$ngram})) { 118 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 119 | $CORRECT[$n] += $T_NGRAM{$ngram}; 120 | # $corr = $T_NGRAM{$ngram}; 121 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 122 | } 123 | else { 124 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 125 | # $corr = $REF_NGRAM{$ngram}; 126 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 127 | } 128 | } 129 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 130 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 131 | } 132 | } 133 | $s++; 134 | } 135 | my $brevity_penalty = 1; 136 | my $bleu = 0; 137 | 138 | my @bleu=(); 139 | 140 | for(my $n=1;$n<=4;$n++) { 141 | if (defined ($TOTAL[$n])){ 142 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 143 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 144 | }else{ 145 | $bleu[$n]=0; 146 | } 147 | } 148 | 149 | if ($length_reference==0){ 150 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 151 | exit(1); 152 | } 153 | 154 | if ($length_translation<$length_reference) { 155 | $brevity_penalty = exp(1-$length_reference/$length_translation); 156 | } 157 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 158 | my_log( $bleu[2] ) + 159 | my_log( $bleu[3] ) + 160 | my_log( $bleu[4] ) ) / 4) ; 161 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 162 | 100*$bleu, 163 | 100*$bleu[1], 164 | 100*$bleu[2], 165 | 100*$bleu[3], 166 | 100*$bleu[4], 167 | $brevity_penalty, 168 | $length_translation / $length_reference, 169 | $length_translation, 170 | $length_reference; 171 | 172 | 173 | print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; 174 | 175 | sub my_log { 176 | return -9999999999 unless $_[0]; 177 | return log($_[0]); 178 | } 179 | -------------------------------------------------------------------------------- /modules/lm/lm_lstm.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | 3 | import time 4 | import argparse 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 11 | 12 | import numpy as np 13 | 14 | class LSTM_LM(nn.Module): 15 | """LSTM decoder with constant-length data""" 16 | def __init__(self, args, vocab, model_init, emb_init): 17 | super(LSTM_LM, self).__init__() 18 | self.ni = args.ni 19 | self.nh = args.dec_nh 20 | 21 | # no padding when setting padding_idx to -1 22 | self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=-1) 23 | 24 | self.dropout_in = nn.Dropout(args.dec_dropout_in) 25 | self.dropout_out = nn.Dropout(args.dec_dropout_out) 26 | 27 | # concatenate z with input 28 | self.lstm = nn.LSTM(input_size=args.ni, 29 | hidden_size=args.dec_nh, 30 | num_layers=1, 31 | batch_first=True) 32 | 33 | # prediction layer 34 | self.pred_linear = nn.Linear(args.dec_nh, len(vocab), bias=False) 35 | 36 | vocab_mask = torch.ones(len(vocab)) 37 | # vocab_mask[vocab['']] = 0 38 | self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduce=False) 39 | 40 | self.reset_parameters(model_init, emb_init) 41 | 42 | def reset_parameters(self, model_init, emb_init): 43 | # for name, param in self.lstm.named_parameters(): 44 | # # self.initializer(param) 45 | # if 'bias' in name: 46 | # nn.init.constant_(param, 0.0) 47 | # # model_init(param) 48 | # elif 'weight' in name: 49 | # model_init(param) 50 | 51 | # model_init(self.trans_linear.weight) 52 | # model_init(self.pred_linear.weight) 53 | for param in 
self.parameters(): 54 | model_init(param) 55 | emb_init(self.embed.weight) 56 | 57 | 58 | def decode(self, input): 59 | """ 60 | Args: 61 | input: (batch_size, seq_len) 62 | """ 63 | 64 | # not predicting start symbol 65 | # sents_len -= 1 66 | 67 | batch_size, seq_len = input.size() 68 | 69 | # (batch_size, seq_len, ni) 70 | word_embed = self.embed(input) 71 | word_embed = self.dropout_in(word_embed) 72 | 73 | c_init = word_embed.new_zeros((1, batch_size, self.nh)) 74 | h_init = word_embed.new_zeros((1, batch_size, self.nh)) 75 | output, _ = self.lstm(word_embed, (h_init, c_init)) 76 | 77 | output = self.dropout_out(output) 78 | 79 | # (batch_size, seq_len, vocab_size) 80 | output_logits = self.pred_linear(output) 81 | 82 | return output_logits 83 | 84 | def reconstruct_error(self, x): 85 | """Cross Entropy in the language case 86 | Args: 87 | x: (batch_size, seq_len) 88 | z: (batch_size, n_sample, nz) 89 | Returns: 90 | loss: (batch_size). Loss across different sentences 91 | """ 92 | 93 | #remove end symbol 94 | src = x[:, :-1] 95 | 96 | # remove start symbol 97 | tgt = x[:, 1:] 98 | 99 | batch_size, seq_len = src.size() 100 | 101 | # (batch_size * n_sample, seq_len, vocab_size) 102 | output_logits = self.decode(src) 103 | 104 | tgt = tgt.contiguous().view(-1) 105 | 106 | # (batch_size * seq_len) 107 | loss = self.loss(output_logits.view(-1, output_logits.size(2)), 108 | tgt) 109 | 110 | 111 | # (batch_size) 112 | return loss.view(batch_size, -1).sum(-1) 113 | 114 | def log_probability(self, x): 115 | """Cross Entropy in the language case 116 | Args: 117 | x: (batch_size, seq_len) 118 | Returns: 119 | log_p: (batch_size). 120 | """ 121 | 122 | return -self.reconstruct_error(x) 123 | 124 | 125 | # class VarLSTMDecoder(LSTMDecoder): 126 | # """LSTM decoder with constant-length data""" 127 | # def __init__(self, args, vocab, model_init, emb_init): 128 | # super(VarLSTMDecoder, self).__init__(args, vocab, model_init, emb_init) 129 | 130 | # self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=vocab['']) 131 | # vocab_mask = torch.ones(len(vocab)) 132 | # vocab_mask[vocab['']] = 0 133 | # self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduce=False) 134 | 135 | # self.reset_parameters(model_init, emb_init) 136 | 137 | # def decode(self, input, z): 138 | # """ 139 | # Args: 140 | # input: tuple which contains x and sents_len 141 | # x: (batch_size, seq_len) 142 | # sents_len: long tensor of sentence lengths 143 | # z: (batch_size, n_sample, nz) 144 | # """ 145 | 146 | # input, sents_len = input 147 | 148 | # # not predicting start symbol 149 | # sents_len = sents_len - 1 150 | 151 | # batch_size, n_sample, _ = z.size() 152 | # seq_len = input.size(1) 153 | 154 | # # (batch_size, seq_len, ni) 155 | # word_embed = self.embed(input) 156 | # word_embed = self.dropout_in(word_embed) 157 | 158 | # if n_sample == 1: 159 | # z_ = z.expand(batch_size, seq_len, self.nz) 160 | 161 | # else: 162 | # word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \ 163 | # .contiguous() 164 | 165 | # # (batch_size * n_sample, seq_len, ni) 166 | # word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni) 167 | 168 | # z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous() 169 | # z_ = z_.view(batch_size * n_sample, seq_len, self.nz) 170 | 171 | # # (batch_size * n_sample, seq_len, ni + nz) 172 | # word_embed = torch.cat((word_embed, z_), -1) 173 | 174 | # sents_len = sents_len.unsqueeze(1).expand(batch_size, n_sample).contiguous().view(-1) 
175 | # packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True) 176 | 177 | # z = z.view(batch_size * n_sample, self.nz) 178 | # # h_init = self.trans_linear(z).unsqueeze(0) 179 | # # c_init = h_init.new_zeros(h_init.size()) 180 | # c_init = self.trans_linear(z).unsqueeze(0) 181 | # h_init = torch.tanh(c_init) 182 | # output, _ = self.lstm(packed_embed, (h_init, c_init)) 183 | # output, _ = pad_packed_sequence(output, batch_first=True) 184 | 185 | # output = self.dropout_out(output) 186 | 187 | # # (batch_size * n_sample, seq_len, vocab_size) 188 | # output_logits = self.pred_linear(output) 189 | 190 | # return output_logits 191 | 192 | # def reconstruct_error(self, x, z): 193 | # """Cross Entropy in the language case 194 | # Args: 195 | # x: tuple which contains x_ and sents_len 196 | # x_: (batch_size, seq_len) 197 | # sents_len: long tensor of sentence lengths 198 | # z: (batch_size, n_sample, nz) 199 | # Returns: 200 | # loss: (batch_size, n_sample). Loss 201 | # across different sentence and z 202 | # """ 203 | 204 | # x, sents_len = x 205 | 206 | # #remove end symbol 207 | # src = x[:, :-1] 208 | 209 | # # remove start symbol 210 | # tgt = x[:, 1:] 211 | 212 | # batch_size, seq_len = src.size() 213 | # n_sample = z.size(1) 214 | 215 | # # (batch_size * n_sample, seq_len, vocab_size) 216 | # output_logits = self.decode((src, sents_len), z) 217 | 218 | # if n_sample == 1: 219 | # tgt = tgt.contiguous().view(-1) 220 | # else: 221 | # # (batch_size * n_sample * seq_len) 222 | # tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \ 223 | # .contiguous().view(-1) 224 | 225 | # # (batch_size * n_sample * seq_len) 226 | # loss = self.loss(output_logits.view(-1, output_logits.size(2)), 227 | # tgt) 228 | 229 | 230 | # # (batch_size, n_sample) 231 | # return loss.view(batch_size, n_sample, -1).sum(-1) 232 | 233 | # def log_probability(self, x, z): 234 | # """Cross Entropy in the language case 235 | # Args: 236 | # x: tuple which contains x_ and sents_len 237 | # x_: (batch_size, seq_len) 238 | # sents_len: long tensor of sentence lengths 239 | # z: (batch_size, n_sample, nz) 240 | # Returns: 241 | # log_p(x|z): (batch_size, n_sample). 
242 | # """ 243 | 244 | # return -self.reconstruct_error(x, z) 245 | 246 | -------------------------------------------------------------------------------- /lm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import importlib 5 | import argparse 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch import nn, optim 11 | 12 | from data import MonoTextData 13 | 14 | from modules import LSTM_LM 15 | 16 | from exp_utils import create_exp_dir 17 | from utils import uniform_initializer, xavier_normal_initializer 18 | 19 | # clip_grad = 5.0 20 | # decay_epoch = 2 21 | # lr_decay = 0.5 22 | max_decay = 5 23 | 24 | 25 | def init_config(): 26 | parser = argparse.ArgumentParser(description='VAE mode collapse study') 27 | 28 | parser.add_argument('--dataset', type=str, required=True, help='dataset to use') 29 | parser.add_argument('--exp_dir', default=None, type=str, 30 | help='experiment directory.') 31 | 32 | # select mode 33 | parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll') 34 | parser.add_argument('--load_path', type=str, default='') 35 | 36 | # optimization parameters 37 | parser.add_argument('--lr', type=float, default=1.0, help='Learning rate') 38 | parser.add_argument('--momentum', type=float, default=0, help='sgd momentum') 39 | parser.add_argument('--opt', type=str, choices=["sgd", "adam"], default="sgd", help='sgd momentum') 40 | 41 | parser.add_argument('--lr_decay', type=float, default=0.5) 42 | parser.add_argument('--decay_epoch', type=int, default=2) 43 | parser.add_argument('--clip_grad', type=float, default=5.0, help='') 44 | 45 | # others 46 | parser.add_argument('--seed', type=int, default=783435, metavar='S', help='random seed') 47 | parser.add_argument('--cuda', action='store_true', default=False, help='use gpu') 48 | 49 | args = parser.parse_args() 50 | 51 | # set args.cuda 52 | args.cuda = torch.cuda.is_available() 53 | 54 | np.random.seed(args.seed) 55 | torch.manual_seed(args.seed) 56 | if args.cuda: 57 | torch.cuda.manual_seed(args.seed) 58 | 59 | # load config file into args 60 | config_file = "config.config_%s" % args.dataset 61 | params = importlib.import_module(config_file).params 62 | args = argparse.Namespace(**vars(args), **params) 63 | 64 | 65 | # set load and save paths 66 | if args.exp_dir == None: 67 | args.exp_dir = "exp_{}_lm/{}_{}_{}".format(args.dataset, 68 | args.dataset, args.opt, args.lr) 69 | 70 | if len(args.load_path) <= 0 and args.eval: 71 | args.load_path = os.path.join(args.exp_dir, 'model.pt') 72 | 73 | args.save_path = os.path.join(args.exp_dir, 'model.pt') 74 | 75 | return args 76 | 77 | def test(model, test_data_batch, args): 78 | global logging 79 | 80 | report_loss = 0 81 | report_num_words = report_num_sents = 0 82 | for i in np.random.permutation(len(test_data_batch)): 83 | batch_data = test_data_batch[i] 84 | batch_size, sent_len = batch_data.size() 85 | 86 | # not predict start symbol 87 | report_num_words += (sent_len - 1) * batch_size 88 | 89 | report_num_sents += batch_size 90 | 91 | 92 | loss = model.reconstruct_error(batch_data) 93 | 94 | 95 | loss = loss.sum() 96 | 97 | report_loss += loss.item() 98 | 99 | nll = (report_loss) / report_num_sents 100 | ppl = np.exp(nll * report_num_sents / report_num_words) 101 | 102 | logging('avg_loss: %.4f, nll: %.4f, ppl: %.4f' % \ 103 | (nll, nll, ppl)) 104 | sys.stdout.flush() 105 | 106 | return nll, ppl 107 | 108 | def main(args): 109 | global logging 110 | logging 
= create_exp_dir(args.exp_dir, scripts_to_save=["text_cyc_anneal.py"]) 111 | 112 | if args.cuda: 113 | logging('using cuda') 114 | 115 | logging('model saving path: %s' % args.save_path) 116 | 117 | logging(str(args)) 118 | 119 | opt_dict = {"not_improved": 0, "lr": args.lr, "best_loss": 1e4} 120 | 121 | train_data = MonoTextData(args.train_data) 122 | 123 | vocab = train_data.vocab 124 | vocab_size = len(vocab) 125 | 126 | val_data = MonoTextData(args.val_data, vocab=vocab) 127 | test_data = MonoTextData(args.test_data, vocab=vocab) 128 | 129 | logging('Train data: %d samples' % len(train_data)) 130 | logging('finish reading datasets, vocab size is %d' % len(vocab)) 131 | logging('dropped sentences: %d' % train_data.dropped) 132 | sys.stdout.flush() 133 | 134 | model_init = uniform_initializer(0.01) 135 | emb_init = uniform_initializer(0.1) 136 | 137 | device = torch.device("cuda" if args.cuda else "cpu") 138 | args.device = device 139 | lm = LSTM_LM(args, vocab, model_init, emb_init).to(device) 140 | 141 | if args.load_path: 142 | loaded_state_dict = torch.load(args.load_path) 143 | lm.load_state_dict(loaded_state_dict) 144 | logging("%s loaded" % args.load_path) 145 | 146 | if args.opt == "sgd": 147 | optimizer = optim.SGD(lm.parameters(), lr=args.lr, momentum=args.momentum) 148 | opt_dict['lr'] = args.lr 149 | elif args.opt == "adam": 150 | optimizer = optim.Adam(lm.parameters(), lr=args.lr) 151 | opt_dict['lr'] = args.lr 152 | else: 153 | raise ValueError("optimizer not supported") 154 | 155 | iter_ = decay_cnt = 0 156 | best_loss = 1e4 157 | best_nll = best_ppl = 0 158 | lm.train() 159 | start = time.time() 160 | 161 | train_data_batch = train_data.create_data_batch(batch_size=args.batch_size, 162 | device=device, 163 | batch_first=True) 164 | val_data_batch = val_data.create_data_batch(batch_size=args.batch_size, 165 | device=device, 166 | batch_first=True) 167 | test_data_batch = test_data.create_data_batch(batch_size=args.batch_size, 168 | device=device, 169 | batch_first=True) 170 | for epoch in range(args.epochs): 171 | report_loss = 0 172 | report_num_words = report_num_sents = 0 173 | for i in np.random.permutation(len(train_data_batch)): 174 | batch_data = train_data_batch[i] 175 | batch_size, sent_len = batch_data.size() 176 | 177 | # not predict start symbol 178 | report_num_words += (sent_len - 1) * batch_size 179 | 180 | report_num_sents += batch_size 181 | 182 | optimizer.zero_grad() 183 | 184 | loss = lm.reconstruct_error(batch_data) 185 | 186 | report_loss += loss.sum().item() 187 | loss = loss.mean(dim=-1) 188 | loss.backward() 189 | torch.nn.utils.clip_grad_norm_(lm.parameters(), args.clip_grad) 190 | 191 | optimizer.step() 192 | 193 | if iter_ % args.log_niter == 0: 194 | train_loss = report_loss / report_num_sents 195 | 196 | logging('epoch: %d, iter: %d, avg_loss: %.4f, time elapsed %.2fs' % 197 | (epoch, iter_, train_loss, time.time() - start)) 198 | sys.stdout.flush() 199 | 200 | iter_ += 1 201 | 202 | if epoch % args.test_nepoch == 0: 203 | #logging('epoch: %d, testing' % epoch) 204 | lm.eval() 205 | 206 | with torch.no_grad(): 207 | nll, ppl = test(lm, test_data_batch, args) 208 | logging('test | epoch: %d, nll: %.4f, ppl: %.4f' % (epoch, nll, ppl)) 209 | lm.train() 210 | 211 | 212 | lm.eval() 213 | with torch.no_grad(): 214 | nll, ppl = test(lm, val_data_batch, args) 215 | logging('valid | epoch: %d, nll: %.4f, ppl: %.4f' % (epoch, nll, ppl)) 216 | 217 | if nll < best_loss: 218 | logging('update best loss') 219 | best_loss = nll 220 | best_nll = nll 221 | 
best_ppl = ppl 222 | torch.save(lm.state_dict(), args.save_path) 223 | 224 | if nll > opt_dict["best_loss"]: 225 | opt_dict["not_improved"] += 1 226 | if opt_dict["not_improved"] >= args.decay_epoch: 227 | opt_dict["best_loss"] = loss 228 | opt_dict["not_improved"] = 0 229 | opt_dict["lr"] = opt_dict["lr"] * args.lr_decay 230 | lm.load_state_dict(torch.load(args.save_path)) 231 | logging('new lr: %f' % opt_dict["lr"]) 232 | decay_cnt += 1 233 | if args.opt == "sgd": 234 | optimizer = optim.SGD(lm.parameters(), lr=opt_dict["lr"], momentum=args.momentum) 235 | elif args.opt == "adam": 236 | optimizer = optim.Adam(lm.parameters(), lr=opt_dict["lr"]) 237 | else: 238 | raise ValueError("optimizer not supported") 239 | else: 240 | opt_dict["not_improved"] = 0 241 | opt_dict["best_loss"] = nll 242 | 243 | if decay_cnt == max_decay: 244 | break 245 | 246 | lm.train() 247 | 248 | logging('valid | best_loss: %.4f, nll: %.4f, ppl: %.4f' \ 249 | % (best_loss, best_nll, best_ppl)) 250 | 251 | # reload best lm model 252 | lm.load_state_dict(torch.load(args.save_path)) 253 | 254 | with torch.no_grad(): 255 | nll, ppl = test(lm, test_data_batch, args) 256 | logging('test | nll: %.4f, ppl: %.4f' % (nll, ppl)) 257 | 258 | 259 | 260 | if __name__ == '__main__': 261 | args = init_config() 262 | main(args) 263 | -------------------------------------------------------------------------------- /text_get_mean.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import time 3 | import importlib 4 | import argparse 5 | 6 | import numpy as np 7 | 8 | import torch 9 | from torch import nn, optim 10 | 11 | from data import MonoTextData 12 | from modules import VAE 13 | from modules import GaussianLSTMEncoder, LSTMDecoder 14 | 15 | from exp_utils import create_exp_dir 16 | from utils import uniform_initializer, xavier_normal_initializer, calc_iwnll, calc_mi, calc_au, sample_sentences, visualize_latent, reconstruct 17 | 18 | # old parameters 19 | clip_grad = 5.0 20 | decay_epoch = 2 21 | lr_decay = 0.5 22 | max_decay = 5 23 | 24 | # Junxian's new parameters 25 | # clip_grad = 1.0 26 | # decay_epoch = 5 27 | # lr_decay = 0.5 28 | # max_decay = 5 29 | 30 | def init_config(): 31 | parser = argparse.ArgumentParser(description='VAE mode collapse study') 32 | 33 | # model hyperparameters 34 | parser.add_argument('--dataset', type=str, required=True, help='dataset to use') 35 | # optimization parameters 36 | parser.add_argument('--momentum', type=float, default=0, help='sgd momentum') 37 | parser.add_argument('--opt', type=str, choices=["sgd", "adam"], default="sgd", help='sgd momentum') 38 | 39 | parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training') 40 | parser.add_argument('--iw_nsamples', type=int, default=500, 41 | help='number of samples to compute importance weighted estimate') 42 | 43 | # select mode 44 | parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll') 45 | parser.add_argument('--load_dir', type=str, default='') 46 | 47 | # decoding 48 | parser.add_argument('--reconstruct_from', type=str, default='', help="the model checkpoint path") 49 | parser.add_argument('--reconstruct_to', type=str, default="decoding.txt", help="save file") 50 | parser.add_argument('--decoding_strategy', type=str, choices=["greedy", "beam", "sample"], default="greedy") 51 | 52 | # annealing paramters 53 | parser.add_argument('--warm_up', type=int, default=10, help="number of annealing epochs. 
warm_up=0 means not anneal") 54 | parser.add_argument('--kl_start', type=float, default=1.0, help="starting KL weight") 55 | 56 | 57 | # inference parameters 58 | parser.add_argument('--seed', type=int, default=783435, metavar='S', help='random seed') 59 | 60 | # output directory 61 | parser.add_argument('--exp_dir', default=None, type=str, 62 | help='experiment directory.') 63 | parser.add_argument("--save_ckpt", type=int, default=0, 64 | help="save checkpoint every epoch before this number") 65 | parser.add_argument("--save_latent", type=int, default=0) 66 | 67 | # new 68 | parser.add_argument("--fix_var", type=float, default=-1) 69 | parser.add_argument("--reset_dec", action="store_true", default=False) 70 | parser.add_argument("--load_best_epoch", type=int, default=15) 71 | parser.add_argument("--lr", type=float, default=1.) 72 | 73 | parser.add_argument("--fb", type=int, default=0, 74 | help="0: no fb; 1: fb; 2: max(target_kl, kl) for each dimension") 75 | parser.add_argument("--target_kl", type=float, default=-1, 76 | help="target kl of the free bits trick") 77 | 78 | 79 | args = parser.parse_args() 80 | 81 | # set args.cuda 82 | args.cuda = torch.cuda.is_available() 83 | 84 | # set seeds 85 | # seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909] 86 | # args.seed = seed_set[args.taskid] 87 | np.random.seed(args.seed) 88 | torch.manual_seed(args.seed) 89 | if args.cuda: 90 | torch.cuda.manual_seed(args.seed) 91 | torch.backends.cudnn.deterministic = True 92 | 93 | # load config file into args 94 | config_file = "config.config_%s" % args.dataset 95 | params = importlib.import_module(config_file).params 96 | args = argparse.Namespace(**vars(args), **params) 97 | 98 | args.save_dir = args.load_dir 99 | args.load_path = os.path.join(args.load_dir, "model.pt") 100 | 101 | # set args.label 102 | if 'label' in params: 103 | args.label = params['label'] 104 | else: 105 | args.label = False 106 | 107 | return args 108 | 109 | 110 | def test(model, test_data_batch, mode, args, verbose=True): 111 | report_kl_loss = report_rec_loss = report_loss = 0 112 | report_num_words = report_num_sents = 0 113 | for i in np.random.permutation(len(test_data_batch)): 114 | batch_data = test_data_batch[i] 115 | batch_size, sent_len = batch_data.size() 116 | 117 | # not predict start symbol 118 | report_num_words += (sent_len - 1) * batch_size 119 | report_num_sents += batch_size 120 | loss, loss_rc, loss_kl = model.loss(batch_data, 1.0, nsamples=args.nsamples) 121 | assert(not loss_rc.requires_grad) 122 | 123 | loss_rc = loss_rc.sum() 124 | loss_kl = loss_kl.sum() 125 | loss = loss.sum() 126 | 127 | report_rec_loss += loss_rc.item() 128 | report_kl_loss += loss_kl.item() 129 | if args.warm_up == 0 and args.kl_start < 1e-6: 130 | report_loss += loss_rc.item() 131 | else: 132 | report_loss += loss.item() 133 | 134 | mutual_info = calc_mi(model, test_data_batch) 135 | 136 | test_loss = report_loss / report_num_sents 137 | 138 | nll = (report_kl_loss + report_rec_loss) / report_num_sents 139 | kl = report_kl_loss / report_num_sents 140 | ppl = np.exp(nll * report_num_sents / report_num_words) 141 | if verbose: 142 | print('%s --- avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, nll: %.4f, ppl: %.4f' % \ 143 | (mode, test_loss, report_kl_loss / report_num_sents, mutual_info, 144 | report_rec_loss / report_num_sents, nll, ppl)) 145 | #sys.stdout.flush() 146 | 147 | return test_loss, nll, kl, ppl, mutual_info 148 | 149 | 150 | def save_latents(args, vae, test_data_batch, test_label_batch, str_): 151 | 
fout_label = open(os.path.join(args.save_dir, f'{str_}.label'),'w') 152 | with open(os.path.join(args.save_dir, f'{str_}.vec'),'w') as f: 153 | for i in range(len(test_data_batch)): 154 | batch_data = test_data_batch[i] 155 | batch_label = test_label_batch[i] 156 | batch_size, sent_len = batch_data.size() 157 | means, _ = vae.encoder.forward(batch_data) 158 | for j in range(batch_size): 159 | fout_label.write(batch_label[j] + "\n") 160 | mean = means[j,:].cpu().detach().numpy().tolist() 161 | f.write('\t'.join([str(val) for val in mean]) + '\n') 162 | 163 | 164 | def main(args): 165 | train_data = MonoTextData(args.train_data, label=args.label) 166 | vocab = train_data.vocab 167 | vocab_size = len(vocab) 168 | 169 | vocab_path = os.path.join("/".join(args.train_data.split("/")[:-1]), "vocab.txt") 170 | with open(vocab_path, "w") as fout: 171 | for i in range(vocab_size): 172 | fout.write("{}\n".format(vocab.id2word(i))) 173 | #return 174 | 175 | val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab) 176 | test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab) 177 | 178 | print('Train data: %d samples' % len(train_data)) 179 | print('finish reading datasets, vocab size is %d' % len(vocab)) 180 | print('dropped sentences: %d' % train_data.dropped) 181 | sys.stdout.flush() 182 | 183 | log_niter = (len(train_data)//args.batch_size)//10 184 | 185 | model_init = uniform_initializer(0.01) 186 | emb_init = uniform_initializer(0.1) 187 | 188 | #device = torch.device("cuda" if args.cuda else "cpu") 189 | device = "cuda" if args.cuda else "cpu" 190 | args.device = device 191 | 192 | if args.enc_type == 'lstm': 193 | encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init) 194 | args.enc_nh = args.dec_nh 195 | else: 196 | raise ValueError("the specified encoder type is not supported") 197 | 198 | decoder = LSTMDecoder(args, vocab, model_init, emb_init) 199 | vae = VAE(encoder, decoder, args).to(device) 200 | 201 | print('begin evaluation') 202 | vae.load_state_dict(torch.load(args.load_path)) 203 | vae.eval() 204 | with torch.no_grad(): 205 | test_data_batch, test_batch_labels = test_data.create_data_batch_labels(batch_size=args.batch_size, 206 | device=device, 207 | batch_first=True) 208 | 209 | # test(vae, test_data_batch, "TEST", args) 210 | # au, au_var = calc_au(vae, test_data_batch) 211 | # print("%d active units" % au) 212 | 213 | train_data_batch, train_batch_labels = train_data.create_data_batch_labels(batch_size=args.batch_size, 214 | device=device, 215 | batch_first=True) 216 | 217 | val_data_batch, val_batch_labels = val_data.create_data_batch_labels(batch_size=args.batch_size, 218 | device=device, 219 | batch_first=True) 220 | 221 | print("getting vectors for training") 222 | save_latents(args, vae, train_data_batch, train_batch_labels, "train") 223 | print("getting vectors for validating") 224 | save_latents(args, vae, val_data_batch, val_batch_labels, "val") 225 | print("getting vectors for testing") 226 | save_latents(args, vae, test_data_batch, test_batch_labels, "test") 227 | 228 | 229 | if __name__ == '__main__': 230 | args = init_config() 231 | main(args) 232 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, sys 3 | import torch 4 | from torch import nn, optim 5 | import subprocess 6 | 7 | class uniform_initializer(object): 8 | def __init__(self, stdv): 9 | self.stdv = stdv 10 
| def __call__(self, tensor): 11 | nn.init.uniform_(tensor, -self.stdv, self.stdv) 12 | 13 | 14 | class xavier_normal_initializer(object): 15 | def __call__(self, tensor): 16 | nn.init.xavier_normal_(tensor) 17 | 18 | def reconstruct(model, test_data_batch, vocab, strategy, fname): 19 | hyps = [] 20 | refs = [] 21 | with open(fname, "w") as fout: 22 | #for i in range(10): 23 | # batch_data = test_data_batch[i] 24 | 25 | for batch_data in test_data_batch: 26 | decoded_batch = model.reconstruct(batch_data, strategy) 27 | 28 | source = [[vocab.id2word(id_.item()) for id_ in sent] for sent in batch_data] 29 | for j in range(len(batch_data)): 30 | ref = " ".join(source[j]) 31 | hyp = " ".join(decoded_batch[j]) 32 | fout.write("SOURCE: {}\n".format(ref)) 33 | fout.write("RECON: {}\n\n".format(hyp)) 34 | 35 | refs += [ref[len(""): -len("")]] 36 | if strategy == "beam": 37 | hyps += [hyp[len(""): -len("")]] 38 | else: 39 | hyps += [hyp[: -len("")]] 40 | 41 | fname_ref = fname + ".ref" 42 | fname_hyp = fname + ".hyp" 43 | with open(fname_ref, "w") as f: 44 | f.write("\n".join(refs)) 45 | with open(fname_hyp, "w") as f: 46 | f.write("\n".join(hyps)) 47 | call_multi_bleu_perl("scripts/multi-bleu.perl", fname_hyp, fname_ref, verbose=True) 48 | 49 | 50 | def calc_iwnll(model, test_data_batch, args, ns=100): 51 | 52 | report_nll_loss = 0 53 | report_num_words = report_num_sents = 0 54 | print("iw nll computing ", end="") 55 | for id_, i in enumerate(np.random.permutation(len(test_data_batch))): 56 | batch_data = test_data_batch[i] 57 | batch_size, sent_len = batch_data.size() 58 | 59 | # not predict start symbol 60 | report_num_words += (sent_len - 1) * batch_size 61 | 62 | report_num_sents += batch_size 63 | if id_ % (round(len(test_data_batch) / 20)) == 0: 64 | print('%d%% ' % (id_/(round(len(test_data_batch) / 20)) * 5), end="") 65 | sys.stdout.flush() 66 | 67 | loss = model.nll_iw(batch_data, nsamples=args.iw_nsamples, ns=ns) 68 | 69 | report_nll_loss += loss.sum().item() 70 | 71 | print() 72 | sys.stdout.flush() 73 | 74 | nll = report_nll_loss / report_num_sents 75 | ppl = np.exp(nll * report_num_sents / report_num_words) 76 | 77 | return nll, ppl 78 | 79 | # def calc_mi(model, test_data_batch): 80 | # mi = 0 81 | # num_examples = 0 82 | # for batch_data in test_data_batch: 83 | # batch_size = batch_data.size(0) 84 | # num_examples += batch_size 85 | # mutual_info = model.calc_mi_q(batch_data) 86 | # mi += mutual_info * batch_size 87 | 88 | # return mi / num_examples 89 | 90 | def calc_mi(model, test_data_batch): 91 | # calc_mi_v3 92 | import math 93 | from modules.utils import log_sum_exp 94 | 95 | mi = 0 96 | num_examples = 0 97 | 98 | mu_batch_list, logvar_batch_list = [], [] 99 | neg_entropy = 0. 100 | for batch_data in test_data_batch: 101 | mu, logvar = model.encoder.forward(batch_data) 102 | x_batch, nz = mu.size() 103 | ##print(x_batch, end=' ') 104 | num_examples += x_batch 105 | 106 | # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1) 107 | neg_entropy += (-0.5 * nz * math.log(2 * math.pi)- 0.5 * (1 + logvar).sum(-1)).sum().item() 108 | mu_batch_list += [mu.cpu()] 109 | logvar_batch_list += [logvar.cpu()] 110 | 111 | neg_entropy = neg_entropy / num_examples 112 | ##print() 113 | 114 | num_examples = 0 115 | log_qz = 0. 
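# Added descriptive note on the estimator used in calc_mi (derived from the code itself):
# it follows MI(x, z) = E_x[ E_{q(z|x)}[log q(z|x)] ] - E_{q(z)}[log q(z)].
# The first term was accumulated above in closed form for a diagonal Gaussian posterior,
#     E_{q(z|x)}[log q(z|x)] = -0.5 * nz * log(2*pi) - 0.5 * sum_i (1 + logvar_i),
# and the loop that follows estimates the second term with the aggregate posterior
# q(z) ~= (1/N) * sum_n q(z | x_n): one z is sampled per example, its log-density is
# evaluated under every stored (mu, logvar) pair, and log q(z) is obtained as
# log_sum_exp(log_density) - log(N) for numerical stability, so that mi = neg_entropy - log_qz.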
116 | for i in range(len(mu_batch_list)): 117 | ############### 118 | # get z_samples 119 | ############### 120 | mu, logvar = mu_batch_list[i].cuda(), logvar_batch_list[i].cuda() 121 | 122 | # [z_batch, 1, nz] 123 | if hasattr(model.encoder, 'reparameterize'): 124 | z_samples = model.encoder.reparameterize(mu, logvar, 1) 125 | else: 126 | z_samples = model.encoder.gaussian_enc.reparameterize(mu, logvar, 1) 127 | z_samples = z_samples.view(-1, 1, nz) 128 | num_examples += z_samples.size(0) 129 | 130 | ############### 131 | # compute density 132 | ############### 133 | # [1, x_batch, nz] 134 | #mu, logvar = mu_batch_list[i].cuda(), logvar_batch_list[i].cuda() 135 | #indices = list(np.random.choice(np.arange(len(mu_batch_list)), 10)) + [i] 136 | indices = np.arange(len(mu_batch_list)) 137 | mu = torch.cat([mu_batch_list[_] for _ in indices], dim=0).cuda() 138 | logvar = torch.cat([logvar_batch_list[_] for _ in indices], dim=0).cuda() 139 | x_batch, nz = mu.size() 140 | 141 | mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0) 142 | var = logvar.exp() 143 | 144 | # (z_batch, x_batch, nz) 145 | dev = z_samples - mu 146 | 147 | # (z_batch, x_batch) 148 | log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \ 149 | 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1)) 150 | 151 | # log q(z): aggregate posterior 152 | # [z_batch] 153 | log_qz += (log_sum_exp(log_density, dim=1) - math.log(x_batch)).sum(-1) 154 | 155 | log_qz /= num_examples 156 | mi = neg_entropy - log_qz 157 | 158 | return mi 159 | 160 | 161 | def calc_au(model, test_data_batch, delta=0.01): 162 | """compute the number of active units 163 | """ 164 | cnt = 0 165 | for batch_data in test_data_batch: 166 | mean, _ = model.encode_stats(batch_data) 167 | if cnt == 0: 168 | means_sum = mean.sum(dim=0, keepdim=True) 169 | else: 170 | means_sum = means_sum + mean.sum(dim=0, keepdim=True) 171 | cnt += mean.size(0) 172 | 173 | # (1, nz) 174 | mean_mean = means_sum / cnt 175 | 176 | cnt = 0 177 | for batch_data in test_data_batch: 178 | mean, _ = model.encode_stats(batch_data) 179 | if cnt == 0: 180 | var_sum = ((mean - mean_mean) ** 2).sum(dim=0) 181 | else: 182 | var_sum = var_sum + ((mean - mean_mean) ** 2).sum(dim=0) 183 | cnt += mean.size(0) 184 | 185 | # (nz) 186 | au_var = var_sum / (cnt - 1) 187 | 188 | return (au_var >= delta).sum().item(), au_var 189 | 190 | 191 | def sample_sentences(vae, vocab, device, num_sentences): 192 | global logging 193 | 194 | vae.eval() 195 | sampled_sents = [] 196 | for i in range(num_sentences): 197 | z = vae.sample_from_prior(1) 198 | z = z.view(1,1,-1) 199 | start = vocab.word2id[''] 200 | # START = torch.tensor([[[start]]]) 201 | START = torch.tensor([[start]]) 202 | end = vocab.word2id[''] 203 | START = START.to(device) 204 | z = z.to(device) 205 | vae.eval() 206 | sentence = vae.decoder.sample_text(START, z, end, device) 207 | decoded_sentence = vocab.decode_sentence(sentence) 208 | sampled_sents.append(decoded_sentence) 209 | for i, sent in enumerate(sampled_sents): 210 | logging(i,":",' '.join(sent)) 211 | 212 | # def visualize_latent(args, vae, device, test_data): 213 | # f = open('yelp_embeddings_z','w') 214 | # g = open('yelp_embeddings_labels','w') 215 | 216 | # test_data_batch, test_label_batch = test_data.create_data_batch_labels(batch_size=args.batch_size, device=device, batch_first=True) 217 | # for i in range(len(test_data_batch)): 218 | # batch_data = test_data_batch[i] 219 | # batch_label = test_label_batch[i] 220 | # batch_size, sent_len = batch_data.size() 221 | # means, _ = 
vae.encoder.forward(batch_data) 222 | # for i in range(batch_size): 223 | # mean = means[i,:].cpu().detach().numpy().tolist() 224 | # for val in mean: 225 | # f.write(str(val)+'\t') 226 | # f.write('\n') 227 | # for label in batch_label: 228 | # g.write(label+'\n') 229 | # fo 230 | # print(mean.size()) 231 | # print(logvar.size()) 232 | # fooo 233 | 234 | def visualize_latent(args, epoch, vae, device, test_data): 235 | nsamples = 1 236 | 237 | with open(os.path.join(args.exp_dir, f'synthetic_latent_{epoch}.txt'),'w') as f: 238 | test_data_batch, test_label_batch = test_data.create_data_batch_labels(batch_size=args.batch_size, device=device, batch_first=True) 239 | for i in range(len(test_data_batch)): 240 | batch_data = test_data_batch[i] 241 | batch_label = test_label_batch[i] 242 | batch_size, sent_len = batch_data.size() 243 | samples, _ = vae.encoder.encode(batch_data, nsamples) 244 | for i in range(batch_size): 245 | for j in range(nsamples): 246 | sample = samples[i,j,:].cpu().detach().numpy().tolist() 247 | f.write(batch_label[i] + '\t' + ' '.join([str(val) for val in sample]) + '\n') 248 | 249 | 250 | def call_multi_bleu_perl(fname_bleu_script, fname_hyp, fname_ref, verbose=True): 251 | cmd = "perl %s %s < %s" % (fname_bleu_script, fname_ref, fname_hyp) 252 | popen = subprocess.Popen(cmd, stdout=subprocess.PIPE, \ 253 | stderr=subprocess.PIPE, shell=True) 254 | popen.wait() 255 | try: 256 | bleu_result = popen.stdout.readline().strip().decode("utf-8") 257 | if verbose: 258 | print(bleu_result) 259 | bleu = float(bleu_result[7:bleu_result.index(',')]) 260 | stderrs = popen.stderr.readlines() 261 | if len(stderrs) > 1: 262 | for line in stderrs: 263 | print(line.strip()) 264 | except Exception as e: 265 | print(e) 266 | bleu = 0. 267 | return bleu 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | -------------------------------------------------------------------------------- /data/text_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | 5 | from collections import defaultdict 6 | 7 | 8 | class VocabEntry(object): 9 | """docstring for Vocab""" 10 | def __init__(self, word2id=None): 11 | super(VocabEntry, self).__init__() 12 | 13 | if word2id: 14 | self.word2id = word2id 15 | self.unk_id = word2id[''] 16 | else: 17 | self.word2id = dict() 18 | self.unk_id = 3 19 | self.word2id[''] = 0 20 | self.word2id[''] = 1 21 | self.word2id[''] = 2 22 | self.word2id[''] = self.unk_id 23 | 24 | self.id2word_ = {v: k for k, v in self.word2id.items()} 25 | 26 | def __getitem__(self, word): 27 | return self.word2id.get(word, self.unk_id) 28 | 29 | def __contains__(self, word): 30 | return word in self.word2id 31 | 32 | def __len__(self): 33 | return len(self.word2id) 34 | 35 | def add(self, word): 36 | if word not in self: 37 | wid = self.word2id[word] = len(self) 38 | self.id2word[wid] = word 39 | return wid 40 | 41 | else: 42 | return self[word] 43 | 44 | def id2word(self, wid): 45 | return self.id2word_[wid] 46 | 47 | def decode_sentence(self, sentence): 48 | decoded_sentence = [] 49 | for wid_t in sentence: 50 | wid = wid_t.item() 51 | decoded_sentence.append(self.id2word_[wid]) 52 | return decoded_sentence 53 | 54 | 55 | @staticmethod 56 | def from_corpus(fname): 57 | vocab = VocabEntry() 58 | with open(fname) as fin: 59 | for line in fin: 60 | _ 
= [vocab.add(word) for word in line.split()] 61 | 62 | return vocab 63 | 64 | 65 | class MonoTextData(object): 66 | """docstring for MonoTextData""" 67 | def __init__(self, fname, label=False, max_length=None, vocab=None): 68 | super(MonoTextData, self).__init__() 69 | 70 | self.data, self.vocab, self.dropped, self.labels = self._read_corpus(fname, label, max_length, vocab) 71 | 72 | def __len__(self): 73 | return len(self.data) 74 | 75 | def _read_corpus(self, fname, label, max_length, vocab): 76 | data = [] 77 | labels = [] if label else None 78 | dropped = 0 79 | if not vocab: 80 | vocab = defaultdict(lambda: len(vocab)) 81 | vocab[''] = 0 82 | vocab[''] = 1 83 | vocab[''] = 2 84 | vocab[''] = 3 85 | 86 | with open(fname) as fin: 87 | for line in fin: 88 | if label: 89 | split_line = line.split('\t') 90 | lb = split_line[0] 91 | split_line = split_line[1].split() 92 | else: 93 | split_line = line.split() 94 | if len(split_line) < 1: 95 | dropped += 1 96 | continue 97 | 98 | if max_length: 99 | if len(split_line) > max_length: 100 | dropped += 1 101 | continue 102 | 103 | if label: 104 | labels.append(lb) 105 | data.append([vocab[word] for word in split_line]) 106 | 107 | if isinstance(vocab, VocabEntry): 108 | return data, vocab, dropped, labels 109 | 110 | return data, VocabEntry(vocab), dropped, labels 111 | 112 | def _to_tensor(self, batch_data, batch_first, device): 113 | """pad a list of sequences, and transform them to tensors 114 | Args: 115 | batch_data: a batch of sentences (list) that are composed of 116 | word ids. 117 | batch_first: If true, the returned tensor shape is 118 | (batch, seq_len), otherwise (seq_len, batch) 119 | device: torch.device 120 | Returns: Tensor, Int list 121 | Tensor: Tensor of the batch data after padding 122 | Int list: a list of integers representing the length 123 | of each sentence (including start and stop symbols) 124 | """ 125 | 126 | 127 | # pad stop symbol 128 | batch_data = [sent + [self.vocab['']] for sent in batch_data] 129 | 130 | sents_len = [len(sent) for sent in batch_data] 131 | 132 | max_len = max(sents_len) 133 | 134 | batch_size = len(sents_len) 135 | sents_new = [] 136 | 137 | # pad start symbol 138 | sents_new.append([self.vocab['']] * batch_size) 139 | for i in range(max_len): 140 | sents_new.append([sent[i] if len(sent) > i else self.vocab[''] \ 141 | for sent in batch_data]) 142 | 143 | 144 | sents_ts = torch.tensor(sents_new, dtype=torch.long, 145 | requires_grad=False, device=device) 146 | 147 | if batch_first: 148 | sents_ts = sents_ts.permute(1, 0).contiguous() 149 | 150 | return sents_ts, [length + 1 for length in sents_len] 151 | 152 | 153 | def data_iter(self, batch_size, device, batch_first=False, shuffle=True): 154 | """pad data with start and stop symbol, and pad to the same length 155 | Returns: 156 | batch_data: LongTensor with shape (seq_len, batch_size) 157 | sents_len: list of data length, this is the data length 158 | after counting start and stop symbols 159 | """ 160 | index_arr = np.arange(len(self.data)) 161 | 162 | if shuffle: 163 | np.random.shuffle(index_arr) 164 | 165 | batch_num = int(np.ceil(len(index_arr)) / float(batch_size)) 166 | for i in range(batch_num): 167 | batch_ids = index_arr[i * batch_size : (i+1) * batch_size] 168 | batch_data = [self.data[index] for index in batch_ids] 169 | 170 | # uncomment this line if the dataset has variable length 171 | batch_data.sort(key=lambda e: -len(e)) 172 | 173 | batch_data, sents_len = self._to_tensor(batch_data, batch_first, device) 174 | 175 | yield 
batch_data, sents_len 176 | 177 | def create_data_batch_labels(self, batch_size, device, batch_first=False): 178 | """pad data with start and stop symbol, batching is performed w.r.t. 179 | the sentence length, so that each returned batch has the same length, 180 | no further pack sequence function (e.g. pad_packed_sequence) is required 181 | Returns: List 182 | List: a list of batched data, each element is a tensor with shape 183 | (seq_len, batch_size) 184 | """ 185 | sents_len = np.array([len(sent) for sent in self.data]) 186 | sort_idx = np.argsort(sents_len) 187 | sort_len = sents_len[sort_idx] 188 | 189 | # record the locations where length changes 190 | change_loc = [] 191 | for i in range(1, len(sort_len)): 192 | if sort_len[i] != sort_len[i-1]: 193 | change_loc.append(i) 194 | change_loc.append(len(sort_len)) 195 | 196 | batch_data_list = [] 197 | batch_label_list = [] 198 | total = 0 199 | curr = 0 200 | for idx in change_loc: 201 | while curr < idx: 202 | batch_data = [] 203 | batch_label = [] 204 | next = min(curr + batch_size, idx) 205 | for id_ in range(curr, next): 206 | batch_data.append(self.data[sort_idx[id_]]) 207 | batch_label.append(self.labels[sort_idx[id_]]) 208 | curr = next 209 | batch_data, sents_len = self._to_tensor(batch_data, batch_first, device) 210 | batch_data_list.append(batch_data) 211 | batch_label_list.append(batch_label) 212 | 213 | total += batch_data.size(0) 214 | assert(sents_len == ([sents_len[0]] * len(sents_len))) 215 | 216 | assert(total == len(self.data)) 217 | return batch_data_list, batch_label_list 218 | 219 | def create_data_batch(self, batch_size, device, batch_first=False): 220 | """pad data with start and stop symbol, batching is performed w.r.t. 221 | the sentence length, so that each returned batch has the same length, 222 | no further pack sequence function (e.g. 
pad_packed_sequence) is required 223 | Returns: List 224 | List: a list of batched data, each element is a tensor with shape 225 | (seq_len, batch_size) 226 | """ 227 | sents_len = np.array([len(sent) for sent in self.data]) 228 | sort_idx = np.argsort(sents_len) 229 | sort_len = sents_len[sort_idx] 230 | 231 | # record the locations where length changes 232 | change_loc = [] 233 | for i in range(1, len(sort_len)): 234 | if sort_len[i] != sort_len[i-1]: 235 | change_loc.append(i) 236 | change_loc.append(len(sort_len)) 237 | 238 | batch_data_list = [] 239 | total = 0 240 | curr = 0 241 | for idx in change_loc: 242 | while curr < idx: 243 | batch_data = [] 244 | next = min(curr + batch_size, idx) 245 | for id_ in range(curr, next): 246 | batch_data.append(self.data[sort_idx[id_]]) 247 | curr = next 248 | batch_data, sents_len = self._to_tensor(batch_data, batch_first, device) 249 | batch_data_list.append(batch_data) 250 | 251 | total += batch_data.size(0) 252 | assert(sents_len == ([sents_len[0]] * len(sents_len))) 253 | 254 | assert(total == len(self.data)) 255 | return batch_data_list 256 | 257 | 258 | def data_sample(self, nsample, device, batch_first=False, shuffle=True): 259 | """sample a subset of data (like data_iter) 260 | Returns: 261 | batch_data: LongTensor with shape (seq_len, batch_size) 262 | sents_len: list of data length, this is the data length 263 | after counting start and stop symbols 264 | """ 265 | 266 | index_arr = np.arange(len(self.data)) 267 | 268 | if shuffle: 269 | np.random.shuffle(index_arr) 270 | 271 | batch_ids = index_arr[: nsample] 272 | batch_data = [self.data[index] for index in batch_ids] 273 | 274 | # uncomment this line if the dataset has variable length 275 | batch_data.sort(key=lambda e: -len(e)) 276 | 277 | batch_data, sents_len = self._to_tensor(batch_data, batch_first, device) 278 | 279 | return batch_data, sents_len 280 | -------------------------------------------------------------------------------- /modules/vae.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .utils import log_sum_exp 6 | from .lm import LSTM_LM 7 | 8 | 9 | class VAE(nn.Module): 10 | """VAE with normal prior""" 11 | def __init__(self, encoder, decoder, args): 12 | super(VAE, self).__init__() 13 | self.encoder = encoder 14 | self.decoder = decoder 15 | 16 | self.args = args 17 | 18 | self.nz = args.nz 19 | 20 | loc = torch.zeros(self.nz, device=args.device) 21 | scale = torch.ones(self.nz, device=args.device) 22 | 23 | self.prior = torch.distributions.normal.Normal(loc, scale) 24 | 25 | def encode(self, x, nsamples=1): 26 | """ 27 | Returns: Tensor1, Tensor2 28 | Tensor1: the tensor latent z with shape [batch, nsamples, nz] 29 | Tensor2: the tenor of KL for each x with shape [batch] 30 | """ 31 | return self.encoder.encode(x, nsamples) 32 | 33 | def encode_stats(self, x): 34 | """ 35 | Returns: Tensor1, Tensor2 36 | Tensor1: the mean of latent z with shape [batch, nz] 37 | Tensor2: the logvar of latent z with shape [batch, nz] 38 | """ 39 | 40 | return self.encoder.encode_stats(x) 41 | 42 | def decode(self, z, strategy, K=10): 43 | """generate samples from z given strategy 44 | 45 | Args: 46 | z: [batch, nsamples, nz] 47 | strategy: "beam" or "greedy" or "sample" 48 | K: the beam width parameter 49 | 50 | Returns: List1 51 | List1: a list of decoded word sequence 52 | """ 53 | 54 | if strategy == "beam": 55 | return self.decoder.beam_search_decode(z, K) 56 | elif strategy 
== "greedy": 57 | return self.decoder.greedy_decode(z) 58 | elif strategy == "sample": 59 | return self.decoder.sample_decode(z) 60 | else: 61 | raise ValueError("the decoding strategy is not supported") 62 | 63 | 64 | def reconstruct(self, x, decoding_strategy="greedy", K=5): 65 | """reconstruct from input x 66 | 67 | Args: 68 | x: (batch, *) 69 | decoding_strategy: "beam" or "greedy" or "sample" 70 | K: the beam width parameter 71 | 72 | Returns: List1 73 | List1: a list of decoded word sequence 74 | """ 75 | z = self.sample_from_inference(x).squeeze(1) 76 | 77 | return self.decode(z, decoding_strategy, K) 78 | 79 | 80 | def loss(self, x, kl_weight, nsamples=1): 81 | """ 82 | Args: 83 | x: if the data is constant-length, x is the data tensor with 84 | shape (batch, *). Otherwise x is a tuple that contains 85 | the data tensor and length list 86 | 87 | Returns: Tensor1, Tensor2, Tensor3 88 | Tensor1: total loss [batch] 89 | Tensor2: reconstruction loss shape [batch] 90 | Tensor3: KL loss shape [batch] 91 | """ 92 | 93 | z, KL = self.encode(x, nsamples) 94 | 95 | # (batch) 96 | reconstruct_err = self.decoder.reconstruct_error(x, z).mean(dim=1) 97 | 98 | 99 | return reconstruct_err + kl_weight * KL, reconstruct_err, KL 100 | 101 | 102 | def loss_iw(self, x, kl_weight, nsamples=50, ns=10): 103 | """ 104 | Args: 105 | x: if the data is constant-length, x is the data tensor with 106 | shape (batch, *). Otherwise x is a tuple that contains 107 | the data tensor and length list 108 | 109 | Returns: Tensor1, Tensor2, Tensor3 110 | Tensor1: total loss [batch] 111 | Tensor2: reconstruction loss shape [batch] 112 | Tensor3: KL loss shape [batch] 113 | """ 114 | 115 | mu, logvar = self.encoder.forward(x) 116 | 117 | ################## 118 | # compute KL 119 | ################## 120 | KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1) 121 | 122 | tmp = [] 123 | reconstruct_err_sum = 0 124 | 125 | #import pdb 126 | 127 | for _ in range(int(nsamples / ns)): 128 | 129 | # (batch, nsamples, nz) 130 | z = self.encoder.reparameterize(mu, logvar, ns) 131 | 132 | ################## 133 | # compute qzx 134 | ################## 135 | nz = z.size(2) 136 | 137 | # (batch_size, 1, nz) 138 | _mu, _logvar = mu.unsqueeze(1), logvar.unsqueeze(1) 139 | var = _logvar.exp() 140 | 141 | # (batch_size, nsamples, nz) 142 | dev = z - _mu 143 | 144 | # (batch_size, nsamples) 145 | log_qzx = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \ 146 | 0.5 * (nz * math.log(2 * math.pi) + _logvar.sum(-1)) 147 | 148 | ################## 149 | # compute qzx 150 | ################## 151 | log_pz = (-0.5 * math.log(2*math.pi) - z**2 / 2).sum(dim=-1) 152 | 153 | 154 | ################## 155 | # compute reconstruction loss 156 | ################## 157 | # (batch) 158 | reconstruct_err = self.decoder.reconstruct_error(x, z) 159 | reconstruct_err_sum += reconstruct_err.cpu().detach().sum(dim=1) 160 | 161 | #pdb.set_trace() 162 | 163 | tmp.append(reconstruct_err + kl_weight * (log_qzx - log_pz)) 164 | 165 | nll_iw = log_sum_exp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples) 166 | 167 | return nll_iw, reconstruct_err_sum / nsamples, KL 168 | 169 | 170 | def nll_iw(self, x, nsamples, ns=100): 171 | """compute the importance weighting estimate of the log-likelihood 172 | Args: 173 | x: if the data is constant-length, x is the data tensor with 174 | shape (batch, *). 
Otherwise x is a tuple that contains 175 | the data tensor and length list 176 | nsamples: Int 177 | the number of samples required to estimate marginal data likelihood 178 | Returns: Tensor1 179 | Tensor1: the estimate of log p(x), shape [batch] 180 | """ 181 | 182 | # compute iw every ns samples to address the memory issue 183 | # nsamples = 500, ns = 100 184 | # nsamples = 500, ns = 10 185 | 186 | # TODO: note that x is forwarded twice in self.encoder.sample(x, ns) and self.eval_inference_dist(x, z, param) 187 | #. this problem is to be solved in order to speed up 188 | 189 | tmp = [] 190 | for _ in range(int(nsamples / ns)): 191 | # [batch, ns, nz] 192 | # param is the parameters required to evaluate q(z|x) 193 | z, param = self.encoder.sample(x, ns) 194 | 195 | # [batch, ns] 196 | log_comp_ll = self.eval_complete_ll(x, z) 197 | log_infer_ll = self.eval_inference_dist(x, z, param) 198 | 199 | tmp.append(log_comp_ll - log_infer_ll) 200 | 201 | ll_iw = log_sum_exp(torch.cat(tmp, dim=-1), dim=-1) - math.log(nsamples) 202 | 203 | return -ll_iw 204 | 205 | def KL(self, x): 206 | _, KL = self.encode(x, 1) 207 | 208 | return KL 209 | 210 | def eval_prior_dist(self, zrange): 211 | """perform grid search to calculate the true posterior 212 | Args: 213 | zrange: tensor 214 | different z points that will be evaluated, with 215 | shape (k^2, nz), where k=(zmax - zmin)/space 216 | """ 217 | 218 | # (k^2) 219 | return self.prior.log_prob(zrange).sum(dim=-1) 220 | 221 | def eval_complete_ll(self, x, z): 222 | """compute log p(z,x) 223 | Args: 224 | x: Tensor 225 | input with shape [batch, seq_len] 226 | z: Tensor 227 | evaluation points with shape [batch, nsamples, nz] 228 | Returns: Tensor1 229 | Tensor1: log p(z,x) Tensor with shape [batch, nsamples] 230 | """ 231 | 232 | # [batch, nsamples] 233 | log_prior = self.eval_prior_dist(z) 234 | log_gen = self.eval_cond_ll(x, z) 235 | 236 | return log_prior + log_gen 237 | 238 | def eval_cond_ll(self, x, z): 239 | """compute log p(x|z) 240 | """ 241 | 242 | return self.decoder.log_probability(x, z) 243 | 244 | def eval_log_model_posterior(self, x, grid_z): 245 | """perform grid search to calculate the true posterior 246 | this function computes p(z|x) 247 | Args: 248 | grid_z: tensor 249 | different z points that will be evaluated, with 250 | shape (k^2, nz), where k=(zmax - zmin)/pace 251 | 252 | Returns: Tensor 253 | Tensor: the log posterior distribution log p(z|x) with 254 | shape [batch_size, K^2] 255 | """ 256 | try: 257 | batch_size = x.size(0) 258 | except: 259 | batch_size = x[0].size(0) 260 | 261 | # (batch_size, k^2, nz) 262 | grid_z = grid_z.unsqueeze(0).expand(batch_size, *grid_z.size()).contiguous() 263 | 264 | # (batch_size, k^2) 265 | log_comp = self.eval_complete_ll(x, grid_z) 266 | 267 | # normalize to posterior 268 | log_posterior = log_comp - log_sum_exp(log_comp, dim=1, keepdim=True) 269 | 270 | return log_posterior 271 | 272 | def sample_from_inference(self, x, nsamples=1): 273 | """perform sampling from inference net 274 | Returns: Tensor 275 | Tensor: samples from infernece nets with 276 | shape (batch_size, nsamples, nz) 277 | """ 278 | z, _ = self.encoder.sample(x, nsamples) 279 | 280 | return z 281 | 282 | 283 | def sample_from_posterior(self, x, nsamples): 284 | """perform MH sampling from model posterior 285 | Returns: Tensor 286 | Tensor: samples from model posterior with 287 | shape (batch_size, nsamples, nz) 288 | """ 289 | 290 | # use the samples from inference net as initial points 291 | # for MCMC sampling. 
[batch_size, nsamples, nz] 292 | cur = self.encoder.sample_from_inference(x, 1) 293 | cur_ll = self.eval_complete_ll(x, cur) 294 | total_iter = self.args.mh_burn_in + nsamples * self.args.mh_thin 295 | samples = [] 296 | for iter_ in range(total_iter): 297 | next = torch.normal(mean=cur, 298 | std=cur.new_full(size=cur.size(), fill_value=self.args.mh_std)) 299 | # [batch_size, 1] 300 | next_ll = self.eval_complete_ll(x, next) 301 | ratio = next_ll - cur_ll 302 | 303 | accept_prob = torch.min(ratio.exp(), ratio.new_ones(ratio.size())) 304 | 305 | uniform_t = accept_prob.new_empty(accept_prob.size()).uniform_() 306 | 307 | # [batch_size, 1] 308 | mask = (uniform_t < accept_prob).float() 309 | 310 | mask_ = mask.unsqueeze(2) 311 | 312 | cur = mask_ * next + (1 - mask_) * cur 313 | cur_ll = mask * next_ll + (1 - mask) * cur_ll 314 | 315 | if iter_ >= self.args.mh_burn_in and (iter_ - self.args.mh_burn_in) % self.args.mh_thin == 0: 316 | samples.append(cur.unsqueeze(1)) 317 | 318 | 319 | return torch.cat(samples, dim=1) 320 | 321 | def calc_model_posterior_mean(self, x, grid_z): 322 | """compute the mean value of model posterior, i.e. E_{z ~ p(z|x)}[z] 323 | Args: 324 | grid_z: different z points that will be evaluated, with 325 | shape (k^2, nz), where k=(zmax - zmin)/pace 326 | x: [batch, *] 327 | 328 | Returns: Tensor1 329 | Tensor1: the mean value tensor with shape [batch, nz] 330 | 331 | """ 332 | 333 | # [batch, K^2] 334 | log_posterior = self.eval_log_model_posterior(x, grid_z) 335 | posterior = log_posterior.exp() 336 | 337 | # [batch, nz] 338 | return torch.mul(posterior.unsqueeze(2), grid_z.unsqueeze(0)).sum(1) 339 | 340 | def calc_infer_mean(self, x): 341 | """ 342 | Returns: Tensor1 343 | Tensor1: the mean of inference distribution, with shape [batch, nz] 344 | """ 345 | 346 | mean, logvar = self.encoder.forward(x) 347 | 348 | return mean 349 | 350 | 351 | 352 | def eval_inference_dist(self, x, z, param=None): 353 | """ 354 | Returns: Tensor 355 | Tensor: the posterior density tensor with 356 | shape (batch_size, nsamples) 357 | """ 358 | return self.encoder.eval_inference_dist(x, z, param) 359 | 360 | def calc_mi_q(self, x): 361 | """Approximate the mutual information between x and z 362 | under distribution q(z|x) 363 | 364 | Args: 365 | x: [batch_size, *]. 
The sampled data to estimate mutual info 366 | """ 367 | 368 | return self.encoder.calc_mi(x) 369 | -------------------------------------------------------------------------------- /text_ss_ft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import importlib 4 | import argparse 5 | 6 | import numpy as np 7 | 8 | import torch 9 | from torch import nn, optim 10 | 11 | from collections import defaultdict 12 | 13 | from data import MonoTextData, VocabEntry 14 | from modules import VAE, LinearDiscriminator, MLPDiscriminator 15 | from modules import GaussianLSTMEncoder, LSTMDecoder 16 | 17 | from exp_utils import create_exp_dir 18 | from utils import uniform_initializer, xavier_normal_initializer, calc_iwnll, calc_mi, calc_au, sample_sentences, visualize_latent, reconstruct 19 | 20 | # old parameters 21 | clip_grad = 5.0 22 | decay_epoch = 2 23 | lr_decay = 0.5 24 | max_decay = 5 25 | 26 | # Junxian's new parameters 27 | # clip_grad = 1.0 28 | # decay_epoch = 5 29 | # lr_decay = 0.5 30 | # max_decay = 5 31 | 32 | logging = None 33 | 34 | def init_config(): 35 | parser = argparse.ArgumentParser(description='VAE mode collapse study') 36 | 37 | # model hyperparameters 38 | parser.add_argument('--dataset', type=str, required=True, help='dataset to use') 39 | # optimization parameters 40 | parser.add_argument('--momentum', type=float, default=0, help='sgd momentum') 41 | parser.add_argument('--opt', type=str, choices=["sgd", "adam"], default="sgd", help='sgd momentum') 42 | 43 | parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training') 44 | parser.add_argument('--iw_nsamples', type=int, default=500, 45 | help='number of samples to compute importance weighted estimate') 46 | 47 | # select mode 48 | parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll') 49 | parser.add_argument('--load_path', type=str, default='') 50 | 51 | # decoding 52 | parser.add_argument('--reconstruct_from', type=str, default='', help="the model checkpoint path") 53 | parser.add_argument('--reconstruct_to', type=str, default="decoding.txt", help="save file") 54 | parser.add_argument('--decoding_strategy', type=str, choices=["greedy", "beam", "sample"], default="greedy") 55 | 56 | # annealing paramters 57 | parser.add_argument('--warm_up', type=int, default=10, help="number of annealing epochs. warm_up=0 means not anneal") 58 | parser.add_argument('--kl_start', type=float, default=1.0, help="starting KL weight") 59 | 60 | 61 | # inference parameters 62 | parser.add_argument('--seed', type=int, default=783435, metavar='S', help='random seed') 63 | 64 | # output directory 65 | parser.add_argument('--exp_dir', default=None, type=str, 66 | help='experiment directory.') 67 | parser.add_argument("--save_ckpt", type=int, default=0, 68 | help="save checkpoint every epoch before this number") 69 | parser.add_argument("--save_latent", type=int, default=0) 70 | 71 | # new 72 | parser.add_argument("--fix_var", type=float, default=-1) 73 | parser.add_argument("--reset_dec", action="store_true", default=False) 74 | parser.add_argument("--load_best_epoch", type=int, default=0) 75 | parser.add_argument("--lr", type=float, default=1.) 
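# Illustrative invocation of this fine-tuning script (example values only; they are not
# prescribed anywhere in this repo, and the chosen dataset's config module must define the
# matching params_ss_<num_label> dict and a vocab_file entry for init_config()/main() to work):
#   python text_ss_ft.py --dataset short_yelp --load_path exp_short_yelp/model.pt \
#       --num_label 100 --discriminator linear --opt adam --freeze_enc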
76 | 77 | parser.add_argument("--fb", type=int, default=0, 78 | help="0: no fb; 1: fb; 2: max(target_kl, kl) for each dimension") 79 | parser.add_argument("--target_kl", type=float, default=-1, 80 | help="target kl of the free bits trick") 81 | 82 | parser.add_argument("--batch_size", type=int, default=16, 83 | help="target kl of the free bits trick") 84 | parser.add_argument("--update_every", type=int, default=1, 85 | help="target kl of the free bits trick") 86 | parser.add_argument("--num_label", type=int, default=100, 87 | help="target kl of the free bits trick") 88 | parser.add_argument("--freeze_enc", action="store_true", default=False) 89 | parser.add_argument("--discriminator", type=str, default="linear") 90 | 91 | args = parser.parse_args() 92 | 93 | # set args.cuda 94 | args.cuda = torch.cuda.is_available() 95 | 96 | # set seeds 97 | # seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909] 98 | # args.seed = seed_set[args.taskid] 99 | np.random.seed(args.seed) 100 | torch.manual_seed(args.seed) 101 | if args.cuda: 102 | torch.cuda.manual_seed(args.seed) 103 | torch.backends.cudnn.deterministic = True 104 | 105 | # load config file into args 106 | config_file = "config.config_%s" % args.dataset 107 | if args.num_label == 100: 108 | params = importlib.import_module(config_file).params_ss_100 109 | elif args.num_label == 500: 110 | params = importlib.import_module(config_file).params_ss_500 111 | elif args.num_label == 1000: 112 | params = importlib.import_module(config_file).params_ss_1000 113 | elif args.num_label == 2000: 114 | params = importlib.import_module(config_file).params_ss_2000 115 | elif args.num_label == 10000: 116 | params = importlib.import_module(config_file).params_ss_10000 117 | 118 | args = argparse.Namespace(**vars(args), **params) 119 | 120 | load_str = "_load" if args.load_path != "" else "" 121 | if args.fb == 0: 122 | fb_str = "" 123 | elif args.fb == 1: 124 | fb_str = "_fb" 125 | elif args.fb == 2: 126 | fb_str = "_fbdim" 127 | 128 | opt_str = "_adam" if args.opt == "adam" else "_sgd" 129 | nlabel_str = "_nlabel{}".format(args.num_label) 130 | freeze_str = "_freeze" if args.freeze_enc else "" 131 | 132 | if len(args.load_path.split("/")) > 2: 133 | load_path_str = args.load_path.split("/")[1] 134 | else: 135 | load_path_str = args.load_path.split("/")[0] 136 | 137 | model_str = "_{}".format(args.discriminator) 138 | # set load and save paths 139 | if args.exp_dir == None: 140 | args.exp_dir = "exp_{}{}_ss_ft/{}{}{}{}{}".format(args.dataset, 141 | load_str, load_path_str, model_str, opt_str, nlabel_str, freeze_str) 142 | 143 | 144 | if len(args.load_path) <= 0 and args.eval: 145 | args.load_path = os.path.join(args.exp_dir, 'model.pt') 146 | 147 | args.save_path = os.path.join(args.exp_dir, 'model.pt') 148 | 149 | # set args.label 150 | if 'label' in params: 151 | args.label = params['label'] 152 | else: 153 | args.label = False 154 | 155 | return args 156 | 157 | 158 | def test(model, test_data_batch, test_labels_batch, mode, args, verbose=True): 159 | global logging 160 | 161 | report_correct = report_loss = 0 162 | report_num_words = report_num_sents = 0 163 | for i in np.random.permutation(len(test_data_batch)): 164 | batch_data = test_data_batch[i] 165 | batch_labels = test_labels_batch[i] 166 | batch_labels = [int(x) for x in batch_labels] 167 | 168 | batch_labels = torch.tensor(batch_labels, dtype=torch.long, requires_grad=False, device=args.device) 169 | 170 | batch_size, sent_len = batch_data.size() 171 | 172 | # not predict start symbol 
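# (sent_len - 1): the tensors built by MonoTextData._to_tensor are prefixed with the start
# symbol, which the model conditions on but never predicts, so one token per sentence is
# excluded from the word count. In this discriminator test() the loss and accuracy are
# normalized per sentence; report_num_words is not used further in this function (the
# generative test() functions use the same count to normalize perplexity).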
173 | report_num_words += (sent_len - 1) * batch_size 174 | report_num_sents += batch_size 175 | loss, correct = model.get_performance(batch_data, batch_labels) 176 | 177 | loss = loss.sum() 178 | 179 | report_loss += loss.item() 180 | report_correct += correct 181 | 182 | test_loss = report_loss / report_num_sents 183 | acc = report_correct / report_num_sents 184 | 185 | if verbose: 186 | logging('%s --- avg_loss: %.4f, acc: %.4f' % \ 187 | (mode, test_loss, acc)) 188 | #sys.stdout.flush() 189 | 190 | return test_loss, acc 191 | 192 | 193 | def main(args): 194 | global logging 195 | logging = create_exp_dir(args.exp_dir, scripts_to_save=[]) 196 | 197 | if args.cuda: 198 | logging('using cuda') 199 | logging(str(args)) 200 | 201 | opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4} 202 | 203 | vocab = {} 204 | with open(args.vocab_file) as fvocab: 205 | for i, line in enumerate(fvocab): 206 | vocab[line.strip()] = i 207 | 208 | vocab = VocabEntry(vocab) 209 | 210 | train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab) 211 | 212 | vocab_size = len(vocab) 213 | 214 | val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab) 215 | test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab) 216 | 217 | logging('Train data: %d samples' % len(train_data)) 218 | logging('finish reading datasets, vocab size is %d' % len(vocab)) 219 | logging('dropped sentences: %d' % train_data.dropped) 220 | #sys.stdout.flush() 221 | 222 | log_niter = max(1, (len(train_data)//(args.batch_size * args.update_every))//10) 223 | 224 | model_init = uniform_initializer(0.01) 225 | emb_init = uniform_initializer(0.1) 226 | 227 | #device = torch.device("cuda" if args.cuda else "cpu") 228 | device = "cuda" if args.cuda else "cpu" 229 | args.device = device 230 | 231 | if args.enc_type == 'lstm': 232 | encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init) 233 | args.enc_nh = args.dec_nh 234 | else: 235 | raise ValueError("the specified encoder type is not supported") 236 | 237 | decoder = LSTMDecoder(args, vocab, model_init, emb_init) 238 | vae = VAE(encoder, decoder, args).to(device) 239 | 240 | if args.load_path: 241 | loaded_state_dict = torch.load(args.load_path) 242 | #curr_state_dict = vae.state_dict() 243 | #curr_state_dict.update(loaded_state_dict) 244 | vae.load_state_dict(loaded_state_dict) 245 | logging("%s loaded" % args.load_path) 246 | 247 | if args.eval: 248 | logging('begin evaluation') 249 | vae.load_state_dict(torch.load(args.load_path)) 250 | vae.eval() 251 | with torch.no_grad(): 252 | test_data_batch = test_data.create_data_batch(batch_size=args.batch_size, 253 | device=device, 254 | batch_first=True) 255 | 256 | test(vae, test_data_batch, test_labels_batch, "TEST", args) 257 | au, au_var = calc_au(vae, test_data_batch) 258 | logging("%d active units" % au) 259 | # print(au_var) 260 | 261 | test_data_batch = test_data.create_data_batch(batch_size=1, 262 | device=device, 263 | batch_first=True) 264 | calc_iwnll(vae, test_data_batch, args) 265 | 266 | return 267 | 268 | if args.discriminator == "linear": 269 | discriminator = LinearDiscriminator(args, vae.encoder).to(device) 270 | elif args.discriminator == "mlp": 271 | discriminator = MLPDiscriminator(args, vae.encoder).to(device) 272 | 273 | if args.opt == "sgd": 274 | optimizer = optim.SGD(discriminator.parameters(), lr=args.lr, momentum=args.momentum) 275 | opt_dict['lr'] = args.lr 276 | elif args.opt == "adam": 277 | optimizer = optim.Adam(discriminator.parameters(), lr=0.001) 278 | 
opt_dict['lr'] = 0.001 279 | else: 280 | raise ValueError("optimizer not supported") 281 | 282 | iter_ = decay_cnt = 0 283 | best_loss = 1e4 284 | best_kl = best_nll = best_ppl = 0 285 | pre_mi = 0 286 | discriminator.train() 287 | start = time.time() 288 | 289 | kl_weight = args.kl_start 290 | if args.warm_up > 0: 291 | anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size)) 292 | else: 293 | anneal_rate = 0 294 | 295 | dim_target_kl = args.target_kl / float(args.nz) 296 | 297 | train_data_batch, train_labels_batch = train_data.create_data_batch_labels(batch_size=args.batch_size, 298 | device=device, 299 | batch_first=True) 300 | 301 | val_data_batch, val_labels_batch = val_data.create_data_batch_labels(batch_size=128, 302 | device=device, 303 | batch_first=True) 304 | 305 | test_data_batch, test_labels_batch = test_data.create_data_batch_labels(batch_size=128, 306 | device=device, 307 | batch_first=True) 308 | 309 | acc_cnt = 1 310 | acc_loss = 0. 311 | for epoch in range(args.epochs): 312 | report_loss = 0 313 | report_correct = report_num_words = report_num_sents = 0 314 | acc_batch_size = 0 315 | optimizer.zero_grad() 316 | for i in np.random.permutation(len(train_data_batch)): 317 | 318 | batch_data = train_data_batch[i] 319 | batch_labels = train_labels_batch[i] 320 | batch_labels = [int(x) for x in batch_labels] 321 | 322 | batch_labels = torch.tensor(batch_labels, dtype=torch.long, requires_grad=False, device=device) 323 | 324 | batch_size, sent_len = batch_data.size() 325 | 326 | # not predict start symbol 327 | report_num_words += (sent_len - 1) * batch_size 328 | report_num_sents += batch_size 329 | acc_batch_size += batch_size 330 | 331 | # (batch_size) 332 | loss, correct = discriminator.get_performance(batch_data, batch_labels) 333 | 334 | acc_loss = acc_loss + loss.sum() 335 | 336 | if acc_cnt % args.update_every == 0: 337 | acc_loss = acc_loss / acc_batch_size 338 | acc_loss.backward() 339 | 340 | torch.nn.utils.clip_grad_norm_(discriminator.parameters(), clip_grad) 341 | 342 | optimizer.step() 343 | optimizer.zero_grad() 344 | 345 | acc_cnt = 0 346 | acc_loss = 0 347 | acc_batch_size = 0 348 | 349 | acc_cnt += 1 350 | report_loss += loss.sum().item() 351 | report_correct += correct 352 | 353 | if iter_ % log_niter == 0: 354 | #train_loss = (report_rec_loss + report_kl_loss) / report_num_sents 355 | train_loss = report_loss / report_num_sents 356 | 357 | 358 | logging('epoch: %d, iter: %d, avg_loss: %.4f, acc %.4f,' \ 359 | 'time %.2fs' % 360 | (epoch, iter_, train_loss, report_correct / report_num_sents, 361 | time.time() - start)) 362 | 363 | #sys.stdout.flush() 364 | 365 | iter_ += 1 366 | 367 | logging('lr {}'.format(opt_dict["lr"])) 368 | 369 | discriminator.eval() 370 | 371 | with torch.no_grad(): 372 | loss, acc = test(discriminator, val_data_batch, val_labels_batch, "VAL", args) 373 | # print(au_var) 374 | 375 | if loss < best_loss: 376 | logging('update best loss') 377 | best_loss = loss 378 | best_acc = acc 379 | torch.save(discriminator.state_dict(), args.save_path) 380 | 381 | if loss > opt_dict["best_loss"]: 382 | opt_dict["not_improved"] += 1 383 | if opt_dict["not_improved"] >= decay_epoch and epoch >= args.load_best_epoch: 384 | opt_dict["best_loss"] = loss 385 | opt_dict["not_improved"] = 0 386 | opt_dict["lr"] = opt_dict["lr"] * lr_decay 387 | discriminator.load_state_dict(torch.load(args.save_path)) 388 | logging('new lr: %f' % opt_dict["lr"]) 389 | decay_cnt += 1 390 | if args.opt == "sgd": 391 | optimizer = 
optim.SGD(discriminator.parameters(), lr=opt_dict["lr"], momentum=args.momentum) 392 | opt_dict['lr'] = opt_dict["lr"] 393 | elif args.opt == "adam": 394 | optimizer = optim.Adam(discriminator.parameters(), lr=opt_dict["lr"]) 395 | opt_dict['lr'] = opt_dict["lr"] 396 | else: 397 | raise ValueError("optimizer not supported") 398 | 399 | else: 400 | opt_dict["not_improved"] = 0 401 | opt_dict["best_loss"] = loss 402 | 403 | if decay_cnt == max_decay: 404 | break 405 | 406 | if epoch % args.test_nepoch == 0: 407 | with torch.no_grad(): 408 | loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args) 409 | 410 | discriminator.train() 411 | 412 | 413 | # compute importance weighted estimate of log p(x) 414 | discriminator.load_state_dict(torch.load(args.save_path)) 415 | discriminator.eval() 416 | 417 | with torch.no_grad(): 418 | loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args) 419 | # print(au_var) 420 | 421 | if __name__ == '__main__': 422 | args = init_config() 423 | main(args) 424 | -------------------------------------------------------------------------------- /text_beta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import importlib 5 | import argparse 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch import nn, optim 11 | 12 | from data import MonoTextData 13 | from modules import VAE 14 | from modules import GaussianLSTMEncoder, LSTMDecoder 15 | 16 | from exp_utils import create_exp_dir 17 | from utils import uniform_initializer, xavier_normal_initializer, calc_iwnll, calc_mi, calc_au, sample_sentences, visualize_latent, reconstruct 18 | 19 | clip_grad = 5.0 20 | decay_epoch = 5 21 | lr_decay = 0.5 22 | max_decay = 5 23 | 24 | ns=2 25 | 26 | logging = None 27 | 28 | def init_config(): 29 | parser = argparse.ArgumentParser(description='VAE mode collapse study') 30 | 31 | # model hyperparameters 32 | parser.add_argument('--dataset', type=str, required=True, help='dataset to use') 33 | 34 | # optimization parameters 35 | parser.add_argument('--momentum', type=float, default=0, help='sgd momentum') 36 | parser.add_argument('--opt', type=str, choices=["sgd", "adam"], default="sgd", help='sgd momentum') 37 | parser.add_argument('--lr', type=float, default=1.0) 38 | parser.add_argument('--nsamples', type=int, default=1, help='number of iw samples for training') 39 | parser.add_argument('--iw_train_nsamples', type=int, default=-1) 40 | parser.add_argument('--iw_train_ns', type=int, default=1, help='number of iw samples for training in each batch') 41 | parser.add_argument('--iw_nsamples', type=int, default=500, 42 | help='number of samples to compute importance weighted estimate') 43 | 44 | # select mode 45 | parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll') 46 | parser.add_argument('--load_path', type=str, default='') 47 | 48 | # decoding 49 | parser.add_argument('--reconstruct_from', type=str, default='', help="the model checkpoint path") 50 | parser.add_argument('--reconstruct_to', type=str, default="decoding.txt", help="save file") 51 | parser.add_argument('--decoding_strategy', type=str, choices=["greedy", "beam", "sample"], default="greedy") 52 | 53 | # annealing paramters 54 | parser.add_argument('--warm_up', type=int, default=10, help="number of annealing epochs") 55 | parser.add_argument('--kl_start', type=float, default=1.0, help="starting KL weight") 56 | 57 | # inference parameters 58 | 
parser.add_argument('--seed', type=int, default=783435, metavar='S', help='random seed') 59 | 60 | # output directory 61 | parser.add_argument('--exp_dir', default=None, type=str, 62 | help='experiment directory.') 63 | parser.add_argument("--save_ckpt", type=int, default=0, 64 | help="save checkpoint every epoch before this number") 65 | parser.add_argument("--save_latent", type=int, default=0) 66 | 67 | # new 68 | parser.add_argument("--fix_var", type=float, default=-1) 69 | parser.add_argument("--freeze_epoch", type=int, default=-1) 70 | parser.add_argument("--reset_dec", action="store_true", default=False) 71 | parser.add_argument("--beta", type=float, default=1.0) 72 | parser.add_argument("--load_best_epoch", type=int, default=15) 73 | 74 | args = parser.parse_args() 75 | 76 | # set args.cuda 77 | args.cuda = torch.cuda.is_available() 78 | 79 | # set seeds 80 | # seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909] 81 | # args.seed = seed_set[args.taskid] 82 | np.random.seed(args.seed) 83 | torch.manual_seed(args.seed) 84 | if args.cuda: 85 | torch.cuda.manual_seed(args.seed) 86 | torch.backends.cudnn.deterministic = True 87 | 88 | # load config file into args 89 | config_file = "config.config_%s" % args.dataset 90 | params = importlib.import_module(config_file).params 91 | args = argparse.Namespace(**vars(args), **params) 92 | 93 | # set load and save paths 94 | load_str = "_load" if args.load_path != "" else "" 95 | iw_str = "_iw{}".format(args.iw_train_nsamples) if args.iw_train_nsamples > 0 else "" 96 | 97 | if args.exp_dir == None: 98 | args.exp_dir = "exp_{}_beta/{}_lr{}_beta{}_drop{}_{}".format( 99 | args.dataset, args.dataset, args.lr, args.beta, args.dec_dropout_in, iw_str) 100 | 101 | if len(args.load_path) <= 0 and args.eval: 102 | args.load_path = os.path.join(args.exp_dir, 'model.pt') 103 | 104 | args.save_path = os.path.join(args.exp_dir, 'model.pt') 105 | 106 | # set args.label 107 | if 'label' in params: 108 | args.label = params['label'] 109 | else: 110 | args.label = False 111 | 112 | return args 113 | 114 | 115 | def test(model, test_data_batch, mode, args, verbose=True): 116 | global logging 117 | 118 | report_kl_loss = report_rec_loss = report_loss = 0 119 | report_num_words = report_num_sents = 0 120 | for i in np.random.permutation(len(test_data_batch)): 121 | batch_data = test_data_batch[i] 122 | batch_size, sent_len = batch_data.size() 123 | 124 | # not predict start symbol 125 | report_num_words += (sent_len - 1) * batch_size 126 | report_num_sents += batch_size 127 | #loss, loss_rc, loss_kl = model.loss(batch_data, args.beta, nsamples=args.nsamples) 128 | 129 | if args.iw_train_nsamples < 0: 130 | loss, loss_rc, loss_kl = model.loss(batch_data, args.beta, nsamples=args.nsamples) 131 | else: 132 | loss, loss_rc, loss_kl = model.loss_iw(batch_data, args.beta, nsamples=args.iw_train_nsamples, ns=ns) 133 | 134 | assert(not loss_rc.requires_grad) 135 | 136 | loss_rc = loss_rc.sum() 137 | loss_kl = loss_kl.sum() 138 | loss = loss.sum() 139 | 140 | report_rec_loss += loss_rc.item() 141 | report_kl_loss += loss_kl.item() 142 | report_loss += loss.item() 143 | 144 | mutual_info = calc_mi(model, test_data_batch) 145 | 146 | test_loss = report_loss / report_num_sents 147 | 148 | nll = (report_kl_loss + report_rec_loss) / report_num_sents 149 | kl = report_kl_loss / report_num_sents 150 | ppl = np.exp(nll * report_num_sents / report_num_words) 151 | if verbose: 152 | logging('%s --- avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, nll: %.4f, ppl: %.4f' % \ 
153 | (mode, test_loss, report_kl_loss / report_num_sents, mutual_info, 154 | report_rec_loss / report_num_sents, nll, ppl)) 155 | #sys.stdout.flush() 156 | 157 | return test_loss, nll, kl, ppl, mutual_info 158 | 159 | 160 | def main(args): 161 | global logging 162 | debug = (args.reconstruct_from != "" or args.eval == True) # don't make exp dir for reconstruction 163 | logging = create_exp_dir(args.exp_dir, scripts_to_save=None, debug=debug) 164 | 165 | if args.cuda: 166 | logging('using cuda') 167 | logging(str(args)) 168 | 169 | opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4} 170 | 171 | train_data = MonoTextData(args.train_data, label=args.label) 172 | 173 | vocab = train_data.vocab 174 | vocab_size = len(vocab) 175 | 176 | val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab) 177 | test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab) 178 | 179 | logging('Train data: %d samples' % len(train_data)) 180 | logging('finish reading datasets, vocab size is %d' % len(vocab)) 181 | logging('dropped sentences: %d' % train_data.dropped) 182 | #sys.stdout.flush() 183 | 184 | log_niter = (len(train_data)//args.batch_size)//10 185 | 186 | model_init = uniform_initializer(0.01) 187 | emb_init = uniform_initializer(0.1) 188 | 189 | #device = torch.device("cuda" if args.cuda else "cpu") 190 | device = "cuda" if args.cuda else "cpu" 191 | args.device = device 192 | 193 | if args.enc_type == 'lstm': 194 | encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init) 195 | args.enc_nh = args.dec_nh 196 | else: 197 | raise ValueError("the specified encoder type is not supported") 198 | 199 | decoder = LSTMDecoder(args, vocab, model_init, emb_init) 200 | vae = VAE(encoder, decoder, args).to(device) 201 | 202 | if args.load_path: 203 | loaded_state_dict = torch.load(args.load_path) 204 | #curr_state_dict = vae.state_dict() 205 | #curr_state_dict.update(loaded_state_dict) 206 | vae.load_state_dict(loaded_state_dict) 207 | logging("%s loaded" % args.load_path) 208 | 209 | if args.reset_dec: 210 | vae.decoder.reset_parameters(model_init, emb_init) 211 | 212 | 213 | if args.eval: 214 | logging('begin evaluation') 215 | vae.load_state_dict(torch.load(args.load_path)) 216 | vae.eval() 217 | with torch.no_grad(): 218 | test_data_batch = test_data.create_data_batch(batch_size=args.batch_size, 219 | device=device, 220 | batch_first=True) 221 | 222 | test(vae, test_data_batch, "TEST", args) 223 | au, au_var = calc_au(vae, test_data_batch) 224 | logging("%d active units" % au) 225 | # print(au_var) 226 | 227 | test_data_batch = test_data.create_data_batch(batch_size=1, 228 | device=device, 229 | batch_first=True) 230 | 231 | nll, ppl = calc_iwnll(vae, test_data_batch, args) 232 | logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl)) 233 | 234 | return 235 | 236 | if args.reconstruct_from != "": 237 | print("begin decoding") 238 | sys.stdout.flush() 239 | 240 | vae.load_state_dict(torch.load(args.reconstruct_from)) 241 | vae.eval() 242 | with torch.no_grad(): 243 | test_data_batch = test_data.create_data_batch(batch_size=args.batch_size, 244 | device=device, 245 | batch_first=True) 246 | # test(vae, test_data_batch, "TEST", args) 247 | reconstruct(vae, test_data_batch, vocab, args.decoding_strategy, args.reconstruct_to) 248 | 249 | return 250 | 251 | if args.opt == "sgd": 252 | enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=args.lr, momentum=args.momentum) 253 | dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=args.lr, momentum=args.momentum) 254 | 
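        # The encoder and decoder are optimized separately; opt_dict["lr"] tracks the current
        # learning rate so the plateau-decay logic further down can halve it and rebuild both
        # SGD optimizers.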
opt_dict['lr'] = args.lr 255 | elif args.opt == "adam": 256 | enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001) 257 | dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001) 258 | opt_dict['lr'] = 0.001 259 | else: 260 | raise ValueError("optimizer not supported") 261 | 262 | iter_ = decay_cnt = 0 263 | best_loss = 1e4 264 | best_kl = best_nll = best_ppl = 0 265 | pre_mi = 0 266 | vae.train() 267 | start = time.time() 268 | 269 | train_data_batch = train_data.create_data_batch(batch_size=args.batch_size, 270 | device=device, 271 | batch_first=True) 272 | 273 | val_data_batch = val_data.create_data_batch(batch_size=args.batch_size, 274 | device=device, 275 | batch_first=True) 276 | 277 | test_data_batch = test_data.create_data_batch(batch_size=args.batch_size, 278 | device=device, 279 | batch_first=True) 280 | 281 | # At any point you can hit Ctrl + C to break out of training early. 282 | try: 283 | for epoch in range(args.epochs): 284 | report_kl_loss = report_rec_loss = report_loss = 0 285 | report_num_words = report_num_sents = 0 286 | 287 | for i in np.random.permutation(len(train_data_batch)): 288 | 289 | batch_data = train_data_batch[i] 290 | batch_size, sent_len = batch_data.size() 291 | 292 | # not predict start symbol 293 | report_num_words += (sent_len - 1) * batch_size 294 | report_num_sents += batch_size 295 | 296 | kl_weight = args.beta 297 | 298 | enc_optimizer.zero_grad() 299 | dec_optimizer.zero_grad() 300 | 301 | if args.iw_train_nsamples < 0: 302 | loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, nsamples=args.nsamples) 303 | else: 304 | loss, loss_rc, loss_kl = vae.loss_iw(batch_data, kl_weight, nsamples=args.iw_train_nsamples, ns=ns) 305 | loss = loss.mean(dim=-1) 306 | 307 | loss.backward() 308 | torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad) 309 | 310 | loss_rc = loss_rc.sum() 311 | loss_kl = loss_kl.sum() 312 | 313 | enc_optimizer.step() 314 | dec_optimizer.step() 315 | 316 | report_rec_loss += loss_rc.item() 317 | report_kl_loss += loss_kl.item() 318 | report_loss += loss.item() * batch_size 319 | 320 | if iter_ % log_niter == 0: 321 | #train_loss = (report_rec_loss + report_kl_loss) / report_num_sents 322 | train_loss = report_loss / report_num_sents 323 | logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \ 324 | 'time elapsed %.2fs, kl_weight %.4f' % 325 | (epoch, iter_, train_loss, report_kl_loss / report_num_sents, 326 | report_rec_loss / report_num_sents, time.time() - start, kl_weight)) 327 | 328 | #sys.stdout.flush() 329 | 330 | report_rec_loss = report_kl_loss = report_loss = 0 331 | report_num_words = report_num_sents = 0 332 | 333 | iter_ += 1 334 | 335 | logging('kl weight %.4f' % kl_weight) 336 | 337 | vae.eval() 338 | with torch.no_grad(): 339 | loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args) 340 | au, au_var = calc_au(vae, val_data_batch) 341 | logging("%d active units" % au) 342 | # print(au_var) 343 | 344 | if args.save_ckpt > 0 and epoch <= args.save_ckpt: 345 | logging('save checkpoint') 346 | torch.save(vae.state_dict(), os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt')) 347 | 348 | if loss < best_loss: 349 | logging('update best loss') 350 | best_loss = loss 351 | best_nll = nll 352 | best_kl = kl 353 | best_ppl = ppl 354 | torch.save(vae.state_dict(), args.save_path) 355 | 356 | if loss > opt_dict["best_loss"]: 357 | opt_dict["not_improved"] += 1 358 | if opt_dict["not_improved"] >= decay_epoch and epoch >=args.load_best_epoch: 359 | opt_dict["best_loss"] = loss 360 | 
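                    # Validation loss has not improved for decay_epoch epochs (and we are past
                    # --load_best_epoch): reset the counter, multiply the learning rate by
                    # lr_decay (0.5), reload the best checkpoint, and rebuild both SGD
                    # optimizers; training stops once max_decay such decays have occurred.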
opt_dict["not_improved"] = 0 361 | opt_dict["lr"] = opt_dict["lr"] * lr_decay 362 | vae.load_state_dict(torch.load(args.save_path)) 363 | logging('new lr: %f' % opt_dict["lr"]) 364 | decay_cnt += 1 365 | enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum) 366 | dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum) 367 | 368 | else: 369 | opt_dict["not_improved"] = 0 370 | opt_dict["best_loss"] = loss 371 | 372 | if decay_cnt == max_decay: 373 | break 374 | 375 | if epoch % args.test_nepoch == 0: 376 | with torch.no_grad(): 377 | loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args) 378 | 379 | if args.save_latent > 0 and epoch <= args.save_latent: 380 | visualize_latent(args, epoch, vae, "cuda", test_data) 381 | 382 | vae.train() 383 | 384 | except KeyboardInterrupt: 385 | logging('-' * 100) 386 | logging('Exiting from training early') 387 | 388 | # compute importance weighted estimate of log p(x) 389 | vae.load_state_dict(torch.load(args.save_path)) 390 | 391 | vae.eval() 392 | with torch.no_grad(): 393 | loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args) 394 | au, au_var = calc_au(vae, test_data_batch) 395 | logging("%d active units" % au) 396 | # print(au_var) 397 | 398 | test_data_batch = test_data.create_data_batch(batch_size=1, 399 | device=device, 400 | batch_first=True) 401 | with torch.no_grad(): 402 | nll, ppl = calc_iwnll(vae, test_data_batch, args) 403 | logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl)) 404 | 405 | if __name__ == '__main__': 406 | args = init_config() 407 | main(args) 408 | -------------------------------------------------------------------------------- /modules/decoders/dec_lstm.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | 3 | import time 4 | import argparse 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 11 | 12 | import numpy as np 13 | 14 | from .decoder import DecoderBase 15 | from .decoder_helper import BeamSearchNode 16 | 17 | class LSTMDecoder(DecoderBase): 18 | """LSTM decoder with constant-length data""" 19 | def __init__(self, args, vocab, model_init, emb_init): 20 | super(LSTMDecoder, self).__init__() 21 | self.ni = args.ni 22 | self.nh = args.dec_nh 23 | self.nz = args.nz 24 | self.vocab = vocab 25 | self.device = args.device 26 | 27 | # no padding when setting padding_idx to -1 28 | self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=-1) 29 | 30 | self.dropout_in = nn.Dropout(args.dec_dropout_in) 31 | self.dropout_out = nn.Dropout(args.dec_dropout_out) 32 | 33 | # for initializing hidden state and cell 34 | self.trans_linear = nn.Linear(args.nz, args.dec_nh, bias=False) 35 | 36 | # concatenate z with input 37 | self.lstm = nn.LSTM(input_size=args.ni + args.nz, 38 | hidden_size=args.dec_nh, 39 | num_layers=1, 40 | batch_first=True) 41 | 42 | # prediction layer 43 | self.pred_linear = nn.Linear(args.dec_nh, len(vocab), bias=False) 44 | 45 | vocab_mask = torch.ones(len(vocab)) 46 | # vocab_mask[vocab['']] = 0 47 | self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduce=False) 48 | 49 | self.reset_parameters(model_init, emb_init) 50 | 51 | def reset_parameters(self, model_init, emb_init): 52 | # for name, param in self.lstm.named_parameters(): 53 | # # self.initializer(param) 54 | # if 'bias' in name: 55 | # nn.init.constant_(param, 0.0) 56 | # # 
model_init(param) 57 | # elif 'weight' in name: 58 | # model_init(param) 59 | 60 | # model_init(self.trans_linear.weight) 61 | # model_init(self.pred_linear.weight) 62 | for param in self.parameters(): 63 | model_init(param) 64 | emb_init(self.embed.weight) 65 | 66 | def sample_text(self, input, z, EOS, device): 67 | sentence = [input] 68 | max_index = 0 69 | 70 | input_word = input 71 | batch_size, n_sample, _ = z.size() 72 | seq_len = 1 73 | z_ = z.expand(batch_size, seq_len, self.nz) 74 | seq_len = input.size(1) 75 | softmax = torch.nn.Softmax(dim=0) 76 | while max_index != EOS and len(sentence) < 100: 77 | # (batch_size, seq_len, ni) 78 | word_embed = self.embed(input_word) 79 | word_embed = torch.cat((word_embed, z_), -1) 80 | c_init = self.trans_linear(z).unsqueeze(0) 81 | h_init = torch.tanh(c_init) 82 | if len(sentence) == 1: 83 | h_init = h_init.squeeze(dim=1) 84 | c_init = c_init.squeeze(dim=1) 85 | output, hidden = self.lstm.forward(word_embed, (h_init, c_init)) 86 | else: 87 | output, hidden = self.lstm.forward(word_embed, hidden) 88 | # (batch_size * n_sample, seq_len, vocab_size) 89 | output_logits = self.pred_linear(output) 90 | output_logits = output_logits.view(-1) 91 | probs = softmax(output_logits) 92 | # max_index = torch.argmax(output_logits) 93 | max_index = torch.multinomial(probs, num_samples=1) 94 | input_word = torch.tensor([[max_index]]).to(device) 95 | sentence.append(max_index) 96 | return sentence 97 | 98 | def decode(self, input, z): 99 | """ 100 | Args: 101 | input: (batch_size, seq_len) 102 | z: (batch_size, n_sample, nz) 103 | """ 104 | 105 | # not predicting start symbol 106 | # sents_len -= 1 107 | 108 | batch_size, n_sample, _ = z.size() 109 | seq_len = input.size(1) 110 | 111 | # (batch_size, seq_len, ni) 112 | word_embed = self.embed(input) 113 | word_embed = self.dropout_in(word_embed) 114 | 115 | if n_sample == 1: 116 | z_ = z.expand(batch_size, seq_len, self.nz) 117 | 118 | else: 119 | word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \ 120 | .contiguous() 121 | 122 | # (batch_size * n_sample, seq_len, ni) 123 | word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni) 124 | 125 | z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous() 126 | z_ = z_.view(batch_size * n_sample, seq_len, self.nz) 127 | 128 | # (batch_size * n_sample, seq_len, ni + nz) 129 | word_embed = torch.cat((word_embed, z_), -1) 130 | 131 | z = z.view(batch_size * n_sample, self.nz) 132 | c_init = self.trans_linear(z).unsqueeze(0) 133 | h_init = torch.tanh(c_init) 134 | # h_init = self.trans_linear(z).unsqueeze(0) 135 | # c_init = h_init.new_zeros(h_init.size()) 136 | output, _ = self.lstm(word_embed, (h_init, c_init)) 137 | 138 | output = self.dropout_out(output) 139 | 140 | # (batch_size * n_sample, seq_len, vocab_size) 141 | output_logits = self.pred_linear(output) 142 | 143 | return output_logits 144 | 145 | def reconstruct_error(self, x, z): 146 | """Cross Entropy in the language case 147 | Args: 148 | x: (batch_size, seq_len) 149 | z: (batch_size, n_sample, nz) 150 | Returns: 151 | loss: (batch_size, n_sample). 
Loss 152 | across different sentence and z 153 | """ 154 | 155 | #remove end symbol 156 | src = x[:, :-1] 157 | 158 | # remove start symbol 159 | tgt = x[:, 1:] 160 | 161 | batch_size, seq_len = src.size() 162 | n_sample = z.size(1) 163 | 164 | # (batch_size * n_sample, seq_len, vocab_size) 165 | output_logits = self.decode(src, z) 166 | 167 | if n_sample == 1: 168 | tgt = tgt.contiguous().view(-1) 169 | else: 170 | # (batch_size * n_sample * seq_len) 171 | tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \ 172 | .contiguous().view(-1) 173 | 174 | # (batch_size * n_sample * seq_len) 175 | loss = self.loss(output_logits.view(-1, output_logits.size(2)), 176 | tgt) 177 | 178 | 179 | # (batch_size, n_sample) 180 | return loss.view(batch_size, n_sample, -1).sum(-1) 181 | 182 | 183 | def log_probability(self, x, z): 184 | """Cross Entropy in the language case 185 | Args: 186 | x: (batch_size, seq_len) 187 | z: (batch_size, n_sample, nz) 188 | Returns: 189 | log_p: (batch_size, n_sample). 190 | log_p(x|z) across different x and z 191 | """ 192 | 193 | return -self.reconstruct_error(x, z) 194 | 195 | def beam_search_decode(self, z, K=5): 196 | """beam search decoding, code is based on 197 | https://github.com/pcyin/pytorch_basic_nmt/blob/master/nmt.py 198 | 199 | the current implementation decodes sentence one by one, further batching would improve the speed 200 | 201 | Args: 202 | z: (batch_size, nz) 203 | K: the beam width 204 | 205 | Returns: List1 206 | List1: the decoded word sentence list 207 | """ 208 | 209 | decoded_batch = [] 210 | batch_size, nz = z.size() 211 | 212 | # (1, batch_size, nz) 213 | c_init = self.trans_linear(z).unsqueeze(0) 214 | h_init = torch.tanh(c_init) 215 | 216 | # decoding goes sentence by sentence 217 | for idx in range(batch_size): 218 | # Start with the start of the sentence token 219 | decoder_input = torch.tensor([[self.vocab[""]]], dtype=torch.long, device=self.device) 220 | decoder_hidden = (h_init[:,idx,:].unsqueeze(1), c_init[:,idx,:].unsqueeze(1)) 221 | 222 | node = BeamSearchNode(decoder_hidden, None, decoder_input, 0., 1) 223 | live_hypotheses = [node] 224 | 225 | completed_hypotheses = [] 226 | 227 | t = 0 228 | while len(completed_hypotheses) < K and t < 100: 229 | t += 1 230 | 231 | # (len(live), 1) 232 | decoder_input = torch.cat([node.wordid for node in live_hypotheses], dim=0) 233 | 234 | # (1, len(live), nh) 235 | decoder_hidden_h = torch.cat([node.h[0] for node in live_hypotheses], dim=1) 236 | decoder_hidden_c = torch.cat([node.h[1] for node in live_hypotheses], dim=1) 237 | 238 | decoder_hidden = (decoder_hidden_h, decoder_hidden_c) 239 | 240 | 241 | # (len(live), 1, ni) --> (len(live), 1, ni+nz) 242 | word_embed = self.embed(decoder_input) 243 | word_embed = torch.cat((word_embed, z[idx].view(1, 1, -1).expand( 244 | len(live_hypotheses), 1, nz)), dim=-1) 245 | 246 | output, decoder_hidden = self.lstm(word_embed, decoder_hidden) 247 | 248 | # (len(live), 1, vocab_size) 249 | output_logits = self.pred_linear(output) 250 | decoder_output = F.log_softmax(output_logits, dim=-1) 251 | 252 | prev_logp = torch.tensor([node.logp for node in live_hypotheses], dtype=torch.float, device=self.device) 253 | decoder_output = decoder_output + prev_logp.view(len(live_hypotheses), 1, 1) 254 | 255 | # (len(live) * vocab_size) 256 | decoder_output = decoder_output.view(-1) 257 | 258 | # (K) 259 | log_prob, indexes = torch.topk(decoder_output, K-len(completed_hypotheses)) 260 | 261 | live_ids = indexes // len(self.vocab) 262 | word_ids = indexes % 
len(self.vocab) 263 | 264 | live_hypotheses_new = [] 265 | for live_id, word_id, log_prob_ in zip(live_ids, word_ids, log_prob): 266 | node = BeamSearchNode((decoder_hidden[0][:, live_id, :].unsqueeze(1), 267 | decoder_hidden[1][:, live_id, :].unsqueeze(1)), 268 | live_hypotheses[live_id], word_id.view(1, 1), log_prob_, t) 269 | 270 | if word_id.item() == self.vocab[""]: 271 | completed_hypotheses.append(node) 272 | else: 273 | live_hypotheses_new.append(node) 274 | 275 | live_hypotheses = live_hypotheses_new 276 | 277 | if len(completed_hypotheses) == K: 278 | break 279 | 280 | for live in live_hypotheses: 281 | completed_hypotheses.append(live) 282 | 283 | utterances = [] 284 | for n in sorted(completed_hypotheses, key=lambda node: node.logp, reverse=True): 285 | utterance = [] 286 | utterance.append(self.vocab.id2word(n.wordid.item())) 287 | # back trace 288 | while n.prevNode != None: 289 | n = n.prevNode 290 | utterance.append(self.vocab.id2word(n.wordid.item())) 291 | 292 | utterance = utterance[::-1] 293 | 294 | utterances.append(utterance) 295 | 296 | # only save the top 1 297 | break 298 | 299 | decoded_batch.append(utterances[0]) 300 | 301 | return decoded_batch 302 | 303 | 304 | def greedy_decode(self, z): 305 | return self.sample_decode(z, greedy=True) 306 | 307 | def sample_decode(self, z, greedy=False): 308 | """sample/greedy decoding from z 309 | Args: 310 | z: (batch_size, nz) 311 | 312 | Returns: List1 313 | List1: the decoded word sentence list 314 | """ 315 | 316 | batch_size = z.size(0) 317 | decoded_batch = [[] for _ in range(batch_size)] 318 | 319 | # (batch_size, 1, nz) 320 | c_init = self.trans_linear(z).unsqueeze(0) 321 | h_init = torch.tanh(c_init) 322 | 323 | decoder_hidden = (h_init, c_init) 324 | decoder_input = torch.tensor([self.vocab[""]] * batch_size, dtype=torch.long, device=self.device).unsqueeze(1) 325 | end_symbol = torch.tensor([self.vocab[""]] * batch_size, dtype=torch.long, device=self.device) 326 | 327 | mask = torch.ones((batch_size), dtype=torch.uint8, device=self.device) 328 | length_c = 1 329 | while mask.sum().item() != 0 and length_c < 100: 330 | 331 | # (batch_size, 1, ni) --> (batch_size, 1, ni+nz) 332 | word_embed = self.embed(decoder_input) 333 | word_embed = torch.cat((word_embed, z.unsqueeze(1)), dim=-1) 334 | 335 | output, decoder_hidden = self.lstm(word_embed, decoder_hidden) 336 | 337 | # (batch_size, 1, vocab_size) --> (batch_size, vocab_size) 338 | decoder_output = self.pred_linear(output) 339 | output_logits = decoder_output.squeeze(1) 340 | 341 | # (batch_size) 342 | if greedy: 343 | max_index = torch.argmax(output_logits, dim=1) 344 | else: 345 | probs = F.softmax(output_logits, dim=1) 346 | max_index = torch.multinomial(probs, num_samples=1).squeeze(1) 347 | 348 | decoder_input = max_index.unsqueeze(1) 349 | length_c += 1 350 | 351 | for i in range(batch_size): 352 | word = self.vocab.id2word(max_index[i].item()) 353 | if mask[i].item(): 354 | decoded_batch[i].append(self.vocab.id2word(max_index[i].item())) 355 | 356 | mask = torch.mul((max_index != end_symbol), mask) 357 | 358 | return decoded_batch 359 | 360 | class VarLSTMDecoder(LSTMDecoder): 361 | """LSTM decoder with constant-length data""" 362 | def __init__(self, args, vocab, model_init, emb_init): 363 | super(VarLSTMDecoder, self).__init__(args, vocab, model_init, emb_init) 364 | 365 | self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=vocab['']) 366 | vocab_mask = torch.ones(len(vocab)) 367 | vocab_mask[vocab['']] = 0 368 | self.loss = 
nn.CrossEntropyLoss(weight=vocab_mask, reduce=False) 369 | 370 | self.reset_parameters(model_init, emb_init) 371 | 372 | def decode(self, input, z): 373 | """ 374 | Args: 375 | input: tuple which contains x and sents_len 376 | x: (batch_size, seq_len) 377 | sents_len: long tensor of sentence lengths 378 | z: (batch_size, n_sample, nz) 379 | """ 380 | 381 | input, sents_len = input 382 | 383 | # not predicting start symbol 384 | sents_len = sents_len - 1 385 | 386 | batch_size, n_sample, _ = z.size() 387 | seq_len = input.size(1) 388 | 389 | # (batch_size, seq_len, ni) 390 | word_embed = self.embed(input) 391 | word_embed = self.dropout_in(word_embed) 392 | 393 | if n_sample == 1: 394 | z_ = z.expand(batch_size, seq_len, self.nz) 395 | 396 | else: 397 | word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \ 398 | .contiguous() 399 | 400 | # (batch_size * n_sample, seq_len, ni) 401 | word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni) 402 | 403 | z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous() 404 | z_ = z_.view(batch_size * n_sample, seq_len, self.nz) 405 | 406 | # (batch_size * n_sample, seq_len, ni + nz) 407 | word_embed = torch.cat((word_embed, z_), -1) 408 | 409 | sents_len = sents_len.unsqueeze(1).expand(batch_size, n_sample).contiguous().view(-1) 410 | packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True) 411 | 412 | z = z.view(batch_size * n_sample, self.nz) 413 | # h_init = self.trans_linear(z).unsqueeze(0) 414 | # c_init = h_init.new_zeros(h_init.size()) 415 | c_init = self.trans_linear(z).unsqueeze(0) 416 | h_init = torch.tanh(c_init) 417 | output, _ = self.lstm(packed_embed, (h_init, c_init)) 418 | output, _ = pad_packed_sequence(output, batch_first=True) 419 | 420 | output = self.dropout_out(output) 421 | 422 | # (batch_size * n_sample, seq_len, vocab_size) 423 | output_logits = self.pred_linear(output) 424 | 425 | return output_logits 426 | 427 | def reconstruct_error(self, x, z): 428 | """Cross Entropy in the language case 429 | Args: 430 | x: tuple which contains x_ and sents_len 431 | x_: (batch_size, seq_len) 432 | sents_len: long tensor of sentence lengths 433 | z: (batch_size, n_sample, nz) 434 | Returns: 435 | loss: (batch_size, n_sample). 
Loss 436 | across different sentence and z 437 | """ 438 | 439 | x, sents_len = x 440 | 441 | #remove end symbol 442 | src = x[:, :-1] 443 | 444 | # remove start symbol 445 | tgt = x[:, 1:] 446 | 447 | batch_size, seq_len = src.size() 448 | n_sample = z.size(1) 449 | 450 | # (batch_size * n_sample, seq_len, vocab_size) 451 | output_logits = self.decode((src, sents_len), z) 452 | 453 | if n_sample == 1: 454 | tgt = tgt.contiguous().view(-1) 455 | else: 456 | # (batch_size * n_sample * seq_len) 457 | tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \ 458 | .contiguous().view(-1) 459 | 460 | # (batch_size * n_sample * seq_len) 461 | loss = self.loss(output_logits.view(-1, output_logits.size(2)), 462 | tgt) 463 | 464 | 465 | # (batch_size, n_sample) 466 | return loss.view(batch_size, n_sample, -1).sum(-1) 467 | 468 | -------------------------------------------------------------------------------- /text_anneal_fb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import importlib 4 | import argparse 5 | 6 | import numpy as np 7 | 8 | import torch 9 | from torch import nn, optim 10 | 11 | from data import MonoTextData 12 | from modules import VAE 13 | from modules import GaussianLSTMEncoder, LSTMDecoder 14 | 15 | from exp_utils import create_exp_dir 16 | from utils import uniform_initializer, xavier_normal_initializer, calc_iwnll, calc_mi, calc_au, sample_sentences, visualize_latent, reconstruct 17 | 18 | clip_grad = 5.0 19 | decay_epoch = 5 20 | lr_decay = 0.5 21 | max_decay = 5 22 | 23 | logging = None 24 | 25 | def init_config(): 26 | parser = argparse.ArgumentParser(description='VAE mode collapse study') 27 | 28 | # model hyperparameters 29 | parser.add_argument('--dataset', type=str, required=True, help='dataset to use') 30 | # optimization parameters 31 | parser.add_argument('--momentum', type=float, default=0, help='sgd momentum') 32 | parser.add_argument('--opt', type=str, choices=["sgd", "adam"], default="sgd", help='sgd momentum') 33 | 34 | parser.add_argument('--nsamples', type=int, default=1, help='number of samples for training') 35 | parser.add_argument('--iw_nsamples', type=int, default=500, 36 | help='number of samples to compute importance weighted estimate') 37 | 38 | # select mode 39 | parser.add_argument('--eval', action='store_true', default=False, help='compute iw nll') 40 | parser.add_argument('--load_path', type=str, default='') 41 | 42 | # decoding 43 | parser.add_argument('--reconstruct_from', type=str, default='', help="the model checkpoint path") 44 | parser.add_argument('--reconstruct_to', type=str, default="decoding.txt", help="save file") 45 | parser.add_argument('--decoding_strategy', type=str, choices=["greedy", "beam", "sample"], default="greedy") 46 | 47 | # annealing paramters 48 | parser.add_argument('--warm_up', type=int, default=10, help="number of annealing epochs. 
warm_up=0 means not anneal") 49 | parser.add_argument('--kl_start', type=float, default=1.0, help="starting KL weight") 50 | 51 | 52 | # inference parameters 53 | parser.add_argument('--seed', type=int, default=783435, metavar='S', help='random seed') 54 | 55 | # output directory 56 | parser.add_argument('--exp_dir', default=None, type=str, 57 | help='experiment directory.') 58 | parser.add_argument("--save_ckpt", type=int, default=0, 59 | help="save checkpoint every epoch before this number") 60 | parser.add_argument("--save_latent", type=int, default=0) 61 | 62 | # new 63 | parser.add_argument("--fix_var", type=float, default=-1) 64 | parser.add_argument("--reset_dec", action="store_true", default=False) 65 | parser.add_argument("--load_best_epoch", type=int, default=15) 66 | parser.add_argument("--lr", type=float, default=1.) 67 | 68 | parser.add_argument("--fb", type=int, default=0, 69 | help="0: no fb; 1: fb; 2: max(target_kl, kl) for each dimension") 70 | parser.add_argument("--target_kl", type=float, default=-1, 71 | help="target kl of the free bits trick") 72 | 73 | args = parser.parse_args() 74 | 75 | # set args.cuda 76 | args.cuda = torch.cuda.is_available() 77 | 78 | # set seeds 79 | # seed_set = [783435, 101, 202, 303, 404, 505, 606, 707, 808, 909] 80 | # args.seed = seed_set[args.taskid] 81 | np.random.seed(args.seed) 82 | torch.manual_seed(args.seed) 83 | if args.cuda: 84 | torch.cuda.manual_seed(args.seed) 85 | torch.backends.cudnn.deterministic = True 86 | 87 | # load config file into args 88 | config_file = "config.config_%s" % args.dataset 89 | params = importlib.import_module(config_file).params 90 | args = argparse.Namespace(**vars(args), **params) 91 | 92 | load_str = "_load" if args.load_path != "" else "" 93 | if args.fb == 0: 94 | fb_str = "" 95 | elif args.fb == 1: 96 | fb_str = "_fb" 97 | elif args.fb == 2: 98 | fb_str = "_fbdim" 99 | elif args.fb == 3: 100 | fb_str = "_fb3" 101 | 102 | # set load and save paths 103 | if args.exp_dir == None: 104 | args.exp_dir = "exp_{}{}/{}_warm{}_kls{:.1f}{}_tr{}".format(args.dataset, 105 | load_str, args.dataset, args.warm_up, args.kl_start, fb_str, args.target_kl) 106 | 107 | 108 | if len(args.load_path) <= 0 and args.eval: 109 | args.load_path = os.path.join(args.exp_dir, 'model.pt') 110 | 111 | args.save_path = os.path.join(args.exp_dir, 'model.pt') 112 | 113 | # set args.label 114 | if 'label' in params: 115 | args.label = params['label'] 116 | else: 117 | args.label = False 118 | 119 | return args 120 | 121 | 122 | def test(model, test_data_batch, mode, args, verbose=True): 123 | global logging 124 | 125 | report_kl_loss = report_rec_loss = report_loss = 0 126 | report_num_words = report_num_sents = 0 127 | for i in np.random.permutation(len(test_data_batch)): 128 | batch_data = test_data_batch[i] 129 | batch_size, sent_len = batch_data.size() 130 | 131 | # not predict start symbol 132 | report_num_words += (sent_len - 1) * batch_size 133 | report_num_sents += batch_size 134 | loss, loss_rc, loss_kl = model.loss(batch_data, 1.0, nsamples=args.nsamples) 135 | assert(not loss_rc.requires_grad) 136 | 137 | loss_rc = loss_rc.sum() 138 | loss_kl = loss_kl.sum() 139 | loss = loss.sum() 140 | 141 | report_rec_loss += loss_rc.item() 142 | report_kl_loss += loss_kl.item() 143 | if args.warm_up == 0 and args.kl_start < 1e-6: 144 | report_loss += loss_rc.item() 145 | else: 146 | report_loss += loss.item() 147 | 148 | mutual_info = calc_mi(model, test_data_batch) 149 | 150 | test_loss = report_loss / report_num_sents 151 | 152 | nll 
= (report_kl_loss + report_rec_loss) / report_num_sents 153 | kl = report_kl_loss / report_num_sents 154 | ppl = np.exp(nll * report_num_sents / report_num_words) 155 | if verbose: 156 | logging('%s --- avg_loss: %.4f, kl: %.4f, mi: %.4f, recon: %.4f, nll: %.4f, ppl: %.4f' % \ 157 | (mode, test_loss, report_kl_loss / report_num_sents, mutual_info, 158 | report_rec_loss / report_num_sents, nll, ppl)) 159 | #sys.stdout.flush() 160 | 161 | return test_loss, nll, kl, ppl, mutual_info 162 | 163 | 164 | def main(args): 165 | global logging 166 | debug = (args.reconstruct_from != "" or args.eval == True) # don't make exp dir for reconstruction 167 | logging = create_exp_dir(args.exp_dir, scripts_to_save=None, debug=debug) 168 | 169 | if args.cuda: 170 | logging('using cuda') 171 | logging(str(args)) 172 | 173 | opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4} 174 | 175 | train_data = MonoTextData(args.train_data, label=args.label) 176 | 177 | vocab = train_data.vocab 178 | vocab_size = len(vocab) 179 | 180 | val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab) 181 | test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab) 182 | 183 | logging('Train data: %d samples' % len(train_data)) 184 | logging('finish reading datasets, vocab size is %d' % len(vocab)) 185 | logging('dropped sentences: %d' % train_data.dropped) 186 | #sys.stdout.flush() 187 | 188 | log_niter = (len(train_data)//args.batch_size)//10 189 | 190 | model_init = uniform_initializer(0.01) 191 | emb_init = uniform_initializer(0.1) 192 | 193 | #device = torch.device("cuda" if args.cuda else "cpu") 194 | device = "cuda" if args.cuda else "cpu" 195 | args.device = device 196 | 197 | if args.enc_type == 'lstm': 198 | encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init) 199 | args.enc_nh = args.dec_nh 200 | else: 201 | raise ValueError("the specified encoder type is not supported") 202 | 203 | decoder = LSTMDecoder(args, vocab, model_init, emb_init) 204 | vae = VAE(encoder, decoder, args).to(device) 205 | 206 | if args.load_path: 207 | loaded_state_dict = torch.load(args.load_path) 208 | #curr_state_dict = vae.state_dict() 209 | #curr_state_dict.update(loaded_state_dict) 210 | vae.load_state_dict(loaded_state_dict) 211 | logging("%s loaded" % args.load_path) 212 | 213 | if args.reset_dec: 214 | logging("\n-------reset decoder-------\n") 215 | vae.decoder.reset_parameters(model_init, emb_init) 216 | 217 | 218 | if args.eval: 219 | logging('begin evaluation') 220 | vae.load_state_dict(torch.load(args.load_path)) 221 | vae.eval() 222 | with torch.no_grad(): 223 | test_data_batch = test_data.create_data_batch(batch_size=args.batch_size, 224 | device=device, 225 | batch_first=True) 226 | 227 | test(vae, test_data_batch, "TEST", args) 228 | au, au_var = calc_au(vae, test_data_batch) 229 | logging("%d active units" % au) 230 | # print(au_var) 231 | 232 | test_data_batch = test_data.create_data_batch(batch_size=1, 233 | device=device, 234 | batch_first=True) 235 | nll, ppl = calc_iwnll(vae, test_data_batch, args) 236 | logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl)) 237 | 238 | return 239 | 240 | if args.reconstruct_from != "": 241 | print("begin decoding") 242 | vae.load_state_dict(torch.load(args.reconstruct_from)) 243 | vae.eval() 244 | with torch.no_grad(): 245 | test_data_batch = test_data.create_data_batch(batch_size=args.batch_size, 246 | device=device, 247 | batch_first=True) 248 | # test(vae, test_data_batch, "TEST", args) 249 | reconstruct(vae, test_data_batch, vocab, 
args.decoding_strategy, args.reconstruct_to) 250 | 251 | return 252 | 253 | if args.opt == "sgd": 254 | enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=args.lr, momentum=args.momentum) 255 | dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=args.lr, momentum=args.momentum) 256 | opt_dict['lr'] = args.lr 257 | elif args.opt == "adam": 258 | enc_optimizer = optim.Adam(vae.encoder.parameters(), lr=0.001) 259 | dec_optimizer = optim.Adam(vae.decoder.parameters(), lr=0.001) 260 | opt_dict['lr'] = 0.001 261 | else: 262 | raise ValueError("optimizer not supported") 263 | 264 | iter_ = decay_cnt = 0 265 | best_loss = 1e4 266 | best_kl = best_nll = best_ppl = 0 267 | pre_mi = 0 268 | vae.train() 269 | start = time.time() 270 | 271 | kl_weight = args.kl_start 272 | if args.warm_up > 0: 273 | anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size)) 274 | else: 275 | anneal_rate = 0 276 | 277 | dim_target_kl = args.target_kl / float(args.nz) 278 | 279 | train_data_batch = train_data.create_data_batch(batch_size=args.batch_size, 280 | device=device, 281 | batch_first=True) 282 | 283 | val_data_batch = val_data.create_data_batch(batch_size=args.batch_size, 284 | device=device, 285 | batch_first=True) 286 | 287 | test_data_batch = test_data.create_data_batch(batch_size=args.batch_size, 288 | device=device, 289 | batch_first=True) 290 | 291 | # At any point you can hit Ctrl + C to break out of training early. 292 | try: 293 | for epoch in range(args.epochs): 294 | report_kl_loss = report_rec_loss = report_loss = 0 295 | report_num_words = report_num_sents = 0 296 | 297 | for i in np.random.permutation(len(train_data_batch)): 298 | 299 | batch_data = train_data_batch[i] 300 | batch_size, sent_len = batch_data.size() 301 | 302 | # not predict start symbol 303 | report_num_words += (sent_len - 1) * batch_size 304 | report_num_sents += batch_size 305 | 306 | kl_weight = min(1.0, kl_weight + anneal_rate) 307 | 308 | enc_optimizer.zero_grad() 309 | dec_optimizer.zero_grad() 310 | 311 | 312 | if args.fb == 0: 313 | loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, nsamples=args.nsamples) 314 | elif args.fb == 1: 315 | loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, nsamples=args.nsamples) 316 | kl_mask = (loss_kl > args.target_kl).float() 317 | loss = loss_rc + kl_mask * kl_weight * loss_kl 318 | elif args.fb == 2: 319 | mu, logvar = vae.encoder(batch_data) 320 | z = vae.encoder.reparameterize(mu, logvar, args.nsamples) 321 | loss_kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1) 322 | kl_mask = (loss_kl > dim_target_kl).float() 323 | fake_loss_kl = (kl_mask * loss_kl).sum(dim=1) 324 | loss_rc = vae.decoder.reconstruct_error(batch_data, z).mean(dim=1) 325 | loss = loss_rc + kl_weight * fake_loss_kl 326 | elif args.fb == 3: 327 | loss, loss_rc, loss_kl = vae.loss(batch_data, kl_weight, nsamples=args.nsamples) 328 | kl_mask = (loss_kl.mean() > args.target_kl).float() 329 | loss = loss_rc + kl_mask * kl_weight * loss_kl 330 | 331 | loss = loss.mean(dim=-1) 332 | 333 | loss.backward() 334 | torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_grad) 335 | 336 | loss_rc = loss_rc.sum() 337 | loss_kl = loss_kl.sum() 338 | 339 | enc_optimizer.step() 340 | dec_optimizer.step() 341 | 342 | report_rec_loss += loss_rc.item() 343 | report_kl_loss += loss_kl.item() 344 | report_loss += loss_rc.item() + loss_kl.item() 345 | 346 | if iter_ % log_niter == 0: 347 | #train_loss = (report_rec_loss + report_kl_loss) / report_num_sents 348 | train_loss = report_loss / 
report_num_sents 349 | 350 | logging('epoch: %d, iter: %d, avg_loss: %.4f, kl: %.4f, recon: %.4f,' \ 351 | 'time %.2fs, kl_weight %.4f' % 352 | (epoch, iter_, train_loss, report_kl_loss / report_num_sents, 353 | report_rec_loss / report_num_sents, time.time() - start, kl_weight)) 354 | 355 | #sys.stdout.flush() 356 | 357 | report_rec_loss = report_kl_loss = report_loss = 0 358 | report_num_words = report_num_sents = 0 359 | 360 | iter_ += 1 361 | 362 | logging('kl weight %.4f' % kl_weight) 363 | logging('lr {}'.format(opt_dict["lr"])) 364 | 365 | vae.eval() 366 | with torch.no_grad(): 367 | loss, nll, kl, ppl, mi = test(vae, val_data_batch, "VAL", args) 368 | au, au_var = calc_au(vae, val_data_batch) 369 | logging("%d active units" % au) 370 | # print(au_var) 371 | 372 | if args.save_ckpt > 0 and epoch <= args.save_ckpt: 373 | logging('save checkpoint') 374 | torch.save(vae.state_dict(), os.path.join(args.exp_dir, f'model_ckpt_{epoch}.pt')) 375 | 376 | if loss < best_loss: 377 | logging('update best loss') 378 | best_loss = loss 379 | best_nll = nll 380 | best_kl = kl 381 | best_ppl = ppl 382 | torch.save(vae.state_dict(), args.save_path) 383 | 384 | if loss > opt_dict["best_loss"]: 385 | opt_dict["not_improved"] += 1 386 | if opt_dict["not_improved"] >= decay_epoch and epoch >= args.load_best_epoch: 387 | opt_dict["best_loss"] = loss 388 | opt_dict["not_improved"] = 0 389 | opt_dict["lr"] = opt_dict["lr"] * lr_decay 390 | vae.load_state_dict(torch.load(args.save_path)) 391 | logging('new lr: %f' % opt_dict["lr"]) 392 | decay_cnt += 1 393 | enc_optimizer = optim.SGD(vae.encoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum) 394 | dec_optimizer = optim.SGD(vae.decoder.parameters(), lr=opt_dict["lr"], momentum=args.momentum) 395 | 396 | else: 397 | opt_dict["not_improved"] = 0 398 | opt_dict["best_loss"] = loss 399 | 400 | if decay_cnt == max_decay: 401 | break 402 | 403 | if epoch % args.test_nepoch == 0: 404 | with torch.no_grad(): 405 | loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args) 406 | 407 | if args.save_latent > 0 and epoch <= args.save_latent: 408 | visualize_latent(args, epoch, vae, "cuda", test_data) 409 | 410 | vae.train() 411 | 412 | except KeyboardInterrupt: 413 | logging('-' * 100) 414 | logging('Exiting from training early') 415 | 416 | # compute importance weighted estimate of log p(x) 417 | vae.load_state_dict(torch.load(args.save_path)) 418 | 419 | vae.eval() 420 | with torch.no_grad(): 421 | loss, nll, kl, ppl, _ = test(vae, test_data_batch, "TEST", args) 422 | au, au_var = calc_au(vae, test_data_batch) 423 | logging("%d active units" % au) 424 | # print(au_var) 425 | 426 | test_data_batch = test_data.create_data_batch(batch_size=1, 427 | device=device, 428 | batch_first=True) 429 | with torch.no_grad(): 430 | nll, ppl = calc_iwnll(vae, test_data_batch, args) 431 | logging('iw nll: %.4f, iw ppl: %.4f' % (nll, ppl)) 432 | 433 | if __name__ == '__main__': 434 | args = init_config() 435 | main(args) 436 | --------------------------------------------------------------------------------
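For readers who want the two training tricks used in text_anneal_fb.py in isolation, the following is a minimal standalone sketch (it is not part of the repository). The helper names, tensor shapes, and the demo values under __main__ are illustrative assumptions; the KL formula, the linear annealing rate, and the --fb 1 / --fb 2 masking mirror the training code above.

import torch


def linear_anneal_rate(kl_start, warm_up, batches_per_epoch):
    """Per-batch increment that moves the KL weight from kl_start to 1.0
    over warm_up epochs; warm_up == 0 disables annealing."""
    if warm_up > 0:
        return (1.0 - kl_start) / (warm_up * batches_per_epoch)
    return 0.0


def free_bits_kl(mu, logvar, target_kl, per_dimension=True):
    """KL(q(z|x) || N(0, I)) with the free-bits trick.

    mu, logvar: (batch_size, nz) parameters of the diagonal Gaussian posterior.
    per_dimension=True thresholds each latent dimension at target_kl / nz
    (the --fb 2 variant); per_dimension=False thresholds the summed KL per
    sentence (the --fb 1 variant). Returns a (batch_size,) tensor.
    """
    # analytic KL of a diagonal Gaussian against the standard normal, per dimension
    kl_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)  # (batch_size, nz)

    if per_dimension:
        dim_target = target_kl / float(mu.size(1))
        mask = (kl_per_dim > dim_target).float()
        return (mask * kl_per_dim).sum(dim=1)
    kl = kl_per_dim.sum(dim=1)
    mask = (kl > target_kl).float()
    return mask * kl


if __name__ == "__main__":
    torch.manual_seed(0)
    mu, logvar = torch.randn(4, 32), torch.randn(4, 32)

    kl_weight = 0.1
    rate = linear_anneal_rate(kl_start=0.1, warm_up=10, batches_per_epoch=100)
    kl_weight = min(1.0, kl_weight + rate)  # one annealing step, as done per batch above

    kl_term = free_bits_kl(mu, logvar, target_kl=8.0)  # --fb 2 style masking
    print(kl_weight, kl_term.shape)                    # 0.1009 torch.Size([4])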