├── .gitignore
├── LICENSE
├── README.md
├── _readme_imgs
│   ├── multi_dsprites_recons.png
│   ├── multi_mnist_recons.png
│   └── original_multimnist_recons.png
├── data
│   ├── multi-dsprites-binary-rgb
│   │   └── multi_dsprites_color_012.npz
│   └── multi_mnist
│       ├── multi_binary_mnist_012.npz
│       └── multi_mnist_pyro.npz
├── datasets.py
├── experiments
│   ├── __init__.py
│   └── air_experiment
│       ├── __init__.py
│       ├── data.py
│       └── experiment_manager.py
├── main.py
├── models
│   ├── __init__.py
│   └── air.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── misc.py
    └── spatial_transform.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | __pycache__
4 | *.pyc
5 | checkpoints
6 | data
7 | results
8 | tensorboard_logs
9 |
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Andrea Dittadi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attend Infer Repeat
2 |
3 | Attend Infer Repeat (AIR) [1] in PyTorch. Parts of this implementation are
4 | inspired by [this Pyro tutorial](https://pyro.ai/examples/air.html) [2] and
5 | [this blog post](http://akosiorek.github.io/ml/2017/09/03/implementing-air.html) [3].
6 | See [below](#high-level-model-description) for a description of the model.
7 |
8 | Install requirements and run:
9 | ```
10 | pip install -r requirements.txt
11 | CUDA_VISIBLE_DEVICES=0 python main.py
12 | ```
13 |
14 |
15 |
16 | ## Results
17 |
18 | ### Original multi-MNIST dataset
19 |
20 | With this implementation, the original multi-MNIST results are reproduced in about
21 | 80% of runs. In the remaining runs, the model converges to a local maximum of the
22 | ELBO in which the recurrent attention often predicts more objects than the ground truth
23 | (either using more than one object to model a single digit, or inferring blank objects).
24 |
25 | | dataset              | likelihood      | accuracy      | ELBO        | log _p(x)_ ≥ [100 iws] |
26 | | -------------------- |:---------------:|:-------------:|:-----------:|:----------------------:|
27 | | original multi-MNIST | N(_f(z)_, 0.3²) | 98.3 ± 0.15 % | 627.4 ± 0.7 | 636.8 ± 0.8            |
28 |
29 | where:
30 | - mean and std do not include the bad runs
31 | - the likelihood is a Gaussian with fixed (and large!) variance [1]
32 | - the last column is a tighter lower bound on the log-likelihood than the ELBO; "iws"
33 |   stands for importance-weighted samples [4] (see the sketch below)
34 |
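For reference, the importance-weighted bound in the last column is estimated from K posterior samples per image. A minimal sketch (not the code used in this repo), assuming `log_w[i, k] = log p(x_i, z_k) - log q(z_k | x_i)` for K samples from the approximate posterior:

```
import math
import torch

def iw_bound(log_w):
    """log_w: (batch, K) log importance weights -> (batch,) lower bound on log p(x)."""
    k = log_w.size(1)
    # log (1/K) sum_k exp(log_w[:, k]), computed stably with logsumexp
    return torch.logsumexp(log_w, dim=1) - math.log(k)
```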
35 |
36 | Reconstructions with inferred bounding boxes for one of the good runs:
37 |
38 | 
39 |
40 |
41 |
42 | ### Smaller objects
43 |
44 | Preliminary results on multi-MNIST and multi-dSprites, with larger images (64x64)
45 | and smaller objects (object patches range from 9x9 to 18x18). The prior on object
46 | scale was updated to reflect the smaller patch size (see the snippet below). This setting
47 | is harder for the model to learn, especially correct inference of z_pres: on the
48 | dSprites dataset only 1 in 4 runs is successful. Examples of successful runs below.
49 |
50 | 
51 |
52 | 
53 |
54 |
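The prior change is the per-dataset `scale_prior_mean` passed to the model in
`experiments/air_experiment/experiment_manager.py`; a larger mean of the z_where scale
corresponds to smaller objects on the canvas (the comments are added here for clarity):

```
scale_prior_mean = {
    'pyro_multi_mnist': 3.,             # 28x28 digit patches
    'multi_mnist_binary': 4.5,          # patches up to 18x18 on 64x64 images
    'multi_dsprites_binary_rgb': 4.5,
}[args.dataset_name]
```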
55 |
56 | ## Implementation notes
57 |
58 | - In [1] the prior probability for presence is annealed from almost 1 to 1e-5
59 | or less [3], whereas here it is fixed to 0.01 as suggested in [2].
60 | - This means that the generative model does not make much sense, since an image
61 | sampled from the model will be empty with probability 0.99.
62 | - The presence prior can still be learned (but only after the model has converged,
63 | otherwise it doesn't work), in which case it settles at approximately 0.5. Results soon.
64 | - The other defaults are as in the paper [1].
65 | - In [1] the likelihood _p(x | z)_ is a Gaussian with fixed variance, but it has
66 | to be Bernoulli for binary data.
67 | - This means that, unlike in [1], the decoder output must be in the interval [0, 1].
68 | Clamping is the simplest solution, but it doesn't work particularly well (see the snippet below).
69 |
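The two likelihood options correspond to `get_output_dist` in `models/air.py`: a Normal with
fixed std 0.3 around the decoded canvas ('original'), or a Bernoulli with the canvas as its
mean ('bernoulli'), which requires the mean to lie in [0, 1]:

```
# from models/air.py (comments added)
def get_output_dist(self, mean):
    if self.likelihood == 'original':
        std = torch.tensor(0.3).to(self.get_device())  # fixed std, as in [1]
        dist = Normal(mean, std.expand_as(mean))
    elif self.likelihood == 'bernoulli':
        dist = Bernoulli(probs=mean)  # decoder/canvas output must be in [0, 1]
    else:
        raise RuntimeError("Unrecognized likelihood '{}'".format(self.likelihood))
    return dist
```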
70 |
71 |
72 | ## High-level model description
73 |
74 | Attend Infer Repeat (AIR) [1] is a structured generative model of visual scenes
75 | that attempts to explain such scenes as compositions of discrete objects. Each
76 | object is described by its (binary) presence, location/scale, and appearance.
77 | The model is invariant to object permutations.
78 |
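The generative process is essentially what `AIR.sample_prior` in `models/air.py` implements,
condensed below; the final line, which draws pixels from the likelihood, is added here for
completeness and is not part of `sample_prior` (which returns the clean canvas):

```
# condensed from AIR.sample_prior in models/air.py
z_pres  = self.pres_prior.sample((n_imgs, self.max_steps))            # (B, T)
z_what  = self.what_prior.sample((n_imgs, self.max_steps, self.z_what_dim))
z_where = self.where_prior.sample((n_imgs, self.max_steps))           # (B, T, 3): scale, x, y
for t in range(1, self.max_steps):
    z_pres[:, t] *= z_pres[:, t - 1]          # no objects after the first z_pres = 0
sprites = self.decoder(z_what)                                         # (B*T, C, h, w) object patches
imgs = self.spatial_transf.forward(sprites, z_where.view(-1, 3))       # paste patches onto canvases
imgs = imgs.view(n_imgs, self.max_steps, self.color_channels, self.img_size, self.img_size)
canvas = (imgs * z_pres[:, :, None, None, None]).sum(1)                # keep present objects, compose
x = self.get_output_dist(canvas).sample()                              # pixel likelihood (added here)
```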
79 | Exact inference of the posterior _p(z | x)_ is intractable and is approximated
80 | through (stochastic, amortized) variational inference. An iterative (recurrent)
81 | algorithm infers whether there is another object to be explained (i.e. the next
82 | object has presence=1) and, if so, it infers its location/scale and appearance.
83 | Appearance is inferred from a patch of the input image given by the inferred
84 | location/scale.
85 |
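A rough sketch of a single inference step, using the modules defined in `models/air.py`
(simplified: the actual loop in `AIR.forward` also feeds the image and previous latents into
the LSTM, accumulates KL terms, and runs the baseline LSTM; names like `h`, `x`, and
`prev_pres` are illustrative, not the ones used in the code):

```
z_pres_p, z_where_loc, z_where_scale = self.predictor(h)        # from LSTM hidden state
z_pres  = Bernoulli(z_pres_p * prev_pres).sample()              # presence bit; 0 stops the loop
z_where = Normal(z_where_loc, z_where_scale).rsample()          # location/scale of the object
crop    = self.spatial_transf.inverse(x, z_where)               # attend: crop the image patch
z_what_loc, z_what_scale = self.encoder(crop)
z_what  = Normal(z_what_loc, z_what_scale).rsample()            # appearance code
```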
86 | ## Requirements
87 | ```
88 | python 3.7.6
89 | numpy 1.18.1
90 | torch 1.4.0
91 | torchvision 0.5.0
92 | matplotlib 3.1.2
93 | tqdm 4.41.1
94 | boilr 0.6.0
95 | multiobject 0.0.3
96 | ```
97 |
98 | ## References
99 |
100 | [1] A Eslami,
101 | N Heess,
102 | T Weber,
103 | Y Tassa,
104 | D Szepesvari,
105 | K Kavukcuoglu,
106 | G Hinton.
107 | _Attend, Infer, Repeat: Fast Scene Understanding with Generative Models_, NIPS 2016
108 |
109 | [2] [https://pyro.ai/examples/air.html](https://pyro.ai/examples/air.html)
110 |
111 | [3] [http://akosiorek.github.io/ml/2017/09/03/implementing-air.html](http://akosiorek.github.io/ml/2017/09/03/implementing-air.html)
112 |
113 | [4] Y Burda, RB Grosse, R Salakhutdinov.
114 | _Importance Weighted Autoencoders_,
115 | ICLR 2016
116 |
--------------------------------------------------------------------------------
/_readme_imgs/multi_dsprites_recons.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/_readme_imgs/multi_dsprites_recons.png
--------------------------------------------------------------------------------
/_readme_imgs/multi_mnist_recons.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/_readme_imgs/multi_mnist_recons.png
--------------------------------------------------------------------------------
/_readme_imgs/original_multimnist_recons.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/_readme_imgs/original_multimnist_recons.png
--------------------------------------------------------------------------------
/data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz
--------------------------------------------------------------------------------
/data/multi_mnist/multi_binary_mnist_012.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/data/multi_mnist/multi_binary_mnist_012.npz
--------------------------------------------------------------------------------
/data/multi_mnist/multi_mnist_pyro.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/data/multi_mnist/multi_mnist_pyro.npz
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class PyroMultiMNIST(Dataset):
7 | def __init__(self, path, train):
8 | self.path = path
9 | self.train = train
10 | data = np.load(path, allow_pickle=True)
11 | x = data['x']
12 | y = data['y']
13 | split = 50000
14 | if train:
15 | self.x, self.y = x[:split], y[:split]
16 | else:
17 | self.x, self.y = x[split:], y[split:]
18 |
19 | def __getitem__(self, index):
20 | """
21 | Returns (x, y), where x is (1, H, W) in range (0, 1),
22 | y is a label dict with only a 'n_obj' key.
23 | """
24 | # x: uint8, (H, W) -- a channel dim is added below
25 | # y: per-image label; len(y) = number of objects
26 | x, y = self.x[index], self.y[index]
27 | y = np.array(len(y))
28 | x = x / 255.0
29 |
30 | x, y = torch.from_numpy(x).float(), torch.from_numpy(y).float()
31 | x = x[None]
32 | y = {'n_obj': y} # label dict: compatible with multiobject dataloader
33 |
34 | return x, y
35 |
36 |
37 | def __len__(self):
38 | return len(self.x)
39 |
--------------------------------------------------------------------------------
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 | from .air_experiment.experiment_manager import AIRExperiment
2 |
--------------------------------------------------------------------------------
/experiments/air_experiment/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/experiments/air_experiment/__init__.py
--------------------------------------------------------------------------------
/experiments/air_experiment/data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from multiobject.pytorch import MultiObjectDataset, MultiObjectDataLoader
4 |
5 | from datasets import PyroMultiMNIST
6 |
7 | multiobject_paths = {
8 | 'multi_mnist_binary': './data/multi_mnist/multi_binary_mnist_012.npz',
9 | 'multi_dsprites_binary_rgb': './data/multi-dsprites-binary-rgb/multi_dsprites_color_012.npz',
10 | }
11 | multiobject_datasets = multiobject_paths.keys()
12 | pyro_mnist_path = os.path.join('data', 'multi_mnist', 'multi_mnist_pyro.npz')
13 |
14 |
15 | class DatasetLoader:
16 | """
17 | Wrapper for DataLoaders. Data attributes:
18 | - train: DataLoader object for training set
19 | - test: DataLoader object for test set
20 | - data_shape: shape of each data point (channels, height, width)
21 | - img_size: spatial dimensions of each data point (height, width)
22 | - color_ch: number of color channels
23 | """
24 |
25 | def __init__(self, args, cuda):
26 |
27 | # Default arguments for dataloaders
28 | kwargs = {'num_workers': 1, 'pin_memory': False} if cuda else {}
29 |
30 | # Define training and test set
31 | if args.dataset_name == 'pyro_multi_mnist':
32 | train_set = PyroMultiMNIST(pyro_mnist_path, train=True)
33 | test_set = PyroMultiMNIST(pyro_mnist_path, train=False)
34 | elif args.dataset_name in multiobject_datasets:
35 | data_path = multiobject_paths[args.dataset_name]
36 | train_set = MultiObjectDataset(data_path, train=True)
37 | test_set = MultiObjectDataset(data_path, train=False)
38 | else:
39 | raise RuntimeError("Unrecognized data set '{}'".format(args.dataset_name))
40 |
41 | # Dataloaders
42 | self.train = MultiObjectDataLoader(
43 | train_set,
44 | batch_size=args.batch_size,
45 | shuffle=True,
46 | drop_last=True,
47 | **kwargs
48 | )
49 | self.test = MultiObjectDataLoader(
50 | test_set,
51 | batch_size=args.test_batch_size,
52 | shuffle=False,
53 | **kwargs
54 | )
55 |
56 | self.data_shape = self.train.dataset[0][0].size()
57 | self.img_size = self.data_shape[1:]
58 | self.color_ch = self.data_shape[0]
59 |
--------------------------------------------------------------------------------
/experiments/air_experiment/experiment_manager.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from boilr import VIExperimentManager
7 | from boilr.viz import img_grid_pad_value
8 | from torch import optim
9 | from torchvision.utils import save_image
10 |
11 | from models.air import AIR
12 | from utils.spatial_transform import batch_add_bounding_boxes
13 | from .data import DatasetLoader
14 |
15 |
16 | class AIRExperiment(VIExperimentManager):
17 | """
18 | Experiment manager.
19 |
20 | Data attributes:
21 | - 'args': argparse.Namespace containing all config parameters. When
22 | initializing this object, if 'args' is not given, all config
23 | parameters are set based on experiment defaults and user input, using
24 | argparse.
25 | - 'run_description': string description of the run that includes a timestamp
26 | and can be used e.g. as folder name for logging.
27 | - 'model': the model.
28 | - 'device': torch.device that is being used
29 | - 'dataloaders': DataLoaders, with attributes 'train' and 'test'
30 | - 'optimizer': the optimizer
31 | """
32 |
33 | def make_datamanager(self):
34 | cuda = self.device.type == 'cuda'
35 | return DatasetLoader(self.args, cuda)
36 |
37 | def make_model(self):
38 | args = self.args
39 |
40 | obj_size = {
41 | 'pyro_multi_mnist': 28,
42 | 'multi_mnist_binary': 18,
43 | 'multi_dsprites_binary_rgb': 18,
44 | }[args.dataset_name]
45 |
46 | scale_prior_mean = {
47 | 'pyro_multi_mnist': 3.,
48 | 'multi_mnist_binary': 4.5,
49 | 'multi_dsprites_binary_rgb': 4.5,
50 | }[args.dataset_name]
51 |
52 | model = AIR(
53 | img_size=self.dataloaders.img_size[0], # assume h=w
54 | color_channels=self.dataloaders.color_ch,
55 | object_size=obj_size,
56 | max_steps=3,
57 | likelihood=args.likelihood,
58 | scale_prior_mean=scale_prior_mean,
59 | )
60 | return model
61 |
62 | def make_optimizer(self):
63 | args = self.args
64 | optimizer = optim.Adam([
65 | {
66 | 'params': self.model.air_params(),
67 | 'lr': args.lr,
68 | 'weight_decay': args.weight_decay,
69 | },
70 | {
71 | 'params': self.model.baseline_params(),
72 | 'lr': args.bl_lr,
73 | 'weight_decay': args.weight_decay,
74 | },
75 | ])
76 | return optimizer
77 |
78 |
79 |
80 | def forward_pass(self, x, y=None):
81 | """
82 | Simple single-pass model evaluation. It consists of a forward pass
83 | and computation of all necessary losses and metrics.
84 | """
85 |
86 | # Forward pass
87 | x = x.to(self.device, non_blocking=True)
88 | out = self.model(x)
89 |
90 | elbo_sep = out['elbo_sep']
91 | bl_target = out['baseline_target']
92 | bl_value = out['baseline_value']
93 | data_likelihood_sep = out['data_likelihood']
94 | z_pres_likelihood = out['z_pres_likelihood']
95 | mask = out['mask_prev']
96 |
97 | # The baseline target is:
98 | # sum_{i=t}^T KL[i] - log p(x | z)
99 | # for all steps up to (and including) the first z_pres=0
100 | bl_target = bl_target - data_likelihood_sep[:, None]
101 | bl_target = bl_target * mask # (B, T)
102 |
103 | # The "REINFORCE" term in the gradient is:
104 | # (baseline_target - baseline_value) * gradient[z_pres_likelihood]
105 | reinforce_term = ((bl_target - bl_value).detach() * z_pres_likelihood)
106 | reinforce_term = reinforce_term * mask
107 | reinforce_term = reinforce_term.sum(1) # (B, )
108 |
109 | # Maximize ELBO with additional REINFORCE term for discrete variables
110 | model_loss = reinforce_term - elbo_sep # (B, )
111 | model_loss = model_loss.mean() # mean over batch
112 |
113 | # MSE as baseline loss
114 | baseline_loss = F.mse_loss(bl_value, bl_target.detach(), reduction='none')
115 | baseline_loss = baseline_loss * mask
116 | baseline_loss = baseline_loss.sum(1).mean() # mean over batch
117 |
118 | loss = model_loss + baseline_loss
119 | out['loss'] = loss
120 |
121 | # L2
122 | l2 = 0.0
123 | for p in self.model.parameters():
124 | l2 = l2 + torch.sum(p ** 2)
125 | l2 = l2.sqrt()
126 | out['l2'] = l2
127 |
128 | # Accuracy
129 | out['accuracy'] = None
130 | if y is not None:
131 | n_obj = y['n_obj'].to(self.device)
132 | n_pred = out['inferred_n'] # (B, )
133 | correct = (n_pred == n_obj).float().sum()
134 | acc = correct / n_pred.size(0)
135 | out['accuracy'] = acc
136 |
137 | # TODO Only for viz, as std=0.3 is pretty high so samples are not good
138 | out['out_sample'] = out['out_mean'] # this is actually NOT a sample!
139 |
140 | return out
141 |
142 |
143 | @staticmethod
144 | def print_train_log(step, epoch, summaries):
145 | s = (" [step {step}] loss: {loss:.5g} ELBO: {elbo:.5g} "
146 | "recons: {recons:.3g} KL: {kl:.3g} acc: {acc:.3g}")
147 | s = s.format(
148 | step=step,
149 | loss=summaries['loss/loss'],
150 | elbo=summaries['elbo/elbo'],
151 | recons=summaries['elbo/recons'],
152 | kl=summaries['elbo/kl'],
153 | acc=summaries['accuracy'])
154 | print(s)
155 |
156 |
157 | @staticmethod
158 | def print_test_log(summaries, step=None, epoch=None):
159 | log_string = " "
160 | if epoch is not None:
161 | log_string += "[step {}, epoch {}] ".format(step, epoch)
162 | s = "ELBO {elbo:.5g} recons: {recons:.3g} KL: {kl:.3g} acc: {acc:.3g}"
163 | log_string += s.format(
164 | elbo=summaries['elbo/elbo'],
165 | recons=summaries['elbo/recons'],
166 | kl=summaries['elbo/kl'],
167 | acc=summaries['accuracy'])
168 | ll_key = None
169 | for k in summaries.keys():
170 | if k.find('elbo_IW') > -1:
171 | ll_key = k
172 | iw_samples = k.split('_')[-1]
173 | break
174 | if ll_key is not None:
175 | log_string += " marginal log-likelihood ({}) {:.5g}".format(
176 | iw_samples, summaries[ll_key])
177 |
178 | print(log_string)
179 |
180 |
181 | @staticmethod
182 | def get_metrics_dict(results):
183 | metrics_dict = {
184 | 'loss/loss': results['loss'].item(),
185 | 'elbo/elbo': results['elbo'].item(),
186 | 'elbo/recons': results['recons'].item(),
187 | 'elbo/kl': results['kl'].item(),
188 | 'l2/l2': results['l2'].item(),
189 | 'accuracy': results['accuracy'].item(),
190 |
191 | 'kl/pres': results['kl_pres'].item(),
192 | 'kl/what': results['kl_what'].item(),
193 | 'kl/where': results['kl_where'].item(),
194 | }
195 | return metrics_dict
196 |
197 |
198 |
199 | def additional_testing(self, img_folder):
200 | """
201 | Perform additional testing, including possibly generating images.
202 |
203 | In this case, save samples from the generative model, and pairs
204 | input/reconstruction from the test set.
205 |
206 | :param img_folder: folder to store images
207 | """
208 |
209 | step = self.model.global_step
210 |
211 | if not self.args.dry_run:
212 |
213 | # Saved images will have n**2 sub-images
214 | n = 8
215 |
216 | # Save model samples
217 | sample, zwhere, n_obj = self.model.sample_prior(n ** 2)
218 | annotated_sample = batch_add_bounding_boxes(sample, zwhere, n_obj)
219 | fname = os.path.join(img_folder, 'sample_' + str(step) + '.png')
220 | pad = img_grid_pad_value(annotated_sample)
221 | save_image(annotated_sample, fname, nrow=n, pad_value=pad)
222 |
223 | # Get first test batch
224 | (x, _) = next(iter(self.dataloaders.test))
225 | fname = os.path.join(img_folder, 'reconstruction_' + str(step) + '.png')
226 |
227 | # Save model original/reconstructions
228 | self.save_input_and_recons(x, fname, n)
229 |
230 |
231 | def save_input_and_recons(self, x, fname, n):
232 | n_img = n ** 2 // 2
233 | if x.shape[0] < n_img:
234 | msg = ("{} data points required, but given batch has size {}. "
235 | "Please use a larger batch.".format(n_img, x.shape[0]))
236 | raise RuntimeError(msg)
237 | x = x.to(self.device)
238 | outputs = self.forward_pass(x)
239 | x = x[:n_img]
240 | if x.shape[1] == 1:
241 | x = x.expand(-1, 3, -1, -1)
242 | recons = outputs['out_sample'][:n_img]
243 | z_where = outputs['all_z_where'][:n_img]
244 | pred_count = outputs['inferred_n']
245 | recons = batch_add_bounding_boxes(recons, z_where, pred_count)
246 | imgs = torch.stack([x.cpu(), recons.cpu()])
247 | imgs = imgs.permute(1, 0, 2, 3, 4)
248 | imgs = imgs.reshape(n ** 2, x.size(1), x.size(2), x.size(3))
249 | pad = img_grid_pad_value(imgs)
250 | save_image(imgs, fname, nrow=n, pad_value=pad)
251 |
252 |
253 | def _parse_args(self):
254 | """
255 | Parse command-line arguments defining experiment settings.
256 |
257 | :return: args: argparse.Namespace with experiment settings
258 | """
259 |
260 | def list_options(lst):
261 | if lst:
262 | return "'" + "' | '".join(lst) + "'"
263 | return ""
264 |
265 | legal_datasets = [
266 | 'pyro_multi_mnist',
267 | 'multi_mnist_binary',
268 | 'multi_dsprites_binary_rgb']
269 | legal_likelihoods = ['bernoulli', 'original']
270 |
271 | parser = argparse.ArgumentParser(
272 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
273 | allow_abbrev=False)
274 |
275 | self.add_required_args(parser,
276 |
277 | # General
278 | batch_size=64,
279 | test_batch_size=2000,
280 | lr=1e-4,
281 | seed=54321,
282 | log_interval=50000,
283 | test_log_interval=50000,
284 | checkpoint_interval=100000,
285 | resume="",
286 |
287 | # VI-specific
288 | ll_every=50000,
289 | loglik_samples=100,)
290 |
291 | parser.add_argument('-d', '--dataset',
292 | type=str,
293 | choices=legal_datasets,
294 | default='pyro_multi_mnist',
295 | metavar='NAME',
296 | dest='dataset_name',
297 | help="dataset: " + list_options(legal_datasets))
298 |
299 | parser.add_argument('--likelihood',
300 | type=str,
301 | choices=legal_likelihoods,
302 | metavar='NAME',
303 | dest='likelihood',
304 | help="likelihood: {}; default depends on dataset".format(
305 | list_options(legal_likelihoods)))
306 |
307 | parser.add_argument('--bl-lr',
308 | type=float,
309 | default=1e-1,
310 | metavar='LR',
311 | help="baseline's learning rate")
312 |
313 | parser.add_argument('--wd',
314 | type=float,
315 | default=0.0,
316 | dest='weight_decay',
317 | help='weight decay')
318 |
319 | args = parser.parse_args()
320 |
321 | assert args.loglik_interval % args.test_log_interval == 0
322 |
323 | if args.likelihood is None: # defaults
324 | args.likelihood = {
325 | 'pyro_multi_mnist': 'original',
326 | 'multi_mnist_binary': 'original', # 'bernoulli',
327 | 'multi_dsprites_binary_rgb': 'original', # 'bernoulli',
328 | }[args.dataset_name]
329 |
330 | return args
331 |
332 | @staticmethod
333 | def _make_run_description(args):
334 | """
335 | Create a string description of the run. It is used in the names of the
336 | logging folders.
337 |
338 | :param args: experiment config
339 | :return: the run description
340 | """
341 | s = ''
342 | s += args.dataset_name
343 | s += ',seed{}'.format(args.seed)
344 | if len(args.additional_descr) > 0:
345 | s += ',' + args.additional_descr
346 | return s
347 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import boilr
4 | import multiobject
5 | from boilr import Trainer
6 |
7 | from experiments import AIRExperiment
8 |
9 |
10 | def _check_version(pkg, pkg_str, version_info):
11 | def to_str(v):
12 | return ".".join(str(x) for x in v)
13 | if pkg.__version_info__[:2] != version_info[:2]:
14 | msg = "This was last tested with {} {}, but the current version is {}"
15 | msg = msg.format(pkg_str, to_str(version_info), pkg.__version__)
16 | warnings.warn(msg)
17 |
18 | BOILR_VERSION = (0, 5, 1)
19 | MULTIOBJ_VERSION = (0, 0, 3)
20 | _check_version(boilr, 'boilr', BOILR_VERSION)
21 | _check_version(multiobject, 'multiobject', MULTIOBJ_VERSION)
22 |
23 | def main():
24 | experiment = AIRExperiment()
25 | trainer = Trainer(experiment)
26 | trainer.run()
27 |
28 | if __name__ == "__main__":
29 | main()
30 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/addtt/attend-infer-repeat-pytorch/ef7aa3cc7b51460a501b10b1758b240ed9594148/models/__init__.py
--------------------------------------------------------------------------------
/models/air.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from boilr import BaseGenerativeModel
6 | from torch import nn
7 | from torch.distributions.bernoulli import Bernoulli
8 | from torch.distributions.kl import kl_divergence
9 | from torch.distributions.normal import Normal
10 | from torch.nn import LSTMCell
11 |
12 | from utils import nograd_param
13 | from utils.spatial_transform import SpatialTransformer
14 |
15 | State = namedtuple(
16 | 'State',
17 | ['h', 'c', 'bl_h', 'bl_c', 'z_pres', 'z_where', 'z_what'])
18 |
19 |
20 | class Predictor(nn.Module):
21 | """
22 | Infer presence and location from LSTM hidden state
23 | """
24 | def __init__(self, lstm_hidden_dim):
25 | nn.Module.__init__(self)
26 | self.seq = nn.Sequential(
27 | nn.Linear(lstm_hidden_dim, 200),
28 | nn.ReLU(),
29 | nn.Linear(200, 7),
30 | )
31 |
32 | def forward(self, h):
33 | z = self.seq(h)
34 | z_pres_p = torch.sigmoid(z[:, :1])
35 | z_where_loc = z[:, 1:4]
36 | z_where_scale = F.softplus(z[:, 4:])
37 | return z_pres_p, z_where_loc, z_where_scale
38 |
39 |
40 | class AppearanceEncoder(nn.Module):
41 | """
42 | Infer object appearance latent z_what given an image crop around the object
43 | """
44 | def __init__(self, object_size, color_channels, encoder_hidden_dim, z_what_dim):
45 | super().__init__()
46 | object_numel = color_channels * (object_size ** 2)
47 | self.net = nn.Sequential(
48 | nn.Linear(object_numel, encoder_hidden_dim),
49 | nn.ReLU(),
50 | nn.Linear(encoder_hidden_dim, z_what_dim * 2)
51 | )
52 |
53 | def forward(self, crop):
54 | """
55 | :param crop: (B, C, H, W)
56 | :return: z_what_loc, z_what_scale
57 | """
58 | bs = crop.size(0)
59 | crop_flat = crop.view(bs, -1)
60 | x = self.net(crop_flat)
61 | z_what_loc, z_what_scale = x.chunk(2, dim=1)
62 | z_what_scale = F.softplus(z_what_scale)
63 |
64 | return z_what_loc, z_what_scale
65 |
66 | class AppearanceDecoder(nn.Module):
67 | """
68 | Generate pixel representation of an object given its latent code z_what
69 | """
70 | def __init__(self, z_what_dim, decoder_hidden_dim,
71 | object_size, color_channels, bias=-2.0):
72 | super().__init__()
73 | object_numel = color_channels * (object_size ** 2)
74 | self.net = nn.Sequential(
75 | nn.Linear(z_what_dim, decoder_hidden_dim),
76 | nn.ReLU(),
77 | nn.Linear(decoder_hidden_dim, object_numel),
78 | )
79 | self.sz = object_size
80 | self.ch = color_channels
81 | self.bias = bias
82 |
83 | def forward(self, z_what):
84 | x = self.net(z_what)
85 | x = x.view(-1, self.ch, self.sz, self.sz)
86 | x = torch.sigmoid(x + self.bias)
87 | return x
88 |
89 |
90 | class AIR(BaseGenerativeModel):
91 | """
92 | AIR model. Default settings are from the pyro tutorial. With those settings
93 | we can reproduce results from the original paper (although about 1/10 times
94 | it doesn't converge to a good solution).
95 | """
96 |
97 | z_where_dim = 3
98 | z_pres_dim = 1
99 |
100 | def __init__(self,
101 | img_size,
102 | object_size,
103 | max_steps,
104 | color_channels,
105 | likelihood=None,
106 | z_what_dim=50,
107 | lstm_hidden_dim=256,
108 | baseline_hidden_dim=256,
109 | encoder_hidden_dim=200,
110 | decoder_hidden_dim=200,
111 | scale_prior_mean=3.0,
112 | scale_prior_std=0.2,
113 | pos_prior_mean=0.0,
114 | pos_prior_std=1.0,
115 | ):
116 | super().__init__()
117 |
118 | #### Settings
119 |
120 | self.max_steps = max_steps
121 |
122 | self.img_size = img_size
123 | self.object_size = object_size
124 | self.color_channels = color_channels
125 | self.z_what_dim = z_what_dim
126 | self.lstm_hidden_dim = lstm_hidden_dim
127 | self.baseline_hidden_dim = baseline_hidden_dim
128 | self.encoder_hidden_dim = encoder_hidden_dim
129 | self.decoder_hidden_dim = decoder_hidden_dim
130 |
131 | self.z_pres_prob_prior = nograd_param(0.01)
132 | self.z_where_loc_prior = nograd_param(
133 | [scale_prior_mean, pos_prior_mean, pos_prior_mean])
134 | self.z_where_scale_prior = nograd_param(
135 | [scale_prior_std, pos_prior_std, pos_prior_std])
136 | self.z_what_loc_prior = nograd_param(0.0)
137 | self.z_what_scale_prior = nograd_param(1.0)
138 |
139 | ####
140 |
141 | self.img_numel = color_channels * (img_size ** 2)
142 |
143 | lstm_input_size = (self.img_numel + self.z_what_dim
144 | + self.z_where_dim + self.z_pres_dim)
145 | self.lstm = LSTMCell(lstm_input_size, self.lstm_hidden_dim)
146 |
147 | # Infer presence and location from LSTM hidden state
148 | self.predictor = Predictor(self.lstm_hidden_dim)
149 |
150 | # Infer z_what given an image crop around the object
151 | self.encoder = AppearanceEncoder(object_size, color_channels,
152 | encoder_hidden_dim, z_what_dim)
153 |
154 | # Generate pixel representation of an object given its z_what
155 | self.decoder = AppearanceDecoder(z_what_dim, decoder_hidden_dim,
156 | object_size, color_channels)
157 |
158 | # Spatial transformer (does both forward and inverse)
159 | self.spatial_transf = SpatialTransformer(
160 | (self.object_size, self.object_size),
161 | (self.img_size, self.img_size))
162 |
163 | # Baseline LSTM
164 | self.bl_lstm = LSTMCell(lstm_input_size, self.baseline_hidden_dim)
165 |
166 | # Baseline regressor
167 | self.bl_regressor = nn.Sequential(
168 | nn.Linear(self.baseline_hidden_dim, 200),
169 | nn.ReLU(),
170 | nn.Linear(200, 1)
171 | )
172 |
173 | # Prior distributions
174 | self.pres_prior = Bernoulli(probs=self.z_pres_prob_prior)
175 | self.where_prior = Normal(loc=self.z_where_loc_prior,
176 | scale=self.z_where_scale_prior)
177 | self.what_prior = Normal(loc=self.z_what_loc_prior,
178 | scale=self.z_what_scale_prior)
179 |
180 | # Data likelihood
181 | self.likelihood = likelihood
182 |
183 | @staticmethod
184 | def _module_list_to_params(modules):
185 | params = []
186 | for module in modules:
187 | params.extend(module.parameters())
188 | return params
189 |
190 | def air_params(self):
191 | air_modules = [self.predictor, self.lstm, self.encoder, self.decoder]
192 | return self._module_list_to_params(air_modules) + [self.z_pres_prob_prior]
193 |
194 | def baseline_params(self):
195 | baseline_modules = [self.bl_regressor, self.bl_lstm]
196 | return self._module_list_to_params(baseline_modules)
197 |
198 | def get_output_dist(self, mean):
199 | if self.likelihood == 'original':
200 | std = torch.tensor(0.3).to(self.get_device())
201 | dist = Normal(mean, std.expand_as(mean))
202 | elif self.likelihood == 'bernoulli':
203 | dist = Bernoulli(probs=mean)
204 | else:
205 | msg = "Unrecognized likelihood '{}'".format(self.likelihood)
206 | raise RuntimeError(msg)
207 | return dist
208 |
209 | def forward(self, x):
210 | bs = x.size(0)
211 |
212 | # Init model state
213 | state = State(
214 | h=torch.zeros(bs, self.lstm_hidden_dim, device=x.device),
215 | c=torch.zeros(bs, self.lstm_hidden_dim, device=x.device),
216 | bl_h=torch.zeros(bs, self.baseline_hidden_dim, device=x.device),
217 | bl_c=torch.zeros(bs, self.baseline_hidden_dim, device=x.device),
218 | z_pres=torch.ones(bs, 1, device=x.device),
219 | z_where=torch.zeros(bs, 3, device=x.device),
220 | z_what=torch.zeros(bs, self.z_what_dim, device=x.device),
221 | )
222 |
223 | # KL divergence for each step
224 | kl = torch.zeros(bs, self.max_steps, device=x.device)
225 |
226 | # Store KL for pres, where, and what separately
227 | kl_pres = torch.zeros(bs, self.max_steps, device=x.device)
228 | kl_where = torch.zeros(bs, self.max_steps, device=x.device)
229 | kl_what = torch.zeros(bs, self.max_steps, device=x.device)
230 |
231 | # Baseline value for each step
232 | baseline_value = torch.zeros(bs, self.max_steps, device=x.device)
233 |
234 | # Log likelihood for each step, with shape (B, T):
235 | # log q(z_pres[t] | x, z_{<t})
413 | # When z_pres[i] is 0, z_where and z_what are not used -> set KL=0
414 | kl_where = kl_where * z_pres.squeeze()
415 | kl_what = kl_what * z_pres.squeeze()
416 |
417 | # When z_pres[i-1] is 0, zpres is not used -> set KL=0
418 | kl_pres = kl_pres * prev.z_pres.squeeze()
419 |
420 | kl = (kl_pres + kl_where + kl_what)
421 |
422 | # New state
423 | new_state = State(
424 | z_pres=z_pres,
425 | z_where=z_where,
426 | z_what=z_what,
427 | h=h,
428 | c=c,
429 | bl_c=bl_c,
430 | bl_h=bl_h,
431 | )
432 |
433 | out = {
434 | 'state': new_state,
435 | 'kl': kl,
436 | 'kl_pres': kl_pres,
437 | 'kl_where': kl_where,
438 | 'kl_what': kl_what,
439 | 'baseline_value': baseline_value,
440 | 'z_pres_likelihood': z_pres_likelihood,
441 | }
442 | return out
443 |
444 |
445 | def sample_prior(self, n_imgs, **kwargs):
446 |
447 | # Sample from prior. Shapes:
448 | # z_pres: (B, T)
449 | # z_what: (B, T, z_what_dim)
450 | # z_where: (B, T, 3)
451 | z_pres = self.pres_prior.sample((n_imgs, self.max_steps))
452 | z_what = self.what_prior.sample((n_imgs, self.max_steps, self.z_what_dim))
453 | z_where = self.where_prior.sample((n_imgs, self.max_steps))
454 |
455 | # TODO This is only for visualization! Not real model samples
456 | # The prior of z_pres puts a lot of probability on n=0, which doesn't
457 | # lead to informative samples. Instead, generate half images with 1
458 | # object and half with 2.
459 | # z_pres.fill_(0.)
460 | # z_pres[:, 0].fill_(1.)
461 | # z_pres[n_imgs//2:, 1].fill_(1.)
462 |
463 | # If z_pres is sampled from the prior, make sure there are no ones
464 | # after a zero.
465 | for t in range(1, self.max_steps):
466 | z_pres[:, t] *= z_pres[:, t-1] # if previous=0, this is also 0
467 |
468 | n_obj = z_pres.sum(1)
469 |
470 | # Decode z_what to object appearance
471 | sprites = self.decoder(z_what)
472 |
473 | # Spatial-transform them to images with shape (B*T, 1, H, W)
474 | z_where_ = z_where.view(n_imgs * self.max_steps, 3) # shape (B*T, 3)
475 | imgs = self.spatial_transf.forward(sprites, z_where_)
476 |
477 | # Reshape images to (B, T, 1, H, W)
478 | h = w = self.img_size
479 | ch = self.color_channels
480 | imgs = imgs.view(n_imgs, self.max_steps, ch, h, w)
481 |
482 | # Make canvas by masking and summing over timesteps
483 | canvas = imgs * z_pres[:, :, None, None, None]
484 | canvas = canvas.sum(1)
485 |
486 | return canvas, z_where, n_obj
487 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.18.1
2 | torch==1.4.0
3 | torchvision==0.5.0
4 | matplotlib==3.1.2
5 | tqdm==4.41.1
6 | boilr>=0.6.0,<0.7
7 | multiobject
8 | tensorboard
9 |
10 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .misc import *
2 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | def nograd_param(x):
6 | """
7 | Naively make tensor from x, then wrap with nn.Parameter without gradient.
8 | """
9 | return nn.Parameter(torch.tensor(x), requires_grad=False)
10 |
--------------------------------------------------------------------------------
/utils/spatial_transform.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from boilr.utils import to_np
7 |
8 |
9 | class SpatialTransformer:
10 | def __init__(self, input_shape, output_shape):
11 | """
12 | :param input_shape: (H, W)
13 | :param output_shape: (H, W)
14 | """
15 | self.input_shape = input_shape
16 | self.output_shape = output_shape
17 |
18 | def _transform(self, x, z_where, inverse):
19 | """
20 | :param x: (B, 1, Hin, Win)
21 | :param z_where: [s, x, y]
22 | :param inverse: inverse z_where
23 | :return: y of output_size
24 | """
25 | if inverse:
26 | z_where = invert_z_where(z_where)
27 | out_shp = self.input_shape
28 | else:
29 | out_shp = self.output_shape
30 |
31 | out = spatial_transformer(x, z_where, out_shp)
32 | return out
33 |
34 | def forward(self, x, z_where):
35 | return self._transform(x, z_where, inverse=False)
36 |
37 | def inverse(self, x, z_where):
38 | return self._transform(x, z_where, inverse=True)
39 |
40 |
41 | def spatial_transformer(x, z_where, out_shape):
42 | """
43 | Resamples x on a grid of shape out_shape based on an affine transform
44 | parameterized by z_where.
45 | The output image has shape out_shape.
46 |
47 | :param x:
48 | :param z_where:
49 | :param out_shape:
50 | :return:
51 | """
52 | batch_sz = x.size(0)
53 | theta = expand_z_where(z_where)
54 | grid_shape = torch.Size((batch_sz, 1) + out_shape)
55 | grid = F.affine_grid(theta, grid_shape, align_corners=False)
56 | out = F.grid_sample(x, grid, align_corners=False)
57 | return out
58 |
59 | def expand_z_where(z_where):
60 | """
61 | :param z_where: batch. [s, x, y]
62 | :return: [[s, 0, x], [0, s, y]]
63 | """
64 | bs = z_where.size(0)
65 | dev = z_where.device
66 |
67 | # [s, x, y] -> [s, 0, x, 0, s, y]
68 | z_where = torch.cat((torch.zeros(bs, 1, device=dev), z_where), dim=1)
69 | expansion_indices = torch.tensor([1, 0, 2, 0, 1, 3], device=dev)
70 | matrix = torch.index_select(z_where, dim=1, index=expansion_indices)
71 | matrix = matrix.view(bs, 2, 3)
72 |
73 | return matrix
74 |
75 | def invert_z_where(z_where):
76 | z_where_inv = torch.zeros_like(z_where)
77 | scale = z_where[:, 0:1] # (batch, 1)
78 | z_where_inv[:, 1:3] = -z_where[:, 1:3] / scale # (batch, 2)
79 | z_where_inv[:, 0:1] = 1 / scale # (batch, 1)
80 | return z_where_inv
81 |
82 |
83 | def batch_add_bounding_boxes(imgs, z_wheres, n_obj, color=None, n_img=None):
84 | """
85 |
86 | :param imgs: 4d tensor or numpy array, channel dim either 1 or 3
87 | :param z_wheres: tensor or numpy of shape (n_imgs, max_n_objects, 3)
88 | :param n_obj:
89 | :param color:
90 | :param n_img:
91 | :return:
92 | """
93 |
94 | # Check arguments
95 | assert len(imgs.shape) == 4
96 | assert imgs.shape[1] in [1, 3]
97 | assert len(z_wheres.shape) == 3
98 | assert z_wheres.shape[0] == imgs.shape[0]
99 | assert z_wheres.shape[2] == 3
100 |
101 | target_shape = list(imgs.shape)
102 | target_shape[1] = 3
103 |
104 | if n_img is None:
105 | n_img = len(imgs)
106 | if color is None:
107 | color = np.array([1., 0., 0.])
108 | out = torch.stack([
109 | add_bounding_boxes(imgs[j], z_wheres[j], color, n_obj[j])
110 | for j in range(n_img)
111 | ])
112 |
113 | out_shape = tuple(out.shape)
114 | target_shape = tuple(target_shape)
115 | assert out_shape == target_shape, "{}, {}".format(out_shape, target_shape)
116 | return out
117 |
118 |
119 | def add_bounding_boxes(img, z_wheres, color, n_obj):
120 | """
121 | Adds bounding boxes to the n_obj objects in img, according to z_wheres.
122 | The output is never on cuda.
123 |
124 | :param img: image in 3d or 4d shape, either Tensor or numpy. If 4d, the
125 | first dimension must be 1. The channel dimension must be
126 | either 1 or 3.
127 | :param z_wheres: tensor or numpy of shape (1, max_n_objects, 3) or
128 | (max_n_objects, 3)
129 | :param color: color of all bounding boxes (RGB)
130 | :param n_obj: number of objects in the scene. This controls the number of
131 | bounding boxes to be drawn, and cannot be greater than the
132 | max number of objects supported by z_where (dim=1). Has to be
133 | a scalar or a single-element Tensor/array.
134 | :return: image with required bounding boxes, with same type and dimension
135 | as the original image input, except 3 color channels.
136 | """
137 |
138 | try:
139 | n_obj = n_obj.item()
140 | except AttributeError:
141 | pass
142 | n_obj = int(round(n_obj))
143 | assert n_obj <= z_wheres.shape[1]
144 |
145 | try:
146 | img = img.cpu()
147 | except AttributeError:
148 | pass
149 |
150 | if len(img.shape) == 3:
151 | color_dim = 0
152 | else:
153 | color_dim = 1
154 |
155 | if len(z_wheres.shape) == 3:
156 | assert z_wheres.shape[0] == 1
157 | z_wheres = z_wheres[0]
158 |
159 | target_shape = list(img.shape)
160 | target_shape[color_dim] = 3
161 |
162 | for i in range(n_obj):
163 | img = add_bounding_box(img, z_wheres[i:i+1], color)
164 | if img.shape[color_dim] == 1: # this might happen if n_obj==0
165 | reps = [3, 1, 1]
166 | if color_dim == 1:
167 | reps = [1] + reps
168 | reps = tuple(reps)
169 | if isinstance(img, torch.Tensor):
170 | img = img.repeat(*reps)
171 | else:
172 | img = np.tile(img, reps)
173 |
174 | target_shape = tuple(target_shape)
175 | img_shape = tuple(img.shape)
176 | assert img_shape == target_shape, "{}, {}".format(img_shape, target_shape)
177 | return img
178 |
179 |
180 | def add_bounding_box(img, z_where, color):
181 | """
182 | Adds a bounding box to img with parameters z_where and the given color.
183 | Makes a copy of the input image, which is left unaltered. The output is
184 | never on cuda.
185 |
186 | :param img: image in 3d or 4d shape, either Tensor or numpy. If 4d, the
187 | first dimension must be 1. The channel dimension must be
188 | either 1 or 3.
189 | :param z_where: tensor or numpy with 3 elements, and shape (1, ..., 1, 3)
190 | :param color:
191 | :return: image with required bounding box in the specified color, with same
192 | type and dimension as the original image input, except 3 color
193 | channels.
194 | """
195 | def _bounding_box(z_where, x_size, rounded=True, margin=1):
196 | z_where = to_np(z_where).flatten()
197 | assert z_where.shape[0] == z_where.size == 3
198 | s, x, y = tuple(z_where)
199 | w = x_size / s
200 | h = x_size / s
201 | xtrans = -x / s * x_size / 2
202 | ytrans = -y / s * x_size / 2
203 | x1 = (x_size - w) / 2 + xtrans - margin
204 | y1 = (x_size - h) / 2 + ytrans - margin
205 | x2 = x1 + w + 2 * margin
206 | y2 = y1 + h + 2 * margin
207 | x1, x2 = sorted((x1, x2))
208 | y1, y2 = sorted((y1, y2))
209 | coords = (x1, x2, y1, y2)
210 | if rounded:
211 | coords = (int(round(t)) for t in coords)
212 | return coords
213 |
214 | target_shape = list(img.shape)
215 | collapse_first = False
216 | torch_tensor = isinstance(img, torch.Tensor)
217 | img = to_np(img).copy()
218 | if len(img.shape) == 3:
219 | collapse_first = True
220 | img = np.expand_dims(img, 0)
221 | target_shape[0] = 3
222 | else:
223 | target_shape[1] = 3
224 | assert len(img.shape) == 4 and img.shape[0] == 1
225 | if img.shape[1] == 1:
226 | img = np.tile(img, (1, 3, 1, 1))
227 | assert img.shape[1] == 3
228 | color = color[:, None]
229 |
230 | x1, x2, y1, y2 = _bounding_box(z_where, img.shape[2])
231 | x_max = y_max = img.shape[2] - 1
232 | if 0 <= y1 <= y_max:
233 | img[0, :, y1, max(x1, 0):min(x2, x_max)] = color
234 | if 0 <= y2 - 1 <= y_max:
235 | img[0, :, y2 - 1, max(x1, 0):min(x2, x_max)] = color
236 | if 0 <= x1 <= x_max:
237 | img[0, :, max(y1, 0):min(y2, y_max), x1] = color
238 | if 0 <= x2 - 1 <= x_max:
239 | img[0, :, max(y1, 0):min(y2, y_max), x2 - 1] = color
240 |
241 | if collapse_first:
242 | img = img[0]
243 | if torch_tensor:
244 | img = torch.from_numpy(img)
245 |
246 | target_shape = tuple(target_shape)
247 | img_shape = tuple(img.shape)
248 | assert img_shape == target_shape, "{}, {}".format(img_shape, target_shape)
249 | return img
250 |
251 | def _test(obj_size, canvas_size, color_ch):
252 |
253 | # Object to image. Meaningful scenario:
254 | # scale > 1, x and y in [-scale, +scale]
255 | # Perfect copy (no interpolation) when scale == canvas_size / obj_size
256 | obj = (torch.rand(1, color_ch, obj_size, obj_size) < 0.8).float()
257 | z_where = torch.tensor([[6., 2., 4.]])
258 | spatial_transf = SpatialTransformer(
259 | (obj_size, obj_size), (canvas_size, canvas_size))
260 | out = spatial_transf.forward(obj, z_where)
261 |
262 | plt.figure()
263 | plt.imshow(out[0].permute(1, 2, 0).squeeze(), vmin=0., vmax=1.)
264 | plt.show()
265 |
266 | # Image to object.
267 | # Here we retrieve the same object we initially drew on the canvas.
268 | img = out
269 | out = spatial_transf.inverse(img, z_where)
270 |
271 | plt.figure()
272 | plt.imshow(out[0].permute(1, 2, 0).squeeze(), vmin=0., vmax=1.)
273 | plt.show()
274 |
275 | # show bounding box
276 | img_np = to_np(img)
277 | color = np.array([1., 0., 0.])
278 | img_np = add_bounding_box(img_np, z_where, color)
279 | plt.figure()
280 | plt.imshow(img_np[0].transpose(1, 2, 0), vmin=0., vmax=1.)
281 | plt.show()
282 |
283 | # test bounding box methods
284 | add_bounding_box(img[0], z_where, color) # 3d tensor
285 | add_bounding_box(img_np[0], z_where, color) # 3d numpy
286 | add_bounding_box(img, z_where, color) # 4d tensor
287 | add_bounding_box(img_np, z_where, color) # 4d numpy
288 | z_wheres = z_where.repeat(1, 4, 1)
289 | z_wheres += torch.randn_like(z_wheres) * 2
290 | add_bounding_boxes(img[0], z_wheres, color, 4) # 3d tensor
291 | add_bounding_boxes(img_np[0], z_wheres, color, 4) # 3d numpy
292 | add_bounding_boxes(img, z_wheres, color, 4) # 4d tensor
293 | img_np = add_bounding_boxes(img_np, z_wheres, color, 4) # 4d numpy
294 | plt.figure()
295 | plt.imshow(img_np[0].transpose(1, 2, 0), vmin=0., vmax=1.)
296 | plt.show()
297 |
298 | # case nobj = 0 missing
299 |
300 |
301 | if __name__ == '__main__':
302 | obj_size = 8
303 | canvas_size = 48
304 | color_ch = 3
305 | _test(obj_size, canvas_size, color_ch)
306 | color_ch = 1
307 | _test(obj_size, canvas_size, color_ch)
308 |
--------------------------------------------------------------------------------