├── .gitignore ├── LICENSE ├── MADE.py ├── README.md ├── baseline_model.py ├── common.py ├── config.py ├── data ├── indep_bernoulli │ ├── .gitignore │ └── load_data.py └── ptb │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── discreteflow_model.py ├── flows.py ├── lstm_flow.py ├── main.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | 3 | *~ 4 | *.png 5 | *.pdf 6 | output 7 | 8 | slurm_scripts 9 | *.out 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Harvard NLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MADE.py: -------------------------------------------------------------------------------- 1 | import math 2 | import collections 3 | import numpy as np 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch.nn.parameter import Parameter 9 | from torch.autograd import Variable 10 | 11 | class MaskedHiddenLayer(nn.Module): 12 | def __init__(self, d_in, d_out, data_dim, nonlinearity, previous_m_k, output_order, bias=True, alt_init=False): 13 | super().__init__() 14 | if nonlinearity == 'relu': 15 | self.nonlin = nn.ReLU(inplace=True) 16 | elif nonlinearity == 'tanh': 17 | self.nonlin = nn.Tanh() 18 | elif nonlinearity == 'elu': 19 | self.nonlin = nn.ELU(inplace=True) 20 | elif nonlinearity == None: 21 | self.nonlin = lambda x : x 22 | else: 23 | raise NotImplementedError('only relu, tanh, and elu nonlinearities have been implemented') 24 | 25 | self.weight = Parameter(torch.Tensor(d_out, d_in)) 26 | if bias: 27 | self.bias = Parameter(torch.Tensor(d_out)) 28 | else: 29 | self.register_parameter('bias', None) 30 | 31 | self.alt_init = alt_init 32 | self.reset_parameters() 33 | 34 | if isinstance(output_order, str): 35 | if output_order == 'random': 36 | self.m_k = torch.empty(d_out, dtype=torch.long).random_(1, data_dim) 37 | elif output_order == 'sequential': 38 | self.m_k = torch.arange(0, data_dim) 39 | self.m_k = self.m_k.repeat(d_out//data_dim+1)[:d_out] 40 | else: 41 | # Allow for the network to produce multiple outputs conditioned on the same degree 42 | self.m_k = output_order.repeat(d_out//data_dim) 43 | 44 | mask = (self.m_k[:, None] >= previous_m_k[None, :]).float() 45 | self.register_buffer('mask', mask) 46 | 47 | def reset_parameters(self): 48 | if self.alt_init: 49 | stdv = 1. / math.sqrt(self.weight.size(0) + 1) 50 | self.weight.data.uniform_(-0.001, 0.001) 51 | else: 52 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 53 | self.weight.data.uniform_(-stdv, stdv) 54 | 55 | if self.bias is not None: 56 | if self.alt_init: 57 | self.bias.data.zero_() 58 | else: 59 | self.bias.data.uniform_(-stdv, stdv) 60 | 61 | def forward(self, x): 62 | x = F.linear(x, Variable(self.mask)*self.weight, self.bias) 63 | x = self.nonlin(x) 64 | 65 | return x 66 | 67 | 68 | class MADE(nn.Module): 69 | def __init__(self, data_dim, n_hidden_layers, n_hidden_units, nonlinearity, hidden_order, bias=True, out_dim_per_inp_dim=1, input_order=None, 70 | conditional_inp_dim=None, dropout=[0, 0], nonar=False, alt_init=True): 71 | super().__init__() 72 | 73 | if not isinstance(dropout, collections.Iterable) or len(dropout) != 2: 74 | raise ValueError('dropout argument should be an iterable with [input drop fraction, hidden drop fraction') 75 | 76 | layers = [] 77 | if input_order is None: 78 | previous_m_k = torch.arange(data_dim)+1 79 | end_order = torch.arange(data_dim) 80 | else: 81 | if not nonar and not np.all(np.sort(input_order) == np.arange(data_dim)+1): 82 | raise ValueError('input_order must contain 1 through data_dim, inclusive, in any order') 83 | previous_m_k = input_order 84 | end_order = input_order-1 85 | 86 | if conditional_inp_dim is not None: 87 | previous_m_k = torch.cat([previous_m_k, torch.zeros(conditional_inp_dim, dtype=previous_m_k.dtype)]) 88 | 89 | effective_data_dim = torch.max(previous_m_k) # This is only used to set the m_k values for each hidden layer 90 | 91 | for i in range(n_hidden_layers): 92 | if i == 0: 93 | d_in = data_dim 94 | if conditional_inp_dim is not None: 95 | d_in += conditional_inp_dim 96 | drop_val = dropout[0] 97 | else: 98 | d_in = n_hidden_units 99 | drop_val = dropout[1] 100 | 101 | if drop_val > 0: 102 | layers.append(nn.Dropout(drop_val)) 103 | 104 | new_layer = MaskedHiddenLayer(d_in, n_hidden_units, effective_data_dim, nonlinearity, previous_m_k, hidden_order, bias=bias, alt_init=alt_init) 105 | previous_m_k = new_layer.m_k 106 | layers.append(new_layer) 107 | 108 | layers.append(MaskedHiddenLayer(n_hidden_units, data_dim*out_dim_per_inp_dim, data_dim, None, previous_m_k, end_order, bias=bias)) 109 | 110 | self.network = nn.Sequential(*layers) 111 | self.data_dim = data_dim 112 | self.out_dim_per_inp_dim = out_dim_per_inp_dim 113 | self.end_order = end_order 114 | self.conditional_inp_dim = conditional_inp_dim 115 | 116 | def forward(self, inputs): 117 | if self.conditional_inp_dim is not None: 118 | x, cond_inp = inputs 119 | x = torch.cat([x, cond_inp], -1) 120 | else: 121 | x = inputs 122 | 123 | x = self.network(x) 124 | 125 | if self.out_dim_per_inp_dim == 1: 126 | return x 127 | 128 | # If the network produces multiple outputs conditioned on the same degree, return as [B, data_dim, out_dim_per_inp_dim] 129 | #x = torch.transpose(x.view(-1, self.out_dim_per_inp_dim, self.data_dim), -1, -2) 130 | #if x_1d: 131 | # x = x.squeeze() 132 | x = x.view(*x.shape[:-1], self.out_dim_per_inp_dim, self.data_dim) 133 | x = torch.transpose(x, -1, -2) 134 | 135 | return x 136 | 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Latent Normalizing Flows for Discrete Sequences 2 | 3 | This code provides an implementation for and reproduces results from [Latent Normalizing Flows for Discrete Sequences](https://arxiv.org/abs/1901.10548) 4 | Zachary Ziegler, Alexander Rush 5 | ICML 2019 6 | 7 | ## Dependencies 8 | The code was 
tested with `python 3.6`, `pytorch 0.4.1`, `torchtext 0.2.3`, and `CUDA 9.2`. 9 | 10 | ## Character-level language modeling: 11 | 12 | PTB data with [Mikolov preprocessing](http://www.fit.vutbr.cz/~imikolov/rnnlm/char.pdf) is checked in. 13 | 14 | Baseline and proposed models can be trained with: 15 | ``` 16 | python main.py --dataset ptb --run_name charptb_baselinelstm --model_type baseline --dropout_p 0.1 --optim sgd --lr 20 17 | python main.py --dataset ptb --run_name charptb_discreteflow_af-af 18 | python main.py --dataset ptb --run_name charptb_discreteflow_af-scf --hiddenflow_scf_layers 19 | python main.py --dataset ptb --run_name charptb_discreteflow_iaf-scf --dropout_p 0 --hiddenflow_flow_layers 3 --hiddenflow_scf_layers --prior_type IAF 20 | ``` 21 | 22 | Evaluation on the test set is run with e.g. 23 | ``` 24 | python main.py --dataset ptb --run_name charptb_baselinelstm --model_type baseline --dropout_p 0.1 --optim sgd --lr 20 --load_dir output/charptb_baselinelstm/saves/ --evaluate 25 | ``` 26 | 27 | ## Polyphonic music: 28 | 29 | Commands to train the baseline and proposed models are 30 | 31 | Nottingham: 32 | ``` 33 | python main.py --indep_bernoulli --dataset nottingham --run_name nottingham_baselinelstm --model_type baseline --dropout_p 0.1 --optim sgd --lr 20 --patience 4 34 | python main.py --indep_bernoulli --dataset nottingham --run_name nottingham_discreteflow_af-af --patience 2 --ELBO_samples 1 --B_train 5 --B_val 1 --initial_kl_zero 20 --kl_rampup_time 15 35 | python main.py --indep_bernoulli --dataset nottingham --run_name nottingham_discreteflow_af-scf --patience 2 --ELBO_samples 1 --B_train 5 --B_val 1 --initial_kl_zero 20 --kl_rampup_time 15 --hiddenflow_scf_layers 36 | python main.py --indep_bernoulli --dataset nottingham --run_name nottingham_discreteflow_iaf-scf --patience 2 --ELBO_samples 1 --B_train 5 --B_val 1 --initial_kl_zero 20 --kl_rampup_time 15 --hiddenflow_scf_layers --prior_type IAF 37 | ``` 38 | 39 | Piano_midi 40 | ``` 41 | python main.py --indep_bernoulli --dataset piano_midi --run_name piano_midi_baselinelstm --model_type baseline --dropout_p 0.1 --optim sgd --lr 20 --patience 4 42 | python main.py --indep_bernoulli --dataset piano_midi --run_name piano_midi_discreteflow_af-af --patience 2 --ELBO_samples 1 --B_train 1 --B_val 1 --initial_kl_zero 20 --kl_rampup_time 15 --nll_samples 4 --grad_accum 8 43 | python main.py --indep_bernoulli --dataset piano_midi --run_name piano_midi_discreteflow_af-scf --patience 2 --ELBO_samples 1 --B_train 1 --B_val 1 --initial_kl_zero 20 --kl_rampup_time 15 --nll_samples 4 --grad_accum 8 --hiddenflow_scf_layers 44 | python main.py --indep_bernoulli --dataset piano_midi --run_name piano_midi_discreteflow_iaf-scf --patience 2 --ELBO_samples 1 --B_train 1 --B_val 1 --initial_kl_zero 20 --kl_rampup_time 15 --nll_samples 4 --grad_accum 8 --hiddenflow_scf_layers --prior_type IAF 45 | ``` 46 | 47 | Musedata 48 | ``` 49 | python main.py --indep_bernoulli --dataset muse_data --run_name muse_data_baselinelstm --model_type baseline --dropout_p 0.1 --optim sgd --lr 20 --patience 4 50 | python main.py --indep_bernoulli --dataset muse_data --run_name muse_data_discreteflow_af-af --patience 2 --ELBO_samples 1 --B_train 1 --B_val 1 --initial_kl_zero 20 --kl_rampup_time 15 --nll_samples 4 --grad_accum 8 51 | python main.py --indep_bernoulli --dataset muse_data --run_name muse_data_discreteflow_af-scf --patience 2 --ELBO_samples 1 --B_train 1 --B_val 1 --initial_kl_zero 20 --kl_rampup_time 15 --nll_samples 4 --grad_accum 8 
--hiddenflow_scf_layers 52 | python main.py --indep_bernoulli --dataset muse_data --run_name muse_data_discreteflow_iaf-scf --patience 2 --ELBO_samples 1 --B_train 1 --B_val 1 --initial_kl_zero 20 --kl_rampup_time 15 --nll_samples 4 --grad_accum 8 --hiddenflow_scf_layers --prior_type IAF 53 | ``` 54 | 55 | JSB_chorales 56 | ``` 57 | python main.py --indep_bernoulli --dataset jsb_chorales --run_name jsb_chorales_baselinelstm --model_type baseline --dropout_p 0.1 --optim sgd --lr 20 --patience 4 58 | python main.py --indep_bernoulli --dataset jsb_chorales --run_name jsb_chorales_discreteflow_af-af --patience 2 --ELBO_samples 1 --B_train 5 --B_val 64 --initial_kl_zero 20 --kl_rampup_time 15 59 | python main.py --indep_bernoulli --dataset jsb_chorales --run_name jsb_chorales_discreteflow_af-scf --patience 2 --ELBO_samples 1 --B_train 5 --B_val 64 --initial_kl_zero 20 --kl_rampup_time 15 --hiddenflow_scf_layers 60 | python main.py --indep_bernoulli --dataset jsb_chorales --run_name jsb_chorales_discreteflow_iaf-scf --patience 2 --ELBO_samples 1 --B_train 5 --B_val 64 --initial_kl_zero 20 --kl_rampup_time 15 --hiddenflow_scf_layers --prior_type IAF 61 | ``` 62 | -------------------------------------------------------------------------------- /baseline_model.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from torch import nn 5 | from torch.distributions import Categorical, Bernoulli 6 | 7 | from common import FeedForwardNet 8 | from utils import make_pos_cond 9 | 10 | class LSTMModel(nn.Module): 11 | 12 | def __init__(self, vocab_size, loss_weights, n_inp_units, n_hidden_units, n_layers, dropout_p, T_condition=False, max_T=-1, tie_weights=False, indep_bernoulli=False): 13 | super().__init__() 14 | 15 | self.dropout = nn.Dropout(dropout_p) 16 | self.input_embedding = nn.Embedding(vocab_size, n_inp_units) 17 | 18 | rnn_inp_size = n_inp_units 19 | if T_condition: 20 | rnn_inp_size += max_T*2 21 | 22 | self.rnn = nn.LSTM(rnn_inp_size, n_hidden_units, n_layers, dropout=dropout_p) 23 | 24 | self.output_embedding = FeedForwardNet(n_hidden_units, n_inp_units, vocab_size, 1, 'none') 25 | 26 | if tie_weights: 27 | self.output_embedding.network[-1].weight = self.input_embedding.weight 28 | 29 | self.indep_bernoulli = indep_bernoulli 30 | self.vocab_size = vocab_size 31 | self.n_layers = n_layers 32 | self.n_hidden_units = n_hidden_units 33 | self.T_condition = T_condition 34 | self.max_T = max_T 35 | 36 | if self.indep_bernoulli: 37 | self.criterion = torch.nn.BCEWithLogitsLoss(reduction='none') 38 | else: 39 | self.criterion = torch.nn.CrossEntropyLoss(loss_weights, reduction='none') 40 | 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | nn.init.xavier_uniform_(self.input_embedding.weight) 45 | 46 | def forward(self, x, lengths): 47 | # Input is [T, B] with index of word 48 | T, B = x.shape[0], x.shape[1] 49 | 50 | hidden = self.init_hidden(B) 51 | 52 | if self.T_condition: 53 | cond_inp = make_pos_cond(T, B, lengths.cpu(), self.max_T).to(x.device) 54 | 55 | if self.indep_bernoulli: 56 | embeddings = torch.matmul(x, self.input_embedding.weight) 57 | else: 58 | embeddings = self.input_embedding(x) 59 | embeddings = self.dropout(embeddings) # [T, B, n_inp_units] 60 | 61 | if self.T_condition: 62 | cond_inp_shifted = torch.cat((cond_inp[1:], torch.zeros((1, B, self.max_T*2), device=cond_inp.device)), 0) 63 | embeddings = torch.cat((embeddings, cond_inp_shifted), -1) 64 | 65 | embeddings = 
nn.utils.rnn.pack_padded_sequence(embeddings, lengths) 66 | rnn_outp, _ = self.rnn(embeddings, hidden) # [T, B, n_hidden_units], [num_layers, B, n_hidden_units]x2 67 | rnn_outp = nn.utils.rnn.pad_packed_sequence(rnn_outp)[0] 68 | 69 | rnn_outp = torch.cat((hidden[0][-1:], rnn_outp), 0)[:-1] 70 | rnn_outp = self.dropout(rnn_outp) 71 | 72 | scores = self.output_embedding(rnn_outp) # [T, B, V] 73 | 74 | if self.indep_bernoulli: 75 | loss = self.criterion(scores.view(-1, scores.shape[-1]), x.view(-1, x.shape[-1])).view(scores.shape).sum(-1) 76 | # This doesn't 0 out loss values from padding, but later on the main loop will do that 77 | else: 78 | loss = self.criterion(scores.view(-1, scores.shape[-1]), x.view(-1)).view(scores.shape[:-1]) # [T, B] 79 | 80 | return loss 81 | 82 | def generate(self, T, B): 83 | if not self.T_condition: 84 | raise NotImplementedError("Only the version conditioned on T has been implemented.") 85 | 86 | hidden = self.init_hidden(B) 87 | lengths = torch.tensor([T]*B) 88 | device = hidden[0].device 89 | 90 | cond_inp = make_pos_cond(T, B, lengths, self.max_T).to(device) 91 | 92 | if self.indep_bernoulli: 93 | generation = torch.zeros(T, B, self.vocab_size, dtype=torch.long, device=device) 94 | else: 95 | generation = torch.zeros(T, B, dtype=torch.long, device=device) 96 | 97 | last_rnn_outp = hidden[0][-1] 98 | for t in range(T): 99 | scores = self.output_embedding(last_rnn_outp) # [B, V] 100 | if self.indep_bernoulli: 101 | word_dist = Bernoulli(logits=scores) 102 | else: 103 | word_dist = Categorical(logits=scores) 104 | 105 | selected_index = word_dist.sample() 106 | generation[t] = selected_index 107 | 108 | if t < T-1: 109 | if self.indep_bernoulli: 110 | inp_embeddings = torch.matmul(generation[t].float(), self.input_embedding.weight) 111 | else: 112 | inp_embeddings = self.input_embedding(generation[t]) # [B, E] 113 | inp_embeddings = torch.cat((inp_embeddings, cond_inp[t+1]), -1) 114 | 115 | last_rnn_outp, hidden = self.rnn(inp_embeddings[None, :, :], hidden) 116 | last_rnn_outp = last_rnn_outp[0] 117 | 118 | return generation 119 | 120 | def gen_one_noTcond(self, eos_index, max_T): 121 | hidden = self.init_hidden(1) 122 | device = hidden[0].device 123 | 124 | last_rnn_outp = hidden[0][-1] # [1, C] 125 | generation = [] 126 | 127 | for t in range(max_T): 128 | scores = self.output_embedding(last_rnn_outp) # [1, V] 129 | word_dist = Categorical(logits=scores) 130 | selected_index = word_dist.sample() # [1] 131 | 132 | if selected_index == eos_index: 133 | break 134 | 135 | generation.append(selected_index) 136 | inp_embeddings = self.input_embedding(selected_index) # [1, inp_E] 137 | last_rnn_outp, hidden = self.rnn(inp_embeddings[None, :, :], hidden) 138 | last_rnn_outp = last_rnn_outp[0] 139 | 140 | return torch.tensor(generation, dtype=torch.long, device=device) 141 | 142 | def init_hidden(self, batch_size): 143 | weight = next(self.parameters()) 144 | h = weight.new_zeros(self.n_layers, batch_size, self.n_hidden_units) 145 | c = weight.new_zeros(self.n_layers, batch_size, self.n_hidden_units) 146 | return (h, c) 147 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class FeedForwardNet(nn.Module): 4 | def __init__(self, inp_dim, hidden_dim, outp_dim, n_layers, nonlinearity, dropout=0): 5 | super().__init__() 6 | 7 | layers = [] 8 | d_in = inp_dim 9 | for i in range(n_layers): 10 | module = 
nn.Linear(d_in, hidden_dim) 11 | self.reset_parameters(module) 12 | layers.append(module) 13 | 14 | if dropout > 0: 15 | layers.append(nn.Dropout(dropout)) 16 | 17 | if nonlinearity == 'relu': 18 | nonlin = nn.ReLU(inplace=True) 19 | elif nonlinearity == 'tanh': 20 | nonlin = nn.Tanh() 21 | elif nonlinearity == 'elu': 22 | nonlin = nn.ELU(inplace=True) 23 | elif nonlinearity != 'none': 24 | raise NotImplementedError('only relu, tanh, and elu nonlinearities have been implemented') 25 | 26 | if nonlinearity != 'none': 27 | layers.append(nonlin) 28 | 29 | d_in = hidden_dim 30 | 31 | module = nn.Linear(d_in, outp_dim) 32 | self.reset_parameters(module) 33 | layers.append(module) 34 | 35 | self.network = nn.Sequential(*layers) 36 | 37 | def reset_parameters(self, module): 38 | init_range = 0.07 39 | module.weight.data.uniform_(-init_range, init_range) 40 | module.bias.data.zero_() 41 | 42 | def forward(self, x): 43 | return self.network(x) 44 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser() 7 | 8 | # Data parameters 9 | parser.add_argument('--run_name', type=str, default='charptb_AF-AF') 10 | parser.add_argument('--output_dir', type=str, default='output') 11 | parser.add_argument('--load_dir', type=str) 12 | parser.add_argument('--evaluate_only', action='store_true') 13 | parser.add_argument('--dataset', type=str, default='ptb') 14 | parser.add_argument('--nll_every', type=int, default=5) 15 | parser.add_argument('--indep_bernoulli', action='store_true') 16 | parser.add_argument('--noT_condition', action='store_true') 17 | 18 | # Optimization parameters 19 | parser.add_argument('--num_epochs', type=int, default=100) 20 | parser.add_argument('--B_train', type=int, default=15) 21 | parser.add_argument('--B_val', type=int, default=15) 22 | parser.add_argument('--grad_accum', type=int, default=1) 23 | parser.add_argument('--optim', type=str, default='adam') 24 | parser.add_argument('--lr', type=float, default=1e-3) 25 | parser.add_argument('--dropout_p', type=float, default=0.2) 26 | parser.add_argument('--grad_clip', type=float, default=0.25) 27 | parser.add_argument('--seed', type=int) 28 | 29 | # KL/LR schedule parameters 30 | parser.add_argument('--initial_kl_zero', type=int, default=4) 31 | parser.add_argument('--kl_rampup_time', type=int, default=10) 32 | parser.add_argument('--patience', type=int, default=1) 33 | 34 | # Sample parameters 35 | parser.add_argument('--ELBO_samples', type=int, default=10) 36 | parser.add_argument('--nll_samples', type=int, default=30) 37 | 38 | # General model parameters 39 | parser.add_argument('--model_type', type=str, default='discrete_flow') 40 | parser.add_argument('--inp_embedding_size', type=int, default=500) 41 | parser.add_argument('--zsize', type=int, default=50) 42 | parser.add_argument('--hidden_size', type=int, default=500) 43 | parser.add_argument('--dlocs', nargs='*', default=['prior_rnn']) 44 | parser.add_argument('--notie_weights', action='store_true') 45 | 46 | # Inference network parameters 47 | parser.add_argument('--q_rnn_layers', type=int, default=2) 48 | 49 | # Generative network parameters 50 | ## Prior 51 | parser.add_argument('--prior_type', type=str, default='AF') 52 | parser.add_argument('--p_ff_layers', type=int, default=0) 53 | parser.add_argument('--p_rnn_layers', type=int, default=2) 54 | 
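    # (added note) These prior flags are consumed by GenerativeBlock in discreteflow_model.py and by
    # AFPrior/LSTMFlow in lstm_flow.py: --prior_type selects AF, IAF, or hiddenflow_only;
    # --p_rnn_layers/--p_rnn_units size the prior LSTM (a negative --p_rnn_units falls back to
    # --hidden_size); --p_num_flow_layers sets the number of flow layers; and --transform_function
    # chooses between the 'affine' and 'nlsq' element-wise transforms defined in flows.py.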
parser.add_argument('--p_rnn_units', type=int, default=500) 55 | parser.add_argument('--p_num_flow_layers', type=int, default=1) 56 | parser.add_argument('--transform_function', type=str, default='nlsq') 57 | 58 | ### Prior MADE Flow 59 | parser.add_argument('--nohiddenflow', action='store_true') 60 | parser.add_argument('--hiddenflow_layers', type=int, default=2) 61 | parser.add_argument('--hiddenflow_units', type=int, default=100) 62 | parser.add_argument('--hiddenflow_flow_layers', type=int, default=5) 63 | parser.add_argument('--hiddenflow_scf_layers', action='store_true') 64 | 65 | ## Likelihood parameters 66 | parser.add_argument('--gen_bilstm_layers', type=int, default=2) 67 | 68 | args = parser.parse_args() 69 | 70 | if args.dlocs is None: 71 | setattr(args, 'dlocs', []) 72 | 73 | setattr(args, 'savedir', args.output_dir+'/'+args.run_name+'/saves/') 74 | setattr(args, 'logdir', args.output_dir+'/'+args.run_name+'/logs/') 75 | 76 | os.makedirs(args.savedir, exist_ok=True) 77 | os.makedirs(args.logdir, exist_ok=True) 78 | 79 | if args.seed is None: 80 | setattr(args, 'seed', random.randint(0, 1000000)) 81 | 82 | return args 83 | -------------------------------------------------------------------------------- /data/indep_bernoulli/.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | -------------------------------------------------------------------------------- /data/indep_bernoulli/load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from collections import namedtuple 5 | from six.moves.urllib.request import urlopen 6 | 7 | import six.moves.cPickle as pickle 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | dset = namedtuple("dset", ["name", "url", "filename"]) 13 | 14 | JSB_CHORALES = dset("jsb_chorales", 15 | "http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.pickle", 16 | "jsb_chorales.pkl") 17 | 18 | PIANO_MIDI = dset("piano_midi", 19 | "http://www-etud.iro.umontreal.ca/~boulanni/Piano-midi.de.pickle", 20 | "piano_midi.pkl") 21 | 22 | MUSE_DATA = dset("muse_data", 23 | "http://www-etud.iro.umontreal.ca/~boulanni/MuseData.pickle", 24 | "muse_data.pkl") 25 | 26 | NOTTINGHAM = dset("nottingham", 27 | "http://www-etud.iro.umontreal.ca/~boulanni/Nottingham.pickle", 28 | "nottingham.pkl") 29 | str2obj = {'jsb_chorales': JSB_CHORALES, 'piano_midi': PIANO_MIDI, 'muse_data': MUSE_DATA, 'nottingham': NOTTINGHAM} 30 | 31 | 32 | # this function processes the raw data; in particular it unsparsifies it 33 | def process_data(base_path, dataset, min_note=21, note_range=88): 34 | output = os.path.join(base_path, dataset.filename) 35 | if os.path.exists(output): 36 | try: 37 | with open(output, "rb") as f: 38 | return pickle.load(f) 39 | except (ValueError, UnicodeDecodeError): 40 | # Assume python env has changed. 41 | # Recreate pickle file in this env's format. 
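            # (added note) A cache pickled under a different Python/pickle protocol can fail to load
            # with ValueError or UnicodeDecodeError, so the stale file is removed here and rebuilt below.
            # The rebuilt dict maps each split to {'sequences': list of [T_i, note_range] 0/1 float
            # tensors, 'sequence_lengths': LongTensor holding the T_i values}.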
42 | os.remove(output) 43 | 44 | print("processing raw data - {} ...".format(dataset.name)) 45 | data = pickle.load(urlopen(dataset.url)) 46 | processed_dataset = {} 47 | for split, data_split in data.items(): 48 | processed_dataset[split] = {} 49 | n_seqs = len(data_split) 50 | processed_dataset[split]['sequence_lengths'] = torch.zeros(n_seqs, dtype=torch.long) 51 | processed_dataset[split]['sequences'] = [] 52 | for seq in range(n_seqs): 53 | seq_length = len(data_split[seq]) 54 | processed_dataset[split]['sequence_lengths'][seq] = seq_length 55 | processed_sequence = torch.zeros((seq_length, note_range)) 56 | for t in range(seq_length): 57 | note_slice = torch.tensor(list(data_split[seq][t]), dtype=torch.int64) - min_note 58 | slice_length = len(note_slice) 59 | if slice_length > 0: 60 | processed_sequence[t, note_slice] = torch.ones(slice_length) 61 | processed_dataset[split]['sequences'].append(processed_sequence) 62 | print(split) 63 | print(n_seqs) 64 | print(processed_dataset[split]['sequence_lengths']) 65 | print(processed_dataset[split]['sequence_lengths'].max()) 66 | print(processed_dataset[split]['sequences'][0][0], processed_dataset[split]['sequences'][0].shape) 67 | pickle.dump(processed_dataset, open(output, "wb"), pickle.HIGHEST_PROTOCOL) 68 | print("dumped processed data to %s" % output) 69 | 70 | 71 | # this logic will be initiated upon import 72 | base_path = os.path.dirname(os.path.realpath(__file__)) 73 | if not os.path.exists(base_path): 74 | os.mkdir(base_path) 75 | 76 | # ingest training/validation/test data from disk 77 | def load_data(dataset): 78 | # download and process dataset if it does not exist 79 | dataset = str2obj[dataset] 80 | process_data(base_path, dataset) 81 | file_loc = os.path.join(base_path, dataset.filename) 82 | with open(file_loc, "rb") as f: 83 | dset = pickle.load(f) 84 | #for k, v in dset.items(): 85 | # sequences = v["sequences"] 86 | # dset[k]["sequences"] = pad_sequence(sequences, batch_first=True).type(torch.Tensor) 87 | # dset[k]["sequence_lengths"] = v["sequence_lengths"] 88 | return dset 89 | 90 | if __name__ == '__main__': 91 | for k, v in str2obj.items(): 92 | load_data(k) 93 | -------------------------------------------------------------------------------- /discreteflow_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.distributions import Categorical 4 | import torch.nn.functional as F 5 | 6 | import math 7 | import sys 8 | 9 | from lstm_flow import AFPrior 10 | from common import FeedForwardNet 11 | from utils import make_pos_cond 12 | 13 | class InferenceBlock(nn.Module): 14 | def __init__(self, inf_inp_dim, hidden_size, zsize, dropout_p, q_rnn_layers, dropout_locations, max_T): 15 | super().__init__() 16 | self.dropout = nn.Dropout(dropout_p) 17 | 18 | rnn_q_inp_size = inf_inp_dim + 2*max_T 19 | self.rnn_q = torch.nn.LSTM(rnn_q_inp_size, hidden_size, q_rnn_layers, dropout=dropout_p if 'rnn_x' in dropout_locations else 0, bidirectional=True) 20 | 21 | self.q_base_ff = nn.Linear(hidden_size*2, zsize*2) 22 | 23 | self.hidden_size = hidden_size 24 | self.zsize = zsize 25 | self.q_rnn_layers = q_rnn_layers 26 | self.dropout_locations = dropout_locations 27 | self.reset_parameters() 28 | 29 | def reset_parameters(self): 30 | init_range = 0.07 31 | self.q_base_ff.weight.data.uniform_(-init_range, init_range) 32 | self.q_base_ff.bias.data.zero_() 33 | 34 | def sample_q_z(self, inf_inp, lengths, cond_inp, ELBO_samples): 35 | """ 36 | output 
is z [T, B, s, E] 37 | """ 38 | 39 | ## Run RNN over input 40 | T, B = inf_inp.shape[:2] 41 | hidden_rnn = self.init_hidden_rnn(B) 42 | 43 | inf_inp_packed = torch.cat((inf_inp, cond_inp), -1) 44 | 45 | total_length = inf_inp_packed.shape[0] 46 | inf_inp_packed = torch.nn.utils.rnn.pack_padded_sequence(inf_inp_packed, lengths) 47 | rnn_outp, _ = self.rnn_q(inf_inp_packed, hidden_rnn) # [T, B, hidden_size], [num_layers, B, hidden_size]x2 48 | rnn_outp = torch.nn.utils.rnn.pad_packed_sequence(rnn_outp, total_length=total_length)[0] 49 | 50 | if 'rnn_x' in self.dropout_locations: 51 | rnn_outp = self.dropout(rnn_outp) 52 | 53 | ## Sample ELBO_sample z's from RNN output 54 | rnn_outp = rnn_outp[:, :, None, :].repeat(1, 1, ELBO_samples, 1) 55 | q_z_base = self.q_base_ff(rnn_outp) 56 | 57 | q_z_base = q_z_base.view(*rnn_outp.shape[:-1], self.zsize, 2) 58 | z_base_mean = q_z_base[..., 0] 59 | z_base_logvar = q_z_base[..., 1] 60 | z_base_std = torch.exp(0.5*z_base_logvar) 61 | 62 | eps_initial = torch.randn(T, B, ELBO_samples, self.zsize, device=z_base_mean.device) 63 | z = z_base_mean + z_base_std*eps_initial # [T, B, s, E] 64 | 65 | log_q_z = -1/2*(math.log(2*math.pi) + z_base_logvar + (z - z_base_mean).pow(2)/z_base_std.pow(2)).sum(-1) # [T, B, s] 66 | 67 | # Reshape z into B and s 68 | z = z.view(T, B, ELBO_samples, self.zsize) # [T, B, s, E] 69 | 70 | return z, log_q_z 71 | 72 | def init_hidden_rnn(self, batch_size): 73 | weight = next(self.parameters()) 74 | h = weight.new_zeros(self.q_rnn_layers*2, batch_size, self.hidden_size) 75 | c = weight.new_zeros(self.q_rnn_layers*2, batch_size, self.hidden_size) 76 | return (h, c) 77 | 78 | 79 | class GenerativeBlock(nn.Module): 80 | def __init__(self, hidden_size, zsize, prior_type, dropout_p, dropout_locations, outp_rnn_layers, max_T, **kwargs): 81 | super().__init__() 82 | self.dropout = nn.Dropout(dropout_p) 83 | 84 | # Prior 85 | if prior_type not in ['AF', 'IAF', 'hiddenflow_only']: 86 | raise ValueError('Error, prior_type %s unknown' % prior_type) 87 | 88 | p_rnn_layers = kwargs['p_rnn_layers'] 89 | p_rnn_units = kwargs['p_rnn_units'] 90 | if p_rnn_units < 0: 91 | p_rnn_units = hidden_size 92 | 93 | p_num_flow_layers = kwargs['p_num_flow_layers'] 94 | transform_function = kwargs['transform_function'] 95 | hiddenflow_params = {k: v for k, v in kwargs.items() if 'hiddenflow' in k} 96 | 97 | self.prior = AFPrior(p_rnn_units, zsize, dropout_p, dropout_locations, prior_type, p_num_flow_layers, p_rnn_layers, 98 | max_T=max_T, transform_function=transform_function, hiddenflow_params=hiddenflow_params) 99 | 100 | # BiLSTM 101 | self.rnn_outp = nn.LSTM(zsize + 2*max_T, hidden_size, outp_rnn_layers, dropout=dropout_p if 'rnn_outp' in dropout_locations else 0, bidirectional=True) 102 | self.outp_dim = 2*hidden_size + zsize 103 | 104 | self.outp_rnn_layers = outp_rnn_layers 105 | self.hidden_size = hidden_size 106 | self.zsize = zsize 107 | self.dropout_locations = dropout_locations 108 | 109 | def apply_bilstm(self, z, lengths_s, cond_inp_s): 110 | """ 111 | z is [T, B, s, E] 112 | """ 113 | T, B, ELBO_samples = z.shape[:3] 114 | 115 | hidden_outp = self.init_hidden(B) 116 | hidden_outp = tuple(h[:, :, None, :].repeat(1, 1, ELBO_samples, 1).view(-1, B*ELBO_samples, self.hidden_size) for h in hidden_outp) 117 | 118 | z = z.view(T, B*ELBO_samples, self.zsize) 119 | z_packed = z 120 | if 'z_before_outp' in self.dropout_locations: 121 | z_packed = self.dropout(z_packed) 122 | 123 | z_packed = torch.cat((z_packed, cond_inp_s), -1) 124 | 125 | total_length = 
z_packed.shape[0] 126 | z_packed = nn.utils.rnn.pack_padded_sequence(z_packed, lengths_s) 127 | rnn_outp_outp, _ = self.rnn_outp(z_packed, hidden_outp) 128 | rnn_outp_outp = nn.utils.rnn.pad_packed_sequence(rnn_outp_outp, total_length=total_length)[0] 129 | 130 | if 'rnn_outp' in self.dropout_locations: 131 | rnn_outp_outp = self.dropout(rnn_outp_outp) 132 | 133 | z_cat = z.view(T, B, ELBO_samples, self.zsize) 134 | rnn_outp_outp = rnn_outp_outp.view(T, B, ELBO_samples, 2, self.hidden_size) 135 | 136 | # Reorganize rnn output 137 | hidden_outp_sep = hidden_outp[0].view(self.outp_rnn_layers, 2, B, ELBO_samples, self.hidden_size) 138 | 139 | rnn_outp_outp_shifted_forward = torch.cat((hidden_outp_sep[-1:, 0], rnn_outp_outp[:, :, :, 0]), 0)[:-1] # [T, B, s, hidden] 140 | rnn_outp_outp_shifted_backward = torch.cat((rnn_outp_outp[:, :, :, 1], hidden_outp_sep[-1:, 1]), 0)[1:] 141 | 142 | z_with_hist = torch.cat((z_cat, rnn_outp_outp_shifted_forward, rnn_outp_outp_shifted_backward), -1) 143 | 144 | return z_with_hist 145 | 146 | def init_hidden(self, batch_size): 147 | weight = next(self.parameters()) 148 | h = weight.new_zeros(self.outp_rnn_layers*2, batch_size, self.hidden_size) 149 | c = weight.new_zeros(self.outp_rnn_layers*2, batch_size, self.hidden_size) 150 | return (h, c) 151 | 152 | 153 | class DFModel(nn.Module): 154 | def __init__(self, vocab_size, loss_weights, n_inp_embedding, hidden_size, zsize, dropout_p, dropout_locations, # general parameters 155 | prior_type, gen_bilstm_layers, prior_kwargs, # gen block parameters 156 | q_rnn_layers, tie_weights, max_T, indep_bernoulli=False): # misc parameters 157 | super().__init__() 158 | 159 | for loc in dropout_locations: 160 | if loc not in ['embedding', 'rnn_x', 'z_before_prior', 'prior_rnn_inp', 'prior_rnn', 'prior_ff', 'z_before_outp', 'rnn_outp', 'outp_ff']: 161 | raise ValueError('dropout location %s not a valid location' % loc) 162 | 163 | self.dropout = torch.nn.Dropout(dropout_p) 164 | 165 | ## Initial embedding 166 | self.input_embedding = torch.nn.Embedding(vocab_size, n_inp_embedding) 167 | 168 | ## Latent models 169 | self.generative_model = GenerativeBlock(hidden_size, zsize, prior_type, dropout_p, dropout_locations, gen_bilstm_layers, max_T, **prior_kwargs) 170 | 171 | self.inference_model = InferenceBlock(n_inp_embedding, hidden_size, zsize, dropout_p, q_rnn_layers, dropout_locations, max_T) 172 | 173 | ## Generative output to x 174 | self.outp_ff = FeedForwardNet(self.generative_model.outp_dim, hidden_size, vocab_size, 1, 'none', dropout=dropout_p if 'outp_ff' in dropout_locations else 0) 175 | 176 | if tie_weights: 177 | self.outp_ff.network[-1].weight = self.input_embedding.weight 178 | 179 | if indep_bernoulli: 180 | self.cross_entropy = torch.nn.BCEWithLogitsLoss(reduction='none') 181 | else: 182 | self.cross_entropy = torch.nn.CrossEntropyLoss(loss_weights, reduction='none') 183 | 184 | self.indep_bernoulli = indep_bernoulli 185 | self.vocab_size = vocab_size 186 | self.dropout_locations = dropout_locations 187 | self.max_T = max_T 188 | 189 | self.reset_parameters() 190 | 191 | def reset_parameters(self): 192 | init_range = 0.07 193 | self.input_embedding.weight.data.uniform_(-init_range, init_range) 194 | 195 | def generate(self, lengths, temp=1.0, argmax_x=True): 196 | """ 197 | lengths is [B] with lengths of each sentence in the batch 198 | all inputs should be on the same compute device 199 | """ 200 | 201 | T = torch.max(lengths) 202 | B = lengths.shape[0] 203 | 204 | ## Calculate position conditioning 205 | 
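        # (added note) make_pos_cond lives in utils.py (not shown in this excerpt); from its uses here
        # and in baseline_model.py it returns a [T, B, 2*max_T] position/length conditioning tensor,
        # which is the shape assumed by the repeat/view calls in evaluate_x below.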
pos_cond = make_pos_cond(T, B, lengths, self.max_T) 206 | 207 | ## Generate z's from prior 208 | z, _ = self.generative_model.prior.generate(lengths, cond_inp=pos_cond, temp=temp) 209 | z = z[:, :, None, :] 210 | 211 | ## Apply BiLSTM part of likelihood 212 | gen_outp = self.generative_model.apply_bilstm(z, lengths, cond_inp_s=pos_cond) 213 | gen_outp = gen_outp.squeeze(2) 214 | 215 | ## Final output 216 | scores = self.outp_ff(gen_outp) # [T, B, V] 217 | 218 | if argmax_x: 219 | if self.indep_bernoulli: 220 | probs = torch.sigmoid(scores) 221 | generation = (probs > 0.5).long() 222 | else: 223 | generation = torch.argmax(scores, -1) 224 | else: 225 | if self.indep_bernoulli: 226 | word_dist = Bernoulli(logits=scores) 227 | else: 228 | word_dist = Categorical(logits=scores) 229 | generation = word_dist.sample() 230 | 231 | return generation 232 | 233 | def evaluate_x(self, x, lengths, ELBO_samples=1): 234 | """ 235 | x is [T, B] with indices of tokens 236 | lengths is [B] with lengths of each sentence in the batch 237 | all inputs should be on the same compute device 238 | """ 239 | 240 | T, B = x.shape[:2] 241 | 242 | ## Create ELBO_sample versions of inputs copied across a new dimension 243 | lengths_s = lengths[:, None].repeat(1, ELBO_samples).view(-1) 244 | 245 | pos_cond = make_pos_cond(T, B, lengths, self.max_T) 246 | pos_cond_s = pos_cond[:, :, None, :].repeat(1, 1, ELBO_samples, 1).view(T, B*ELBO_samples, self.max_T*2) 247 | 248 | ## Get the initial x embeddings 249 | if self.indep_bernoulli: 250 | embeddings = torch.matmul(x, self.input_embedding.weight) 251 | else: 252 | embeddings = self.input_embedding(x) # [T, B, n_inp_embedding] 253 | 254 | if 'embedding' in self.dropout_locations: 255 | embeddings = self.dropout(embeddings) 256 | 257 | z, log_q_z = self.inference_model.sample_q_z(embeddings, lengths, pos_cond, ELBO_samples) # [T, B, s, E] 258 | 259 | log_p_z = self.generative_model.prior.evaluate(z, lengths_s, cond_inp_s=pos_cond_s) 260 | gen_outp = self.generative_model.apply_bilstm(z, lengths_s, cond_inp_s=pos_cond_s) 261 | gen_outp = gen_outp.view(gen_outp.shape[0], B*ELBO_samples, gen_outp.shape[-1]) # [T, B*s, 2*hidden] 262 | 263 | ## Final output 264 | scores = self.outp_ff(gen_outp) # [T, B, s, V] 265 | 266 | if self.indep_bernoulli: 267 | targets = x[:, :, None, :].repeat(1, 1, ELBO_samples, 1) 268 | reconst_loss = self.cross_entropy(scores.view(-1, self.vocab_size), targets.view(-1, self.vocab_size)).view(T, B, ELBO_samples, self.vocab_size).sum(-1) # [T, B, s] 269 | else: 270 | targets = x[:, :, None].repeat(1, 1, ELBO_samples) 271 | reconst_loss = self.cross_entropy(scores.view(-1, self.vocab_size), targets.view(-1)).view(T, B, ELBO_samples) 272 | 273 | kl_loss = (log_q_z - log_p_z) # [T, B, s] 274 | 275 | return reconst_loss, kl_loss 276 | -------------------------------------------------------------------------------- /flows.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import math 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn.parameter import Parameter 8 | from torch.autograd import Variable 9 | 10 | from MADE import MADE 11 | from common import FeedForwardNet 12 | 13 | # Transformation functions 14 | 15 | class Affine(): 16 | num_params = 2 17 | 18 | @staticmethod 19 | def get_pseudo_params(nn_outp): 20 | a = nn_outp[..., 0] # [B, D] 21 | var_outp = nn_outp[..., 1] 22 | 23 | b = torch.exp(0.5*var_outp) 24 | logbsq = var_outp 25 | 26 | return a, logbsq, b 27 | 28 
| @staticmethod 29 | def standard(x, nn_outp): 30 | a, logbsq, b = Affine.get_pseudo_params(nn_outp) 31 | 32 | y = a + b*x 33 | logdet = 0.5*logbsq.sum(-1) 34 | 35 | return y, logdet 36 | 37 | @staticmethod 38 | def reverse(y, nn_outp): 39 | a, logbsq, b = Affine.get_pseudo_params(nn_outp) 40 | 41 | x = (y - a)/b 42 | logdet = 0.5*logbsq.sum(-1) 43 | 44 | return x, logdet 45 | 46 | def arccosh(x): 47 | return torch.log(x + torch.sqrt(x.pow(2)-1)) 48 | 49 | def arcsinh(x): 50 | return torch.log(x + torch.sqrt(x.pow(2)+1)) 51 | 52 | class NLSq(): 53 | num_params = 5 54 | logA = math.log(8*math.sqrt(3)/9-0.05) # 0.05 is a small number to prevent exactly 0 slope 55 | 56 | @staticmethod 57 | def get_pseudo_params(nn_outp): 58 | a = nn_outp[..., 0] # [B, D] 59 | logb = nn_outp[..., 1]*0.4 60 | B = nn_outp[..., 2]*0.3 61 | logd = nn_outp[..., 3]*0.4 62 | f = nn_outp[..., 4] 63 | 64 | b = torch.exp(logb) 65 | d = torch.exp(logd) 66 | c = torch.tanh(B)*torch.exp(NLSq.logA + logb - logd) 67 | 68 | return a, b, c, d, f 69 | 70 | @staticmethod 71 | def standard(x, nn_outp): 72 | a, b, c, d, f = NLSq.get_pseudo_params(nn_outp) 73 | 74 | # double needed for stability. No effect on overall speed 75 | a = a.double() 76 | b = b.double() 77 | c = c.double() 78 | d = d.double() 79 | f = f.double() 80 | x = x.double() 81 | 82 | aa = -b*d.pow(2) 83 | bb = (x-a)*d.pow(2) - 2*b*d*f 84 | cc = (x-a)*2*d*f - b*(1+f.pow(2)) 85 | dd = (x-a)*(1+f.pow(2)) - c 86 | 87 | p = (3*aa*cc - bb.pow(2))/(3*aa.pow(2)) 88 | q = (2*bb.pow(3) - 9*aa*bb*cc + 27*aa.pow(2)*dd)/(27*aa.pow(3)) 89 | 90 | t = -2*torch.abs(q)/q*torch.sqrt(torch.abs(p)/3) 91 | inter_term1 = -3*torch.abs(q)/(2*p)*torch.sqrt(3/torch.abs(p)) 92 | inter_term2 = 1/3*arccosh(torch.abs(inter_term1-1)+1) 93 | t = t*torch.cosh(inter_term2) 94 | 95 | tpos = -2*torch.sqrt(torch.abs(p)/3) 96 | inter_term1 = 3*q/(2*p)*torch.sqrt(3/torch.abs(p)) 97 | inter_term2 = 1/3*arcsinh(inter_term1) 98 | tpos = tpos*torch.sinh(inter_term2) 99 | 100 | t[p > 0] = tpos[p > 0] 101 | y = t - bb/(3*aa) 102 | 103 | arg = d*y + f 104 | denom = 1 + arg.pow(2) 105 | 106 | x_new = a + b*y + c/denom 107 | 108 | logdet = -torch.log(b - 2*c*d*arg/denom.pow(2)).sum(-1) 109 | 110 | y = y.float() 111 | logdet = logdet.float() 112 | 113 | return y, logdet 114 | 115 | 116 | @staticmethod 117 | def reverse(y, nn_outp): 118 | a, b, c, d, f = NLSq.get_pseudo_params(nn_outp) 119 | 120 | arg = d*y + f 121 | denom = 1 + arg.pow(2) 122 | x = a + b*y + c/denom 123 | 124 | logdet = -torch.log(b - 2*c*d*arg/denom.pow(2)).sum(-1) 125 | 126 | return x, logdet 127 | 128 | class SCFLayer(nn.Module): 129 | def __init__(self, data_dim, n_hidden_layers, n_hidden_units, nonlinearity, transform_function, hidden_order=None, swap_trngen_dirs=False, 130 | input_order=None, conditional_inp_dim=None, dropout=[0, 0]): 131 | super().__init__() 132 | 133 | self.net = FeedForwardNet(data_dim//2 + conditional_inp_dim, n_hidden_units, (data_dim-(data_dim//2))*transform_function.num_params, n_hidden_layers, nonlinearity, dropout=dropout[1]) 134 | 135 | self.train_func = transform_function.standard if swap_trngen_dirs else transform_function.reverse 136 | self.gen_func = transform_function.reverse if swap_trngen_dirs else transform_function.standard 137 | self.input_order = input_order 138 | 139 | self.use_cond_inp = conditional_inp_dim is not None 140 | 141 | def forward(self, inputs): 142 | """ 143 | Defines the reverse pass which is used during training 144 | logdet means log det del_y/del_x 145 | """ 146 | 147 | data_dim = 
len(self.input_order) 148 | assert data_dim == inputs[0].shape[-1] 149 | 150 | first_indices = torch.arange(len(self.input_order))[self.input_order <= data_dim//2] # This is <= because input_order goes from 1 to data_dim+1 151 | second_indices = torch.arange(len(self.input_order))[self.input_order > data_dim//2] 152 | 153 | if self.use_cond_inp: 154 | y, logdet, cond_inp = inputs 155 | net_inp = torch.cat([y[..., first_indices], cond_inp], -1) 156 | else: 157 | y, logdet = inputs 158 | net_inp = y[..., first_indices] 159 | 160 | nn_outp = self.net(net_inp).view(*net_inp.shape[:-1], data_dim-(data_dim//2), -1) # [..., ~data_dim/2, num_params] 161 | 162 | x = torch.tensor(y) 163 | x[..., second_indices], change_logdet = self.train_func(y[..., second_indices], nn_outp) 164 | 165 | return x, logdet + change_logdet, cond_inp 166 | 167 | def generate(self, inputs): 168 | """ 169 | Defines the forward pass which is used during testing 170 | logdet means log det del_y/del_x 171 | """ 172 | 173 | data_dim = len(self.input_order) 174 | assert data_dim == inputs[0].shape[-1] 175 | 176 | first_indices = torch.arange(len(self.input_order))[self.input_order <= data_dim//2] # This is <= because input_order goes from 1 to data_dim+1 177 | second_indices = torch.arange(len(self.input_order))[self.input_order > data_dim//2] 178 | 179 | if self.use_cond_inp: 180 | x, logdet, cond_inp = inputs 181 | net_inp = torch.cat([x[..., first_indices], cond_inp], -1) 182 | else: 183 | x, logdet = inputs 184 | net_inp = x[..., first_indices] 185 | 186 | nn_outp = self.net(net_inp).view(*net_inp.shape[:-1], data_dim-(data_dim//2), -1) # [..., ~data_dim/2, num_params] 187 | 188 | y = torch.tensor(x) 189 | y[..., second_indices], change_logdet = self.gen_func(x[..., second_indices], nn_outp) 190 | 191 | return y, logdet + change_logdet, cond_inp 192 | 193 | 194 | class AFLayer(nn.Module): 195 | def __init__(self, data_dim, n_hidden_layers, n_hidden_units, nonlinearity, transform_function, hidden_order='sequential', swap_trngen_dirs=False, 196 | input_order=None, conditional_inp_dim=None, dropout=[0, 0], coupling_level=0): 197 | super().__init__() 198 | 199 | self.made = MADE(data_dim, n_hidden_layers, n_hidden_units, nonlinearity, hidden_order, 200 | out_dim_per_inp_dim=transform_function.num_params, input_order=input_order, conditional_inp_dim=conditional_inp_dim, 201 | dropout=dropout) 202 | 203 | self.train_func = transform_function.standard if swap_trngen_dirs else transform_function.reverse 204 | self.gen_func = transform_function.reverse if swap_trngen_dirs else transform_function.standard 205 | self.output_order = self.made.end_order 206 | self.data_dim = data_dim 207 | 208 | self.use_cond_inp = conditional_inp_dim is not None 209 | 210 | def forward(self, inputs): 211 | """ 212 | Defines the reverse pass which is used during training 213 | logdet means log det del_y/del_x 214 | """ 215 | 216 | if self.use_cond_inp: 217 | y, logdet, cond_inp = inputs 218 | nn_outp = self.made([y, cond_inp]) # [B, D, 2] 219 | else: 220 | y, logdet = inputs 221 | nn_outp = self.made(y) # [B, D, 2] 222 | 223 | x, change_logdet = self.train_func(y, nn_outp) 224 | 225 | return x, logdet + change_logdet, cond_inp 226 | 227 | def generate(self, inputs): 228 | """ 229 | Defines the forward pass which is used during testing 230 | logdet means log det del_y/del_x 231 | """ 232 | if self.use_cond_inp: 233 | x, logdet, cond_inp = inputs 234 | else: 235 | x, logdet = inputs 236 | 237 | y = torch.tensor(x) 238 | for idx in range(self.data_dim): 
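            # (added note) Inverting the autoregressive layer is inherently sequential: at each step the
            # MADE network is re-run on the partially generated y and only the single dimension at
            # position idx of the autoregressive ordering is filled in, so generation costs data_dim
            # MADE passes per flow layer, whereas the training-direction pass above is fully parallel.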
239 | t = (self.output_order==idx).nonzero()[0][0] 240 | 241 | if self.use_cond_inp: 242 | nn_outp = self.made([y, cond_inp]) 243 | else: 244 | nn_outp = self.made(y) 245 | 246 | y[..., t:t+1], new_partial_logdet = self.gen_func(x[..., t:t+1], nn_outp[..., t:t+1, :]) 247 | logdet += new_partial_logdet 248 | 249 | return y, logdet, cond_inp 250 | 251 | # Full flow combining multiple layers 252 | 253 | class Flow(nn.Module): 254 | def __init__(self, data_dim, n_hidden_layers, n_hidden_units, nonlinearity, num_flow_layers, transform_function, 255 | iaf_like=False, hidden_order='sequential', 256 | swap_trngen_dirs=False, conditional_inp_dim=None, dropout=[0, 0], reverse_between_layers=True, 257 | scf_layers=False, reverse_first_layer=False): 258 | super().__init__() 259 | 260 | if transform_function == 'affine': 261 | transform_function = Affine 262 | elif transform_function == 'nlsq': 263 | transform_function = NLSq 264 | elif transform_function != Affine and transform_function != NLSq: # Can pass string or actual class 265 | raise NotImplementedError('Only the affine transformation function has been implemented') 266 | 267 | if scf_layers: 268 | AutoregressiveLayer = SCFLayer 269 | else: 270 | AutoregressiveLayer = AFLayer 271 | 272 | # Note: This ordering is the ordering as applied to go from data -> base 273 | flow_layers = [] 274 | 275 | input_order = torch.arange(data_dim)+1 276 | 277 | if reverse_first_layer: 278 | input_order = reversed(input_order) 279 | 280 | for i in range(num_flow_layers): 281 | flow_layers.append(AutoregressiveLayer(data_dim, n_hidden_layers, n_hidden_units, nonlinearity, transform_function, 282 | hidden_order=hidden_order, swap_trngen_dirs=swap_trngen_dirs, input_order=input_order, 283 | conditional_inp_dim=conditional_inp_dim, dropout=dropout)) 284 | if reverse_between_layers: 285 | input_order = reversed(input_order) 286 | 287 | self.flow = nn.Sequential(*flow_layers) 288 | self.use_cond_inp = conditional_inp_dim is not None 289 | 290 | def forward(self, inputs): 291 | """ 292 | Defines the reverse pass which is used during training 293 | logdet means log det del_y/del_x 294 | """ 295 | if self.use_cond_inp: 296 | y, cond_inp = inputs 297 | else: 298 | y = inputs 299 | 300 | logdet = torch.zeros(y.shape[:-1], device=y.device) 301 | 302 | if self.use_cond_inp: 303 | x, logdet, _ = self.flow([y, logdet, cond_inp]) 304 | else: 305 | x, logdet = self.flow([y, logdet]) 306 | 307 | return x, logdet 308 | 309 | def generate(self, inputs): 310 | """ 311 | Defines the forward pass which is used during testing 312 | logdet means log det del_y/del_x 313 | """ 314 | 315 | if self.use_cond_inp: 316 | x, cond_inp = inputs 317 | else: 318 | x = inputs 319 | 320 | logdet = torch.zeros(x.shape[:-1], device=x.device) 321 | y = x 322 | for flow_layer in reversed(self.flow): 323 | if self.use_cond_inp: 324 | y, logdet, _ = flow_layer.generate([y, logdet, cond_inp]) 325 | else: 326 | y, logdet = flow_layer.generate([y, logdet]) 327 | 328 | return y, logdet 329 | -------------------------------------------------------------------------------- /lstm_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import math 4 | import time 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from flows import Affine, NLSq 10 | from flows import Flow as MADE_flow 11 | from common import FeedForwardNet 12 | from utils import reverse_padded_sequence 13 | 14 | 15 | class LSTM_AFLayer(nn.Module): 16 | def __init__(self, 
layer_num, inp_dim, n_hidden_layers, n_hidden_units, dropout_p, transform_function, rnn_cond_dim=None, 17 | swap_trngen_dirs=False, reverse_inps=False, hiddenflow_params={}, dlocs=['rnn', 'rnn_outp'], notimecontext=False): 18 | super().__init__() 19 | 20 | self.rnn_inp_drop = nn.Dropout(dropout_p) 21 | self.rnn_outp_drop = nn.Dropout(dropout_p) 22 | 23 | lstm_inp_dim = inp_dim 24 | self.use_rnn_cond_inp = rnn_cond_dim is not None 25 | if self.use_rnn_cond_inp: 26 | lstm_inp_dim += rnn_cond_dim 27 | if not notimecontext: 28 | self.initial_hidden_cond_ff = FeedForwardNet(rnn_cond_dim, n_hidden_units, 2*n_hidden_units*n_hidden_layers, 1, 'relu') 29 | 30 | if not notimecontext: 31 | self.lstm = nn.LSTM(lstm_inp_dim, n_hidden_units, n_hidden_layers, dropout=dropout_p if 'rnn' in dlocs else 0) 32 | after_rnn_inp_units = n_hidden_units 33 | else: 34 | after_rnn_inp_units = rnn_cond_dim 35 | 36 | # Whether or not a MADE autoregressive flow (hiddenflow) should be used to model p(z_t), or if the elements of z_t should be independent 37 | self.use_hiddenflow = not hiddenflow_params['nohiddenflow'] 38 | if self.use_hiddenflow: 39 | if reverse_inps: 40 | raise NotImplementedError('hiddenflow with reversing the inputs in time has not been implemented. Will have to take into account the fact that '+ 41 | 'IAF and AF flows both use forward for the training pass, which is diffent than the convention here.') 42 | 43 | hiddenflow_layers = hiddenflow_params['hiddenflow_layers'] 44 | hiddenflow_units = hiddenflow_params['hiddenflow_units'] 45 | hiddenflow_flow_layers = hiddenflow_params['hiddenflow_flow_layers'] # if > 1, automatically reverses the order 46 | hiddenflow_scf_layers = hiddenflow_params['hiddenflow_scf_layers'] 47 | hiddenflow_reverse_first = layer_num % 2 == 1 48 | 49 | if hiddenflow_units <= inp_dim: 50 | raise ValueError('Error, hiddenflow_units must be greater than the inp_dim so all inp variables have connections to the output') 51 | 52 | MADE_dropout = [dropout_p, dropout_p] if 'ff' in dlocs else [0, 0] 53 | self.outp_net = MADE_flow(inp_dim, hiddenflow_layers, hiddenflow_units, 'relu', hiddenflow_flow_layers, transform_function, iaf_like=False, 54 | swap_trngen_dirs=swap_trngen_dirs, conditional_inp_dim=after_rnn_inp_units, dropout=MADE_dropout, 55 | reverse_between_layers=True, scf_layers=hiddenflow_scf_layers, reverse_first_layer=hiddenflow_reverse_first) 56 | else: 57 | if notimecontext: 58 | raise ValueError('notimecontext does not make sense without MADE layers') 59 | 60 | self.outp_net = nn.Linear(after_rnn_inp_units, transform_function.num_params*inp_dim) 61 | 62 | self.num_params = transform_function.num_params 63 | self.train_func = transform_function.standard if swap_trngen_dirs else transform_function.reverse 64 | self.gen_func = transform_function.reverse if swap_trngen_dirs else transform_function.standard 65 | 66 | self.layer_num = layer_num # Needed to keep track of hidden states 67 | self.n_hidden_layers = n_hidden_layers # Needed for init_hidden 68 | self.n_hidden_units = n_hidden_units # Needed for init_hidden 69 | self.inp_dim = inp_dim # Needed for init_last_nn_outp 70 | self.reverse_inps = reverse_inps 71 | self.dlocs = dlocs # Options are [rnn_inp, rnn, rnn_outp, made] 72 | self.notimecontext = notimecontext 73 | 74 | self.reset_parameters() 75 | 76 | def reset_parameters(self): 77 | init_range = 0.07 78 | if not self.use_hiddenflow: 79 | self.outp_net.weight.data.uniform_(-init_range, init_range) 80 | self.outp_net.bias.data.zero_() 81 | 82 | def forward(self, 
inputs): 83 | """ 84 | Defines the reverse pass which is used during training 85 | logdet means log det del_y/del_x 86 | """ 87 | 88 | y, logdet, hiddens, rnn_cond_inp, lengths = inputs # y is [T, B, inp_dim] 89 | 90 | y_packed = y 91 | cur_rnn_cond_inp = rnn_cond_inp 92 | B = y.shape[1] 93 | 94 | lengths_inp = lengths 95 | if not self.notimecontext: 96 | if self.reverse_inps: 97 | y_packed = reverse_padded_sequence(y_packed, lengths_inp) 98 | cur_rnn_cond_inp = reverse_padded_sequence(rnn_cond_inp, lengths_inp) 99 | 100 | if self.use_rnn_cond_inp: 101 | actual_hidden = self.initial_hidden_cond_ff(cur_rnn_cond_inp[0]).view(B, self.n_hidden_layers, self.n_hidden_units, 2) # [B, layers, hidden, 2] 102 | actual_hidden = actual_hidden.transpose(0, 1) # [layers, B, hidden, 2] 103 | actual_hidden = tuple([actual_hidden[..., 0].contiguous(), actual_hidden[..., 1].contiguous()]) 104 | hiddens[self.layer_num] = actual_hidden 105 | 106 | cur_rnn_cond_inp_shifted = torch.cat((cur_rnn_cond_inp[1:], cur_rnn_cond_inp.new_zeros((1, *cur_rnn_cond_inp.shape[1:]))), 0) 107 | y_packed = torch.cat((y_packed, cur_rnn_cond_inp_shifted), -1) 108 | 109 | if 'rnn_inp' in self.dlocs: 110 | y_packed = self.rnn_inp_drop(y_packed) 111 | 112 | if not self.notimecontext: 113 | total_length = y_packed.shape[0] 114 | y_packed = nn.utils.rnn.pack_padded_sequence(y_packed, lengths_inp) 115 | rnn_outp, final_hidden = self.lstm(y_packed, hiddens[self.layer_num]) 116 | rnn_outp = nn.utils.rnn.pad_packed_sequence(rnn_outp, total_length=total_length)[0] 117 | rnn_outp = torch.cat((hiddens[self.layer_num][0][-1:], rnn_outp), 0)[:-1] # This will correctly shift the outputs so they are actually autoregressive 118 | 119 | hiddens[self.layer_num] = final_hidden 120 | 121 | if self.reverse_inps: # Undo the reverse ordering so the outputs have the correct ordering 122 | rnn_outp = reverse_padded_sequence(rnn_outp, lengths_inp) 123 | 124 | if 'rnn_outp' in self.dlocs: 125 | rnn_outp = self.rnn_outp_drop(rnn_outp) 126 | 127 | 128 | if self.use_hiddenflow: 129 | hiddenflow_conditional = cur_rnn_cond_inp if self.notimecontext else rnn_outp 130 | x_new, change_logdet = self.outp_net([y, hiddenflow_conditional]) 131 | else: 132 | nn_outp = self.outp_net(rnn_outp) 133 | nn_outp = nn_outp.view(*nn_outp.shape[:-1], self.inp_dim, self.num_params) 134 | 135 | x_new, change_logdet = self.train_func(y, nn_outp) # x is [T, B, inp_dim], change_logdet is [T, B] 136 | 137 | x = x_new 138 | logdet += change_logdet 139 | 140 | return x, logdet, hiddens, rnn_cond_inp, lengths 141 | 142 | def generate(self, inputs): 143 | """ 144 | Defines the forward pass which is used during testing 145 | logdet means log det del_y/del_x 146 | """ 147 | 148 | x, logdet, hiddens, rnn_cond_inp, lengths = inputs 149 | 150 | rnn_cond_inp_touse = rnn_cond_inp 151 | if self.reverse_inps: 152 | x = reverse_padded_sequence(x, lengths) 153 | rnn_cond_inp_touse = reverse_padded_sequence(rnn_cond_inp, lengths) 154 | 155 | rnn_cond_inp_touse = torch.cat((rnn_cond_inp_touse, rnn_cond_inp_touse.new_zeros((1, *rnn_cond_inp_touse.shape[1:]))), 0) 156 | 157 | y = torch.tensor(x) # [T, B, inp_dim] 158 | change_logdet = torch.zeros_like(logdet) # [T, B] 159 | 160 | if self.use_rnn_cond_inp: 161 | B = x.shape[1] 162 | actual_hidden = self.initial_hidden_cond_ff(rnn_cond_inp_touse[0]).view(B, self.n_hidden_layers, self.n_hidden_units, 2) # [B, layers, hidden, 2] 163 | actual_hidden = actual_hidden.transpose(0, 1) # [layers, B, hidden, 2] 164 | actual_hidden = tuple([actual_hidden[..., 
0].contiguous(), actual_hidden[..., 1].contiguous()]) 165 | hiddens[self.layer_num] = actual_hidden 166 | 167 | last_rnn_outp = hiddens[self.layer_num][0][-1:] # [1, B, hidden] 168 | last_hiddens = hiddens[self.layer_num] 169 | for t in range(x.shape[0]): 170 | if 'rnn_outp' in self.dlocs: 171 | last_rnn_outp = self.rnn_outp_drop(last_rnn_outp) 172 | 173 | if self.use_hiddenflow: 174 | y[t:t+1], new_partial_logdet = self.outp_net.generate([x[t], last_rnn_outp[0]]) 175 | else: 176 | nn_outp = self.outp_net(last_rnn_outp) 177 | nn_outp = nn_outp.view(1, last_rnn_outp.shape[1], self.inp_dim, self.num_params) 178 | 179 | y[t:t+1], new_partial_logdet = self.gen_func(x[t:t+1], nn_outp) 180 | 181 | change_logdet[t] = new_partial_logdet 182 | 183 | if self.use_rnn_cond_inp: 184 | rnn_cond_inp_t = rnn_cond_inp_touse[t+1:t+2] 185 | lstm_inp = torch.cat((y[t:t+1], rnn_cond_inp_t), -1) 186 | else: 187 | lstm_inp = y[t:t+1].clone() 188 | 189 | if 'rnn_inp' in self.dlocs: 190 | lstm_inp = self.rnn_inp_drop(lstm_inp) 191 | 192 | last_rnn_outp, last_hiddens = self.lstm(lstm_inp, last_hiddens) 193 | 194 | for h in last_hiddens: 195 | h[:, :, :] = -9999999999 # If lengths is provided, then the hidden output provided by this function is wrong. If they're ever used for anything, this should make it clear there's an error 196 | 197 | hiddens[self.layer_num] = last_hiddens 198 | 199 | if self.reverse_inps: 200 | y = reverse_padded_sequence(y, lengths_inp) 201 | change_logdet = reverse_padded_sequence(change_logdet, lengths_inp) 202 | 203 | return y, logdet + change_logdet, hiddens, rnn_cond_inp, lengths 204 | 205 | def init_hidden(self, batch_size): 206 | weight = next(self.parameters()) 207 | h = weight.new_zeros(self.n_hidden_layers, batch_size, self.n_hidden_units) 208 | c = weight.new_zeros(self.n_hidden_layers, batch_size, self.n_hidden_units) 209 | return (h, c) 210 | 211 | # Full flow combining multiple layers 212 | 213 | class LSTMFlow(nn.Module): 214 | def __init__(self, inp_dim, n_hidden_layers, n_hidden_units, dropout_p, num_flow_layers, transform_function, 215 | rnn_cond_dim=None, swap_trngen_dirs=False, 216 | sequential_training=False, reverse_ordering=False, hiddenflow_params={}, 217 | dlocs=[], notimecontext=False): 218 | super().__init__() 219 | 220 | if transform_function == 'affine': 221 | transform_function = Affine 222 | elif transform_function == 'nlsq': 223 | transform_function = NLSq 224 | else: 225 | raise NotImplementedError('Only the affine and nlsq transformation functions have been implemented') 226 | 227 | # Note: This ordering is the ordering as applied during training 228 | flow_layers = [] 229 | reverse_inps = False 230 | 231 | # This is neccessary so that q(z) and p(z) are based on the same ordering if there are an even number of layers and IAF posterior is used 232 | if swap_trngen_dirs and num_flow_layers % 2 == 0: 233 | reverse_inps = True 234 | 235 | # This is needed after the previous line, because if using sequential training for p (i.e. 
IAF prior) you don't want to start with reversed inputs if you have an even number of flow layers 236 | if sequential_training: 237 | swap_trngen_dirs = not swap_trngen_dirs 238 | 239 | for i in range(num_flow_layers): 240 | flow_layers.append(LSTM_AFLayer(i, inp_dim, n_hidden_layers, n_hidden_units, dropout_p, transform_function, 241 | rnn_cond_dim=rnn_cond_dim, swap_trngen_dirs=swap_trngen_dirs, reverse_inps=reverse_inps, 242 | hiddenflow_params=hiddenflow_params, dlocs=dlocs, notimecontext=notimecontext)) 243 | if reverse_ordering: 244 | reverse_inps = not reverse_inps 245 | 246 | self.flow = nn.Sequential(*flow_layers) 247 | self.use_rnn_cond_inp = rnn_cond_dim is not None 248 | self.sequential_training = sequential_training 249 | 250 | def forward(self, y, hiddens, lengths, rnn_cond_inp=None): 251 | """ 252 | Defines the reverse pass which is used during training 253 | logdet means log det del_y/del_x 254 | """ 255 | #if self.use_cond_inp: 256 | # y, hiddens, cond_inp = inputs 257 | #else: 258 | # y, hiddens = inputs 259 | 260 | if self.use_rnn_cond_inp and rnn_cond_inp is None: 261 | raise ValueError("use_rnn_cond_inp is set but rnn_cond_inp is None in forward") 262 | 263 | logdet = torch.zeros(y.shape[:-1], device=y.device) 264 | 265 | if self.sequential_training: 266 | x = y 267 | for flow_layer in reversed(self.flow): 268 | x, logdet, hiddens, _, _ = flow_layer.generate([x, logdet, hiddens, rnn_cond_inp, lengths]) 269 | else: 270 | x, logdet, hiddens, _, _ = self.flow([y, logdet, hiddens, rnn_cond_inp, lengths]) 271 | 272 | return x, logdet, hiddens 273 | 274 | def generate(self, x, hiddens, lengths, rnn_cond_inp=None): 275 | """ 276 | Defines the forward pass which is used during testing 277 | logdet means log det del_y/del_x 278 | """ 279 | 280 | if self.use_rnn_cond_inp and rnn_cond_inp is None: 281 | raise ValueError("use_rnn_cond_inp is set but rnn_cond_inp is None in generate") 282 | 283 | logdet = torch.zeros(x.shape[:-1], device=x.device) 284 | 285 | if self.sequential_training: 286 | y, logdet, hiddens, _, _ = self.flow([x, logdet, hiddens, rnn_cond_inp, lengths]) 287 | else: 288 | y = x 289 | for flow_layer in reversed(self.flow): 290 | y, logdet, hiddens, _, _ = flow_layer.generate([y, logdet, hiddens, rnn_cond_inp, lengths]) 291 | 292 | return y, logdet, hiddens 293 | 294 | def init_hidden(self, batch_size): 295 | return [fl.init_hidden(batch_size) for fl in self.flow] 296 | 297 | # Prior using the LSTMFlow 298 | 299 | class AFPrior(nn.Module): 300 | def __init__(self, hidden_size, zsize, dropout_p, dropout_locations, prior_type, num_flow_layers, rnn_layers, max_T=-1, 301 | transform_function='affine', hiddenflow_params={}): 302 | super().__init__() 303 | 304 | sequential_training = prior_type == 'IAF' 305 | notimecontext = prior_type == 'hiddenflow_only' 306 | 307 | dlocs = [] 308 | if 'prior_rnn' in dropout_locations: 309 | dlocs.append('rnn') 310 | dlocs.append('rnn_outp') 311 | if 'prior_rnn_inp' in dropout_locations: 312 | dlocs.append('rnn_inp') 313 | if 'prior_ff' in dropout_locations: 314 | dlocs.append('ff') 315 | 316 | self.flow = LSTMFlow(zsize, rnn_layers, hidden_size, dropout_p, num_flow_layers, 317 | transform_function, rnn_cond_dim=2*max_T, 318 | sequential_training=sequential_training, hiddenflow_params=hiddenflow_params, dlocs=dlocs, 319 | notimecontext=notimecontext) 320 | 321 | self.dropout = nn.Dropout(dropout_p) 322 | 323 | self.hidden_size = hidden_size 324 | self.zsize = zsize 325 | self.dropout_locations=dropout_locations 326 | 327 | def 
evaluate(self, z, lengths_s, cond_inp_s=None): 328 | """ 329 | z is [T, B, s, E] 330 | output is log_p_z [T, B, s] 331 | """ 332 | T, B, ELBO_samples = z.shape[:3] 333 | 334 | hidden = self.flow.init_hidden(B) 335 | hidden = [tuple(h[:, :, None, :].repeat(1, 1, ELBO_samples, 1).view(-1, ELBO_samples*B, self.hidden_size) for h in hidden_pl) for hidden_pl in hidden] 336 | 337 | if 'z_before_prior' in self.dropout_locations: 338 | z = self.dropout(z) 339 | 340 | z = z.view(T, B*ELBO_samples, z.shape[-1]) 341 | eps, logdet, _ = self.flow(z, hidden, lengths_s, rnn_cond_inp=cond_inp_s) 342 | eps = eps.view(T, B, ELBO_samples, self.zsize) 343 | logdet = logdet.view(T, B, ELBO_samples) 344 | 345 | log_p_eps = -1/2*(math.log(2*math.pi) + eps.pow(2)).sum(-1) # [T, B, s] 346 | log_p_z = log_p_eps - logdet 347 | 348 | return log_p_z 349 | 350 | def generate(self, lengths, cond_inp=None, temp=1.0): 351 | T = torch.max(lengths) 352 | B = lengths.shape[0] 353 | 354 | hidden = self.flow.init_hidden(B) 355 | 356 | eps = torch.randn((T, B, self.zsize), device=hidden[0][0].device)*temp 357 | z, logdet, _ = self.flow.generate(eps, hidden, lengths, rnn_cond_inp=cond_inp) 358 | 359 | log_p_eps = -1/2*(math.log(2*math.pi) + eps.pow(2)).sum(-1) # [T, B] 360 | log_p_zs = log_p_eps - logdet 361 | 362 | return z, log_p_zs 363 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import random 4 | import math 5 | import json 6 | import numpy as np 7 | 8 | import torchtext 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | from baseline_model import LSTMModel 13 | from discreteflow_model import DFModel 14 | from config import parse_args 15 | from utils import load_indep_bernoulli, load_categorical, get_optimizer, log, save, load, build_log_p_T, get_kl_weight 16 | 17 | 18 | def run_epoch(train, start_kl_weight, delta_kl_weight, NLL_samples, ds, steps=-1): 19 | if train: 20 | model.train() 21 | else: 22 | model.eval() 23 | 24 | total_loss = 0 25 | avg_kl = 0 26 | total_log_likelihood = 0 27 | total_tokens = 0 28 | start_time = time.time() 29 | 30 | accum_counter = 0 31 | for i, batch in enumerate(iter(ds)): 32 | if steps > 0 and i >= steps: 33 | break 34 | 35 | kl_weight = start_kl_weight + delta_kl_weight*i/len(ds) 36 | 37 | batch_data = Variable(batch.text[0].to(device)) 38 | lengths = Variable(batch.text[1].to(device)) 39 | 40 | if train and accum_counter == 0: 41 | model.zero_grad() 42 | 43 | if args.model_type == 'baseline': 44 | loss = model(batch_data, lengths)[:, :, None] # [T, B, s] 45 | kl_loss = torch.zeros_like(loss) 46 | elif args.model_type == 'discrete_flow': 47 | reconst_loss, kl_loss = model.evaluate_x(batch_data, lengths, ELBO_samples=args.ELBO_samples) # Inputs should be [T, B], outputs should be [T, B, s] 48 | loss = reconst_loss + kl_weight*kl_loss 49 | 50 | # Exact loss is -(ELBO(x_i)+log_p_T(T_i))/T_i for each x_i 51 | # NLL bound is 1/sum(T_i)*sum(-(ELBO(x_i)+log_p_T(T_i))) 52 | 53 | indices = torch.arange(batch_data.shape[0]).view(-1, 1).to(device) 54 | loss_mask = indices >= lengths.view(1, -1) 55 | loss_mask = loss_mask[:, :, None].repeat(1, 1, loss.shape[-1]) 56 | 57 | loss[loss_mask] = 0 58 | kl_loss[loss_mask] = 0 59 | 60 | if not args.noT_condition: 61 | denom = (lengths+1).float() # if T conditioning, should normalizing by lengths+1 to be the same as normal -including models 62 | else: 63 | denom = (lengths).float() 64 
| 65 | loss = loss.mean(-1).sum(0) # mean over ELBO samples and time, [B] 66 | if not args.noT_condition: 67 | loss -= log_p_T[lengths] # Take into account log_p_T for each batch (negative b/c this is NLL) 68 | 69 | obj = (loss/denom).mean() # Mean over batches 70 | 71 | if train: 72 | obj_per_accum = obj.clone()/args.grad_accum 73 | obj_per_accum.backward() 74 | 75 | accum_counter += 1 76 | if accum_counter == args.grad_accum: 77 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 78 | optimizer.step() 79 | accum_counter = 0 80 | 81 | # Estimate NLL with importance sampling 82 | if NLL_samples > 0 and args.model_type == 'discrete_flow': 83 | with torch.no_grad(): 84 | reconst_loss_val, kl_loss_val = model.evaluate_x(batch_data, lengths, ELBO_samples=NLL_samples) # [T, B, s] 85 | 86 | inside_terms = (-reconst_loss_val - kl_loss_val) # [T, B, s] 87 | loss_mask = indices >= lengths.view(1, -1) 88 | loss_mask = loss_mask[:, :, None].repeat(1, 1, NLL_samples) 89 | inside_terms[loss_mask] = 0 90 | 91 | inside_terms_sumT = inside_terms.sum(0) # [B, s] 92 | log_likelihood = torch.logsumexp(inside_terms_sumT, -1) - math.log(NLL_samples) # [B] 93 | 94 | if not args.noT_condition: 95 | log_likelihood += log_p_T[lengths] 96 | 97 | total_log_likelihood += log_likelihood.sum().item() 98 | 99 | kl_loss = kl_loss.mean(-1).sum(0) 100 | total_loss += loss.sum().item() 101 | avg_kl += kl_loss.sum().item() 102 | total_tokens += denom.sum().item() 103 | 104 | avg_kl /= total_tokens 105 | total_loss /= total_tokens 106 | total_log_likelihood /= total_tokens 107 | 108 | return total_loss, avg_kl, total_log_likelihood, time.time()-start_time 109 | 110 | # Setup 111 | ## ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 112 | 113 | # Get parameters 114 | device = torch.device('cuda') 115 | args = parse_args() 116 | np.random.seed(args.seed) 117 | torch.manual_seed(args.seed) 118 | random.seed(args.seed) 119 | 120 | # Load data 121 | if args.indep_bernoulli: 122 | (train, val, test), pad_val, vocab_size = load_indep_bernoulli(args.dataset) 123 | else: 124 | (train, val, test), pad_val, vocab_size = load_categorical(args.dataset, args.noT_condition) 125 | log_p_T, max_T = build_log_p_T(args, train, val) 126 | log_p_T = log_p_T.to(device) 127 | train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits((train, val, test), batch_sizes=[args.B_train, args.B_val, args.B_val], device=-1, repeat=False, sort_key=lambda x: len(x.text), sort_within_batch=True) 128 | 129 | # Build model 130 | loss_weights = torch.ones(vocab_size) 131 | loss_weights[pad_val] = 0 132 | 133 | if args.model_type == 'discrete_flow': 134 | prior_kwargs = {'p_rnn_layers': args.p_rnn_layers, 'p_rnn_units': args.p_rnn_units, 'p_num_flow_layers': args.p_num_flow_layers, 135 | 'nohiddenflow': args.nohiddenflow, 'hiddenflow_layers': args.hiddenflow_layers, 'hiddenflow_units': args.hiddenflow_units, 136 | 'hiddenflow_flow_layers': args.hiddenflow_flow_layers, 'hiddenflow_scf_layers': args.hiddenflow_scf_layers, 137 | 'transform_function': args.transform_function} 138 | model = DFModel(vocab_size, loss_weights, args.inp_embedding_size, args.hidden_size, args.zsize, args.dropout_p, args.dlocs, 139 | args.prior_type, args.gen_bilstm_layers, prior_kwargs, 140 | args.q_rnn_layers, not args.notie_weights, max_T, indep_bernoulli=args.indep_bernoulli).to(device) 141 | elif args.model_type == 'baseline': 142 | 
model = LSTMModel(vocab_size, loss_weights, args.inp_embedding_size, args.hidden_size, args.p_rnn_layers, args.dropout_p, T_condition=not args.noT_condition, 143 | max_T=max_T, tie_weights=not args.notie_weights, indep_bernoulli=args.indep_bernoulli).to(device) 144 | setattr(args, 'ELBO_samples', 1) 145 | setattr(args, 'nll_samples', 0) 146 | setattr(args, 'kl_rampup_time', 0) 147 | setattr(args, 'initial_kl_zero', 0) 148 | else: 149 | raise ValueError('model_type must be one of discrete_flow, baseline') 150 | 151 | # Build optimizer 152 | optimizer = get_optimizer(args.optim, model.parameters(), args.lr) 153 | 154 | # Load parameters if needed 155 | if args.load_dir: 156 | starting_epoch, best_val_loss, lr = load(model, optimizer, args) 157 | auto_lr = True 158 | cur_impatience = 0 159 | optimizer = get_optimizer(args.optim, model.parameters(), lr) 160 | else: 161 | starting_epoch = 0 162 | best_val_loss = 999999999 163 | lr = args.lr 164 | auto_lr = False 165 | 166 | # If evaluate_only, only do that and don't train 167 | 168 | if args.evaluate_only: 169 | torch.set_printoptions(threshold=10000) 170 | train_loss, train_kl, train_LL, train_time = run_epoch(False, 1.0, 0.0, args.nll_samples, train_iter, steps=200) 171 | print('train loss: %.5f, train NLL (%d): %.5f, train kl: %.5f, train time: %.2fs' % (train_loss, args.nll_samples, -train_LL, train_kl, train_time)) 172 | 173 | val_loss, val_kl, val_LL, val_time = run_epoch(False, 1.0, 0.0, args.nll_samples, val_iter) 174 | print('val loss: %.5f, val NLL (%d): %.5f, val kl: %.5f, val time: %.2fs' % (val_loss, args.nll_samples, -val_LL, val_kl, val_time)) 175 | 176 | test_loss, test_kl, test_LL, test_time = run_epoch(False, 1.0, 0.0, args.nll_samples, test_iter) 177 | print('test loss: %.5f, test NLL (%d): %.5f, test kl: %.5f, test time: %.2fs' % (test_loss, args.nll_samples, -test_LL, test_kl, test_time)) 178 | 179 | sys.exit() 180 | 181 | # Train 182 | ## ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 183 | 184 | log(args, '--------------------- NEW START ----------------------') 185 | 186 | # Save parameters 187 | with open(args.output_dir+'/'+args.run_name+'/params.txt', 'w') as f: 188 | f.write(json.dumps(args.__dict__, indent=4, sort_keys=True)) 189 | 190 | 191 | for i in range(starting_epoch, args.num_epochs): 192 | decrease_lr = False 193 | save_model = False 194 | 195 | last_kl_weight, _ = get_kl_weight(args, i-1) 196 | cur_kl_weight, done = get_kl_weight(args, i) 197 | 198 | if done: 199 | auto_lr = True 200 | 201 | train_loss, train_kl, _, train_time = run_epoch(True, last_kl_weight, cur_kl_weight-last_kl_weight, 0, train_iter) 202 | 203 | val_NLL_samples = args.nll_samples if (i+1)%args.nll_every == 0 else 0 204 | val_loss, val_kl, val_log_likelihood, val_time = run_epoch(False, cur_kl_weight, 0.0, val_NLL_samples, val_iter) 205 | 206 | log_str = 'Epoch %d | train loss: %.3f, val loss: %.3f, val NLL (%d): %.3f | train kl: %.3f, val kl: %.3f | kl_weight: %.3f, time: %.2fs/%.2fs' % \ 207 | (i, train_loss, val_loss, val_NLL_samples, -val_log_likelihood, train_kl, val_kl, cur_kl_weight, train_time, val_time) 208 | print(log_str) 209 | log(args, log_str) 210 | 211 | if auto_lr: 212 | if val_loss < best_val_loss: 213 | best_val_loss = val_loss 214 | cur_impatience = 0 215 | save_model = True 216 | else: 217 | cur_impatience += 1 218 | if cur_impatience == args.patience: 219 | decrease_lr = True 220 | 221 | if 
decrease_lr:
222 | lr /= 4
223 | optimizer = get_optimizer(args.optim, model.parameters(), lr)
224 | 
225 | print('* Learning rate dropping by a factor of 4')
226 | log(args, '* Learning rate dropping by a factor of 4')
227 | cur_impatience = 0
228 | 
229 | if save_model:
230 | save(model, optimizer, args, 'after_epoch_%d' % i, i+1, best_val_loss, lr)
231 | 
232 | save(model, optimizer, args, 'end', args.num_epochs, best_val_loss, lr)
233 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from torch.autograd import Function
5 | 
6 | import torchtext
7 | from torchtext import data
8 | import io
9 | import os
10 | 
11 | import numpy as np
12 | 
13 | from data.indep_bernoulli.load_data import load_data
14 | 
15 | # Data processing
16 | # ------------------------------------------------------------------------------------------------------------------------------
17 | 
18 | class SentenceLanguageModelingDataset(data.Dataset):
19 | def __init__(self, path, text_field, encoding='utf-8', include_eos=True, **kwargs):
20 | fields = [('text', text_field)]
21 | examples = []
22 | with io.open(path, encoding=encoding) as f:
23 | for line in f:
24 | text = text_field.preprocess(line)
25 | if include_eos:
26 | text += [u'<eos>']
27 | examples.append(data.Example.fromlist([text], fields))
28 | 
29 | super().__init__(examples, fields, **kwargs)
30 | 
31 | def load_indep_bernoulli(dataset):
32 | dset = load_data(dataset)
33 | 
34 | class MultipleOutputExample:
35 | def __init__(self, tensor):
36 | self.text = tensor
37 | 
38 | class MultipleOutputField(data.Field):
39 | def __init__(self, pad_index):
40 | super().__init__(include_lengths=True, use_vocab=False)
41 | self.pad_index = pad_index
42 | 
43 | def process(self, batch, device, train):
44 | lengths = [len(batch_i) for batch_i in batch]
45 | max_length = max(lengths)
46 | 
47 | D = batch[0].shape[1]
48 | 
49 | new_list = []
50 | for seq in batch:
51 | if len(seq) < max_length:
52 | padding = torch.zeros(1, D)
53 | padding[0, self.pad_index] = 1.
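# The pad row built above is one-hot at pad_index; it is repeated below to fill the sequence
# out to max_length so the whole batch can be stacked into a single [T, B, D] tensor.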
54 | padding = padding.repeat(max_length-len(seq), 1)
55 | seq = torch.cat((seq, padding), 0)
56 | new_list.append(seq)
57 | 
58 | tensor = torch.stack(new_list)
59 | tensor = torch.transpose(tensor, 0, 1)
60 | 
61 | lengths = torch.tensor(lengths)
62 | 
63 | return tensor, lengths
64 | 
65 | pad_val = 0
66 | text = MultipleOutputField(pad_val)
67 | 
68 | datasets = {}
69 | for split, split_data in dset.items():
70 | examples = []
71 | for seq in split_data['sequences']:
72 | new_seq = torch.cat((torch.zeros(seq.shape[0], 1), seq), dim=1)
73 | examples.append(MultipleOutputExample(new_seq))
74 | datasets[split] = data.Dataset(examples, [('text', text)])
75 | 
76 | train = datasets['train']
77 | val = datasets['valid']
78 | test = datasets['test']
79 | 
80 | vocab_size = 89
81 | 
82 | return (train, val, test), pad_val, vocab_size
83 | 
84 | def load_categorical(dataset, noT_condition_prior):
85 | unk_token = '<unk>'
86 | text = torchtext.data.Field(include_lengths=True, unk_token=unk_token, tokenize=(lambda s: list(s.strip()))) # character-level tokenization
87 | 
88 | 
89 | MAX_LEN = 288
90 | MIN_LEN = 1
91 | 
92 | train, val, test = SentenceLanguageModelingDataset.splits(path='./data/%s/'%dataset, train='train.txt', validation='valid.txt', test='test.txt', text_field=text,
93 | include_eos=noT_condition_prior, filter_pred=lambda x: len(vars(x)['text']) <= MAX_LEN and len(vars(x)['text']) >= MIN_LEN)
94 | 
95 | 
96 | text.build_vocab(train)
97 | pad_val = text.vocab.stoi['<pad>']
98 | 
99 | vocab_size = len(text.vocab)
100 | 
101 | return (train, val, test), pad_val, vocab_size
102 | 
103 | # Utility functions
104 | # ------------------------------------------------------------------------------------------------------------------------------
105 | 
106 | def get_optimizer(name, parameters, lr):
107 | if name == 'adadelta':
108 | optimizer = torch.optim.Adadelta(parameters, lr=lr)
109 | elif name == 'adam':
110 | optimizer = torch.optim.Adam(parameters, lr=lr)
111 | elif name == 'sgd':
112 | optimizer = torch.optim.SGD(parameters, lr=lr)
113 | else:
114 | raise NotImplementedError('Only adadelta, adam, and sgd have been implemented')
115 | 
116 | return optimizer
117 | 
118 | def log(args, log_str):
119 | with open(args.logdir+'summary.txt', 'a+') as f:
120 | f.write(log_str+'\n')
121 | 
122 | def save(model, optimizer, args, name, current_epoch, best_val, lr):
123 | savedir = args.savedir+name
124 | os.makedirs(savedir, exist_ok=True)
125 | 
126 | torch.save(model.state_dict(), savedir+'/model.pt')
127 | torch.save(optimizer.state_dict(), savedir+'/optimizer.pt')
128 | np.savez(savedir+'/misc.npz', current_epoch=current_epoch, best_val=best_val, current_lr=lr)
129 | 
130 | def load(model, optimizer, args):
131 | model.load_state_dict(torch.load(args.load_dir+'/model.pt'))
132 | 
133 | try:
134 | optimizer.load_state_dict(torch.load(args.load_dir+'/optimizer.pt'))
135 | misc_data = np.load(args.load_dir+'/misc.npz')
136 | current_epoch = misc_data['current_epoch']
137 | best_val = misc_data['best_val']
138 | current_lr = misc_data['current_lr']
139 | except:
140 | print('Error loading optimizer state.
Will continue anyway starting from beginning') 141 | current_epoch, best_val, current_lr = 0, 999999999, args.lr 142 | 143 | return current_epoch, best_val, current_lr 144 | 145 | def build_log_p_T(args, train, val): 146 | T_hist = torch.zeros(100000) 147 | max_T = 0 148 | for ex in train.examples+val.examples: 149 | ex_len = len(ex.text) 150 | T_hist[ex_len] += 1 151 | if ex_len > max_T: 152 | max_T = ex_len 153 | 154 | if args.indep_bernoulli: 155 | max_T = int(max_T*1.25) 156 | T_hist += 1 157 | 158 | T_hist = T_hist[:max_T+1] 159 | log_p_T = torch.log(T_hist/T_hist.sum()) 160 | 161 | return log_p_T, max_T 162 | 163 | def get_kl_weight(args, i): 164 | if args.initial_kl_zero == 0 and args.kl_rampup_time == 0: 165 | return 1.0, True 166 | 167 | x_start = args.initial_kl_zero 168 | x_end = args.initial_kl_zero + args.kl_rampup_time 169 | y_start = 0.00001 170 | y_end = 1.0 171 | done = False 172 | if i < x_start: 173 | cur_kl_weight = y_start 174 | elif i > x_end: 175 | cur_kl_weight = y_end 176 | done = True 177 | else: 178 | cur_kl_weight = (i-x_start)/(x_end-x_start)*(y_end-y_start) + y_start 179 | 180 | return cur_kl_weight, done 181 | 182 | # Model utility functions 183 | # ------------------------------------------------------------------------------------------------------------------------------ 184 | 185 | def make_pos_cond(T, B, lengths, max_T): 186 | device = lengths.device 187 | 188 | p_plus_int = torch.arange(T, device=device)[:, None].repeat(1, B)[:, :, None] 189 | p_plus_oh = torch.empty(T, B, max_T, device=device).zero_() 190 | p_plus_oh.scatter_(2, p_plus_int, 1) 191 | 192 | p_minus_int = lengths[None, :] - 1 - torch.arange(T, device=device)[:, None] 193 | p_minus_int[p_minus_int < 0] = max_T-1 194 | p_minus_oh = torch.empty(T, B, max_T, device=device).zero_() 195 | p_minus_oh.scatter_(2, p_minus_int[:, :, None], 1) 196 | 197 | pos_cond = torch.cat((p_plus_oh, p_minus_oh), -1) # [T, B, max_T*2] 198 | 199 | return pos_cond 200 | 201 | def reverse_padded_sequence(inputs, lengths, batch_first=False): 202 | if batch_first: 203 | inputs = inputs.transpose(0, 1) 204 | 205 | if inputs.size(1) != len(lengths): 206 | raise ValueError('inputs incompatible with lengths.') 207 | 208 | reversed_inputs = inputs.data.clone() 209 | for i, length in enumerate(lengths): 210 | time_ind = torch.LongTensor(list(reversed(range(length)))) 211 | reversed_inputs[:length, i] = inputs[:, i][time_ind] 212 | 213 | if batch_first: 214 | reversed_inputs = reversed_inputs.transpose(0, 1) 215 | 216 | return reversed_inputs 217 | 218 | --------------------------------------------------------------------------------
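Usage sketch (illustrative, not part of the original repository): a minimal check of the two padding-aware helpers defined in utils.py above. It assumes the snippet is run from the repository root with torch and torchtext installed so that `from utils import ...` (and the `data/` package it pulls in) resolves; the toy batch and lengths below are invented for demonstration.

import torch
from utils import reverse_padded_sequence, make_pos_cond

# Toy batch: T=4 time steps, B=2 sequences, feature dim 1; the second sequence has true length 2.
inputs = torch.arange(8, dtype=torch.float).view(4, 2, 1)
lengths = torch.tensor([4, 2])

# Each sequence is reversed only over its true length; padded positions are left in place.
rev = reverse_padded_sequence(inputs, lengths)

# Positional conditioning used by the prior: a one-hot of the position from the start,
# concatenated with a one-hot of the position counted back from the end of each sequence.
pos = make_pos_cond(4, 2, lengths, max_T=6)

print(rev.shape, pos.shape)  # torch.Size([4, 2, 1]) torch.Size([4, 2, 12])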