├── .gitignore ├── README.md ├── data └── MNIST │ └── .gitkeep ├── main.py ├── models ├── __init__.py ├── basetdvae.py └── tdvae │ ├── __init__.py │ └── tdvae.py ├── readers ├── __init__.py └── moving_mnist.py ├── results ├── 11_5.png └── 1_15.png ├── runners ├── __init__.py ├── basemnist.py └── tdvaerunner.py └── test_reader.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | processed/ 3 | raw/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TD-VAE 2 | 3 | TD-VAE implementation in PyTorch 1.0. 4 | 5 | This code implements the ideas presented in the paper [Temporal Difference Variational Auto-Encoder (Gregor et al)][2]. This implementation includes configurable number of stochastic layers as well as the specific multilayer RNN design proposed in the paper. 6 | 7 | **NOTE**: This implementation also makes use of [`pylego`][1], which is a minimal library to write easily extendable experimental machine learning code. 8 | 9 | ## Results 10 | 11 | Here are the results on Moving MNIST, where the context length is 11, and the last 5 images in each row are generated by jumpy prediction from the model. 
# ===== main.py =====
"""Entry point for TD-VAE experiments.

Parses command-line flags, creates (or reuses) the log directory, then
instantiates and runs the TD-VAE runner.
"""

import argparse
import os

from pylego.misc import add_argument as arg

from runners.tdvaerunner import TDVAERunner


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    arg(parser, 'name', type=str, required=True, help='name of the experiment')
    arg(parser, 'model', type=str, default='tdvae.tdvae', help='model to use')
    # NOTE(review): plain argparse `type=bool` treats any non-empty string as True;
    # assuming pylego's add_argument wrapper handles booleans properly -- confirm.
    arg(parser, 'cuda', type=bool, default=True, help='enable CUDA')
    arg(parser, 'load_file', type=str, default='', help='file to load model from')
    arg(parser, 'save_file', type=str, default='model.dat', help='model save file')
    arg(parser, 'save_every', type=int, default=500, help='save every these many global steps (-1 to disable saving)')
    arg(parser, 'data_path', type=str, default='data/MNIST')
    arg(parser, 'logs_path', type=str, default='logs')
    arg(parser, 'force_logs', type=bool, default=False)
    arg(parser, 'optimizer', type=str, default='adam', help='one of: adam')
    arg(parser, 'learning_rate', type=float, default=5e-4, help='-1 to use model default')
    arg(parser, 'grad_norm', type=float, default=5.0, help='gradient norm clipping (-1 to disable)')
    arg(parser, 'seq_len', type=int, default=20, help='sequence length')
    arg(parser, 'batch_size', type=int, default=64, help='batch size')
    arg(parser, 'samples_per_seq', type=int, default=1, help='(t1, t2) samples per input sequence')
    arg(parser, 'b_size', type=int, default=50, help='belief size')
    arg(parser, 'z_size', type=int, default=8, help='state size')
    arg(parser, 'layers', type=int, default=2, help='number of layers')
    arg(parser, 't_diff_min', type=int, default=1, help='minimum time difference t2-t1')
    arg(parser, 't_diff_max', type=int, default=4, help='maximum time difference t2-t1')
    arg(parser, 'epochs', type=int, default=50000, help='no. of training epochs')
    arg(parser, 'max_batches', type=int, default=-1, help='max batches per split (if not -1, for debugging)')
    arg(parser, 'print_every', type=int, default=100, help='print losses every these many steps')
    arg(parser, 'gpus', type=str, default='0')
    arg(parser, 'threads', type=int, default=-1, help='data processing threads (-1 to determine from CPUs)')
    arg(parser, 'debug', type=bool, default=False, help='run model in debug mode')
    arg(parser, 'visualize_every', type=int, default=-1,
        help='visualize during training every these many steps (-1 to disable)')
    # fix: help text was garbled ('epoch visualize the loaded model and exit')
    arg(parser, 'visualize_only', type=bool, default=False, help='visualize the loaded model and exit')
    arg(parser, 'visualize_split', type=str, default='test', help='split to visualize with visualize_only')
    flags = parser.parse_args()
    if flags.threads < 0:
        # leave one CPU free for the main process
        flags.threads = max(1, len(os.sched_getaffinity(0)) - 1)
    if flags.grad_norm < 0:
        flags.grad_norm = None  # runner interprets None as "no clipping"

    # Create the log directory; on collision, retry with a "_"-suffixed name
    # (up to 4 times) unless force_logs allows reuse of an existing directory.
    iters = 0
    while True:
        if iters == 4:
            raise IOError("Too many retries, choose a different name.")
        flags.log_dir = '{}/{}'.format(flags.logs_path, flags.name)
        try:
            print('* Creating log dir', flags.log_dir)
            os.makedirs(flags.log_dir)
            break
        except OSError:  # IOError is an alias of OSError in Py3; makedirs raises FileExistsError (an OSError)
            if flags.force_logs:
                print('*', flags.log_dir, 'not recreated')
                break
            else:
                print('*', flags.log_dir, 'already exists')
                flags.name = flags.name + "_"
                iters += 1

    print('Arguments:', flags)
    if flags.visualize_only and not flags.load_file:
        print('! WARNING: visualize_only without load_file!')

    if flags.cuda:
        os.environ['CUDA_VISIBLE_DEVICES'] = flags.gpus

    flags.save_file = flags.log_dir + '/' + flags.save_file

    if flags.model.startswith('tdvae.'):
        runner = TDVAERunner
    else:
        # fix: previously fell through to a NameError on `runner` for unknown models
        raise ValueError('unknown model: {}'.format(flags.model))
    runner(flags).run(visualize_only=flags.visualize_only, visualize_split=flags.visualize_split)


