├── .gitignore ├── LICENSE ├── README.md ├── _readme_imgs ├── multi_dsprites_recons.png ├── multi_mnist_recons.png └── original_multimnist_recons.png ├── data ├── multi-dsprites-binary-rgb │ └── multi_dsprites_color_012.npz └── multi_mnist │ ├── multi_binary_mnist_012.npz │ └── multi_mnist_pyro.npz ├── datasets.py ├── experiments ├── __init__.py └── air_experiment │ ├── __init__.py │ ├── data.py │ └── experiment_manager.py ├── main.py ├── models ├── __init__.py └── air.py ├── requirements.txt └── utils ├── __init__.py ├── misc.py └── spatial_transform.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | __pycache__ 4 | *.pyc 5 | checkpoints 6 | data 7 | results 8 | tensorboard_logs 9 | 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Andrea Dittadi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Attend Infer Repeat 2 | 3 | Attend Infer Repeat (AIR) [1] in PyTorch. Parts of this implementation are 4 | inspired by [this Pyro tutorial](https://pyro.ai/examples/air.html) [2] and 5 | [this blog post](http://akosiorek.github.io/ml/2017/09/03/implementing-air.html) [3]. 6 | See [below](#high-level-model-description) for a description of the model. 7 | 8 | Install requirements and run: 9 | ``` 10 | pip install -r requirements.txt 11 | CUDA_VISIBLE_DEVICES=0 python main.py 12 | ``` 13 | 14 | 15 | 16 | ## Results 17 | 18 | ### Original multi-MNIST dataset 19 | 20 | Multi-MNIST results are often (about 80% of the time) reproduced with this implementation. 21 | In the other cases, the model converges to a local maximum of the ELBO where the 22 | recurrent attention often predicts more objects than the ground truth (either 23 | using more objects to model one digit, or inferring blank objects). 24 | 25 | | dataset | likelihood | accuracy | ELBO | log _p(x)_ ≥
[100 iws] |
26 | | -------------------- |:---------------------------:|:------------:|:-----------:|:-------------------:|
27 | | original multi-MNIST | N(_f(z)_, 0.3²) | 98.3 ± 0.15 % | 627.4 ± 0.7 | 636.8 ± 0.8 |
28 | 
29 | where:
30 | - mean and std do not include the bad runs
31 | - the likelihood is a Gaussian with fixed (and large!) variance [1]
32 | - the last column is a tighter log-likelihood lower bound than the ELBO, and iws
33 | stands for importance-weighted samples [4]
34 | 
35 | 
36 | Reconstructions with inferred bounding boxes for one of the good runs:
37 | 
38 | ![Reconstruction on original multi-MNIST](_readme_imgs/original_multimnist_recons.png)
39 | 
40 | 
41 | 
42 | ### Smaller objects
43 | 
44 | Preliminary results on multi-MNIST and multi-dSprites, with larger images (64x64)
45 | and smaller objects (object patches range from 9x9 to 18x18). The prior on object
46 | scale was updated to reflect the smaller patch size. This is harder
47 | for the model to learn, especially the correct inference of z_pres. On the
48 | dSprites dataset only 1 out of 4 runs is successful. Examples of successful runs below.
49 | 
50 | ![Reconstruction on multi-MNIST](_readme_imgs/multi_mnist_recons.png)
51 | 
52 | ![Reconstruction on multi-dSprites](_readme_imgs/multi_dsprites_recons.png)
53 | 
54 | 
55 | 
56 | ## Implementation notes
57 | 
58 | - In [1] the prior probability for presence is annealed from almost 1 to 1e-5
59 | or less [3], whereas here it is fixed to 0.01 as suggested in [2].
60 | - This means that the generative model does not make much sense, since an image
61 | sampled from the model will be empty with probability 0.99.
62 | - We can still learn the presence prior (after convergence, otherwise it doesn't
63 | work), which in this case converges to approximately 0.5. Results soon.
64 | - The other defaults are as in the paper [1].
65 | - In [1] the likelihood _p(x | z)_ is a Gaussian with fixed variance, but it has
66 | to be Bernoulli for binary data.
67 | - This means that, unlike in [1], the decoder output must be in the interval
68 | [0, 1]. Clamping is the simplest solution, but it doesn't work very well.
69 | 
70 | 
71 | 
72 | ## High-level model description
73 | 
74 | Attend Infer Repeat (AIR) [1] is a structured generative model of visual scenes
75 | that attempts to explain such scenes as compositions of discrete objects. Each
76 | object is described by its (binary) presence, location/scale, and appearance.
77 | The model is invariant to object permutations.
78 | 
79 | The intractable inference of the posterior _p(z | x)_ is approximated through
80 | (stochastic, amortized) variational inference. An iterative (recurrent)
81 | algorithm infers whether there is another object to be explained (i.e. the next
82 | object has presence=1) and, if so, it infers its location/scale and appearance.
83 | Appearance is inferred from a patch of the input image given by the inferred
84 | location/scale.
85 | 
86 | ## Requirements
87 | ```
88 | python 3.7.6
89 | numpy 1.18.1
90 | torch 1.4.0
91 | torchvision 0.5.0
92 | matplotlib 3.1.2
93 | tqdm 4.41.1
94 | boilr 0.6.0
95 | multiobject 0.0.3
96 | ```
97 | 
98 | ## References
99 | 
100 | [1] A Eslami,
101 | N Heess,
102 | T Weber,
103 | Y Tassa,
104 | D Szepesvari,
105 | K Kavukcuoglu,
106 | G Hinton.
107 | _Attend, Infer, Repeat: Fast Scene Understanding with Generative Models_, NIPS 2016 108 | 109 | [2] [https://pyro.ai/examples/air.html](https://pyro.ai/examples/air.html) 110 | 111 | [3] [http://akosiorek.github.io/ml/2017/09/03/implementing-air.html](http://akosiorek.github.io/ml/2017/09/03/implementing-air.html) 112 | 113 | [4] Y Burda, RB Grosse, R Salakhutdinov. 114 | _Importance Weighted Autoencoders_, 115 | ICLR 2016 116 | -------------------------------------------------------------------------------- /_readme_imgs/multi_dsprites_recons.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/_readme_imgs/multi_dsprites_recons.png -------------------------------------------------------------------------------- /_readme_imgs/multi_mnist_recons.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/_readme_imgs/multi_mnist_recons.png -------------------------------------------------------------------------------- /_readme_imgs/original_multimnist_recons.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/_readme_imgs/original_multimnist_recons.png -------------------------------------------------------------------------------- /data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz -------------------------------------------------------------------------------- /data/multi_mnist/multi_binary_mnist_012.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/data/multi_mnist/multi_binary_mnist_012.npz -------------------------------------------------------------------------------- /data/multi_mnist/multi_mnist_pyro.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/data/multi_mnist/multi_mnist_pyro.npz -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class PyroMultiMNIST(Dataset): 7 | def __init__(self, path, train): 8 | self.path = path 9 | self.train = train 10 | data = np.load(path, allow_pickle=True) 11 | x = data['x'] 12 | y = data['y'] 13 | split = 50000 14 | if train: 15 | self.x, self.y = x[:split], y[:split] 16 | else: 17 | self.x, self.y = x[split:], y[split:] 18 | 19 | def __getitem__(self, index): 20 | """ 21 | Returns (x, y), where x is (1, H, W) in range (0, 1), 22 | y is a label dict with only a 'n_obj' key. 
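        For example, an image containing two digits is returned with
        y = {'n_obj': tensor(2.)}: the per-object labels stored in the npz file
        are collapsed into a single float object count.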
23 | """ 24 | # x: uint8, (1, H, W) 25 | # y: label dict 26 | x, y = self.x[index], self.y[index] 27 | y = np.array(len(y)) 28 | x = x / 255.0 29 | 30 | x, y = torch.from_numpy(x).float(), torch.from_numpy(y).float() 31 | x = x[None] 32 | y = {'n_obj': y} # label dict: compatible with multiobject dataloader 33 | 34 | return x, y 35 | 36 | 37 | def __len__(self): 38 | return len(self.x) 39 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from .air_experiment.experiment_manager import AIRExperiment 2 | -------------------------------------------------------------------------------- /experiments/air_experiment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/experiments/air_experiment/__init__.py -------------------------------------------------------------------------------- /experiments/air_experiment/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from multiobject.pytorch import MultiObjectDataset, MultiObjectDataLoader 4 | 5 | from datasets import PyroMultiMNIST 6 | 7 | multiobject_paths = { 8 | 'multi_mnist_binary': './data/multi_mnist/multi_binary_mnist_012.npz', 9 | 'multi_dsprites_binary_rgb': './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz', 10 | } 11 | multiobject_datasets = multiobject_paths.keys() 12 | pyro_mnist_path = os.path.join('data', 'multi_mnist', 'multi_mnist_pyro.npz') 13 | 14 | 15 | class DatasetLoader: 16 | """ 17 | Wrapper for DataLoaders. Data attributes: 18 | - train: DataLoader object for training set 19 | - test: DataLoader object for test set 20 | - data_shape: shape of each data point (channels, height, width) 21 | - img_size: spatial dimensions of each data point (height, width) 22 | - color_ch: number of color channels 23 | """ 24 | 25 | def __init__(self, args, cuda): 26 | 27 | # Default arguments for dataloaders 28 | kwargs = {'num_workers': 1, 'pin_memory': False} if cuda else {} 29 | 30 | # Define training and test set 31 | if args.dataset_name == 'pyro_multi_mnist': 32 | train_set = PyroMultiMNIST(pyro_mnist_path, train=True) 33 | test_set = PyroMultiMNIST(pyro_mnist_path, train=False) 34 | elif args.dataset_name in multiobject_datasets: 35 | data_path = multiobject_paths[args.dataset_name] 36 | train_set = MultiObjectDataset(data_path, train=True) 37 | test_set = MultiObjectDataset(data_path, train=False) 38 | else: 39 | raise RuntimeError("Unrecognized data set '{}'".format(args.dataset_name)) 40 | 41 | # Dataloaders 42 | self.train = MultiObjectDataLoader( 43 | train_set, 44 | batch_size=args.batch_size, 45 | shuffle=True, 46 | drop_last=True, 47 | **kwargs 48 | ) 49 | self.test = MultiObjectDataLoader( 50 | test_set, 51 | batch_size=args.test_batch_size, 52 | shuffle=False, 53 | **kwargs 54 | ) 55 | 56 | self.data_shape = self.train.dataset[0][0].size() 57 | self.img_size = self.data_shape[1:] 58 | self.color_ch = self.data_shape[0] 59 | -------------------------------------------------------------------------------- /experiments/air_experiment/experiment_manager.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from boilr import VIExperimentManager 7 | 
from boilr.viz import img_grid_pad_value 8 | from torch import optim 9 | from torchvision.utils import save_image 10 | 11 | from models.air import AIR 12 | from utils.spatial_transform import batch_add_bounding_boxes 13 | from .data import DatasetLoader 14 | 15 | 16 | class AIRExperiment(VIExperimentManager): 17 | """ 18 | Experiment manager. 19 | 20 | Data attributes: 21 | - 'args': argparse.Namespace containing all config parameters. When 22 | initializing this object, if 'args' is not given, all config 23 | parameters are set based on experiment defaults and user input, using 24 | argparse. 25 | - 'run_description': string description of the run that includes a timestamp 26 | and can be used e.g. as folder name for logging. 27 | - 'model': the model. 28 | - 'device': torch.device that is being used 29 | - 'dataloaders': DataLoaders, with attributes 'train' and 'test' 30 | - 'optimizer': the optimizer 31 | """ 32 | 33 | def make_datamanager(self): 34 | cuda = self.device.type == 'cuda' 35 | return DatasetLoader(self.args, cuda) 36 | 37 | def make_model(self): 38 | args = self.args 39 | 40 | obj_size = { 41 | 'pyro_multi_mnist': 28, 42 | 'multi_mnist_binary': 18, 43 | 'multi_dsprites_binary_rgb': 18, 44 | }[args.dataset_name] 45 | 46 | scale_prior_mean = { 47 | 'pyro_multi_mnist': 3., 48 | 'multi_mnist_binary': 4.5, 49 | 'multi_dsprites_binary_rgb': 4.5, 50 | }[args.dataset_name] 51 | 52 | model = AIR( 53 | img_size=self.dataloaders.img_size[0], # assume h=w 54 | color_channels=self.dataloaders.color_ch, 55 | object_size=obj_size, 56 | max_steps=3, 57 | likelihood=args.likelihood, 58 | scale_prior_mean=scale_prior_mean, 59 | ) 60 | return model 61 | 62 | def make_optimizer(self): 63 | args = self.args 64 | optimizer = optim.Adam([ 65 | { 66 | 'params': self.model.air_params(), 67 | 'lr': args.lr, 68 | 'weight_decay': args.weight_decay, 69 | }, 70 | { 71 | 'params': self.model.baseline_params(), 72 | 'lr': args.bl_lr, 73 | 'weight_decay': args.weight_decay, 74 | }, 75 | ]) 76 | return optimizer 77 | 78 | 79 | 80 | def forward_pass(self, x, y=None): 81 | """ 82 | Simple single-pass model evaluation. It consists of a forward pass 83 | and computation of all necessary losses and metrics. 
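        In summary, the loss assembled below is

            loss = mean_B[ sum_t mask_t * sg(bl_target_t - bl_value_t) * log q(z_pres_t) - ELBO ]
                   + mean_B[ sum_t mask_t * (bl_value_t - sg(bl_target_t))^2 ]

        where sg() denotes stop-gradient (detach),
        bl_target_t = sum_{i>=t} KL_i - log p(x | z),
        and the second term trains only the baseline network.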
84 | """ 85 | 86 | # Forward pass 87 | x = x.to(self.device, non_blocking=True) 88 | out = self.model(x) 89 | 90 | elbo_sep = out['elbo_sep'] 91 | bl_target = out['baseline_target'] 92 | bl_value = out['baseline_value'] 93 | data_likelihood_sep = out['data_likelihood'] 94 | z_pres_likelihood = out['z_pres_likelihood'] 95 | mask = out['mask_prev'] 96 | 97 | # The baseline target is: 98 | # sum_{i=t}^T KL[i] - log p(x | z) 99 | # for all steps up to (and including) the first z_pres=0 100 | bl_target = bl_target - data_likelihood_sep[:, None] 101 | bl_target = bl_target * mask # (B, T) 102 | 103 | # The "REINFORCE" term in the gradient is: 104 | # (baseline_target - baseline_value) * gradient[z_pres_likelihood] 105 | reinforce_term = ((bl_target - bl_value).detach() * z_pres_likelihood) 106 | reinforce_term = reinforce_term * mask 107 | reinforce_term = reinforce_term.sum(1) # (B, ) 108 | 109 | # Maximize ELBO with additional REINFORCE term for discrete variables 110 | model_loss = reinforce_term - elbo_sep # (B, ) 111 | model_loss = model_loss.mean() # mean over batch 112 | 113 | # MSE as baseline loss 114 | baseline_loss = F.mse_loss(bl_value, bl_target.detach(), reduction='none') 115 | baseline_loss = baseline_loss * mask 116 | baseline_loss = baseline_loss.sum(1).mean() # mean over batch 117 | 118 | loss = model_loss + baseline_loss 119 | out['loss'] = loss 120 | 121 | # L2 122 | l2 = 0.0 123 | for p in self.model.parameters(): 124 | l2 = l2 + torch.sum(p ** 2) 125 | l2 = l2.sqrt() 126 | out['l2'] = l2 127 | 128 | # Accuracy 129 | out['accuracy'] = None 130 | if y is not None: 131 | n_obj = y['n_obj'].to(self.device) 132 | n_pred = out['inferred_n'] # (B, ) 133 | correct = (n_pred == n_obj).float().sum() 134 | acc = correct / n_pred.size(0) 135 | out['accuracy'] = acc 136 | 137 | # TODO Only for viz, as std=0.3 is pretty high so samples are not good 138 | out['out_sample'] = out['out_mean'] # this is actually NOT a sample! 
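        # Sketch of a possible alternative (not used here): draw a real sample
        # from the output distribution instead of reusing the mean, e.g.
        #     out['out_sample'] = self.model.get_output_dist(out['out_mean']).sample()
        # With the 'original' likelihood this adds N(0, 0.3^2) pixel noise, so the
        # mean gives cleaner visualizations.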
139 | 140 | return out 141 | 142 | 143 | @staticmethod 144 | def print_train_log(step, epoch, summaries): 145 | s = (" [step {step}] loss: {loss:.5g} ELBO: {elbo:.5g} " 146 | "recons: {recons:.3g} KL: {kl:.3g} acc: {acc:.3g}") 147 | s = s.format( 148 | step=step, 149 | loss=summaries['loss/loss'], 150 | elbo=summaries['elbo/elbo'], 151 | recons=summaries['elbo/recons'], 152 | kl=summaries['elbo/kl'], 153 | acc=summaries['accuracy']) 154 | print(s) 155 | 156 | 157 | @staticmethod 158 | def print_test_log(summaries, step=None, epoch=None): 159 | log_string = " " 160 | if epoch is not None: 161 | log_string += "[step {}, epoch {}] ".format(step, epoch) 162 | s = "ELBO {elbo:.5g} recons: {recons:.3g} KL: {kl:.3g} acc: {acc:.3g}" 163 | log_string += s.format( 164 | elbo=summaries['elbo/elbo'], 165 | recons=summaries['elbo/recons'], 166 | kl=summaries['elbo/kl'], 167 | acc=summaries['accuracy']) 168 | ll_key = None 169 | for k in summaries.keys(): 170 | if k.find('elbo_IW') > -1: 171 | ll_key = k 172 | iw_samples = k.split('_')[-1] 173 | break 174 | if ll_key is not None: 175 | log_string += " marginal log-likelihood ({}) {:.5g}".format( 176 | iw_samples, summaries[ll_key]) 177 | 178 | print(log_string) 179 | 180 | 181 | @staticmethod 182 | def get_metrics_dict(results): 183 | metrics_dict = { 184 | 'loss/loss': results['loss'].item(), 185 | 'elbo/elbo': results['elbo'].item(), 186 | 'elbo/recons': results['recons'].item(), 187 | 'elbo/kl': results['kl'].item(), 188 | 'l2/l2': results['l2'].item(), 189 | 'accuracy': results['accuracy'].item(), 190 | 191 | 'kl/pres': results['kl_pres'].item(), 192 | 'kl/what': results['kl_what'].item(), 193 | 'kl/where': results['kl_where'].item(), 194 | } 195 | return metrics_dict 196 | 197 | 198 | 199 | def additional_testing(self, img_folder): 200 | """ 201 | Perform additional testing, including possibly generating images. 202 | 203 | In this case, save samples from the generative model, and pairs 204 | input/reconstruction from the test set. 205 | 206 | :param img_folder: folder to store images 207 | """ 208 | 209 | step = self.model.global_step 210 | 211 | if not self.args.dry_run: 212 | 213 | # Saved images will have n**2 sub-images 214 | n = 8 215 | 216 | # Save model samples 217 | sample, zwhere, n_obj = self.model.sample_prior(n ** 2) 218 | annotated_sample = batch_add_bounding_boxes(sample, zwhere, n_obj) 219 | fname = os.path.join(img_folder, 'sample_' + str(step) + '.png') 220 | pad = img_grid_pad_value(annotated_sample) 221 | save_image(annotated_sample, fname, nrow=n, pad_value=pad) 222 | 223 | # Get first test batch 224 | (x, _) = next(iter(self.dataloaders.test)) 225 | fname = os.path.join(img_folder, 'reconstruction_' + str(step) + '.png') 226 | 227 | # Save model original/reconstructions 228 | self.save_input_and_recons(x, fname, n) 229 | 230 | 231 | def save_input_and_recons(self, x, fname, n): 232 | n_img = n ** 2 // 2 233 | if x.shape[0] < n_img: 234 | msg = ("{} data points required, but given batch has size {}. 
" 235 | "Please use a larger batch.".format(n_img, x.shape[0])) 236 | raise RuntimeError(msg) 237 | x = x.to(self.device) 238 | outputs = self.forward_pass(x) 239 | x = x[:n_img] 240 | if x.shape[1] == 1: 241 | x = x.expand(-1, 3, -1, -1) 242 | recons = outputs['out_sample'][:n_img] 243 | z_where = outputs['all_z_where'][:n_img] 244 | pred_count = outputs['inferred_n'] 245 | recons = batch_add_bounding_boxes(recons, z_where, pred_count) 246 | imgs = torch.stack([x.cpu(), recons.cpu()]) 247 | imgs = imgs.permute(1, 0, 2, 3, 4) 248 | imgs = imgs.reshape(n ** 2, x.size(1), x.size(2), x.size(3)) 249 | pad = img_grid_pad_value(imgs) 250 | save_image(imgs, fname, nrow=n, pad_value=pad) 251 | 252 | 253 | def _parse_args(self): 254 | """ 255 | Parse command-line arguments defining experiment settings. 256 | 257 | :return: args: argparse.Namespace with experiment settings 258 | """ 259 | 260 | def list_options(lst): 261 | if lst: 262 | return "'" + "' | '".join(lst) + "'" 263 | return "" 264 | 265 | legal_datasets = [ 266 | 'pyro_multi_mnist', 267 | 'multi_mnist_binary', 268 | 'multi_dsprites_binary_rgb'] 269 | legal_likelihoods = ['bernoulli', 'original'] 270 | 271 | parser = argparse.ArgumentParser( 272 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 273 | allow_abbrev=False) 274 | 275 | self.add_required_args(parser, 276 | 277 | # General 278 | batch_size=64, 279 | test_batch_size=2000, 280 | lr=1e-4, 281 | seed=54321, 282 | log_interval=50000, 283 | test_log_interval=50000, 284 | checkpoint_interval=100000, 285 | resume="", 286 | 287 | # VI-specific 288 | ll_every=50000, 289 | loglik_samples=100,) 290 | 291 | parser.add_argument('-d', '--dataset', 292 | type=str, 293 | choices=legal_datasets, 294 | default='pyro_multi_mnist', 295 | metavar='NAME', 296 | dest='dataset_name', 297 | help="dataset: " + list_options(legal_datasets)) 298 | 299 | parser.add_argument('--likelihood', 300 | type=str, 301 | choices=legal_likelihoods, 302 | metavar='NAME', 303 | dest='likelihood', 304 | help="likelihood: {}; default depends on dataset".format( 305 | list_options(legal_likelihoods))) 306 | 307 | parser.add_argument('--bl-lr', 308 | type=float, 309 | default=1e-1, 310 | metavar='LR', 311 | help="baseline's learning rate") 312 | 313 | parser.add_argument('--wd', 314 | type=float, 315 | default=0.0, 316 | dest='weight_decay', 317 | help='weight decay') 318 | 319 | args = parser.parse_args() 320 | 321 | assert args.loglik_interval % args.test_log_interval == 0 322 | 323 | if args.likelihood is None: # defaults 324 | args.likelihood = { 325 | 'pyro_multi_mnist': 'original', 326 | 'multi_mnist_binary': 'original', # 'bernoulli', 327 | 'multi_dsprites_binary_rgb': 'original', # 'bernoulli', 328 | }[args.dataset_name] 329 | 330 | return args 331 | 332 | @staticmethod 333 | def _make_run_description(args): 334 | """ 335 | Create a string description of the run. It is used in the names of the 336 | logging folders. 
337 | 338 | :param args: experiment config 339 | :return: the run description 340 | """ 341 | s = '' 342 | s += args.dataset_name 343 | s += ',seed{}'.format(args.seed) 344 | if len(args.additional_descr) > 0: 345 | s += ',' + args.additional_descr 346 | return s 347 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import boilr 4 | import multiobject 5 | from boilr import Trainer 6 | 7 | from experiments import AIRExperiment 8 | 9 | 10 | def _check_version(pkg, pkg_str, version_info): 11 | def to_str(v): 12 | return ".".join(str(x) for x in v) 13 | if pkg.__version_info__[:2] != version_info[:2]: 14 | msg = "This was last tested with {} {}, but the current version is {}" 15 | msg = msg.format(pkg_str, to_str(version_info), pkg.__version__) 16 | warnings.warn(msg) 17 | 18 | BOILR_VERSION = (0, 5, 1) 19 | MULTIOBJ_VERSION = (0, 0, 3) 20 | _check_version(boilr, 'boilr', BOILR_VERSION) 21 | _check_version(multiobject, 'multiobject', MULTIOBJ_VERSION) 22 | 23 | def main(): 24 | experiment = AIRExperiment() 25 | trainer = Trainer(experiment) 26 | trainer.run() 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/models/__init__.py -------------------------------------------------------------------------------- /models/air.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from boilr import BaseGenerativeModel 6 | from torch import nn 7 | from torch.distributions.bernoulli import Bernoulli 8 | from torch.distributions.kl import kl_divergence 9 | from torch.distributions.normal import Normal 10 | from torch.nn import LSTMCell 11 | 12 | from utils import nograd_param 13 | from utils.spatial_transform import SpatialTransformer 14 | 15 | State = namedtuple( 16 | 'State', 17 | ['h', 'c', 'bl_h', 'bl_c', 'z_pres', 'z_where', 'z_what']) 18 | 19 | 20 | class Predictor(nn.Module): 21 | """ 22 | Infer presence and location from LSTM hidden state 23 | """ 24 | def __init__(self, lstm_hidden_dim): 25 | nn.Module.__init__(self) 26 | self.seq = nn.Sequential( 27 | nn.Linear(lstm_hidden_dim, 200), 28 | nn.ReLU(), 29 | nn.Linear(200, 7), 30 | ) 31 | 32 | def forward(self, h): 33 | z = self.seq(h) 34 | z_pres_p = torch.sigmoid(z[:, :1]) 35 | z_where_loc = z[:, 1:4] 36 | z_where_scale = F.softplus(z[:, 4:]) 37 | return z_pres_p, z_where_loc, z_where_scale 38 | 39 | 40 | class AppearanceEncoder(nn.Module): 41 | """ 42 | Infer object appearance latent z_what given an image crop around the object 43 | """ 44 | def __init__(self, object_size, color_channels, encoder_hidden_dim, z_what_dim): 45 | super().__init__() 46 | object_numel = color_channels * (object_size ** 2) 47 | self.net = nn.Sequential( 48 | nn.Linear(object_numel, encoder_hidden_dim), 49 | nn.ReLU(), 50 | nn.Linear(encoder_hidden_dim, z_what_dim * 2) 51 | ) 52 | 53 | def forward(self, crop): 54 | """ 55 | :param crop: (B, C, H, W) 56 | :return: z_what_loc, z_what_scale 57 | """ 58 | bs = crop.size(0) 59 | crop_flat = crop.view(bs, -1) 60 | x = self.net(crop_flat) 61 | z_what_loc, 
z_what_scale = x.chunk(2, dim=1) 62 | z_what_scale = F.softplus(z_what_scale) 63 | 64 | return z_what_loc, z_what_scale 65 | 66 | class AppearanceDecoder(nn.Module): 67 | """ 68 | Generate pixel representation of an object given its latent code z_what 69 | """ 70 | def __init__(self, z_what_dim, decoder_hidden_dim, 71 | object_size, color_channels, bias=-2.0): 72 | super().__init__() 73 | object_numel = color_channels * (object_size ** 2) 74 | self.net = nn.Sequential( 75 | nn.Linear(z_what_dim, decoder_hidden_dim), 76 | nn.ReLU(), 77 | nn.Linear(decoder_hidden_dim, object_numel), 78 | ) 79 | self.sz = object_size 80 | self.ch = color_channels 81 | self.bias = bias 82 | 83 | def forward(self, z_what): 84 | x = self.net(z_what) 85 | x = x.view(-1, self.ch, self.sz, self.sz) 86 | x = torch.sigmoid(x + self.bias) 87 | return x 88 | 89 | 90 | class AIR(BaseGenerativeModel): 91 | """ 92 | AIR model. Default settings are from the pyro tutorial. With those settings 93 | we can reproduce results from the original paper (although about 1/10 times 94 | it doesn't converge to a good solution). 95 | """ 96 | 97 | z_where_dim = 3 98 | z_pres_dim = 1 99 | 100 | def __init__(self, 101 | img_size, 102 | object_size, 103 | max_steps, 104 | color_channels, 105 | likelihood=None, 106 | z_what_dim=50, 107 | lstm_hidden_dim=256, 108 | baseline_hidden_dim=256, 109 | encoder_hidden_dim=200, 110 | decoder_hidden_dim=200, 111 | scale_prior_mean=3.0, 112 | scale_prior_std=0.2, 113 | pos_prior_mean=0.0, 114 | pos_prior_std=1.0, 115 | ): 116 | super().__init__() 117 | 118 | #### Settings 119 | 120 | self.max_steps = max_steps 121 | 122 | self.img_size = img_size 123 | self.object_size = object_size 124 | self.color_channels = color_channels 125 | self.z_what_dim = z_what_dim 126 | self.lstm_hidden_dim = lstm_hidden_dim 127 | self.baseline_hidden_dim = baseline_hidden_dim 128 | self.encoder_hidden_dim = encoder_hidden_dim 129 | self.decoder_hidden_dim = decoder_hidden_dim 130 | 131 | self.z_pres_prob_prior = nograd_param(0.01) 132 | self.z_where_loc_prior = nograd_param( 133 | [scale_prior_mean, pos_prior_mean, pos_prior_mean]) 134 | self.z_where_scale_prior = nograd_param( 135 | [scale_prior_std, pos_prior_std, pos_prior_std]) 136 | self.z_what_loc_prior = nograd_param(0.0) 137 | self.z_what_scale_prior = nograd_param(1.0) 138 | 139 | #### 140 | 141 | self.img_numel = color_channels * (img_size ** 2) 142 | 143 | lstm_input_size = (self.img_numel + self.z_what_dim 144 | + self.z_where_dim + self.z_pres_dim) 145 | self.lstm = LSTMCell(lstm_input_size, self.lstm_hidden_dim) 146 | 147 | # Infer presence and location from LSTM hidden state 148 | self.predictor = Predictor(self.lstm_hidden_dim) 149 | 150 | # Infer z_what given an image crop around the object 151 | self.encoder = AppearanceEncoder(object_size, color_channels, 152 | encoder_hidden_dim, z_what_dim) 153 | 154 | # Generate pixel representation of an object given its z_what 155 | self.decoder = AppearanceDecoder(z_what_dim, decoder_hidden_dim, 156 | object_size, color_channels) 157 | 158 | # Spatial transformer (does both forward and inverse) 159 | self.spatial_transf = SpatialTransformer( 160 | (self.object_size, self.object_size), 161 | (self.img_size, self.img_size)) 162 | 163 | # Baseline LSTM 164 | self.bl_lstm = LSTMCell(lstm_input_size, self.baseline_hidden_dim) 165 | 166 | # Baseline regressor 167 | self.bl_regressor = nn.Sequential( 168 | nn.Linear(self.baseline_hidden_dim, 200), 169 | nn.ReLU(), 170 | nn.Linear(200, 1) 171 | ) 172 | 173 | # Prior 
distributions 174 | self.pres_prior = Bernoulli(probs=self.z_pres_prob_prior) 175 | self.where_prior = Normal(loc=self.z_where_loc_prior, 176 | scale=self.z_where_scale_prior) 177 | self.what_prior = Normal(loc=self.z_what_loc_prior, 178 | scale=self.z_what_scale_prior) 179 | 180 | # Data likelihood 181 | self.likelihood = likelihood 182 | 183 | @staticmethod 184 | def _module_list_to_params(modules): 185 | params = [] 186 | for module in modules: 187 | params.extend(module.parameters()) 188 | return params 189 | 190 | def air_params(self): 191 | air_modules = [self.predictor, self.lstm, self.encoder, self.decoder] 192 | return self._module_list_to_params(air_modules) + [self.z_pres_prob_prior] 193 | 194 | def baseline_params(self): 195 | baseline_modules = [self.bl_regressor, self.bl_lstm] 196 | return self._module_list_to_params(baseline_modules) 197 | 198 | def get_output_dist(self, mean): 199 | if self.likelihood == 'original': 200 | std = torch.tensor(0.3).to(self.get_device()) 201 | dist = Normal(mean, std.expand_as(mean)) 202 | elif self.likelihood == 'bernoulli': 203 | dist = Bernoulli(probs=mean) 204 | else: 205 | msg = "Unrecognized likelihood '{}'".format(self.likelihood) 206 | raise RuntimeError(msg) 207 | return dist 208 | 209 | def forward(self, x): 210 | bs = x.size(0) 211 | 212 | # Init model state 213 | state = State( 214 | h=torch.zeros(bs, self.lstm_hidden_dim, device=x.device), 215 | c=torch.zeros(bs, self.lstm_hidden_dim, device=x.device), 216 | bl_h=torch.zeros(bs, self.baseline_hidden_dim, device=x.device), 217 | bl_c=torch.zeros(bs, self.baseline_hidden_dim, device=x.device), 218 | z_pres=torch.ones(bs, 1, device=x.device), 219 | z_where=torch.zeros(bs, 3, device=x.device), 220 | z_what=torch.zeros(bs, self.z_what_dim, device=x.device), 221 | ) 222 | 223 | # KL divergence for each step 224 | kl = torch.zeros(bs, self.max_steps, device=x.device) 225 | 226 | # Store KL for pres, where, and what separately 227 | kl_pres = torch.zeros(bs, self.max_steps, device=x.device) 228 | kl_where = torch.zeros(bs, self.max_steps, device=x.device) 229 | kl_what = torch.zeros(bs, self.max_steps, device=x.device) 230 | 231 | # Baseline value for each step 232 | baseline_value = torch.zeros(bs, self.max_steps, device=x.device) 233 | 234 | # Log likelihood for each step, with shape (B, T): 235 | # log q(z_pres[t] | x, z_{ set KL=0 414 | kl_where = kl_where * z_pres.squeeze() 415 | kl_what = kl_what * z_pres.squeeze() 416 | 417 | # When z_pres[i-1] is 0, zpres is not used -> set KL=0 418 | kl_pres = kl_pres * prev.z_pres.squeeze() 419 | 420 | kl = (kl_pres + kl_where + kl_what) 421 | 422 | # New state 423 | new_state = State( 424 | z_pres=z_pres, 425 | z_where=z_where, 426 | z_what=z_what, 427 | h=h, 428 | c=c, 429 | bl_c=bl_c, 430 | bl_h=bl_h, 431 | ) 432 | 433 | out = { 434 | 'state': new_state, 435 | 'kl': kl, 436 | 'kl_pres': kl_pres, 437 | 'kl_where': kl_where, 438 | 'kl_what': kl_what, 439 | 'baseline_value': baseline_value, 440 | 'z_pres_likelihood': z_pres_likelihood, 441 | } 442 | return out 443 | 444 | 445 | def sample_prior(self, n_imgs, **kwargs): 446 | 447 | # Sample from prior. Shapes: 448 | # z_pres: (B, T) 449 | # z_what: (B, T, z_what_dim) 450 | # z_where: (B, T, 3) 451 | z_pres = self.pres_prior.sample((n_imgs, self.max_steps)) 452 | z_what = self.what_prior.sample((n_imgs, self.max_steps, self.z_what_dim)) 453 | z_where = self.where_prior.sample((n_imgs, self.max_steps)) 454 | 455 | # TODO This is only for visualization! 
Not real model samples 456 | # The prior of z_pres puts a lot of probability on n=0, which doesn't 457 | # lead to informative samples. Instead, generate half images with 1 458 | # object and half with 2. 459 | # z_pres.fill_(0.) 460 | # z_pres[:, 0].fill_(1.) 461 | # z_pres[n_imgs//2:, 1].fill_(1.) 462 | 463 | # If z_pres is sampled from the prior, make sure there are no ones 464 | # after a zero. 465 | for t in range(1, self.max_steps): 466 | z_pres[:, t] *= z_pres[:, t-1] # if previous=0, this is also 0 467 | 468 | n_obj = z_pres.sum(1) 469 | 470 | # Decode z_what to object appearance 471 | sprites = self.decoder(z_what) 472 | 473 | # Spatial-transform them to images with shape (B*T, 1, H, W) 474 | z_where_ = z_where.view(n_imgs * self.max_steps, 3) # shape (B*T, 3) 475 | imgs = self.spatial_transf.forward(sprites, z_where_) 476 | 477 | # Reshape images to (B, T, 1, H, W) 478 | h = w = self.img_size 479 | ch = self.color_channels 480 | imgs = imgs.view(n_imgs, self.max_steps, ch, h, w) 481 | 482 | # Make canvas by masking and summing over timesteps 483 | canvas = imgs * z_pres[:, :, None, None, None] 484 | canvas = canvas.sum(1) 485 | 486 | return canvas, z_where, n_obj 487 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.1 2 | torch==1.4.0 3 | torchvision==0.5.0 4 | matplotlib==3.1.2 5 | tqdm==4.41.1 6 | boilr>=0.6.0,<0.7 7 | multiobject 8 | tensorboard 9 | 10 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import * 2 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def nograd_param(x): 6 | """ 7 | Naively make tensor from x, then wrap with nn.Parameter without gradient. 8 | """ 9 | return nn.Parameter(torch.tensor(x), requires_grad=False) 10 | -------------------------------------------------------------------------------- /utils/spatial_transform.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from boilr.utils import to_np 7 | 8 | 9 | class SpatialTransformer: 10 | def __init__(self, input_shape, output_shape): 11 | """ 12 | :param input_shape: (H, W) 13 | :param output_shape: (H, W) 14 | """ 15 | self.input_shape = input_shape 16 | self.output_shape = output_shape 17 | 18 | def _transform(self, x, z_where, inverse): 19 | """ 20 | :param x: (B, 1, Hin, Win) 21 | :param z_where: [s, x, y] 22 | :param inverse: inverse z_where 23 | :return: y of output_size 24 | """ 25 | if inverse: 26 | z_where = invert_z_where(z_where) 27 | out_shp = self.input_shape 28 | else: 29 | out_shp = self.output_shape 30 | 31 | out = spatial_transformer(x, z_where, out_shp) 32 | return out 33 | 34 | def forward(self, x, z_where): 35 | return self._transform(x, z_where, inverse=False) 36 | 37 | def inverse(self, x, z_where): 38 | return self._transform(x, z_where, inverse=True) 39 | 40 | 41 | def spatial_transformer(x, z_where, out_shape): 42 | """ 43 | Resamples x on a grid of shape out_shape based on an affine transform 44 | parameterized by z_where. 
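    Each z_where is an [s, x, y] triple; expand_z_where() below turns it into
    the 2x3 affine matrix [[s, 0, x], [0, s, y]] that is passed to affine_grid.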
45 | The output image has shape out_shape. 46 | 47 | :param x: 48 | :param z_where: 49 | :param out_shape: 50 | :return: 51 | """ 52 | batch_sz = x.size(0) 53 | theta = expand_z_where(z_where) 54 | grid_shape = torch.Size((batch_sz, 1) + out_shape) 55 | grid = F.affine_grid(theta, grid_shape, align_corners=False) 56 | out = F.grid_sample(x, grid, align_corners=False) 57 | return out 58 | 59 | def expand_z_where(z_where): 60 | """ 61 | :param z_where: batch. [s, x, y] 62 | :return: [[s, 0, x], [0, s, y]] 63 | """ 64 | bs = z_where.size(0) 65 | dev = z_where.device 66 | 67 | # [s, x, y] -> [s, 0, x, 0, s, y] 68 | z_where = torch.cat((torch.zeros(bs, 1, device=dev), z_where), dim=1) 69 | expansion_indices = torch.tensor([1, 0, 2, 0, 1, 3], device=dev) 70 | matrix = torch.index_select(z_where, dim=1, index=expansion_indices) 71 | matrix = matrix.view(bs, 2, 3) 72 | 73 | return matrix 74 | 75 | def invert_z_where(z_where): 76 | z_where_inv = torch.zeros_like(z_where) 77 | scale = z_where[:, 0:1] # (batch, 1) 78 | z_where_inv[:, 1:3] = -z_where[:, 1:3] / scale # (batch, 2) 79 | z_where_inv[:, 0:1] = 1 / scale # (batch, 1) 80 | return z_where_inv 81 | 82 | 83 | def batch_add_bounding_boxes(imgs, z_wheres, n_obj, color=None, n_img=None): 84 | """ 85 | 86 | :param imgs: 4d tensor of numpy array, channel dim either 1 or 3 87 | :param z_wheres: tensor or numpy of shape (n_imgs, max_n_objects, 3) 88 | :param n_obj: 89 | :param color: 90 | :param n_img: 91 | :return: 92 | """ 93 | 94 | # Check arguments 95 | assert len(imgs.shape) == 4 96 | assert imgs.shape[1] in [1, 3] 97 | assert len(z_wheres.shape) == 3 98 | assert z_wheres.shape[0] == imgs.shape[0] 99 | assert z_wheres.shape[2] == 3 100 | 101 | target_shape = list(imgs.shape) 102 | target_shape[1] = 3 103 | 104 | if n_img is None: 105 | n_img = len(imgs) 106 | if color is None: 107 | color = np.array([1., 0., 0.]) 108 | out = torch.stack([ 109 | add_bounding_boxes(imgs[j], z_wheres[j], color, n_obj[j]) 110 | for j in range(n_img) 111 | ]) 112 | 113 | out_shape = tuple(out.shape) 114 | target_shape = tuple(target_shape) 115 | assert out_shape == target_shape, "{}, {}".format(out_shape, target_shape) 116 | return out 117 | 118 | 119 | def add_bounding_boxes(img, z_wheres, color, n_obj): 120 | """ 121 | Adds bounding boxes to the n_obj objects in img, according to z_wheres. 122 | The output is never on cuda. 123 | 124 | :param img: image in 3d or 4d shape, either Tensor or numpy. If 4d, the 125 | first dimension must be 1. The channel dimension must be 126 | either 1 or 3. 127 | :param z_wheres: tensor or numpy of shape (1, max_n_objects, 3) or 128 | (max_n_objects, 3) 129 | :param color: color of all bounding boxes (RGB) 130 | :param n_obj: number of objects in the scene. This controls the number of 131 | bounding boxes to be drawn, and cannot be greater than the 132 | max number of objects supported by z_where (dim=1). Has to be 133 | a scalar or a single-element Tensor/array. 134 | :return: image with required bounding boxes, with same type and dimension 135 | as the original image input, except 3 color channels. 
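    Example (mirroring the test code at the bottom of this file):
        add_bounding_boxes(img, z_wheres, color=np.array([1., 0., 0.]), n_obj=4)
    draws four red boxes on img.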
136 | """ 137 | 138 | try: 139 | n_obj = n_obj.item() 140 | except AttributeError: 141 | pass 142 | n_obj = int(round(n_obj)) 143 | assert n_obj <= z_wheres.shape[1] 144 | 145 | try: 146 | img = img.cpu() 147 | except AttributeError: 148 | pass 149 | 150 | if len(img.shape) == 3: 151 | color_dim = 0 152 | else: 153 | color_dim = 1 154 | 155 | if len(z_wheres.shape) == 3: 156 | assert z_wheres.shape[0] == 1 157 | z_wheres = z_wheres[0] 158 | 159 | target_shape = list(img.shape) 160 | target_shape[color_dim] = 3 161 | 162 | for i in range(n_obj): 163 | img = add_bounding_box(img, z_wheres[i:i+1], color) 164 | if img.shape[color_dim] == 1: # this might happen if n_obj==0 165 | reps = [3, 1, 1] 166 | if color_dim == 1: 167 | reps = [1] + reps 168 | reps = tuple(reps) 169 | if isinstance(img, torch.Tensor): 170 | img = img.repeat(*reps) 171 | else: 172 | img = np.tile(img, reps) 173 | 174 | target_shape = tuple(target_shape) 175 | img_shape = tuple(img.shape) 176 | assert img_shape == target_shape, "{}, {}".format(img_shape, target_shape) 177 | return img 178 | 179 | 180 | def add_bounding_box(img, z_where, color): 181 | """ 182 | Adds a bounding box to img with parameters z_where and the given color. 183 | Makes a copy of the input image, which is left unaltered. The output is 184 | never on cuda. 185 | 186 | :param img: image in 3d or 4d shape, either Tensor or numpy. If 4d, the 187 | first dimension must be 1. The channel dimension must be 188 | either 1 or 3. 189 | :param z_where: tensor or numpy with 3 elements, and shape (1, ..., 1, 3) 190 | :param color: 191 | :return: image with required bounding box in the specified color, with same 192 | type and dimension as the original image input, except 3 color 193 | channels. 194 | """ 195 | def _bounding_box(z_where, x_size, rounded=True, margin=1): 196 | z_where = to_np(z_where).flatten() 197 | assert z_where.shape[0] == z_where.size == 3 198 | s, x, y = tuple(z_where) 199 | w = x_size / s 200 | h = x_size / s 201 | xtrans = -x / s * x_size / 2 202 | ytrans = -y / s * x_size / 2 203 | x1 = (x_size - w) / 2 + xtrans - margin 204 | y1 = (x_size - h) / 2 + ytrans - margin 205 | x2 = x1 + w + 2 * margin 206 | y2 = y1 + h + 2 * margin 207 | x1, x2 = sorted((x1, x2)) 208 | y1, y2 = sorted((y1, y2)) 209 | coords = (x1, x2, y1, y2) 210 | if rounded: 211 | coords = (int(round(t)) for t in coords) 212 | return coords 213 | 214 | target_shape = list(img.shape) 215 | collapse_first = False 216 | torch_tensor = isinstance(img, torch.Tensor) 217 | img = to_np(img).copy() 218 | if len(img.shape) == 3: 219 | collapse_first = True 220 | img = np.expand_dims(img, 0) 221 | target_shape[0] = 3 222 | else: 223 | target_shape[1] = 3 224 | assert len(img.shape) == 4 and img.shape[0] == 1 225 | if img.shape[1] == 1: 226 | img = np.tile(img, (1, 3, 1, 1)) 227 | assert img.shape[1] == 3 228 | color = color[:, None] 229 | 230 | x1, x2, y1, y2 = _bounding_box(z_where, img.shape[2]) 231 | x_max = y_max = img.shape[2] - 1 232 | if 0 <= y1 <= y_max: 233 | img[0, :, y1, max(x1, 0):min(x2, x_max)] = color 234 | if 0 <= y2 - 1 <= y_max: 235 | img[0, :, y2 - 1, max(x1, 0):min(x2, x_max)] = color 236 | if 0 <= x1 <= x_max: 237 | img[0, :, max(y1, 0):min(y2, y_max), x1] = color 238 | if 0 <= x2 - 1 <= x_max: 239 | img[0, :, max(y1, 0):min(y2, y_max), x2 - 1] = color 240 | 241 | if collapse_first: 242 | img = img[0] 243 | if torch_tensor: 244 | img = torch.from_numpy(img) 245 | 246 | target_shape = tuple(target_shape) 247 | img_shape = tuple(img.shape) 248 | assert img_shape == 
target_shape, "{}, {}".format(img_shape, target_shape) 249 | return img 250 | 251 | def _test(obj_size, canvas_size, color_ch): 252 | 253 | # Object to image. Meaningful scenario: 254 | # scale > 1, x and y in [-scale, +scale] 255 | # Perfect copy (no interpolation) when scale == canvas_size / obj_size 256 | obj = (torch.rand(1, color_ch, obj_size, obj_size) < 0.8).float() 257 | z_where = torch.tensor([[6., 2., 4.]]) 258 | spatial_transf = SpatialTransformer( 259 | (obj_size, obj_size), (canvas_size, canvas_size)) 260 | out = spatial_transf.forward(obj, z_where) 261 | 262 | plt.figure() 263 | plt.imshow(out[0].permute(1, 2, 0).squeeze(), vmin=0., vmax=1.) 264 | plt.show() 265 | 266 | # Image to object. 267 | # Here we retrieve the same object we initially drew on the canvas. 268 | img = out 269 | out = spatial_transf.inverse(img, z_where) 270 | 271 | plt.figure() 272 | plt.imshow(out[0].permute(1, 2, 0).squeeze(), vmin=0., vmax=1.) 273 | plt.show() 274 | 275 | # show bounding box 276 | img_np = to_np(img) 277 | color = np.array([1., 0., 0.]) 278 | img_np = add_bounding_box(img_np, z_where, color) 279 | plt.figure() 280 | plt.imshow(img_np[0].transpose(1, 2, 0), vmin=0., vmax=1.) 281 | plt.show() 282 | 283 | # test bounding box methods 284 | add_bounding_box(img[0], z_where, color) # 3d tensor 285 | add_bounding_box(img_np[0], z_where, color) # 3d numpy 286 | add_bounding_box(img, z_where, color) # 4d tensor 287 | add_bounding_box(img_np, z_where, color) # 4d numpy 288 | z_wheres = z_where.repeat(1, 4, 1) 289 | z_wheres += torch.randn_like(z_wheres) * 2 290 | add_bounding_boxes(img[0], z_wheres, color, 4) # 3d tensor 291 | add_bounding_boxes(img_np[0], z_wheres, color, 4) # 3d numpy 292 | add_bounding_boxes(img, z_wheres, color, 4) # 4d tensor 293 | img_np = add_bounding_boxes(img_np, z_wheres, color, 4) # 4d numpy 294 | plt.figure() 295 | plt.imshow(img_np[0].transpose(1, 2, 0), vmin=0., vmax=1.) 296 | plt.show() 297 | 298 | # case nobj = 0 missing 299 | 300 | 301 | if __name__ == '__main__': 302 | obj_size = 8 303 | canvas_size = 48 304 | color_ch = 3 305 | _test(obj_size, canvas_size, color_ch) 306 | color_ch = 1 307 | _test(obj_size, canvas_size, color_ch) 308 | --------------------------------------------------------------------------------