# ===== models/basetdvae.py =====

from pylego.model import Model


class BaseTDVAE(Model):
    """Base model class that stores the experiment flags and forwards the
    wrapped nn.Module (plus optimizer/checkpoint kwargs) to pylego's Model."""

    def __init__(self, model, flags, *args, **kwargs):
        self.flags = flags
        super().__init__(model=model, *args, **kwargs)
# ===== models/tdvae/tdvae.py =====
"""TD-VAE model (Gregor et al., "Temporal Difference Variational Auto-Encoder").

Some parts adapted from the TD-VAE code by Xinqiang Ding.
"""

import torch
from torch import nn
from torch.nn import functional as F

from pylego import ops

from ..basetdvae import BaseTDVAE


class DBlock(nn.Module):
    """A basic building block for computing parameters of a normal distribution.
    Corresponds to D in the appendix: a gated tanh/sigmoid layer followed by
    separate linear heads for the mean and log-variance."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(input_size, hidden_size)
        self.fc_mu = nn.Linear(hidden_size, output_size)
        self.fc_logsigma = nn.Linear(hidden_size, output_size)

    def forward(self, input_):
        # gated activation: tanh branch modulated by a sigmoid gate
        t = torch.tanh(self.fc1(input_))
        t = t * torch.sigmoid(self.fc2(input_))
        mu = self.fc_mu(t)
        logsigma = self.fc_logsigma(t)
        return mu, logsigma


class PreProcess(nn.Module):
    """The pre-process layer for an MNIST image: two ReLU fully-connected layers."""

    def __init__(self, input_size, processed_x_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, processed_x_size)
        self.fc2 = nn.Linear(processed_x_size, processed_x_size)

    def forward(self, input_):
        t = torch.relu(self.fc1(input_))
        t = torch.relu(self.fc2(t))
        return t


class Decoder(nn.Module):
    """The decoder layer converting a latent state z to Bernoulli pixel
    probabilities for the observation."""

    def __init__(self, z_size, hidden_size, x_size):
        super().__init__()
        self.fc1 = nn.Linear(z_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, x_size)

    def forward(self, z):
        t = torch.tanh(self.fc1(z))
        t = torch.tanh(self.fc2(t))
        p = torch.sigmoid(self.fc3(t))  # per-pixel Bernoulli parameters in (0, 1)
        return p


class TDVAE(nn.Module):
    """The full TD-VAE model with jumpy prediction.

    Uses a multilayer belief LSTM; the layered latent state is sampled from
    the highest layer down, with each lower layer conditioned on the samples
    from the layers above it.
    """

    def __init__(self, x_size, processed_x_size, b_size, z_size, layers, samples_per_seq, t_diff_min, t_diff_max):
        super().__init__()
        self.layers = layers
        self.samples_per_seq = samples_per_seq
        self.t_diff_min = t_diff_min
        self.t_diff_max = t_diff_max
        # fix: removed dead self-assignments (`x_size = x_size`, etc.)

        # Input pre-process layer
        self.process_x = PreProcess(x_size, processed_x_size)

        # Multilayer LSTM for aggregating belief states
        self.b_rnn = ops.MultilayerLSTM(input_size=processed_x_size, hidden_size=b_size, layers=layers,
                                        every_layer_input=True, use_previous_higher=True)

        # Multilayer state model is used. Sampling is done by sampling higher layers first,
        # so every non-top layer also receives the z sampled one layer above (+z_size inputs).
        self.z_b = nn.ModuleList([DBlock(b_size + (z_size if layer < layers - 1 else 0), 50, z_size)
                                  for layer in range(layers)])

        # Given belief and state at time t2, infer the state at time t1
        self.z1_z2_b1 = nn.ModuleList([DBlock(b_size + layers * z_size + (z_size if layer < layers - 1 else 0), 50,
                                              z_size) for layer in range(layers)])

        # Given the state at time t1, model state at time t2 through state transition
        self.z2_z1 = nn.ModuleList([DBlock(layers * z_size + (z_size if layer < layers - 1 else 0), 50, z_size)
                                    for layer in range(layers)])

        # state to observation
        self.x_z = Decoder(layers * z_size, 200, x_size)

    def forward(self, x):
        """Run one training pass: sample (t1, t2) pairs, compute beliefs, and
        return every distribution parameter the loss needs.

        x: (batch, time, x_size) sequence of flattened frames.
        """
        # sample t1 and t2 so that t1 + t_diff_max always fits in the sequence
        t1 = torch.randint(0, x.size(1) - self.t_diff_max, (self.samples_per_seq, x.size(0)), device=x.device)
        t2 = t1 + torch.randint(self.t_diff_min, self.t_diff_max + 1, (self.samples_per_seq, x.size(0)),
                                device=x.device)
        # x = x[:, :t2.max() + 1]  # usually not required with big enough batch size

        # pre-process image x
        processed_x = self.process_x(x)  # max x length is max(t2) + 1

        # aggregate the belief b
        b = self.b_rnn(processed_x)  # size: bs, time, layers, dim

        # replicate b multiple times, once per (t1, t2) sample
        b = b[None, ...].expand(self.samples_per_seq, -1, -1, -1, -1)  # size: copy, bs, time, layers, dim

        # Element-wise indexing of the belief at t1 and t2. sizes: bs, layers, dim
        b1 = torch.gather(b, 2, t1[..., None, None, None].expand(-1, -1, -1, b.size(3), b.size(4))).view(
            -1, b.size(3), b.size(4))
        b2 = torch.gather(b, 2, t2[..., None, None, None].expand(-1, -1, -1, b.size(3), b.size(4))).view(
            -1, b.size(3), b.size(4))

        # q_B(z2 | b2): sample top layer first, condition lower layers on it
        qb_z2_b2_mus, qb_z2_b2_logvars, qb_z2_b2s = [], [], []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](b2[:, layer])
            else:
                qb_z2_b2_mu, qb_z2_b2_logvar = self.z_b[layer](torch.cat([b2[:, layer], qb_z2_b2], dim=1))
            qb_z2_b2_mus.insert(0, qb_z2_b2_mu)
            qb_z2_b2_logvars.insert(0, qb_z2_b2_logvar)

            qb_z2_b2 = ops.reparameterize_gaussian(qb_z2_b2_mu, qb_z2_b2_logvar, self.training)
            qb_z2_b2s.insert(0, qb_z2_b2)

        qb_z2_b2_mu = torch.cat(qb_z2_b2_mus, dim=1)
        qb_z2_b2_logvar = torch.cat(qb_z2_b2_logvars, dim=1)
        qb_z2_b2 = torch.cat(qb_z2_b2s, dim=1)

        # q_S(z1 | z2, b1, b2) ~= q_S(z1 | z2, b1)
        qs_z1_z2_b1_mus, qs_z1_z2_b1_logvars, qs_z1_z2_b1s = [], [], []
        for layer in range(self.layers - 1, -1, -1):  # TODO optionally condition t2 - t1
            if layer == self.layers - 1:
                qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar = self.z1_z2_b1[layer](torch.cat([qb_z2_b2, b1[:, layer]], dim=1))
            else:
                qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar = self.z1_z2_b1[layer](torch.cat([qb_z2_b2, b1[:, layer],
                                                                                    qs_z1_z2_b1], dim=1))
            qs_z1_z2_b1_mus.insert(0, qs_z1_z2_b1_mu)
            qs_z1_z2_b1_logvars.insert(0, qs_z1_z2_b1_logvar)

            qs_z1_z2_b1 = ops.reparameterize_gaussian(qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar, self.training)
            qs_z1_z2_b1s.insert(0, qs_z1_z2_b1)

        qs_z1_z2_b1_mu = torch.cat(qs_z1_z2_b1_mus, dim=1)
        qs_z1_z2_b1_logvar = torch.cat(qs_z1_z2_b1_logvars, dim=1)
        qs_z1_z2_b1 = torch.cat(qs_z1_z2_b1s, dim=1)

        # p_T(z2 | z1), also conditions on q_B(z2) from higher layer
        pt_z2_z1_mus, pt_z2_z1_logvars = [], []
        for layer in range(self.layers - 1, -1, -1):  # TODO optionally condition t2 - t1
            if layer == self.layers - 1:
                pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](qs_z1_z2_b1)
            else:
                pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](torch.cat([qs_z1_z2_b1, qb_z2_b2s[layer + 1]], dim=1))
            pt_z2_z1_mus.insert(0, pt_z2_z1_mu)
            pt_z2_z1_logvars.insert(0, pt_z2_z1_logvar)

        pt_z2_z1_mu = torch.cat(pt_z2_z1_mus, dim=1)
        pt_z2_z1_logvar = torch.cat(pt_z2_z1_logvars, dim=1)

        # p_B(z1 | b1)
        pb_z1_b1_mus, pb_z1_b1_logvars = [], []
        for layer in range(self.layers - 1, -1, -1):  # TODO optionally condition t2 - t1
            if layer == self.layers - 1:
                pb_z1_b1_mu, pb_z1_b1_logvar = self.z_b[layer](b1[:, layer])
            else:
                pb_z1_b1_mu, pb_z1_b1_logvar = self.z_b[layer](torch.cat([b1[:, layer], qs_z1_z2_b1s[layer + 1]],
                                                                         dim=1))
            pb_z1_b1_mus.insert(0, pb_z1_b1_mu)
            pb_z1_b1_logvars.insert(0, pb_z1_b1_logvar)

        pb_z1_b1_mu = torch.cat(pb_z1_b1_mus, dim=1)
        pb_z1_b1_logvar = torch.cat(pb_z1_b1_logvars, dim=1)

        # p_D(x2 | z2)
        pd_x2_z2 = self.x_z(qb_z2_b2)

        return (x, t2, qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar, pb_z1_b1_mu, pb_z1_b1_logvar, qb_z2_b2_mu, qb_z2_b2_logvar,
                qb_z2_b2, pt_z2_z1_mu, pt_z2_z1_logvar, pd_x2_z2)

    def visualize(self, x, t, n):
        """Roll out n jumpy predictions from the belief at time t and decode them."""
        # pre-process image x
        processed_x = self.process_x(x)  # x length is t + 1

        # aggregate the belief b at time t
        b = self.b_rnn(processed_x)[:, t]  # size: bs, layers, dim

        # compute z from b, top layer down; always sample (training=True)
        p_z_bs = []
        for layer in range(self.layers - 1, -1, -1):
            if layer == self.layers - 1:
                p_z_b_mu, p_z_b_logvar = self.z_b[layer](b[:, layer])
            else:
                p_z_b_mu, p_z_b_logvar = self.z_b[layer](torch.cat([b[:, layer], p_z_b], dim=1))
            p_z_b = ops.reparameterize_gaussian(p_z_b_mu, p_z_b_logvar, True)
            p_z_bs.insert(0, p_z_b)

        z = torch.cat(p_z_bs, dim=1)
        rollout_x = []

        for _ in range(n):
            next_z = []
            for layer in range(self.layers - 1, -1, -1):  # TODO optionally condition n
                if layer == self.layers - 1:
                    pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](z)
                else:
                    pt_z2_z1_mu, pt_z2_z1_logvar = self.z2_z1[layer](torch.cat([z, pt_z2_z1], dim=1))
                pt_z2_z1 = ops.reparameterize_gaussian(pt_z2_z1_mu, pt_z2_z1_logvar, True)
                next_z.insert(0, pt_z2_z1)

            z = torch.cat(next_z, dim=1)
            rollout_x.append(self.x_z(z))

        return torch.stack(rollout_x, dim=1)


class TDVAEModel(BaseTDVAE):
    """pylego model wrapper: builds the TDVAE network and defines its loss."""

    def __init__(self, flags, *args, **kwargs):
        # XXX hardcoded for moving MNIST (28x28 frames, flattened)
        super().__init__(TDVAE(28 * 28, 28 * 28, flags.b_size, flags.z_size, flags.layers, flags.samples_per_seq,
                               flags.t_diff_min, flags.t_diff_max), flags, *args, **kwargs)

    def loss_function(self, forward_ret, labels=None):
        """TD-VAE loss: reconstruction BCE (minus the irreducible optimum)
        plus the two KL terms from the paper."""
        (x, t2, qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar, pb_z1_b1_mu, pb_z1_b1_logvar, qb_z2_b2_mu, qb_z2_b2_logvar,
         qb_z2_b2, pt_z2_z1_mu, pt_z2_z1_logvar, pd_x2_z2) = forward_ret

        # replicate x once per (t1, t2) sample and gather the frames at t2
        x = x[None, ...].expand(self.flags.samples_per_seq, -1, -1, -1)  # size: copy, bs, time, dim
        x2 = torch.gather(x, 2, t2[..., None, None].expand(-1, -1, -1, x.size(3))).view(-1, x.size(3))
        batch_size = x2.size(0)

        # KL(q_S(z1|z2,b1) || p_B(z1|b1))
        kl_div_qs_pb = ops.kl_div_gaussian(qs_z1_z2_b1_mu, qs_z1_z2_b1_logvar, pb_z1_b1_mu, pb_z1_b1_logvar).mean()

        # log q_B(z2|b2) - log p_T(z2|z1), evaluated at the sampled z2
        kl_shift_qb_pt = (ops.gaussian_log_prob(qb_z2_b2_mu, qb_z2_b2_logvar, qb_z2_b2) -
                          ops.gaussian_log_prob(pt_z2_z1_mu, pt_z2_z1_logvar, qb_z2_b2)).mean()

        bce = F.binary_cross_entropy(pd_x2_z2, x2, reduction='sum') / batch_size
        # entropy of the (possibly non-binary) targets: the best achievable BCE
        bce_optimal = F.binary_cross_entropy(x2, x2, reduction='sum').detach() / batch_size
        bce_diff = bce - bce_optimal

        loss = bce_diff + kl_div_qs_pb + kl_shift_qb_pt

        return loss, bce_diff, kl_div_qs_pb, kl_shift_qb_pt, bce_optimal


# ===== readers/moving_mnist.py =====
"""Moving-MNIST reader: wraps torchvision MNIST into horizontally rolling
sequences. Some parts adapted from the TD-VAE code by Xinqiang Ding."""

import numpy as np
import torch
from torch.utils import data
from torchvision import datasets

from pylego.reader import DatasetReader


class MovingMNISTDataset(datasets.MNIST):
    """MNIST dataset whose __getitem__ returns a (seq_len, 784) sequence of a
    digit rolling horizontally in a random direction."""

    def __init__(self, data_path, train, seq_len, binarize):
        super().__init__(data_path, train=train, download=True)
        self.seq_len = seq_len          # frames per returned sequence
        self.binarize = binarize        # stochastic binarization vs. [0, 1] floats

    def __getitem__(self, index):
        image, _ = super().__getitem__(index)
        image = np.array(image)

        if self.binarize:
            # stochastic binarization: pixel is 1 with probability intensity/255
            tmp = np.random.rand(28, 28) * 255
            image = tmp <= image
            image = image.astype(np.float32)
        else:
            image = image.astype(np.float32) / 255.0

        # randomly choose a direction and generate a sequence of images that move in the chosen direction
        direction = np.random.choice(2)
        image = np.roll(image, np.random.randint(28), 1)  # start with a random roll
        image_list = [image.reshape(-1)]
        for _ in range(1, self.seq_len):
            if direction:
                image = np.roll(image, -1, 1)
                image_list.append(image.reshape(-1))
            else:
                image = np.roll(image, 1, 1)
                image_list.append(image.reshape(-1))

        return np.array(image_list)


class MovingMNISTReader(DatasetReader):
    """Builds train/val/test splits (val = 10% of train) for pylego."""

    def __init__(self, data_path, seq_len=20, binarize=False):
        train_dataset = MovingMNISTDataset(data_path, True, seq_len, binarize)
        test_dataset = MovingMNISTDataset(data_path, False, seq_len, binarize)

        val_size = int(0.1 * len(train_dataset))
        train_size = len(train_dataset) - val_size
        # NOTE(review): this reseeds torch's *global* RNG as a side effect of
        # constructing the reader (kept for reproducible splits on PyTorch 1.0,
        # which lacks random_split's generator argument) -- confirm acceptable.
        torch.manual_seed(0)
        train_dataset, val_dataset = data.random_split(train_dataset, [train_size, val_size])
        super().__init__({'train': train_dataset, 'val': val_dataset, 'test': test_dataset})


# ===== runners/basemnist.py =====
"""Base runner wiring the Moving-MNIST reader and model class into pylego."""

import importlib

from readers.moving_mnist import MovingMNISTReader
from pylego import misc, runner


class MovingMNISTBaseRunner(runner.Runner):
    """Creates the reader, resolves the concrete model subclass named by
    flags.model under models/, and instantiates it with the training flags."""

    def __init__(self, flags, model_class, log_keys, *args, **kwargs):
        self.flags = flags
        reader = MovingMNISTReader(flags.data_path, seq_len=flags.seq_len)
        summary_dir = flags.log_dir + '/summary'
        super().__init__(reader, flags.batch_size, flags.epochs, summary_dir, log_keys=log_keys,
                         threads=flags.threads, print_every=flags.print_every, visualize_every=flags.visualize_every,
                         max_batches=flags.max_batches, *args, **kwargs)
        model_class = misc.get_subclass(importlib.import_module('models.' + self.flags.model), model_class)
        self.model = model_class(self.flags, optimizer=flags.optimizer, learning_rate=flags.learning_rate,
                                 cuda=flags.cuda, load_file=flags.load_file, save_every=flags.save_every,
                                 save_file=flags.save_file, debug=flags.debug)


# ===== runners/tdvaerunner.py =====
"""Concrete TD-VAE runner: per-batch training step and visualization."""

import collections

import numpy as np

from pylego import misc

from models.basetdvae import BaseTDVAE
from .basemnist import MovingMNISTBaseRunner


class TDVAERunner(MovingMNISTBaseRunner):

    def __init__(self, flags, *args, **kwargs):
        # NOTE(review): *args/**kwargs are accepted but not forwarded to super -- confirm intentional.
        super().__init__(flags, BaseTDVAE, ['loss', 'bce_diff', 'kl_div_qs_pb', 'kl_shift_qb_pt'])

    def run_batch(self, batch, train=False):
        """Run one batch through the model; backprop + clip when train=True.
        Returns an ordered dict of scalar metrics for logging."""
        batch = self.model.prepare_batch(batch)
        loss, bce_diff, kl_div_qs_pb, kl_shift_qb_pt, bce_optimal = self.model.run_loss(batch)
        if train:
            self.model.train(loss, clip_grad_norm=self.flags.grad_norm)

        return collections.OrderedDict([('loss', loss.item()),
                                        ('bce_diff', bce_diff.item()),
                                        ('kl_div_qs_pb', kl_div_qs_pb.item()),
                                        ('kl_shift_qb_pt', kl_shift_qb_pt.item()),
                                        ('bce_optimal', bce_optimal.item())])

    def _visualize_split(self, split, t, n):
        """Take one small batch from `split`, condition on frames [0, t] and
        predict n jumpy steps; returns (grid image array, (rows, cols))."""
        bs = min(self.batch_size, 16)
        batch = next(self.reader.iter_batches(split, bs, shuffle=True, partial_batching=True, threads=self.threads,
                                              max_batches=1))
        batch = self.model.prepare_batch(batch[:, :t + 1])
        out = self.model.run_batch([batch, t, n], visualize=True)

        batch = batch.cpu().numpy()
        out = out.cpu().numpy()
        vis_data = np.concatenate([batch, out], axis=1)  # context frames followed by predictions
        bs, seq_len = vis_data.shape[:2]
        return vis_data.reshape([bs * seq_len, 1, 28, 28]), (bs, seq_len)

    def post_epoch_visualize(self, epoch, split):
        """Save comparison grids after each epoch (val) / at the end (test)."""
        if split != 'train':
            print('* Visualizing', split)
            vis_data, rows_cols = self._visualize_split(split, min(self.flags.seq_len - 1, 10), 5)
            if split == 'test':
                fname = self.flags.log_dir + '/test.png'
            else:
                fname = self.flags.log_dir + '/val%03d.png' % epoch
            misc.save_comparison_grid(fname, vis_data, border_shade=1.0, rows_cols=rows_cols, retain_sequence=True)
            print('* Visualizations saved to', fname)

            if split == 'test':
                # second figure: condition on a single frame, predict 15 steps
                print('* Generating more visualizations for', split)
                vis_data, rows_cols = self._visualize_split(split, 0, 15)
                fname = self.flags.log_dir + '/test_more.png'
                misc.save_comparison_grid(fname, vis_data, border_shade=1.0, rows_cols=rows_cols, retain_sequence=True)
                print('* More visualizations saved to', fname)


# ===== test_reader.py =====
"""Smoke test for the Moving-MNIST reader: prints batch shapes and saves a few
sequence grids as PNGs."""

from pylego.misc import save_comparison_grid

from readers.moving_mnist import MovingMNISTReader


if __name__ == '__main__':
    reader = MovingMNISTReader('data/MNIST')
    for i, batch in enumerate(reader.iter_batches('train', 4, max_batches=5)):
        print(batch.size())
        if i < 3:
            batch = batch[0].numpy().reshape(20, 1, 28, 28)
            save_comparison_grid('seq%d.png' % i, batch)