├── draw
├── __init__.py
└── draw.py
├── gqn
├── __init__.py
├── training.py
├── gqn.py
├── representation.py
└── generator.py
├── scripts
├── gpu.sh
├── data.sh
└── tfrecord-converter.py
├── placeholder.py
├── LICENSE.md
├── environment.yml
├── README.md
├── shepardmetzler.py
├── run-draw.py
├── run-convdraw.py
├── run-gqn.py
└── mental-rotation.ipynb
/draw/__init__.py:
--------------------------------------------------------------------------------
1 | from .draw import DRAW, ConvolutionalDRAW
2 |
--------------------------------------------------------------------------------
/gqn/__init__.py:
--------------------------------------------------------------------------------
1 | from .generator import GeneratorNetwork
2 | from .representation import TowerRepresentation, PyramidRepresentation
3 | from .gqn import GenerativeQueryNetwork
4 | from .training import partition, Annealer
--------------------------------------------------------------------------------
/scripts/gpu.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Start TensorBoard in the background and launch GQN training on two GPUs.
#
# Usage: gpu.sh DATA_DIR
#   DATA_DIR  directory handed to run-gqn.py as --data_dir
export CUDA_VISIBLE_DEVICES=0,1

if [ -z "$1" ]; then
    echo "Usage: $0 DATA_DIR" >&2
    exit 1
fi
DATA_DIR=$1

# Resolve paths relative to this script so it can be started from any
# directory (the README invokes it as `sh scripts/gpu.sh data-dir`).
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
LOG_DIR="$SCRIPT_DIR/../logs"

# Start TensorBoard in background
tensorboard --logdir "$LOG_DIR" &
TENSORBOARD_PID=$!
echo "Started Tensorboard with PID: $TENSORBOARD_PID"

# Start training script
python "$SCRIPT_DIR/../run-gqn.py" \
    --data_dir "$DATA_DIR" \
    --log_dir "$LOG_DIR" \
    --data_parallel "True" \
    --batch_size 1 \
    --workers 6
18 |
--------------------------------------------------------------------------------
/placeholder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 |
4 |
class PlaceholderData(Dataset):
    """
    Random placeholder dataset for testing the
    training loop without loading actual data.

    Any constructor arguments are accepted and ignored, so the class
    can stand in for the real dataset without changing the call site.
    """
    # Number of items reported by __len__
    n_records = 2000
    # Shape of one item: (b, m, c, h, w) images and (b, m, k) viewpoints
    batch = 64
    n_views = 15
    n_channels = 3
    image_size = 64
    v_dim = 7

    def __init__(self, *args, **kwargs):
        super(PlaceholderData, self).__init__()

    def __len__(self):
        return self.n_records

    def __getitem__(self, idx):
        """
        Return a freshly sampled (images, viewpoints) pair; `idx` is ignored.
        """
        # (b, m, c, h, w)
        images = torch.randn(self.batch, self.n_views, self.n_channels,
                             self.image_size, self.image_size)

        # (b, m, 7)
        viewpoints = torch.randn(self.batch, self.n_views, self.v_dim)

        return images, viewpoints
24 |
--------------------------------------------------------------------------------
/scripts/data.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Download the Shepard-Metzler (5 parts) dataset and convert the
# tf-records into gzipped PyTorch records.
#
# Usage: data.sh LOCATION BATCH_SIZE

LOCATION=$1    # example: /tmp/data
BATCH_SIZE=$2  # example: 64

if [ -z "$LOCATION" ] || [ -z "$BATCH_SIZE" ]; then
    echo "Usage: $0 LOCATION BATCH_SIZE" >&2
    exit 1
fi

# The converter lives next to this script; resolve it explicitly so the
# script also works when started from the repository root.
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
DATASET_DIR="$LOCATION/shepard_metzler_5_parts"

echo "Downloading data"
gsutil -m cp -R gs://gqn-dataset/shepard_metzler_5_parts "$LOCATION"

echo "Deleting small records" # less than 10MB
# Let find match and remove the files itself: no dependence on shell
# glob expansion, and no error when nothing matches.
find "$DATASET_DIR" -type f -name "*.tfrecord" -size -10M -exec rm -f {} +

echo "Converting data"
python "$SCRIPT_DIR/tfrecord-converter.py" "$LOCATION" shepard_metzler_5_parts -b "$BATCH_SIZE" -m "train"
echo "Training data: done"
python "$SCRIPT_DIR/tfrecord-converter.py" "$LOCATION" shepard_metzler_5_parts -b "$BATCH_SIZE" -m "test"
echo "Testing data: done"

echo "Removing original records"
# A quoted glob is never expanded by the shell, so `rm -rf "…/**/*.tfrecord"`
# removed nothing; use find to actually delete the source records.
find "$DATASET_DIR" -type f -name "*.tfrecord" -exec rm -f {} +
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License (MIT)
2 | =====================
3 |
4 | Copyright © 2018 Jesper Wohlert
5 |
6 | Permission is hereby granted, free of charge, to any person
7 | obtaining a copy of this software and associated documentation
8 | files (the “Software”), to deal in the Software without
9 | restriction, including without limitation the rights to use,
10 | copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the
12 | Software is furnished to do so, subject to the following
13 | conditions:
14 |
15 | The above copyright notice and this permission notice shall be
16 | included in all copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND,
19 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
20 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
21 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
22 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
23 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
24 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
25 | OTHER DEALINGS IN THE SOFTWARE.
26 |
27 | Clauses:
28 | Permission is NOT granted to individuals who copy or clone this
29 | repository and then redistribute the code. In that case the repository
30 | should instead be forked such that the original ownership is clear.
31 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gqn
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _tflow_1100_select=0.0.1=gpu
7 | - _tflow_190_select=0.0.3=mkl
8 | - blas=1.0=mkl
9 | - cffi=1.11.5=py35he75722e_1
10 | - cudatoolkit=9.0=h13b8566_0
11 | - cudnn=7.1.2=cuda9.0_0
12 | - cupti=9.0.176=0
13 | - glib=2.56.2=hd408876_0
14 | - intel-openmp=2019.0=117
15 | - libffi=3.2.1=hd88cf55_4
16 | - libgcc-ng=8.2.0=hdf63c60_1
17 | - libgfortran-ng=7.3.0=hdf63c60_0
18 | - libprotobuf=3.6.0=hdbcaa40_0
19 | - mkl=2018.0.3=1
20 | - mkl_fft=1.0.6=py35h7dd41cf_0
21 | - mkl_random=1.0.1=py35h4414c95_1
22 | - nccl=1.3.5=cuda9.0_0
23 | - numpy=1.15.2=py35h1d66e8a_0
24 | - numpy-base=1.15.2=py35h81de0dd_0
25 | - pip=10.0.1=py35_0
26 | - protobuf=3.6.0=py35hf484d3e_0
27 | - python=3.5.6=hc3d631a_0
28 | - setuptools=40.2.0=py35_0
29 | - six=1.11.0=py35_1
30 | - tensorboard=1.10.0=py35hf484d3e_0
31 | - tensorflow=1.10.0=gpu_py35h566a776_0
32 | - tensorflow-base=1.10.0=gpu_py35h6ecc378_0
33 | - tensorflow-gpu=1.10.0=hf154084_0
34 | - ignite=0.1.1=py35_0
35 | - pytorch=1.0.0=py3.5_cuda9.0.176_cudnn7.4.1_1
36 | - pytorch-nightly=1.0.0.dev20181010=py3.5_cuda9.0.176_cudnn7.1.2_0
37 | - torchvision=0.2.1=py_2
38 | - pip:
39 | - nvidia-ml-py==375.53.1
40 | - nvidia-ml-py3==7.352.0
41 | - pytorch-ignite==0.1.1
42 | - tensorboardx==1.6
43 | - torch==1.0.0
44 | prefix: ~/.conda/envs/gqn
45 |
46 |
--------------------------------------------------------------------------------
/gqn/training.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 |
class Annealer(object):
    """
    Iterator producing a linearly annealed value.

    Each call to `next` advances the step counter by one and returns
    `delta + (init - delta) * (1 - s / steps)`, never going below
    `delta`.

    :param init: value the schedule starts from
    :param delta: lower bound the schedule is clamped to
    :param steps: number of steps over which the value is interpolated
    """
    def __init__(self, init, delta, steps):
        self.init = init
        self.delta = delta
        self.steps = steps
        self.s = 0
        # Snapshot of the configuration taken at construction time (s == 0)
        self.data = self._state()
        # Most recently produced value
        self.recent = init

    def _state(self):
        """Return the current configuration and step counter as a dict."""
        return {"init": self.init, "delta": self.delta, "steps": self.steps, "s": self.s}

    def __repr__(self):
        # __repr__ must return a string; returning the dict itself made
        # repr()/print() raise TypeError.
        return "Annealer({})".format(self._state())

    def __iter__(self):
        return self

    def __next__(self):
        self.s += 1
        value = max(self.delta + (self.init - self.delta) * (1 - self.s / self.steps), self.delta)
        self.recent = value
        return value
24 |
25 |
def partition(images, viewpoints):
    """
    Partition batch into context and query sets.

    The two leading axes of both inputs are merged into one batch axis,
    a random subset of the `m` views is drawn, and the last drawn view
    becomes the query while the remaining ones form the context.

    :param images: tensor whose shape unpacks as (n, b, m, *image_dims)
    :param viewpoints: tensor whose shape unpacks as (n, b, m, *viewpoint_dims)
    :return: context images, context viewpoint, query image, query viewpoint
    :raises ValueError: if there are fewer than 3 views per scene
    """
    _, b, m, *x_dims = images.shape
    _, b, m, *v_dims = viewpoints.shape

    # Between 2 and m - 1 views are drawn, which needs at least 3 views;
    # fail with a clear message instead of randint's "empty range" error.
    if m < 3:
        raise ValueError("partition requires at least 3 views per scene, got {}".format(m))

    # "Squeeze" the batch dimension (reshape also accepts non-contiguous input)
    images = images.reshape((-1, m, *x_dims))
    viewpoints = viewpoints.reshape((-1, m, *v_dims))

    # Sample a random number of distinct views
    n_context = random.randint(2, m - 1)
    indices = random.sample(range(m), n_context)

    # Partition into context and query sets
    context_idx, query_idx = indices[:-1], indices[-1]

    x, v = images[:, context_idx], viewpoints[:, context_idx]
    x_q, v_q = images[:, query_idx], viewpoints[:, query_idx]

    return x, v, x_q, v_q
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **Update 2019/06/24**: A model trained on 10% of the Shepard-Metzler dataset has been added, the following notebook explains the main features of this model: [nbviewer](https://nbviewer.jupyter.org/github/wohlert/generative-query-network-pytorch/blob/master/mental-rotation.ipynb)
2 |
3 | # Generative Query Network
4 |
5 | This is a PyTorch implementation of the Generative Query Network (GQN)
6 | described in the DeepMind paper "Neural scene representation and
7 | rendering" by Eslami et al. For an introduction to the model and problem
8 | described in the paper look at the article by [DeepMind](https://deepmind.com/blog/neural-scene-representation-and-rendering/).
9 |
10 | 
11 |
12 | The current implementation generalises to any of the datasets described
13 | in the paper. However, currently, *only the Shepard-Metzler dataset* has
14 | been implemented. To use this dataset you can use the provided script:
15 | ```
16 | sh scripts/data.sh data-dir batch-size
17 | ```
18 |
19 | The model can be trained in full, in accordance with the paper, by running the
20 | file `run-gqn.py` or by using the provided training script
21 | ```
22 | sh scripts/gpu.sh data-dir
23 | ```
24 |
25 | ## Implementation
26 |
27 | The implementation shown in this repository consists of all of the
28 | representation architectures described in the paper along with the
29 | generative model that is similar to the one described in
30 | "Towards conceptual compression" by Gregor et al.
31 |
32 | Additionally, this repository also contains implementations of the **DRAW
33 | model and the ConvolutionalDRAW** model both described by Gregor et al.
34 |
35 |
--------------------------------------------------------------------------------
/shepardmetzler.py:
--------------------------------------------------------------------------------
1 | import os, gzip
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 |
6 |
def transform_viewpoint(v):
    """
    Re-encode a viewpoint so that its two angles are
    expressed through their cosine and sine.

    The last axis of `v` is split into a leading group of three
    components followed by two single components (yaw and pitch);
    the result concatenates the first group with
    cos(yaw), sin(yaw), cos(pitch), sin(pitch) along the last axis.
    """
    # First three components, then the remaining (yaw, pitch) pair
    position, angles = torch.split(v, 3, dim=-1)
    yaw, pitch = torch.split(angles, 1, dim=-1)

    encoded = [
        position,
        torch.cos(yaw), torch.sin(yaw),
        torch.cos(pitch), torch.sin(pitch),
    ]
    return torch.cat(encoded, dim=-1)
20 |
21 |
class ShepardMetzler(Dataset):
    """
    Shepard-Metzler mental rotation task
    dataset. Based on the dataset provided
    in the GQN paper. Either 5-parts or
    7-parts.

    :param root_dir: location of data on disc
    :param train: whether to use the train or the test set
    :param transform: transform on images
    :param fraction: fraction of dataset to use, in (0, 1]
    :param target_transform: transform on viewpoints
    """
    def __init__(self, root_dir, train=True, transform=None, fraction=1.0, target_transform=transform_viewpoint):
        super(ShepardMetzler, self).__init__()
        assert 0.0 < fraction <= 1.0

        split = "train" if train else "test"
        self.root_dir = os.path.join(root_dir, split)

        # Keep the first `fraction` of the (sorted) record files
        names = sorted(name for name in os.listdir(self.root_dir) if "pt" in name)
        n_keep = int(len(names) * fraction)
        self.records = names[:n_keep]

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        """
        Load the record at position `idx` and return (images, viewpoints)
        as float tensors, with the optional transforms applied.
        """
        scene_path = os.path.join(self.root_dir, self.records[idx])
        # NOTE: torch.load unpickles the file; only load records you trust.
        with gzip.open(scene_path, "r") as f:
            data = torch.load(f)

        # The record is a sequence of (image, viewpoint) pairs
        images, viewpoints = list(zip(*data))
        images = np.stack(images)
        viewpoints = np.stack(viewpoints)

        # Move the trailing axis to position 2, then scale by 1/255 as float
        images = torch.FloatTensor(images.transpose(0, 1, 4, 2, 3)) / 255
        if self.transform:
            images = self.transform(images)

        viewpoints = torch.FloatTensor(viewpoints)
        if self.target_transform:
            viewpoints = self.target_transform(viewpoints)

        return images, viewpoints
68 |
69 |
--------------------------------------------------------------------------------
/gqn/gqn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.distributions import Normal
4 |
5 | from .representation import TowerRepresentation
6 | from .generator import GeneratorNetwork
7 |
8 |
class GenerativeQueryNetwork(nn.Module):
    """
    Generative Query Network (GQN) as described
    in "Neural scene representation and rendering"
    [Eslami 2018].

    :param x_dim: number of channels in input
    :param v_dim: dimensions of viewpoint
    :param r_dim: dimensions of representation
    :param h_dim: hidden channels in LSTM
    :param z_dim: latent channels
    :param L: Number of refinements of density
    """
    def __init__(self, x_dim, v_dim, r_dim, h_dim, z_dim, L=12):
        super(GenerativeQueryNetwork, self).__init__()
        self.r_dim = r_dim

        # NOTE(review): z_dim is passed before h_dim here - presumably this
        # matches GeneratorNetwork's positional order; verify against it.
        self.generator = GeneratorNetwork(x_dim, v_dim, r_dim, z_dim, h_dim, L)
        self.representation = TowerRepresentation(x_dim, v_dim, r_dim, pool=True)

    def _scene_representation(self, context_x, context_v):
        """
        Encode every context (image, viewpoint) pair and sum the
        encodings over the view axis.

        :param context_x: context images, leading axes (batch, views)
        :param context_v: context viewpoints, leading axes (batch, views)
        :return: one aggregated representation per batch element
        """
        # Merge batch and view dimensions.
        b, m, *x_dims = context_x.shape
        _, _, *v_dims = context_v.shape

        x = context_x.view((-1, *x_dims))
        v = context_v.view((-1, *v_dims))

        # Representation generated from input images
        # and corresponding viewpoints
        phi = self.representation(x, v)

        # Separate batch and view dimensions
        _, *phi_dims = phi.shape
        phi = phi.view((b, m, *phi_dims))

        # Sum over view representations
        return torch.sum(phi, dim=1)

    def forward(self, context_x, context_v, query_x, query_v):
        """
        Forward through the GQN.

        :param context_x: batch of context images [b, m, c, h, w]
        :param context_v: batch of context viewpoints for image [b, m, k]
        :param query_x: batch of query images [b, c, h, w]
        :param query_v: batch of query viewpoints [b, k]
        :return: tuple (x_mu, r, kl) of the generator's reconstruction,
                 the aggregated scene representation and the generator's
                 KL term
        """
        r = self._scene_representation(context_x, context_v)

        # Reconstruct the query image from its viewpoint and the representation
        x_mu, kl = self.generator(query_x, query_v, r)

        # Return reconstruction and representation
        # for computing error
        return (x_mu, r, kl)

    def sample(self, context_x, context_v, query_v, sigma):
        """
        Sample from the network given some context and viewpoint.

        :param context_x: set of context images to generate representation
        :param context_v: viewpoints of `context_x`
        :param query_v: viewpoint to generate image from
        :param sigma: pixel variance; accepted but not used by this method
        :return: image generated by the generator's `sample`
        """
        # Requires a 5-D input; h and w fix the size of the generated image
        _, _, _, h, w = context_x.shape

        r = self._scene_representation(context_x, context_v)

        x_mu = self.generator.sample((h, w), query_v, r)
        return x_mu
89 |
--------------------------------------------------------------------------------
/run-draw.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import argparse
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 | import torchvision.transforms as transforms
9 | from torchvision.datasets import MNIST
10 | from torchvision.utils import save_image
11 |
12 | from draw import DRAW
13 | cuda = torch.cuda.is_available()
14 | device = torch.device("cuda:0" if cuda else "cpu")
15 |
16 |
if __name__ == '__main__':
    def str2bool(value):
        """
        Parse a boolean command-line value.

        argparse's `type=bool` calls bool() on the raw string, so every
        non-empty value - including "False" - becomes True.
        """
        text = str(value).strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser(description='DRAW with MNIST Example')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
    parser.add_argument('--batch_size', type=int, default=64, help='size of batch (default: 64)')
    parser.add_argument('--data_dir', type=str, help='location of training data', default="./train")
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
    parser.add_argument('--data_parallel', type=str2bool, help='whether to parallelise based on data (default: False)', default=False)

    args = parser.parse_args()

    # Define dataset: pixels are binarised by Bernoulli sampling
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: torch.bernoulli(x))
    ])
    dataset = MNIST(root=args.data_dir, train=True, download=True, transform=transform)

    # Create model and optimizer
    model = DRAW(x_dim=784, h_dim=256, z_dim=16, T=10).to(device)
    model = nn.DataParallel(model) if args.data_parallel else model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, 0.5)

    # Load the dataset
    kwargs = {'num_workers': args.workers, 'pin_memory': True} if cuda else {}
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    # Element-wise loss (`reduction='none'` replaces the deprecated `reduce=False`)
    loss = nn.BCELoss(reduction='none').to(device)

    for epoch in range(args.epochs):
        for x, _ in tqdm(loader):
            batch_size = x.size(0)

            # Flatten each image into a vector
            x = x.view(batch_size, -1).to(device)

            x_hat, kl_divergence = model(x)
            x_hat = torch.sigmoid(x_hat)

            reconstruction = loss(x_hat, x).sum(1)
            kl = kl_divergence.sum(1)
            # Minimised objective: reconstruction error plus KL (negative ELBO)
            elbo = torch.mean(reconstruction + kl)

            elbo.backward()
            optimizer.step()
            optimizer.zero_grad()

        with torch.no_grad():
            scheduler.step()

            print("Loss at step {}: {}".format(epoch, elbo.item()))

            # `sample` lives on the wrapped module when using DataParallel
            if isinstance(model, nn.DataParallel):
                x_sample = model.module.sample(args.batch_size)
            else:
                x_sample = model.sample(args.batch_size)

            # NOTE(review): x_hat is still flattened to (batch, 784) here;
            # confirm whether it should be reshaped to images before saving.
            save_image(x_hat, "reconstruction-{}.jpg".format(epoch))
            save_image(x_sample, "sample-{}.jpg".format(epoch))

            if epoch % 10 == 0:
                torch.save(model, "model-{}.pt".format(epoch))
80 |
--------------------------------------------------------------------------------
/scripts/tfrecord-converter.py:
--------------------------------------------------------------------------------
1 | """
2 | tfrecord-converter
3 |
4 | Takes a directory of tf-records with Shepard-Metzler data
5 | and converts it into a number of gzipped PyTorch records
6 | with a fixed batch size.
7 |
8 | Thanks to l3robot and versatran01 for providing initial
9 | scripts.
10 | """
11 | import os, gzip, torch
12 | import tensorflow as tf, numpy as np, multiprocessing as mp
13 | from functools import partial
14 | from itertools import islice, chain
15 | from argparse import ArgumentParser
16 |
17 | # disable logging and gpu
18 | tf.logging.set_verbosity(tf.logging.ERROR)
19 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
20 |
21 | POSE_DIM, IMG_DIM, SEQ_DIM = 5, 64, 15
22 |
def chunk(iterable, size=10):
    """
    Chunks an iterator into subsets of
    a given size.

    Every chunk is materialised as a list before it is yielded, so
    the grouping stays correct even when a consumer requests the next
    chunk before it has exhausted the previous one. The final chunk
    may hold fewer than `size` items.

    :param iterable: any iterable
    :param size: maximum number of items per chunk (at least 1)
    :raises ValueError: if `size` is smaller than 1
    """
    if size < 1:
        raise ValueError("size must be at least 1, got {}".format(size))

    iterator = iter(iterable)
    for first in iterator:
        yield [first] + list(islice(iterator, size - 1))
31 |
def process(record):
    """
    Iterate over a tf-record file and yield one numpy
    (images, poses) tuple per serialized example.

    :param record: path of the tf-record file
    """
    # Layout of a serialized example: SEQ_DIM encoded frames plus
    # SEQ_DIM * POSE_DIM camera floats
    feature_spec = {
        'frames': tf.FixedLenFeature(shape=SEQ_DIM, dtype=tf.string),
        'cameras': tf.FixedLenFeature(shape=SEQ_DIM * POSE_DIM, dtype=tf.float32)
    }
    map_options = dict(dtype=tf.uint8, back_prop=False)

    for serialized in tf.python_io.tf_record_iterator(record):
        example = tf.parse_single_example(serialized, feature_spec)

        # Decode every JPEG frame, then restore the sequence layout
        encoded = tf.reshape(tf.concat(example['frames'], axis=0), [-1])
        decoded = tf.map_fn(tf.image.decode_jpeg, encoded, **map_options)
        frames = tf.reshape(decoded, (-1, SEQ_DIM, IMG_DIM, IMG_DIM, 3))

        cameras = tf.reshape(example['cameras'], (-1, SEQ_DIM, POSE_DIM))

        # Numpy conversion, dropping axes of length one
        frames, cameras = frames.numpy(), cameras.numpy()
        yield np.squeeze(frames), np.squeeze(cameras)
55 |
def convert(record, batch_size):
    """
    Convert one tf-record into gzipped PyTorch files written next to it.

    The examples of `record` are grouped into batches of `batch_size`
    and each batch is saved as "<basename>-<index>.pt.gz".

    :param record: path of the tf-record file
    :param batch_size: number of sequences stored in each output file
    """
    directory, filename = os.path.split(record)
    basename = os.path.splitext(filename)[0]
    print(basename)

    batches = chunk(process(record), batch_size)
    for index, batch in enumerate(batches):
        target = os.path.join(directory, "{0:}-{1:02}.pt.gz".format(basename, index))
        with gzip.open(target, 'wb') as f:
            torch.save(list(batch), f)
70 |
if __name__ == '__main__':
    # The conversion relies on eager tensors (`.numpy()` in `process`)
    tf.enable_eager_execution()

    parser = ArgumentParser(description='Convert gqn tfrecords to gzip files.')
    parser.add_argument('base_dir', nargs=1,
                        help='base directory of gqn dataset')
    parser.add_argument('dataset', type=str, default="shepard_metzler_5_parts",
                        help='datasets to convert, eg. shepard_metzler_5_parts')
    parser.add_argument('-b', '--batch-size', type=int, default=64,
                        help='number of sequences in each output file')
    parser.add_argument('-m', '--mode', type=str, default='train',
                        help='whether to convert train or test')
    args = parser.parse_args()

    # Directory holding the records: <base_dir>/<dataset>/<mode>
    root = os.path.expanduser(args.base_dir[0])
    data_dir = os.path.join(root, args.dataset, args.mode)

    # Collect every path containing "tfrecord", in sorted order
    records = []
    for name in sorted(os.listdir(data_dir)):
        candidate = os.path.join(data_dir, name)
        if "tfrecord" in candidate:
            records.append(candidate)

    # Convert the records in parallel, one worker per CPU
    worker = partial(convert, batch_size=args.batch_size)
    with mp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(worker, records)
95 |
--------------------------------------------------------------------------------
/run-convdraw.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tqdm import tqdm
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torchvision.transforms as transforms
8 | from torch.utils.data import DataLoader
9 | from torchvision.datasets import SVHN, MNIST
10 | from torchvision.utils import save_image
11 |
12 | from draw import ConvolutionalDRAW
13 | cuda = torch.cuda.is_available()
14 | device = torch.device("cuda:0" if cuda else "cpu")
15 |
16 |
if __name__ == '__main__':
    def str2bool(value):
        """
        Parse a boolean command-line value.

        argparse's `type=bool` calls bool() on the raw string, so every
        non-empty value - including "False" - becomes True.
        """
        text = str(value).strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser(description='ConvolutionalDRAW with MNIST/SVHN Example')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
    parser.add_argument('--data_dir', type=str, help='location of training data', default="./train")
    parser.add_argument('--batch_size', type=int, default=128, help='size of batch (default: 128)')
    # Restrict to the datasets handled below; any other value used to
    # fail later with a NameError.
    parser.add_argument('--dataset', type=str, default="MNIST", choices=["MNIST", "SVHN"], help='dataset to use (default: MNIST)')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
    parser.add_argument('--data_parallel', type=str2bool, help='whether to parallelise based on data (default: False)', default=False)

    args = parser.parse_args()

    if args.dataset == "MNIST":
        # Binarised pixels, Bernoulli likelihood, no normalisation
        mean, std = 0, 1
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: torch.bernoulli(x))
        ])
        dataset = MNIST(root=args.data_dir, train=True, download=True, transform=transform)
        loss = nn.BCELoss(reduction='none')
        output_activation = torch.sigmoid
        x_dim, x_shape = 1, (28, 28)

    elif args.dataset == "SVHN":
        # Per-channel normalised pixels, squared-error reconstruction
        mean, std = (0.4376, 0.4437, 0.4728), (0.198, 0.201, 0.197)
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
        dataset = SVHN(root=args.data_dir, split="train", download=True, transform=transform)
        loss = nn.MSELoss(reduction='none')
        output_activation = lambda x: x
        x_dim, x_shape = 3, (32, 32)

    # Channel statistics as tensors that broadcast over (b, c, h, w)
    mean_t = torch.tensor(mean, dtype=torch.float, device=device).reshape(1, -1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float, device=device).reshape(1, -1, 1, 1)

    # Create model and optimizer
    model = ConvolutionalDRAW(x_dim=x_dim, x_shape=x_shape, h_dim=160, z_dim=12, T=16).to(device)
    model = nn.DataParallel(model) if args.data_parallel else model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, 0.5)

    # Load the dataset
    kwargs = {'num_workers': args.workers, 'pin_memory': True} if cuda else {}
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    for epoch in range(args.epochs):
        for x, _ in tqdm(loader):
            batch_size = x.size(0)
            x = x.to(device)

            x_hat, kl = model(x)
            x_hat = output_activation(x_hat)

            reconstruction = torch.sum(loss(x_hat, x).view(batch_size, -1), dim=1)
            kl_divergence = torch.sum(kl.view(batch_size, -1), dim=1)
            # Minimised objective: reconstruction error plus KL (negative ELBO)
            elbo = torch.mean(reconstruction + kl_divergence)

            elbo.backward()
            optimizer.step()
            optimizer.zero_grad()

        with torch.no_grad():
            scheduler.step()

            print("Loss at step {}: {}".format(epoch, elbo.item()))

            # `sample` lives on the wrapped module only when DataParallel
            # is used; `model.module` does not exist otherwise.
            core = model.module if isinstance(model, nn.DataParallel) else model
            x_sample = core.sample(args.batch_size)

            # Undo the input normalisation to visualise: x * std + mean
            # (assumes x_sample is (b, c, h, w) on `device` - TODO confirm)
            x_sample = x_sample * std_t + mean_t
            x_hat = x_hat * std_t + mean_t

            save_image(x_hat, "reconstruction-{}.jpg".format(epoch))
            save_image(x_sample, "sample-{}.jpg".format(epoch))

            if epoch % 10 == 0:
                torch.save(model, "model-{}.pt".format(epoch))
--------------------------------------------------------------------------------
/gqn/representation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class TowerRepresentation(nn.Module):
    def __init__(self, n_channels, v_dim, r_dim=256, pool=True):
        """
        Network that generates a condensed representation
        vector from a joint input of image and viewpoint.

        Employs the tower/pool architecture described in the paper.

        :param n_channels: number of color channels in input image
        :param v_dim: dimensions of the viewpoint vector
        :param r_dim: dimensions of representation
        :param pool: whether to pool representation
        """
        super(TowerRepresentation, self).__init__()
        # Final representation size
        self.r_dim = k = r_dim
        self.pool = pool

        self.conv1 = nn.Conv2d(n_channels, k, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(k, k, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(k, k//2, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(k//2, k, kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(k + v_dim, k, kernel_size=3, stride=1, padding=1)
        self.conv6 = nn.Conv2d(k + v_dim, k//2, kernel_size=3, stride=1, padding=1)
        self.conv7 = nn.Conv2d(k//2, k, kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(k, k, kernel_size=1, stride=1)

        # Global average pooling. The previous fixed window of k // 16 only
        # equalled the feature-map size when r_dim was 256 (64x64 inputs).
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x, v):
        """
        Send an (image, viewpoint) pair into the
        network to generate a representation
        :param x: image
        :param v: viewpoint (x, y, z, cos(yaw), sin(yaw), cos(pitch), sin(pitch))
        :return: representation; a 1x1 map per channel when pooling is
                 enabled, otherwise the full feature map
        """
        # First skip-connected conv block
        skip_in = F.relu(self.conv1(x))
        skip_out = F.relu(self.conv2(skip_in))

        x = F.relu(self.conv3(skip_in))
        x = F.relu(self.conv4(x)) + skip_out

        # Broadcast the viewpoint over the spatial size of the feature map
        # (taken from the features rather than derived from r_dim)
        v = v.view(v.size(0), -1, 1, 1)
        v = v.expand(-1, -1, x.size(2), x.size(3))

        # Second skip-connected conv block (merged)
        skip_in = torch.cat([x, v], dim=1)
        skip_out = F.relu(self.conv5(skip_in))

        x = F.relu(self.conv6(skip_in))
        x = F.relu(self.conv7(x)) + skip_out

        r = F.relu(self.conv8(x))

        if self.pool:
            r = self.avgpool(r)

        return r
68 |
69 |
class PyramidRepresentation(nn.Module):
    def __init__(self, n_channels, v_dim, r_dim=256):
        """
        Network that generates a condensed representation
        vector from a joint input of image and viewpoint.

        Employs the pyramid architecture described in the paper.

        :param n_channels: number of color channels in input image
        :param v_dim: dimensions of the viewpoint vector
        :param r_dim: dimensions of representation
        """
        super(PyramidRepresentation, self).__init__()
        # Final representation size
        self.r_dim = k = r_dim

        # Channel widths grow k/8 -> k/4 -> k/2 -> k across the four layers
        self.conv1 = nn.Conv2d(n_channels + v_dim, k//8, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(k//8, k//4, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(k//4, k//2, kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(k//2, k, kernel_size=8, stride=8)

    def forward(self, x, v):
        """
        Send an (image, viewpoint) pair into the
        network to generate a representation
        :param x: image
        :param v: viewpoint (x, y, z, cos(yaw), sin(yaw), cos(pitch), sin(pitch))
        :return: representation
        """
        n, _, height, width = x.shape

        # Tile the viewpoint over every pixel and stack it onto the image
        planes = v.view(n, -1, 1, 1).repeat(1, 1, height, width)
        out = torch.cat([x, planes], dim=1)

        # Each convolution is followed by a ReLU
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            out = F.relu(conv(out))

        return out
--------------------------------------------------------------------------------
/run-gqn.py:
--------------------------------------------------------------------------------
1 | """
2 | run-gqn.py
3 |
4 | Script to train a GQN on the Shepard-Metzler dataset
5 | in accordance with the hyperparameter settings described in
6 | the supplementary materials of the paper.
7 | """
8 | import random
9 | import math
10 | from argparse import ArgumentParser
11 |
12 | # Torch
13 | import torch
14 | import torch.nn as nn
15 | from torch.distributions import Normal
16 | from torch.utils.data import DataLoader
17 | from torchvision.utils import make_grid
18 |
19 | # TensorboardX
20 | from tensorboardX import SummaryWriter
21 |
22 | # Ignite
23 | from ignite.contrib.handlers import ProgressBar
24 | from ignite.engine import Engine, Events
25 | from ignite.handlers import ModelCheckpoint, Timer
26 | from ignite.metrics import RunningAverage
27 |
28 | from gqn import GenerativeQueryNetwork, partition, Annealer
29 | from shepardmetzler import ShepardMetzler
30 | #from placeholder import PlaceholderData as ShepardMetzler
31 |
32 | cuda = torch.cuda.is_available()
33 | device = torch.device("cuda:0" if cuda else "cpu")
34 |
35 | # Random seeding
36 | random.seed(99)
37 | torch.manual_seed(99)
38 | if cuda: torch.cuda.manual_seed(99)
39 | torch.backends.cudnn.deterministic = True
40 | torch.backends.cudnn.benchmark = False
41 |
if __name__ == '__main__':
    def _str2bool(value):
        """
        Parse a command line string into a boolean.

        argparse's ``type=bool`` calls ``bool()`` on the raw string, so
        every non-empty value (including "False") becomes True. This
        parser maps the usual spellings explicitly instead.

        :param value: string (or bool) given on the command line
        :return: the parsed boolean
        :raises ArgumentTypeError: if the value is not a recognised boolean
        """
        from argparse import ArgumentTypeError

        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string is kept falsy, as it was under ``bool("")``.
        if text in ("false", "f", "no", "n", "0", ""):
            return False
        raise ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    parser = ArgumentParser(description='Generative Query Network on Shepard Metzler Example')
    parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs run (default: 200)')
    parser.add_argument('--batch_size', type=int, default=1, help='multiple of batch size (default: 1)')
    parser.add_argument('--data_dir', type=str, help='location of data', default="train")
    parser.add_argument('--log_dir', type=str, help='location of logging', default="log")
    parser.add_argument('--fraction', type=float, help='how much of the data to use', default=1.0)
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
    # The flag still takes a value (e.g. --data_parallel "True"), but
    # "False" now really disables data parallelism.
    parser.add_argument('--data_parallel', type=_str2bool, help='whether to parallelise based on data (default: False)', default=False)
    args = parser.parse_args()

    # Create model and optimizer
    model = GenerativeQueryNetwork(x_dim=3, v_dim=7, r_dim=256, h_dim=128, z_dim=64, L=8).to(device)
    model = nn.DataParallel(model) if args.data_parallel else model

    optimizer = torch.optim.Adam(model.parameters(), lr=5 * 10 ** (-5))

    # Rate annealing schemes
    sigma_scheme = Annealer(2.0, 0.7, 80000)
    mu_scheme = Annealer(5 * 10 ** (-6), 5 * 10 ** (-6), 1.6 * 10 ** 5)

    # Load the dataset
    train_dataset = ShepardMetzler(root_dir=args.data_dir, fraction=args.fraction)
    valid_dataset = ShepardMetzler(root_dir=args.data_dir, fraction=args.fraction, train=False)

    # Worker processes and pinned memory are only requested when CUDA is used
    kwargs = {'num_workers': args.workers, 'pin_memory': True} if cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    def step(engine, batch):
        """
        Run one optimisation step on a batch.

        :param engine: ignite engine driving the training loop
        :param batch: (images, viewpoints) pair from the data loader
        :return: dict with the scalar "elbo", "kl", "sigma" and "mu" values
        """
        model.train()

        x, v = batch
        x, v = x.to(device), v.to(device)
        x, v, x_q, v_q = partition(x, v)

        # Reconstruction, representation and divergence
        x_mu, _, kl = model(x, v, x_q, v_q)

        # Log likelihood
        sigma = next(sigma_scheme)
        ll = Normal(x_mu, sigma).log_prob(x_q)

        likelihood = torch.mean(torch.sum(ll, dim=[1, 2, 3]))
        kl_divergence = torch.mean(torch.sum(kl, dim=[1, 2, 3]))

        # Evidence lower bound
        elbo = likelihood - kl_divergence
        loss = -elbo
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        with torch.no_grad():
            # Anneal learning rate
            mu = next(mu_scheme)
            i = engine.state.iteration
            for group in optimizer.param_groups:
                group["lr"] = mu * math.sqrt(1 - 0.999 ** i) / (1 - 0.9 ** i)

        return {"elbo": elbo.item(), "kl": kl_divergence.item(), "sigma": sigma, "mu": mu}

    # Trainer and metrics
    trainer = Engine(step)
    metric_names = ["elbo", "kl", "sigma", "mu"]
    RunningAverage(output_transform=lambda x: x["elbo"]).attach(trainer, "elbo")
    RunningAverage(output_transform=lambda x: x["kl"]).attach(trainer, "kl")
    RunningAverage(output_transform=lambda x: x["sigma"]).attach(trainer, "sigma")
    RunningAverage(output_transform=lambda x: x["mu"]).attach(trainer, "mu")
    ProgressBar().attach(trainer, metric_names=metric_names)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint("./", "checkpoint", save_interval=1, n_saved=3,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                              to_save={'model': model.state_dict, 'optimizer': optimizer.state_dict,
                                       'annealers': (sigma_scheme.data, mu_scheme.data)})

    timer = Timer(average=True).attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                                       pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # Tensorboard writer
    writer = SummaryWriter(log_dir=args.log_dir)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_metrics(engine):
        """Write every running metric to TensorBoard after each iteration."""
        for key, value in engine.state.metrics.items():
            writer.add_scalar("training/{}".format(key), value, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_images(engine):
        """Log the representation and reconstruction of the last training batch."""
        with torch.no_grad():
            x, v = engine.state.batch
            x, v = x.to(device), v.to(device)
            x, v, x_q, v_q = partition(x, v)

            x_mu, r, _ = model(x, v, x_q, v_q)

            # Lay the representation out as single-channel 16x16 tiles
            r = r.view(-1, 1, 16, 16)

            # Send to CPU
            x_mu = x_mu.detach().cpu().float()
            r = r.detach().cpu().float()

            writer.add_image("representation", make_grid(r), engine.state.epoch)
            writer.add_image("reconstruction", make_grid(x_mu), engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        """Evaluate the ELBO and KL on a single validation batch."""
        model.eval()
        with torch.no_grad():
            x, v = next(iter(valid_loader))
            x, v = x.to(device), v.to(device)
            x, v, x_q, v_q = partition(x, v)

            # Reconstruction, representation and divergence
            x_mu, _, kl = model(x, v, x_q, v_q)

            # Validate at last sigma
            ll = Normal(x_mu, sigma_scheme.recent).log_prob(x_q)

            likelihood = torch.mean(torch.sum(ll, dim=[1, 2, 3]))
            kl_divergence = torch.mean(torch.sum(kl, dim=[1, 2, 3]))

            # Evidence lower bound
            elbo = likelihood - kl_divergence

            writer.add_scalar("validation/elbo", elbo.item(), engine.state.epoch)
            writer.add_scalar("validation/kl", kl_divergence.item(), engine.state.epoch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        """
        Close the writer and stop training; on a keyboard interrupt
        after the first iteration, checkpoint the model instead of
        re-raising.
        """
        writer.close()
        engine.terminate()
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            import warnings
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    trainer.run(train_loader, args.n_epochs)
    writer.close()
185 |
--------------------------------------------------------------------------------
/gqn/generator.py:
--------------------------------------------------------------------------------
1 | """
2 | The inference-generator architecture is conceptually
3 | similar to the encoder-decoder pair seen in variational
4 | autoencoders. The difference here is that the model
5 | must infer latents from a cascade of time-dependent inputs
6 | using convolutional and recurrent networks.
7 |
8 | Additionally, a representation vector is shared between
9 | the networks.
10 | """
11 | SCALE = 4 # Scale of image generation process
12 |
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from torch.distributions import Normal, kl_divergence
18 |
19 |
class Conv2dLSTMCell(nn.Module):
    """
    Long short-term memory (LSTM) cell built from 2d convolutions.
    It follows the gating scheme of nn.LSTMCell, but uses
    nn.Conv2d layers where nn.LSTMCell uses nn.Linear layers.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: size of image kernel
    :param stride: length of kernel stride
    :param padding: number of pixels to pad with
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        conv_args = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        # The four gates share one configuration; they are registered
        # in the order forget, input, output, state.
        for gate in ("forget", "input", "output", "state"):
            setattr(self, gate, nn.Conv2d(in_channels, out_channels, **conv_args))

        # Maps the hidden state back to the input's channel count
        self.transform = nn.Conv2d(out_channels, in_channels, **conv_args)

    def forward(self, input, states):
        """
        Advance the cell by one step.

        :param input: input to send through
        :param states: (hidden, cell) pair of internal state
        :return new (hidden, cell) pair
        """
        hidden, cell = states

        # Fold the previous hidden state into the input
        mixed = input + self.transform(hidden)

        keep = torch.sigmoid(self.forget(mixed))
        write = torch.sigmoid(self.input(mixed))
        expose = torch.sigmoid(self.output(mixed))
        candidate = torch.tanh(self.state(mixed))

        # New cell state, then the hidden state derived from it
        cell = keep * cell + write * candidate
        return expose * torch.tanh(cell), cell
69 |
70 |
class GeneratorNetwork(nn.Module):
    """
    Network similar to a convolutional variational
    autoencoder that refines the generated image
    over a number of iterations.

    :param x_dim: number of channels in input
    :param v_dim: dimensions of viewpoint
    :param r_dim: dimensions of representation
    :param z_dim: latent channels
    :param h_dim: hidden channels in LSTM
    :param L: number of density refinements
    :param share: whether to share cores across refinements
    """
    def __init__(self, x_dim, v_dim, r_dim, z_dim=64, h_dim=128, L=12, share=True):
        super(GeneratorNetwork, self).__init__()
        self.L = L
        self.z_dim = z_dim
        self.h_dim = h_dim
        self.share = share

        # Core computational units
        kwargs = dict(kernel_size=5, stride=1, padding=2)
        inference_args = dict(in_channels=v_dim + r_dim + x_dim + h_dim, out_channels=h_dim, **kwargs)
        generator_args = dict(in_channels=v_dim + r_dim + z_dim, out_channels=h_dim, **kwargs)
        if self.share:
            # One core reused at every refinement step
            self.inference_core = Conv2dLSTMCell(**inference_args)
            self.generator_core = Conv2dLSTMCell(**generator_args)
        else:
            # One independent core per refinement step
            self.inference_core = nn.ModuleList([Conv2dLSTMCell(**inference_args) for _ in range(L)])
            self.generator_core = nn.ModuleList([Conv2dLSTMCell(**generator_args) for _ in range(L)])

        # Inference, prior (each outputs 2*z_dim channels: mean and raw scale)
        self.posterior_density = nn.Conv2d(h_dim, 2*z_dim, **kwargs)
        self.prior_density = nn.Conv2d(h_dim, 2*z_dim, **kwargs)

        # Generative density
        self.observation_density = nn.Conv2d(h_dim, x_dim, kernel_size=1, stride=1, padding=0)

        # Up/down-sampling primitives
        self.upsample = nn.ConvTranspose2d(h_dim, h_dim, kernel_size=SCALE, stride=SCALE, padding=0, bias=False)
        self.downsample = nn.Conv2d(x_dim, x_dim, kernel_size=SCALE, stride=SCALE, padding=0, bias=False)

    def _generator_at(self, step):
        """Return the generator core to use at refinement step `step`."""
        return self.generator_core if self.share else self.generator_core[step]

    def _inference_at(self, step):
        """Return the inference core to use at refinement step `step`."""
        return self.inference_core if self.share else self.inference_core[step]

    def forward(self, x, v, r):
        """
        Attempt to reconstruct x with corresponding
        viewpoint v and context representation r.

        :param x: image to send through
        :param v: viewpoint of image
        :param r: representation for image
        :return reconstruction of x and kl-divergence
                 (elementwise, summed over the L refinement steps)
        """
        batch_size, _, h, w = x.shape
        kl = 0

        # Downsample x, upsample v and r
        x = self.downsample(x)
        v = v.view(batch_size, -1, 1, 1).repeat(1, 1, h // SCALE, w // SCALE)
        if r.size(2) != h // SCALE:
            # r is tiled when its height differs from the working resolution
            r = r.repeat(1, 1, h // SCALE, w // SCALE)

        # Reset hidden and cell state
        hidden_i = x.new_zeros((batch_size, self.h_dim, h // SCALE, w // SCALE))
        cell_i = x.new_zeros((batch_size, self.h_dim, h // SCALE, w // SCALE))

        hidden_g = x.new_zeros((batch_size, self.h_dim, h // SCALE, w // SCALE))
        cell_g = x.new_zeros((batch_size, self.h_dim, h // SCALE, w // SCALE))

        # Canvas for updating
        u = x.new_zeros((batch_size, self.h_dim, h, w))

        for l in range(self.L):
            # Prior factor (eta π network); softplus keeps the scale positive
            p_mu, p_std = torch.chunk(self.prior_density(hidden_g), 2, dim=1)
            prior_distribution = Normal(p_mu, F.softplus(p_std))

            # Inference state update
            inference = self._inference_at(l)
            hidden_i, cell_i = inference(torch.cat([hidden_g, x, v, r], dim=1), [hidden_i, cell_i])

            # Posterior factor (eta e network)
            q_mu, q_std = torch.chunk(self.posterior_density(hidden_i), 2, dim=1)
            posterior_distribution = Normal(q_mu, F.softplus(q_std))

            # Posterior sample (reparameterised, so gradients flow through z)
            z = posterior_distribution.rsample()

            # Generator state update
            generator = self._generator_at(l)
            hidden_g, cell_g = generator(torch.cat([z, v, r], dim=1), [hidden_g, cell_g])

            # Calculate u
            u = self.upsample(hidden_g) + u

            # Calculate KL-divergence
            kl += kl_divergence(posterior_distribution, prior_distribution)

        x_mu = self.observation_density(u)

        return torch.sigmoid(x_mu), kl

    def sample(self, x_shape, v, r):
        """
        Sample from the prior distribution to generate
        a new image given a viewpoint and representation

        :param x_shape: (height, width) of image
        :param v: viewpoint
        :param r: representation (context)
        :return generated image
        """
        h, w = x_shape
        batch_size = v.size(0)

        # Increase dimensions
        v = v.view(batch_size, -1, 1, 1).repeat(1, 1, h // SCALE, w // SCALE)
        if r.size(2) != h // SCALE:
            r = r.repeat(1, 1, h // SCALE, w // SCALE)

        # Reset hidden and cell state for generator
        hidden_g = v.new_zeros((batch_size, self.h_dim, h // SCALE, w // SCALE))
        cell_g = v.new_zeros((batch_size, self.h_dim, h // SCALE, w // SCALE))

        u = v.new_zeros((batch_size, self.h_dim, h, w))

        for l in range(self.L):
            p_mu, p_std = torch.chunk(self.prior_density(hidden_g), 2, dim=1)
            prior_distribution = Normal(p_mu, F.softplus(p_std))

            # Prior sample
            z = prior_distribution.sample()

            # Pick the core for this step: with share=False the cores
            # live in a ModuleList, which cannot be called directly.
            generator = self._generator_at(l)

            # Calculate u
            hidden_g, cell_g = generator(torch.cat([z, v, r], dim=1), [hidden_g, cell_g])
            u = self.upsample(hidden_g) + u

        x_mu = self.observation_density(u)

        return torch.sigmoid(x_mu)
210 |
--------------------------------------------------------------------------------
/draw/draw.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Normal, kl_divergence
5 |
6 |
class BaseAttention(nn.Module):
    """
    Attention module that applies no attention: reading
    concatenates its inputs and writing is a linear layer.

    :param h_dim: size of the hidden state fed to the write head
    :param x_dim: size of the output of the write head
    """
    def __init__(self, h_dim, x_dim):
        super().__init__()
        self.h_dim, self.x_dim = h_dim, x_dim
        self.write_head = nn.Linear(h_dim, x_dim)

    def read(self, x, x_hat, h):
        """
        Join x and x_hat along dimension 1; h is not used.
        """
        return torch.cat((x, x_hat), dim=1)

    def write(self, x):
        """
        Project x through the linear write head.
        """
        return self.write_head(x)
22 |
23 |
class FilterBankAttention(BaseAttention):
    """
    Filter bank attention mechanism described in the paper.

    Placeholder: neither reading nor writing is implemented yet,
    so both methods raise NotImplementedError.

    :param h_dim: size of the hidden state
    :param x_dim: size of the input
    """
    def __init__(self, h_dim, x_dim):
        super(FilterBankAttention, self).__init__(h_dim, x_dim)

    def read(self, x, error, h):
        """
        :raises NotImplementedError: always
        """
        # Raise rather than return the exception class, so callers fail
        # here instead of later with a confusing type error.
        raise NotImplementedError("FilterBankAttention.read is not implemented")

    def write(self, x):
        """
        :raises NotImplementedError: always
        """
        raise NotImplementedError("FilterBankAttention.write is not implemented")
36 |
37 |
class DRAW(nn.Module):
    """
    Deep Recurrent Attentive Writer (DRAW) [Gregor 2015].

    :param x_dim: size of input
    :param h_dim: number of hidden neurons
    :param z_dim: number of latent neurons
    :param T: number of recurrent layers
    :param attention_module: class built as attention_module(h_dim, x_dim)
                             providing read and write
    """
    def __init__(self, x_dim, h_dim=256, z_dim=10, T=10, attention_module=BaseAttention):
        super().__init__()
        self.x_dim, self.z_dim = x_dim, z_dim
        self.h_dim, self.T = h_dim, T

        # Heads producing distribution parameters
        self.variational = nn.Linear(h_dim, 2 * z_dim)
        self.observation = nn.Linear(x_dim, x_dim)

        # Recurrent encoder and decoder
        self.encoder = nn.LSTMCell(2 * x_dim + h_dim, h_dim)
        self.decoder = nn.LSTMCell(z_dim, h_dim)

        # Read/write attention
        self.attention = attention_module(h_dim, x_dim)

    def forward(self, x):
        """
        Reconstruct x over T recurrent steps.

        :param x: input batch
        :return [reconstruction, kl-divergence accumulated over the T steps]
        """
        n = x.size(0)
        hidden_shape = (n, self.h_dim)
        latent_shape = (n, self.z_dim)

        # Hidden and cell states, allocated with x's dtype and device
        enc_h, enc_c = x.new_zeros(hidden_shape), x.new_zeros(hidden_shape)
        dec_h, dec_c = x.new_zeros(hidden_shape), x.new_zeros(hidden_shape)

        # Standard normal prior; kept on the module because sample() reads it
        self.prior = Normal(x.new_zeros(latent_shape), x.new_ones(latent_shape))

        canvas = x.new_zeros((n, self.x_dim))
        kl = 0

        for _ in range(self.T):
            # What the current canvas still fails to explain
            residual = x - torch.sigmoid(canvas)
            glimpse = self.attention.read(x, residual, dec_h)

            # Encoder update
            enc_h, enc_c = self.encoder(torch.cat([glimpse, dec_h], dim=1), [enc_h, enc_c])

            # Posterior parameters from the encoder state
            q_mu, q_log_std = self.variational(enc_h).split(self.z_dim, dim=1)
            posterior = Normal(q_mu, q_log_std.exp())

            # Reparameterised latent sample
            z = posterior.rsample()

            # Decoder update
            dec_h, dec_c = self.decoder(z, [dec_h, dec_c])

            # Add this step's write to the canvas
            canvas += self.attention.write(dec_h)

            kl += kl_divergence(posterior, self.prior)

        return [self.observation(canvas), kl]

    def sample(self, z=None):
        """
        Generate a sample from the data distribution.

        When z is omitted it is drawn from self.prior, which is
        only assigned inside forward().

        :param z: latent code, otherwise sample from prior
        """
        if z is None:
            z = self.prior.sample()
        n = z.size(0)

        canvas = z.new_zeros((n, self.x_dim))
        dec_h = z.new_zeros((n, self.h_dim))
        dec_c = z.new_zeros((n, self.h_dim))

        # The same latent code is fed to the decoder at every step
        for _ in range(self.T):
            dec_h, dec_c = self.decoder(z, [dec_h, dec_c])
            canvas = canvas + self.attention.write(dec_h)

        return self.observation(canvas)
130 |
131 |
class Conv2dLSTMCell(nn.Module):
    """
    Long short-term memory (LSTM) cell built from 2d convolutions.
    It follows the gating scheme of nn.LSTMCell, but uses
    nn.Conv2d layers where nn.LSTMCell uses nn.Linear layers.

    :param in_channels: number of input channels
    :param out_channels: number of output channels
    :param kernel_size: size of image kernel
    :param stride: length of kernel stride
    :param padding: number of pixels to pad with
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        conv_args = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        # The four gates share one configuration; they are registered
        # in the order forget, input, output, state.
        for gate in ("forget", "input", "output", "state"):
            setattr(self, gate, nn.Conv2d(in_channels, out_channels, **conv_args))

        # Maps the hidden state back to the input's channel count
        self.transform = nn.Conv2d(out_channels, in_channels, **conv_args)

    def forward(self, input, states):
        """
        Advance the cell by one step.

        :param input: input to send through
        :param states: (hidden, cell) pair of internal state
        :return new (hidden, cell) pair
        """
        hidden, cell = states

        # Fold the previous hidden state into the input
        mixed = input + self.transform(hidden)

        keep = torch.sigmoid(self.forget(mixed))
        write = torch.sigmoid(self.input(mixed))
        expose = torch.sigmoid(self.output(mixed))
        candidate = torch.tanh(self.state(mixed))

        # New cell state, then the hidden state derived from it
        cell = keep * cell + write * candidate
        return expose * torch.tanh(cell), cell
181 |
182 |
class ConvolutionalDRAW(nn.Module):
    """
    Convolutional DRAW model described in
    "Towards Conceptual Compression" [Gregor 2016].
    The model consists of an autoregressive density
    estimator using a recurrent convolutional network.

    :param x_dim: number of channels in input
    :param x_shape: tuple representing input image shape
    :param h_dim: number of hidden channels
    :param z_dim: number of channels in latent variable
    :param T: number of recurrent layers
    """
    def __init__(self, x_dim, x_shape=(32, 32), h_dim=256, z_dim=10, T=10):
        super().__init__()
        self.x_dim, self.x_shape = x_dim, x_shape
        self.z_dim, self.h_dim = z_dim, h_dim
        self.T = T

        # Heads producing distribution parameters (mean and log-std)
        density_args = dict(kernel_size=5, stride=1, padding=2)
        self.variational = nn.Conv2d(h_dim, 2 * z_dim, **density_args)
        self.prior = nn.Conv2d(h_dim, 2 * z_dim, **density_args)

        # Counterparts of the read/write operations in the original DRAW
        self.write_head = nn.Conv2d(h_dim, x_dim * 4, kernel_size=1, stride=1, padding=0)
        self.read_head = nn.Conv2d(x_dim, x_dim, kernel_size=3, stride=2, padding=1)

        # Recurrent encoder and decoder
        # NOTE(review): the encoder cell uses stride=2 on a full-resolution
        # input while its hidden state is allocated at half resolution;
        # confirm the cell's hidden-state transform yields a shape that can
        # be added to that input.
        self.encoder = Conv2dLSTMCell(2 * x_dim, h_dim, kernel_size=5, stride=2, padding=2)
        self.decoder = Conv2dLSTMCell(z_dim + x_dim, h_dim, kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        """
        Reconstruct x over T recurrent steps.

        :param x: input batch of images
        :return [canvas, kl-divergence accumulated over the T steps]
        """
        height, width = self.x_shape
        n = x.size(0)
        state_shape = (n, self.h_dim, height // 2, width // 2)

        # Hidden and cell states, allocated with x's dtype and device
        enc_h, enc_c = x.new_zeros(state_shape), x.new_zeros(state_shape)
        dec_h, dec_c = x.new_zeros(state_shape), x.new_zeros(state_shape)

        canvas = x.new_zeros((n, self.x_dim, height, width))
        kl = 0

        for _ in range(self.T):
            # What the canvas still fails to explain
            residual = x - canvas

            # Encoder update
            enc_h, enc_c = self.encoder(torch.cat([x, residual], dim=1), [enc_h, enc_c])

            # Prior, conditioned on the decoder state
            p_mu, p_log_std = self.prior(dec_h).split(self.z_dim, dim=1)
            prior = Normal(p_mu, p_log_std.exp())

            # Posterior, conditioned on the encoder state
            q_mu, q_log_std = self.variational(enc_h).split(self.z_dim, dim=1)
            posterior = Normal(q_mu, q_log_std.exp())

            # Reparameterised latent sample
            z = posterior.rsample()

            # Strided read of the current canvas
            coarse = self.read_head(canvas)

            # Decoder update
            dec_h, dec_c = self.decoder(torch.cat([z, coarse], dim=1), [dec_h, dec_c])

            # Write the refinement back (pixel shuffle doubles the resolution)
            canvas = canvas + F.pixel_shuffle(self.write_head(dec_h), 2)
            kl += kl_divergence(posterior, prior)

        return [canvas, kl]

    def sample(self, x):
        """
        Sample from the prior to generate a new
        datapoint.

        :param x: tensor whose batch size, dtype and device are used
        """
        height, width = self.x_shape
        n = x.size(0)
        state_shape = (n, self.h_dim, height // 2, width // 2)

        dec_h, dec_c = x.new_zeros(state_shape), x.new_zeros(state_shape)
        canvas = x.new_zeros((n, self.x_dim, height, width))

        for _ in range(self.T):
            # Draw the latent from the learned prior
            p_mu, p_log_std = self.prior(dec_h).split(self.z_dim, dim=1)
            z = Normal(p_mu, p_log_std.exp()).sample()

            coarse = self.read_head(canvas)
            dec_h, dec_c = self.decoder(torch.cat([z, coarse], dim=1), [dec_h, dec_c])
            canvas = canvas + F.pixel_shuffle(self.write_head(dec_h), 2)

        return canvas
288 |
--------------------------------------------------------------------------------
/mental-rotation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import matplotlib.pyplot as plt\n",
11 | "%matplotlib inline"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "# Mental rotation\n",
19 | "\n",
20 | "In this notebook we will show how we can enable a pretrained Generative Query Network (GQN) for the Shepard-Metzler mental rotation task. The problem is well studied in psychology to assess spatial intelligence. Mental rotation is a cognitively hard problem as it typically requires the employment of both the ventral and dorsal visual streams for recognition and spatial reasoning respectively. Additionally, a certain degree of metacognition is required to reason about uncertainty.\n",
21 | "\n",
22 | "It turns out that the GQN is capable of this, as we will see in this notebook.\n",
23 | "\n",
24 | "
\n",
25 | "Note: \n",
26 | "This model has only been trained on around 10% of the data for $2 \\times 10^5$ iterations instead of the $2 \\times 10^6$ described in the original paper. This means that the reconstructions are quite bad and the samples are even worse. Consequently, this notebook is just a proof of concept that the model approximately works. If you have the computational means to fully train the model, then please feel free to make a pull request with the trained model, this will help me a lot.\n",
27 | "
\n",
28 | "\n",
29 | "You can download the pretrained model weights from here: [https://github.com/wohlert/generative-query-network-pytorch/releases/tag/0.1](https://github.com/wohlert/generative-query-network-pytorch/releases/download/0.1/model-checkpoint.pth)."
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 35,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "import torch\n",
39 | "import torch.nn as nn\n",
40 | " \n",
41 | "# Load dataset\n",
42 | "from shepardmetzler import ShepardMetzler\n",
43 | "from torch.utils.data import DataLoader\n",
44 | "\n",
45 | "dataset = ShepardMetzler(\"/data/shepard_metzler_5_parts/\") ## <= Choose your data location\n",
46 | "loader = DataLoader(dataset, batch_size=1, shuffle=True)"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 36,
52 | "metadata": {},
53 | "outputs": [
54 | {
55 | "data": {
56 | "text/plain": [
57 | "GenerativeQueryNetwork(\n",
58 | " (generator): GeneratorNetwork(\n",
59 | " (inference_core): Conv2dLSTMCell(\n",
60 | " (forget): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
61 | " (input): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
62 | " (output): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
63 | " (state): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
64 | " (transform): Conv2d(128, 394, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
65 | " )\n",
66 | " (generator_core): Conv2dLSTMCell(\n",
67 | " (forget): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
68 | " (input): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
69 | " (output): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
70 | " (state): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
71 | " (transform): Conv2d(128, 327, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
72 | " )\n",
73 | " (posterior_density): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
74 | " (prior_density): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
75 | " (observation_density): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))\n",
76 | " (upsample): ConvTranspose2d(128, 128, kernel_size=(4, 4), stride=(4, 4), bias=False)\n",
77 | " (downsample): Conv2d(3, 3, kernel_size=(4, 4), stride=(4, 4), bias=False)\n",
78 | " )\n",
79 | " (representation): TowerRepresentation(\n",
80 | " (conv1): Conv2d(3, 256, kernel_size=(2, 2), stride=(2, 2))\n",
81 | " (conv2): Conv2d(256, 256, kernel_size=(2, 2), stride=(2, 2))\n",
82 | " (conv3): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
83 | " (conv4): Conv2d(128, 256, kernel_size=(2, 2), stride=(2, 2))\n",
84 | " (conv5): Conv2d(263, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
85 | " (conv6): Conv2d(263, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
86 | " (conv7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
87 | " (conv8): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n",
88 | " (avgpool): AvgPool2d(kernel_size=16, stride=16, padding=0)\n",
89 | " )\n",
90 | ")"
91 | ]
92 | },
93 | "execution_count": 36,
94 | "metadata": {},
95 | "output_type": "execute_result"
96 | }
97 | ],
98 | "source": [
99 | "from gqn import GenerativeQueryNetwork, partition\n",
100 | "\n",
101 | "# Load model parameters onto CPU\n",
102 | "state_dict = torch.load(\"./model-checkpoint.pth\", map_location=\"cpu\") ## <= Choose your model location\n",
103 | "\n",
104 | "# Initialise new model with the settings of the trained one\n",
105 | "model_settings = dict(x_dim=3, v_dim=7, r_dim=256, h_dim=128, z_dim=64, L=8)\n",
106 | "model = GenerativeQueryNetwork(**model_settings)\n",
107 | "\n",
108 | "# Load trained parameters, un-dataparallel if needed\n",
109 | "if True in [\"module\" in m for m in list(state_dict.keys())]:\n",
110 | " model = nn.DataParallel(model)\n",
111 | " model.load_state_dict(state_dict)\n",
112 | " model = model.module\n",
113 | "else:\n",
114 | " model.load_state_dict(state_dict)\n",
115 | " \n",
116 | "model"
117 | ]
118 | },
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {},
122 | "source": [
123 | "We load a batch of a single image containing a single object seen from 15 different viewpoints. We describe the whole set of image, viewpoint pairs by $\\{x_i, v_i \\}_{i=1}^{n}$. We then separate this set into a context set $\\{x_i, v_i \\}_{i=1}^{m}$ of $m$ random elements and a query set $\\{x^q, v^q \\}$, which contains just a single element."
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 184,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "def deterministic_partition(images, viewpoints, indices):\n",
133 | " \"\"\"\n",
134 | " Partition batch into context and query sets.\n",
135 | " :param images\n",
136 | " :param viewpoints\n",
137 | " :return: context images, context viewpoint, query image, query viewpoint\n",
138 | " \"\"\"\n",
139 | " # Maximum number of context points to use\n",
140 | " _, b, m, *x_dims = images.shape\n",
141 | " _, b, m, *v_dims = viewpoints.shape\n",
142 | "\n",
143 | " # \"Squeeze\" the batch dimension\n",
144 | " images = images.view((-1, m, *x_dims))\n",
145 | " viewpoints = viewpoints.view((-1, m, *v_dims))\n",
146 | "\n",
147 | " # Partition into context and query sets\n",
148 | " context_idx, query_idx = indices[:-1], indices[-1]\n",
149 | "\n",
150 | " x, v = images[:, context_idx], viewpoints[:, context_idx]\n",
151 | " x_q, v_q = images[:, query_idx], viewpoints[:, query_idx]\n",
152 | "\n",
153 | " return x, v, x_q, v_q"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": 186,
159 | "metadata": {},
160 | "outputs": [
161 | {
162 | "data": {
163 | "image/png": "\n",
164 | "text/plain": [
165 | ""
166 | ]
167 | },
168 | "metadata": {
169 | "needs_background": "light"
170 | },
171 | "output_type": "display_data"
172 | }
173 | ],
174 | "source": [
175 | "import random\n",
176 | "\n",
177 | "# Pick a scene to visualise\n",
178 | "scene_id = 0\n",
179 | "\n",
180 | "# Load data\n",
181 | "x, v = next(iter(loader))\n",
182 | "x_, v_ = x.squeeze(0), v.squeeze(0)\n",
183 | "\n",
184 | "# Sample a set of views\n",
185 | "n_context = 7 + 1\n",
186 | "indices = random.sample([i for i in range(v_.size(1))], n_context)\n",
187 | "\n",
188 | "# Separate into context and query sets\n",
189 | "x_c, v_c, x_q, v_q = deterministic_partition(x, v, indices)\n",
190 | "\n",
191 | "# Visualise context and query images\n",
192 | "f, axarr = plt.subplots(1, 15, figsize=(20, 7))\n",
193 | "for i, ax in enumerate(axarr.flat):\n",
194 | " # Move channel dimension to end\n",
195 | " ax.imshow(x_[scene_id][i].permute(1, 2, 0))\n",
196 | " \n",
197 | " if i == indices[-1]:\n",
198 | " ax.set_title(\"Query\", color=\"magenta\")\n",
199 | " elif i in indices[:-1]:\n",
200 | " ax.set_title(\"Context\", color=\"green\")\n",
201 | " else:\n",
202 | " ax.set_title(\"Unused\", color=\"grey\")\n",
203 | " \n",
204 | " ax.axis(\"off\")"
205 | ]
206 | },
207 | {
208 | "cell_type": "markdown",
209 | "metadata": {},
210 | "source": [
211 | "## Reconstruction\n",
212 | "\n",
213 | "Now we feed the whole set into the network and the network will perform the segregation of sets. The query image is then reconstructed in accordance with a given viewpoint and a representation vector that has been generated only by the context set."
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 187,
219 | "metadata": {},
220 | "outputs": [
221 | {
222 | "data": {
223 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAr8AAAD0CAYAAACSGU5oAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xm4ZFdZ7/HfW8OZe57T6SGzJhAiGkJAryBRFPXqRS5iEIWLovdxvKByRZRBBvWKoKIiCiIik8pFhCjDoxGQABIIhEi4JqQ7naQ7nZ5On7FODev+UXVqvbuofapWd5/upPf38zz9PKuq3j3U6b2qVu317ndbCEEAAABAEZTO9Q4AAAAAZwuDXwAAABQGg18AAAAUBoNfAAAAFAaDXwAAABQGg18AAAAUBoPfAjOzl5jZn5/r/QDw8GFmd5jZk871fgAYDn02HYPfVWZmzzWz281s3swOmdkfm9m6c71fkhRCeE0I4SfO9X4AZ5uZ7TOzBTOb7fTLt5nZ1Lner15m9nIze8cqrv9tZvYq/1wI4aoQws2rtU0gxSOlr56qzvu7ISGePnsGMPhdRWb2Ikm/LemXJa2T9HhJeyV9xMyqq7C9ypleJ3Ae+/4QwpSkayR9k6RfPcf7k8za+BzH+W7V+6qZlc/0OvHwxYfmKjGztZJeIennQgj/FEKohxD2SXqmpIsl3diJy/yKM7Mnmdl97vEFZvZ3ZvaQmd1jZj/vXnu5mf2tmb3DzE5K+t+dM8ybXMxjO8t+3WDbn1Uys71mFszseWZ2wMyOm9lPm9m1ZvYlMzthZm90y15iZv9sZkfN7IiZ/bWZre/Z7hfMbMbM/sbM3tPzPr/PzG7rrPdTZnb1af7JgVMSQjgk6cNqf7HKzEbN7HfN7F4ze9DM3mRm48vxZvYDnWP3pJndbWbf3Xn+AjP7gJkdM7O7zOwn3TIvN7P3mtnbO33iDjP7Fvf6i83s/s5rXzWzp3TW+xJJP9w56/XFTuzNZvZqM/s3SfOSLu49e9R7xtjMvrXTz050+vdzzewFkp4t6Vc66/+HTmx3XZ2/xRvM7IHOvzeY2WjntSeZ2X1m9iIzO2xmB83seWf6/wdYltJX3fH5ks531D4ze/byujrfvX9iZjeZ2ZykJw9Y32Yz+2CnDx0zs09Y54enDf6e7tv3zeyvJO2W9A+dPvgrnef/xtpnuafN7ONmdlXnefrsGcLgd/U8QdKYpPf5J0MIs5JukvRdg1bQ6Vj/IOmLknZKeoqkXzSzp7qwH5D0t5LWS3qdpJvVHmAve46kd4cQ6kPu93WSLpP0w5LeIOnXJN0g6SpJzzSzb1/ePUmvlXSBpG+UtEvSyzv7PSLp/0p6m6SNkt4l6b+59/VNkt4q6ackbZL0p5I+sNxBgbPJzC6U9D2S7uo89VuSLlf7C/ZStfveb3RiHyfp7WrP5qyX9F8k7ess925J96ndJ54h6TVm9h1uU/+1E7Ne0gckvbGzzisk/ayka0MIayQ9VdK+EMI/SXqNpPeEEKZCCI9x63qOpBdIWiNp/4D3t0fSP0r6Q0lbOu/rthDCmyX9taTf6az/+/ss/mtqz1hdI+kxkh4n6aXu9e1qz2rtlPR8SX9kZhtW2h/gVKX01Y7tkjZ3nv9xSW/u9LdlN0p6tdr96JMD1vcitfv3Fknb1P5hGob8nu7b90MIz5F0rzpntkMIv9OJ/0e1v4e3Svq82v1U9Nkzh8Hv6tks6UgIodHntYNqd6BBrpW0JYTwyhDCUgjha5L+TNKzXMwtIYT3hxBaIYQFSX8p6Uel7jTOj0j6q4T9/s0QwmII4SOS5iS9K4RwOIRwv6RPqD3lpBDCXSGEj4YQaiGEhyT9nqTlgfHjJVUk/UHnjPf7JH3WbeMFkv40hPCZEEIzhPCXkmqd5YCz5f1mNiPpgKTDkl5mZqb28fm/QgjHQggzag9Al/vc8yW9tXPs
t0II94cQ7jSzXZKeKOnFnf5zm6Q/l/RjbnufDCHcFEJoqt0nlwezTUmjkq40s2oIYV8I4e4B+/62EMIdIYTGED9sb5T0sRDCuzr98Whn/4bxbEmv7HwGPKT2bNZz3Ov1zuv1EMJNkmYlXdFnPcDpOJW+uuzXO99T/yrpQ8qeHPr7EMK/hRBaan8HrbS+uqQdkvZ0jvdPhBCChvuezuv7fYUQ3hpCmAkh1NQ+qfQYG/5aIfrsEMgRXT1HJG02s0qfAfCOzuuD7JF0gZmdcM+V1R6ELjvQs8zfS3qTmV2k9gE9HUL4rIb3oGsv9Hk8JUlmtk3S70v6NrV/NZckHe/EXSDp/s4HQ7/93CPpx83s59xzI53lgLPlB0MIH+vMZrxT7R+sI5ImJN3a/m6V1J7lWM4H3KX2zE2vCyQtf2Eu2y/pW9zjQ649L2ms8/lwl5n9otpfcleZ2YclvTCE8MAK+97b71eyS9KgwXSeC5Q9s7xf2X56tOfzbV6dzwjgDDqVvipJx0MIc+5x7/Hr+9GWAev7P2r30Y90Xn9zCOG3NNz3dF7f/7qTY52TVq+W9N87+9TqvLRZ0nRvfB/02SFw5nf13KL2L8mn+yetfZXq96idniC1z65OuJDtrn1A0j0hhPXu35oQwtNcjB9gKoSwKOm9ap/9fY7SzvqmeE1n248OIaztbG/5E+OgpJ3mPkHU/gJedkDSq3ve10QI4V2rtK9Ars4ZobdJ+l21f5QuSLrKHZvrOhfbSO1j95I+q3lA0kYzW+Oe2y3p/iH34Z0hhG9V+4s0qH2hrNTTv/0iPY8HfY702+eV1r/sgc4+LdvdeQ446xL7qiRtMLNJ97j3+PXH/4rr65yJfVEI4WK10xheaGZP0XDf0yu+rZ7HN6qdzniD2ukJezvPW058L/rsEBj8rpIQwrTa0w1/aGbfbWZVM9ur9sD0iDo5PJJuk/Q0M9toZtsl/aJbzWclzVj7YphxMyub2aPM7NoBm3+7pOeq3UFXa/C7Ru3pkmkz26l2DuSyW9Seyv1ZM6uY2Q+onXe07M8k/bSZXWdtk2b2vT0DB+BseoOk75T0aLWPz9eb2VZJMrOdLn/vLZKeZ+0L0kqd174hhHBA0qckvdbMxqx9AefzJQ0sU2ZmV5jZd3Ry3hfV/gJePtvzoKS9Nriiw22SntX5nPkWtXOOl/21pBvM7Jmd/rjJzK5x6794hfW+S9JLzWyLmW1WO/9x1UqvAUMYtq8ue4WZjZjZt0n6Pkl/02+lndSH3PVZ+yLtSzsndabV/o5r6dS/p5f19sE1ap84O6r2D9rXDIjvRZ8dAoPfVdRJXn+J2r9SZyTdo/bBfIObivkrtRPl90n6iKT3uOWbanfWazrLHlE7j3DF3J8Qwr+p3Sk/H0JY8WKY0/AKSY9V+0PgQ3IX9oUQltQ+4/18SSfUPiv8QbU7tEIIn5P0k2on/R9X++KF567SfgIDdXLj3q72F8WL1T4mP23tKiofUycnrpNC9DxJr1f72P9XxbMsP6L2WZoH1L7g82UhhI8NsflRtS+0OaL29OhWxVJOy1/UR83s8yus49fVPrt7XO2++U733u6V9DS1L9g5pvZAeTnn8C1q5xqfMLP391nvqyR9TtKXJN2u9sU3r+oTB5wVw/bVjkNq94kH1P4R+NMhhDtXWP1K67us83hW7RM8fxxC+JdT/Z52Xqv2YPWEmf1S573tV3vW6D8kfbonnj57Blg2LROrqVNS5JWSntj5QlrNbf2zpHeGEB4Wd3Azs89IelMI4S/O9b4AAM5v1r7j2TtCCBee633Bww8XvJ1FIYS/MLOG2mXQVm3w25lueazaeUPnROfChK+q/Sv42ZKulvRP52p/AAAAJAa/Z10IYbVycCVJZvaXkn5Q0i/0XHl+tl2hdn7zpKSvSXpGCOHgOdwfAAAA0h4AAABQHFzwBgAAgMJg8AsAAIDCOKs5v2ZGjgXghBBscNS5QX8Fsh7O/fW6H31dUn/d
+K/p11yHWi0pvnnsxOAgp/Soy5LiJcnqzbQFGmnxjc1p5edn9o4nxUvSyGxrcJAzv6U8OMjZ+tGUG0K2Le3ZnBR/5JqJwUHOjnfckRQvSY0r9ybFf+yTL83tr5z5BQAAQGEw+AUAAEBhMPgFAABAYTD4BQAAQGEw+AUAAEBhMPgFAABAYTD4BQAAQGEw+AUAAEBhMPgFAABAYTD4BQAAQGEw+AUAAEBhVM71DgAAgNNjzdXfxv978eVJ8Zf/5h1pG7j3YFq8JJuaSooPayaS4huTacOkjZ84kBQvSc23p8VP/tB8UvxXfjPt/02SvuGldybFt775qqT4gz+aFi9J6+6pJy+ThzO/AAAAKAwGvwAAACgMBr8AAAAoDAa/AAAAKAwGvwAAACgMBr8AAAAoDAa/AAAAKAwGvwAAACgMBr8AAAAoDAa/AAAAKAwGvwAAACgMBr8AAAAojMq53gEAAHB61n/5eFJ86/iJ5G1c+quHk+KbjUZS/KFfeEJSvCTt/NDBtAXufzApfPa6zUnxx67ckxQvSetfV0+KH5v/UlL8ps+nn+dsLSwmxe95+teS4mu/vCUpXpL02dvTl8nBmV8AAAAUBoNfAAAAFAaDXwAAABQGg18AAAAUBoNfAAAAFAaDXwAAABQGg18AAAAUBoNfAAAAFAaDXwAAABQGg18AAAAUBoNfAAAAFEblXO8AAAA4PYs7ppLiRyt7krdx97PXJ8Xv/JdGUvz2P7glKV6StHlzUnjt2suS4je9+wtJ8TY1mRQvSeHCbWnbWL8uKX7DnfNJ8ZIUarWk+PrPb0iKL8/PJcVLUjN5iXyc+QUAAEBhMPgFAABAYTD4BQAAQGEw+AUAAEBhMPgFAABAYTD4BQAAQGEw+AUAAEBhMPgFAABAYXCTi7PAXDsME5QnZ+G8RUNeUO96htjBobZxOlZ6/2dsIwAAoOg48wsAAIDCYPALAACAwiDtIVXO9LytMDVfVrnbbrq7Uwe/rqFyI+JvFR9ezcS0uq26xaiw0g5mtt3/91DJrdfL3Gt7qPeQE79i2kP/9x1y9gkAimb0yEJS/ANP3pC8jb0fmE+KH7nrYNoG9uxKi5d08rE7kuIn3v+5pPjGE69Oiq8emU2Kl6TS9FxSfPPYiaR4e/BwUrwkldasSYo/ecW6pPip9346KV6S7vvVJyQvk4czvwAAACgMBr8AAAAoDNIeEvnsAZdVMKBYgUt1yKwsbyOuXYopE9ZyL7T6T/mbXzi4rQ37M+dMpS6EgU/nr+frAklvAAAAZwZnfgEAAFAYDH4BAABQGKQ9DCGTheBfyJnDX2mSvlyOa2v6HAq/Lr+CRv/qEKVSqW98c5idGuaGGj2ag0NO72YUQxaj8H9//5a4DwYAABgGZ34BAABQGAx+AQAAUBikPSTy1RSCm2xfaQreT9s3mzllEHLiS2V/o4r4fKvpbmaRs6/D30TD/QYKOYkFw1SmyGw8Z/3DVm4gjwEAAKwCzvwCAACgMBj8AgAAoDAY/AIAAKAwyPkdgk8/zSsllkmV7QkJOWm+pczzObnELke4VI53e1PmrnFu45VMYnDffc3urOR/AyXfjS1vAYvrDJnt+fxin/+bnwucVwWOtGAAaFvaNJ4Uv+W2xeRtVGaXkuJPfNvepPi1X5lOipek45eXBwc5azauT4p/4Lq0v6u10uIlaesX0v4vyvvvS4q3b74qKV6SdMfdSeEPPGWogqhdV358W1K8JG35Yt4VTuk48wsAAIDCYPALAACAwiDtIVHI+72QqebVM4XvUgMq7s5s1ohxJRdUdivzEwn1VowpVUfj5uo19V2g7FMgVkoSGLL8WD/D5B4MlUuxwjJ+cXIdAADAaeDMLwAAAAqDwS8AAAAKg7SHYeRVN8hMzQ9xNzRJDZfqMOqeH1e8YrXicgMaLqbuqiO0Gu6V/kUgpLy7yZV60hxC36ayv41cu5VXHcL/Dfq3LSdtIazwNxvuLnLA6vKHYdVcf3XtlsvL
qYfYGVt5VWI4iAHgrOPMLwAAAAqDwS8AAAAKg7SHVMPc+aG3JIHLABhxT69ziQ9Xb7io2/4f3/WMbvv33vO6bnvG/VZZbMW0h4qrLtFyeQ++HPRJd6OJaatm96+aUyQ8c3cJl0/hthcyN6rIlLzoux4f4aeRG8PO/oacbQCrYMylNGytjHXb10xu6bav3R4LyB+bm+22bzm2v9u+b+lkt73kKr7MNRe67Vor9tjgPkOaYaXi8aF/2/JuJgMA4MwvAAAACoPBLwAAAAqDtIchWGbavv90ft7kvySZe2bCta/fcEW3vWM6Tq8+5v6YHPH+x/9Gt90su5tfuKlQczewqFXinpwcj8+/4KN/0G1/wSazOzjh6k74sgtNlzjRcPdcb867+P7pEHkVIXyChbm/Wqkn7SFTtEI+TcP/dX1iB1O7ODMmS7E/XFie6ra/qbq1237i6KXd9nVrL++2pzav77Z/4uJ13XZtbCK2p7/SbU9b7Ev3j8ZjeN+muA9/8qUvZ/ZvZiGmSjTdR3gYWdNt22xMs2hO3xNjlo7KPXBrpf880s3uHBkc5Kzdv5i8jdLxmaT4hU3rBgc5jcduSIqXpD3vvi8p/sSTLx0c5IzMpFVk2f7B/YODepXSzkPWr390UvwDT5gYHNTjgurlg4Ocb3zpvqT4xkXbk+IlaWS6PjhoSJz5BQAAQGEw+AUAAEBhkPaQKHOviEw7TuiXe0pC+AmH67d9Y7c9eSxWbNjUistvn4tb2bDkEgWW/NXc7r/OpT00q7F9rBqnCF77/S/rtm/+zmsy+3f71rXddrUap1ttMU6dbmjNddtv/Znnx4X91GkjTsdqKU7nlpr9pyoyxSR6pl2zj3JutpFXXQJIUOo5B7C2FNNs9pZjitBOG++217i0nrUuNWKiEtMebHRzt92YiDGjuzd220vl2Dd2TMUesXNDXP++H3p6Zv/+oxE/H466dIqlibh/1X21brt294kY85W7u+36Xbd12zN3/3ncQCYdAgDOP5z5BQAAQGEw+AUAAEBhkPYwhGwZeVdxwT3fcgkR4z1pD49bE29gsXUuTs9XmzHtoezSJkIpToW26q4iQjlOx7aarppCKW7bLLbXNGN7tBH/q+fHd2b278GpTd12oxH3qTQep2cPuyPliW/7eLddb8R0iN3lmOrwtzd+l9vX425r8f00farC15fIcIu4v3SmYD+pDjh91vN4Yyke7Ntce9wfe/WYVmCuj/rSMFbyKUuxXT8Rt9ioxOdd11Pd3VyjumNHZv8m3OdAbSreeGNpND4/tdl9PjwqLjtyQ7whR2vft3fb//5LH4rPL9zrtpZ2pTsAPBJw5hcAAACFweAXAAAAhUHaQ6Jm5lH/3w7lnuc3V+KV3uXjs932WCVWVghuar/ZiNOoDRczW4lTnIuuOkTFXfE9NxfTEEZH45XqR0NsT5djMXxJWhiN1R6WLO5HCHHKc9q1y5PxsGmEOP07OueKjZdiwfWgOJ+b+fv5P1Pv3HOmrIa/eYYP6n/DESBN9uBZ41IdRtyB2HRH72KIqUlLLn2ptOQrJcS+0SjHSgyLJ+LNBWru2D42Hrd1aDamMNxzr6uiIukhV9FlYbNLhVoT99sW47qqLZf+NBH75fq9sTJFZWOsQrN0/wG3NToWgPMPZ34BAABQGAx+AQAAUBikPaSyvKn2+DuitwZBbTFOc641d/91lwNQdv8V1VJMUTjeivH3T8Wi+cfHYurCwlKc1ixNXRDXU477dM/amD4xHbL/7fOzcf+qI/G14K5u9ykQrVac5rWQ8/vJ3PTvSukN3Q2s9JiqDlg9K93kYtTd9CWEmLpwohpjbluK6Q0uA0mtSlzvoqvCUlp3Rbc9X40fAidHYr86Ph77T202fh5IUnNLTHlqrXXPb4rrmm64zxnX6eaqsT0bXOrUtid120sP3BwXDfGzAQ9vm/79SFL84SduHhzUu43ahqT4E1elfXZf8t5TuMHKYm1wjFN/7rGk+DVv
XD84yGltWDs4qEdpenZwkNMcLQ8OcrZ/emFwUI/S7XcPDnL2vfAxSfF733c0KV6SdDjt/24lnPkFAABAYTD4BQAAQGGQ9nA6MikQ+XP7ZXf1eCu4XAeLy9QUp1fvc6kHi+5GEy/51Ie77Uc9/ulx0+U4LVNfiOusluN69p1w+6PsNM4mV1y/5Noum0KNWLBC8+7q9kolbqMkPxWTmKpwKheVcyE6zgDr6a/mbiJRt9g+2XKpTe4GMvON2HfHajE1wix2mqV6TFMqt2JMrRQP4iWXqTA/F/vYyBGfwiCNuHQFX+Gh6dKcGu49lUrx+dFRVwWiGfuoNf3fIG1KFQAeaTjzCwAAgMJg8AsAAIDCIO0hkbmpdjebqKarjFDqnTZsxdcqbho1KLYfcu0vbolTpG/84Ee67bvKF8b92HZZt70UYhpDw13lbS6tou6uDn3wn7NXTNbHYpyFON06MzrfbV/6tN3dds2lOvjUg7ExlxsR3FStz4Awf8j5tJFsmkReUQiPrAecCaHnYDN3M5mlUizfsNCK6Q0txQoM+0LsryPNdd12ZSn2y1CKz4/X4zpbro823QXZ9cWYHjV1KHuOYmE+Ph5xKVINt3zNVZ1w2RBa654fW3A3rin7z434WRHoZADOQ5z5BQAAQGEw+AUAAEBhkPZwGlw2Q2aavrfOQdMVyp+vx+nFpRBTHRYnY1WHV930oW774PjWbnt6Ik6jHqzGqdaaq97QqMR5zXLZXRXudnbCsr95GrOxsPioq0xRKcV9HXNvqhZ3W24Tas24lQaf+uG2l5ljdukWPX+10gqPuttzyzA7i1PVW+2hUY3VGGZDTIGom2/HtIeDin2u2nLtZmyXmnGdG0P8PCi7KjHWivvhq8KsP+oqxEhSLfatJZeyVJuNvaC2xaVauc6xNr4FVWfieqsl34M4JwLg/ManHAAAAAqDwS8AAAAKg7SHRCHn90I5MwWfrfaw4K8eH4tToScspivcczLen3xmalNcdtTfJzyud7QaUxLqSzEPoVqJMRWXntB0+72wtJjddzfj2WjE99GsuqL5dbcudyMMfzV4Se6FvHIN5laUWbYnLPNomN9oiTfVADpGy1OZx/sXd3Tbh2xLt11v+gomMe1o4WCsnlItxf46MbKh214zHtMeJtyBX3HlYyYsHvUT7pN5352uRIOkZiv2s7W3x5SlNetidZb918T31FqM27DR2J5a41Ij9n81biDEzyI8ctR2rB0c5Gz91NHkbdji0uAg5/IX3pkUX5qaHBzUozkzMzjIOXnrxUnx5fVpSXXjn7w/KV6S7nzFVUnxl7/ktqR423VBUrwktRqNwUHOnr8/NjjImd+ddrxK0uTs/OCgIXHmFwAAAIXB4BcAAACFweAXAAAAhUHO71BKOW3P5/yOZl45WI9/5unZ+PxsNT4/Mx5znWqVmFvYcKXH5Coe1RZiXt5INeb5Nl26sU/ZcVXLvo5Z//Jja+djrtOjDsWIkHPUrHW5jB/2aYO5+b+x2epJq/J/5bxsXsqb4UxY6rm7YKMcc/SDK0e45A58n9fvsyCb7qD0pctKjfhC3eflu7KDI6X4/Ji7feRkI9uByq7jbHM9Zb3r/OGQu+Oky/mdKLuSh5X4AWG1I3EDoae0GgCcZzjzCwAAgMJg8AsAAIDCIO1hKIPLaLXc74jZnrSH/Y118bVSLJFUL8USRk03/Wn+tmlNXxosTneOuKnTpisfNj8fpzKrk3E//N3oGo2eaU2Lh4G5wAvn4j49+V9i+O75hW675JIPptfHMiRvacVlT+alPYS+TUlS0y9jOX9/P11NDgSSxAOs2TPNXy/F1xoudWEpk78T+0zDlfhrhVhGsKzYLyuuPFnLde/gtuVvfujbkz2fP75U2g7X3tCMcVMH47bLLu1hzO9Hfa7bbtYOuC2Q9gDg/MaZXwAAABQGg18AAAAUBmkPydxV1O6K7FbL/47IllaYVrx6vDUe7xjlKzMoxOnI0PJlGtz8
p5vWdLOxWlqM05djY+tdTAyq1+M6S63sf3up7FMo4vuYqsXp0m+Im9DVR9wdo9x6DuyIz/t7t5zMSxvJS4dYKY70BpwRLt2g5xxAy1zHdP0h87xiFYigxb7tVss934ppCKHkK8P49Iu49mYzph6MV7IdZVMlLr/FVW9Y69IVdrgbIZUXYozVYsrS/MIJt3+uTScDcJ7jzC8AAAAKg8EvAAAACoO0h2H0vweEWiGvCkT2N0Vw06VNf+cJd5V5ZSxOozaDK1bv11X1aRZxPeURt3439bngblIxOhFTMZpL2WnNkru03NfTX2zEqdpxV8CiUvLTvzGmZHHb5lMdhkpbWOF3WNNVwshd1eCKHEAUj5dqZTLzylIzHuwli681y+7j0uLRV67HtKaqTXXboy7taKwU++sh9/y6Sjy2x117Mn4c6MFm9mO6VnP9zN1hY3057vfESHxh1OK2JzfFz4GqS8Fa3P+Q2wJ96ZGoXEur0nHPMzYnb2PXR+cGBznVZto+hfnFwUE9Snt3JcVf9NtfSopvLaTt08lnXJsUL0mXvugzSfF3/nHaNrZ/PP08Z7hua1J8eWlwjLfmP2fSFpA0e/WO5GXycOYXAAAAhcHgFwAAAIVB2kMqy2lnbriQnTYMmTg3Wd+K7apLJWgs1mLMhC/rEKecRqvxcu6GX2eIy06Oxptr+HtlKGSncYLbj/JIPCQmxuPzi7OuGoVbmVX6p2jkXjE+bOWGENdVzgnxf2WuT0caX+0he/T4G080QzzWW6H/FG5wR2IIsWqCNWM6hMsO0mg99r9qiGkIcyH2vVmXvlTrSVMK7kY04yPxtSmXNlGqxRgXIl84IizEPm25vQwAzj+c+QUAAEBhMPgFAABAYZD2kMrPteemQLiKDr3KcZrT3NRmw994YnSi/wYtXh352X98i1unuzS8vKbb/Oan3NhtNy1ut1yOhe4lKZgvuu/LWcT9G62Mx90oxX1tWGzXzB9OQ9zBwoe08n+HtdT/70mqA86E3nSGUHJpR63YV4JlDtjYdFVOmjrSbdczIXGdrfqe+EIjVmgYKcV+bNV4zC+1/OeBtOhTF1wnKrmqDhWXClVxMRWXTtRwV2dP2oZueyYcdVujlwE4/3DmFwAAAIXB4BfCnQdNAAAK/klEQVQAAACFQdpDqrxZQH+xdKu3SLx77Col+KvKMzHuBhaquxSFavzvMndld6i7YtGVuP5bP/oOt98xbUHz2bSH0Vacbh1zBfVHm3Gf1lz/C/EtNNw0b4jLloI/nHJ+V+UWgchPFcksMmy1CGBF7vi0bIqOlXwaUKNvnJm/WU1MacgUc3GrbfnqDY173fpjvzzhU5OCv2HM+sz+LbjUqZpi//M15kcqroKLu8lOyX9uuOyOTeULu+0HW3e7/aOTATj/cOYXAAAAhcHgFwAAAIVB2sMw8mb+/POZTIfsbwrzgRbTEqpuCrLsFqn7CUx3E4nsBeZx2tXfIKPVjAX0m6Vj3fbYfNzA7kb26vHtbgb31d/zczFuNk7JbnOZFWV31JSrcdtVN+dr/tDyb9+1Vyqrn7n+PlNVw/9t/Y1FVlgZ8HXiAVOyqcwr9Yo7dl3uQitTxcXdKManKbl0hdDyR3HMMViajdUUKoo3wmjapNts3Ienjl6Z2b9dpVjRZcNSXCa4iiyj62O74j5nfLesbojbLrmbc9CZHpkWto4ODnJ2v+ozydsojVQHBzlh3dqk+H0/dVlSvCTtef0Xk+IfetbVSfHlpcEx3oY7Z9MWkBRCWp+77O21wUFO5WsHk+IlyWyIik1O2LR+cJDT/MpdSfGSNDZ6VfIyeTjzCwAAgMJg8AsAAIDCIO0hkS/QkJ2pcFdR90wbVitxijSE+W670YjT9vU5t2I/tdTyl4+7bbgbTQRXQWLM5U9ULE6NrG/E9f/6jT+W2b/dM3GKdNeJ6W57fG6u226W4xTp4kTcRsntR200vs+8G1N4K030hLwbiABnRDyoytXsdHHJfyq6
gzT4tCNXQSGYmxf1MZmbZ8QYc7k/JXegW+YzJLZnmg9m9m+6Ffv1SZeKMeZi5l1WhrX8VLW7yYVb9oKR3d32HY0vuf3I3gAEAM4HnPkFAABAYTD4BQAAQGGQ9pDIZyHk33Ch7h8oNB6KbTdlWSq5qgvVeAV3019dWvZTsi6VYCnOa5YaceOj7grztW6e9or1G7vtj7znTZn9m3IXjk65ovtl9wbHFff1mU95VoypxP07vDZexV6Tu5GGTxVx282t6NAr83fuvYEIcCrikbi09FDmlUr9mHsU+0Nwx56Zq5Qw4Y7qpksJcje7MVeFpVwd6/u8Wi4lqhnXebAR048kqRbiZ0WjtLXbnm1s67YPL8W0jKlSTHuYLMf2rNvekVbsu9zYAsD5jjO/AAAAKAwGvwAAACgM0h5S5U3P+zIQrWxV7MW5A+5RnEZVaV23WVYsEF0ux3amTr67qnykEqcmxytx27vWb+q2p+ZjfKV+vNtu9kxrLrh1zZfjdGkp+CnceKi8/hO/H5ddiiketyoWG8+kPeTNouamjfQIOW3glMUDqVk/mXmlPnt/fGAxrSdb7SGmD1THptzz/qYYcRutRuwnNuKKwYfY35rNuGzN3TjjUCOb6nMyxBSkGcV0pkOlnd32fOtwtz3iOpq5N3EkxMouX2v+Z9wlkVoE4PzGmV8AAAAUBoNfAAAAFAZpD4nMF73PTNu7/ATLVnuwcMI9momLuKnJ5rwvRO/L1fuNxG3UXOWHmpvCvd1dGL7BreeazVfGF3qK+rdqbkrWpVC0SnH6s1GJ2/vykf+Iz7v11FRzbXcVe16FDHfTDuuZavW/yrLF/93ibhmyIZAmHjEhZG/IsnDyq922v5mFlUa67ZKrwlI74DqdxWO1XIoxZRdfXoz9su4+Nhohrr/sjvlq2JzZv+lWrEZxRPEGGOPhvm77WCumbjRcv1wMMR1prhXTHuZdCgQemSbvmx8c5Bz6meuStzEyk/ZJO3407SYpe9/74OCgHs35tPc9d2HaXZN2v/KWpPj9L7s+KV6SLpq7PG2Bh04OjnGOPvWStPVL2vjeLyTFLz5mV1J8ads1SfGS9ODjRgcHDbv9M7YmAAAA4GGOwS8AAAAKg8EvAAAACoOc39Ph05/8z4ieHEL/Rw4uF9aXHMumxfb/b8mUICr1L0dkirnD862Y33frkc+5/an2LBN33ufRNt2d6nze4JLL9G34klGZvN2cckmh/++tcs/jUs6j/CJMlGfCqco/dvxd3UIzHvehtdg3xie1B1easFmPecFmI32ipZq7iMD3ybrN9uxUpv5ht1VuxvUuKl5n0AhuP1yfbrlrCChvBqBIOPMLAACAwmDwCwAAgMIg7SFZzu8FP51v2bSHui8IlqkUU+rbzoT42mo5qQ5+xrLppkQX/NSp4tRn2aUwrLAqtTKFzJo5UTmGqSbj3lpvQRz/Nyhn9sP/zZiqxWrrf3vBbKpDzpK+LGLmBX88D779YStk+2vIlPjLrSPoWr53URQQADjzCwAAgMJg8AsAAIDCIO1hKP3TE3KtNLOYuSvcMNvOuUNO/9nYbIjFqc+6S4dorHR1u3/glrdh9jUv1cHd9So7S+umb3vWT1UHnL+GOYZjh2iskKaUtwwAIB9nfgEAAFAYDH4BAABQGKQ9DMVfXZ0j76LrFRcaYvrTLztUeMs/GBwz5LaHmlDN+xvkXRm/wkqHedtM8qIYONIx2NKG0aT4nR84kLyNxo4NaQt89o6kcNu6OW39kip7diXFb/tcfXCQU77s4qT4scNJ4W0HH0oKX7rmoqT46lx6uqCNpR1PJy4ZGRzkbP2TW5LiJWn0sscnL5OHM78AAAAoDAa/AAAAKAzSHpIlpiqc0/Wew8oIZ3CmlklfAABwpnDmFwAAAIXB4BcAAACFweAXAAAAhcHgFwAAAIXB4BcAAACFweAXAAAAhcHgFwAAAIXB4BcAAACFwU0uAAB4hBu/44Gk+DA5nryNyqETaQvs
3JEU3ti5MW39kspfPZAUP3pTWnzj+quT4hc3J4VLkprHj6dtY9PlSfHrPncwKV6SmguLSfHbPjOdFH/k+Y9PipekykLyIrk48wsAAIDCYPALAACAwmDwCwAAgMJg8AsAAIDCYPALAACAwmDwCwAAgMJg8AsAAIDCYPALAACAwmDwCwAAgMJg8AsAAIDCYPALAACAwqic6x0AAABnV/M/v5a8TPmSvUnxrQ1rkuIr+w8nxUtSCK2k+PKmjWkb+PI9SeE7Ry9NW7+kh/7n9Unx228+khTfWjuRFC9Jpem0ZZq33pEUv3Xh8qR4STrwvZuTl8nDmV8AAAAUBoNfAAAAFAaDXwAAABQGg18AAAAUBoNfAAAAFAaDXwAAABQGg18AAAAUBoNfAAAAFAaDXwAAABQGg18AAAAUBoNfAAAAFAaDXwAAABRG5VzvAAAAOD37n7M3KX73+8aTt3HTzX+XFP+93/qDSfFHbrgoKV6S5rdZUvy2WxeT4ktLraT4I48aS4qXpO1vvjUp/it/9Oik+CtffjApXpLmH3dJUvz4PceTt5GqdQZHrJz5BQAAQGEw+AUAAEBhMPgFAABAYTD4BQAAQGEw+AUAAEBhMPgFAABAYTD4BQAAQGEw+AUAAEBhMPgFAABAYTD4BQAAQGEw+AUAAEBhWAjhXO8DAAAAcFZw5hcAAACFweAXAAAAhcHgFwAAAIXB4BcAAACFweAXAAAAhcHgFwAAAIXB4BcAAACFweAXAAAAhcHgFwAAAIXB4BcAAACFweAXAAAAhcHgFwAAAIXB4BcAAACFweAXAAAAhcHgFwAAAIXB4BcAAACFweAXAAAAhcHgFwAAAIXB4BcAAACFweAXAAAAhcHgFwAAAIXB4BcAAACFweAXAAAAhfH/AfKvvE0J5qnIAAAAAElFTkSuQmCC\n",
224 | "text/plain": [
225 | ""
226 | ]
227 | },
228 | "metadata": {
229 | "needs_background": "light"
230 | },
231 | "output_type": "display_data"
232 | }
233 | ],
234 | "source": [
235 | "f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 7))\n",
236 | "\n",
237 | "x_mu, r, kl = model(x_c[scene_id].unsqueeze(0), \n",
238 | " v_c[scene_id].unsqueeze(0), \n",
239 | " x_q[scene_id].unsqueeze(0),\n",
240 | " v_q[scene_id].unsqueeze(0))\n",
241 | "\n",
242 | "x_mu = x_mu.squeeze(0)\n",
243 | "r = r.squeeze(0)\n",
244 | "\n",
245 | "ax1.imshow(x_q[scene_id].data.permute(1, 2, 0))\n",
246 | "ax1.set_title(\"Query image\")\n",
247 | "ax1.axis(\"off\")\n",
248 | "\n",
249 | "ax2.imshow(x_mu.data.permute(1, 2, 0))\n",
250 | "ax2.set_title(\"Reconstruction\")\n",
251 | "ax2.axis(\"off\")\n",
252 | "\n",
253 | "ax3.imshow(r.data.view(16, 16))\n",
254 | "ax3.set_title(\"Representation\")\n",
255 | "ax3.axis(\"off\")\n",
256 | "\n",
257 | "plt.show()"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {},
263 | "source": [
264 | "## Visualising representation\n",
265 | "\n",
266 | "We might be interested in visualising the representation as more context points are introduced. The representation network $\\phi(x_i, v_i)$ generates a single representation for each context point $(x_i, v_i)$; these are then aggregated (summed) over all context points to generate the final representation.\n",
267 | "\n",
268 | "Below, we see how adding more context points creates a less sparse representation."
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 188,
274 | "metadata": {},
275 | "outputs": [
276 | {
277 | "data": {
278 | "image/png": "\n",
279 | "text/plain": [
280 | ""
281 | ]
282 | },
283 | "metadata": {
284 | "needs_background": "light"
285 | },
286 | "output_type": "display_data"
287 | }
288 | ],
289 | "source": [
290 | "f, axarr = plt.subplots(1, 7, figsize=(20, 7))\n",
291 | "\n",
292 | "r = torch.zeros(128, 256, 1, 1)\n",
293 | "\n",
294 | "for i, ax in enumerate(axarr.flat):\n",
295 | " phi = model.representation(x_c[:, i], v_c[:, i])\n",
296 | " r += phi\n",
297 | " ax.imshow(r[scene_id].data.view(16, 16))\n",
298 | " ax.axis(\"off\")\n",
299 | " ax.set_title(\"#Context points: {}\".format(i+1))"
300 | ]
301 | },
302 | {
303 | "cell_type": "markdown",
304 | "metadata": {},
305 | "source": [
306 | "## Sample from the prior.\n",
307 | "\n",
308 | "Because we use a conditional prior density $\\pi(z|y)$ that is parametrised by a neural network, we should be able to continuously refine it during training such that if $y = (v, r)$ we can generate a sample from the data distribution by sampling $z \\sim \\pi(z|v,r)$ and sending it through the generative model $g_{\\theta}(x|z, y)$.\n",
309 | "\n",
310 | "This means that we can give a number of context points along with a query viewpoint and generate a new image."
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": 189,
316 | "metadata": {},
317 | "outputs": [
318 | {
319 | "data": {
320 | "image/png": "iVBORw0KGgoAAAANSUhEUgAABH4AAACmCAYAAACsl0hIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XmwpFl6HvTnfFtm3rVuLV29TvfM9HRrpmf1aJCskTSDPIqRBcgiBEYegbBBhiBCEAF2WNjYZjBjbCtAGAth5ECyzCgCS2Bsh2wLCGGGsCRLsyhGQrN0T6/VS1V1rXfL5dsOf7zP+92lu7qrq2oq8373+UV0ZN+beb/Me/Otc06e8573hBgjRERERERERESkf5J5vwAREREREREREfnm0MSPiIiIiIiIiEhPaeJHRERERERERKSnNPEjIiIiIiIiItJTmvgREREREREREekpTfyIiIiIiIiIiPSUJn4WUAjhKyGEj8/7dcjRo9iRW6G4kVul2JFbobiRW6XYkVuhuJFb1afYWbiJnxDCp0IIXwwh7IQQzocQfjWE8J134Lq/EEL4zB16jX88hPDrd+JaryfG+ESM8XM3+VqeDyF84k48bwjhZAjh74cQdkMIL4QQPnUnrnu3KHbmGjs/zr/9LITwC3fimneL4mY+cRNCGIQQfo5tzXYI4cshhD98u9e9mxQ7c21zfpF/860QwlMhhB+7E9e9GxQ384ubfdd8VwhhGkL4xTt53W82xc5c25zPMWZ2+N+Td+K6d4PiZr5tTgjhh0MIX+Pnq2dCCN91p679zabYmWubs3PovyaE8NO3c82FmvgJIfwnAP46gP8KwFkAbwPwPwD4I/N8XcfIzwAoYX/7HwHwN0MIT8z3Jd0cxc7cvQLgMwB+ft4v5K1Q3MxVBuBFAB8DsA7gzwP45RDCI3N8TTdNsTN3fwXAIzHGNQA/AOAzIYQPz/k1vSnFzcL4GQBfmPeLeCsUOwvhx2OMK/zv8Xm/mJuhuJmvEML3AvhrAP4EgFUA3w3g2bm+qJuk2JmvfW3NCoB7AUwA/K+3e9GF+A828N8B8K+/wWMGsAB8hf/9dQAD3vdxAC8B+FMAXgVwHsCf4H3/HoAKNqmxA+BX+P37Afw9AJcAPAfgP9r3XP8EwH+z7+u/C/tQ+24AUwANr3X9Bq/1c7CB6ecBbAH4hwBO7rv/BwB8BcB1Pvbd++57HsAn+P+fBvDLAP5nANv8mW/lfZ8F0MICYQfAnwEwBPCLAK7w2l8AcPYm/v7L/Ps8tu97nwXwV+cdG4qdxY6dQ6/9MwB+Yd4xobg5WnGz73X8HoAfmndsKHaOVuwAeJx/wz8679hQ3Cx+3AD4YT7fpwH84rzjQrFzNGKHr+PH5h0LipsjFze/CeDfnXcsKHaOXuwcev3/NmzCMNzW+zrvwNr3C30fgBpA9gaP+UsAfgvAPQDO8B/Tf7kvwGo+Jgfw/QDGADZ4/y8A+My+ayUAvgTgLwIoALyDf9BP8v57GajfA8t+eRbAKu/74wB+/U1+n88BeBnAe2GTKn8PHGAAeAzALoDv5Wv9MwCeBlDcIMCm/H1SBu1vvV4w8ut/H8CvAFji4z8MYI33/acA/tENXu+HAIwPfe9Pg/8YF/k/xc58Y+fQaz9KEz+KmwWJGz72LJ/3W+YdG4qdoxE7sJXHMYAI4HcArMw7NhQ3ix03ANYAPAXgQRytiR/Fzvxj53OwD6SXAfwGgI/POy4UN4sdN3xsycc8DZsI+e8BjOYdG4qdxY6d13n9/xTAp2/7fZ13YO37hX4EwIU3ecwzAL5/39efBPD8vgCb7A9QBsi33yDAvg3AuUPX/7MA/va+r38IthXhMoDv3Pf9mw2wv7rv6/fA/vGnAP4CgF8+FOwvg53I6wTYrx26zuQNAuzfgf3De/9b/Pt/1+G/P4A/CeBz844Nxc5ix86h136UJn4UN4sT
NzmAXwPws/OOC8XOkYudFMB3wrYK5vOODcXNYscNgP8OwE/se96jMvGj2Jl/7HwbbKvOALb6vg3gnfOODcXN4sYNLIMlAvgigPsAnIZNGv7leceGYmexY+fQa38YltH09tt9Xxepxs8VAKdDCNkbPOZ+AC/s+/oFfq+7Royx3vf1GMDKDa71MID7QwjX/T8Afw626ux+BRYQT8YYb6Vo1IuHXmsO+0d/4PeIMbZ87AM3uM6Fff8/BjB8g7/TZwH8nwD+bgjhlRDCT4YQ8pt4rTuwlbD91mAd26JT7Mw3do4qxc0CxE0IIeE1SgA/frM/N2eKnQWIHb6ehr/vgwD+g7fys3OguJlj3IQQPgjgEwD+2zd77AJS7My5zYkx/naMcTvGOIsx/h3YB/jvv5mfnSPFzXzjZsLbn44xno8xXgbwU1j8uAEUO/OOnf3+LdjE1nNv8edeY5Emfv45gBmAH3yDx7wCCwz3Nn7vZsRDX78I4LkY44l9/63GGPf/Y/zLAL4G4L4Qwh97g2vdyEOHXmsFm6U88HuEEAIf+/JNXne/A68lxljFGP+LGON7AHwHgH8ZwI/exHWeApCFEN6173sfgO1dXHSKnfnGzlGluJlz3PB1/BysY/+hGGN1C69nHhQ7i9fmZADeeYs/e7cobuYbNx8H8AiAcyGEC7Dt7D8UQvidW3hNd5tiZ/HanAgg3OLP3i2KmznGTYzxGmx71/7r3ezvOW+KncVpc34UwN+5hdfyGgsz8RNj3ITt6/uZEMIPhhCWQgh5COEPhxB+kg/7XwD8+RDCmRDCaT7+Zo/ivAjbL+g+D2A7hPATIYRRCCENIbw3hPARAAghfDesAvuPwlI6fzqE8MC+az0YQije5Dn/zRDCe0IIS7A9jv9bjLGBFYX6l0IIf4izfn8K9o/rN2/yd7nh7xVC+BdDCO8LIaSw4lUVrNDUG4ox7gL43wH8pRDCcgjho7Cq7Z+9hdd0Vyl25hs7/NkshDCEzcSnIYQ3mv1eCIqb+ccNgL8JK8z3r8QYJ2/24EWh2Jlv7IQQ7gl2PO4K/xafBPDHAPzft/Ca7hrFzdzbnL8Fmxz8IP/7HwH8Y9j2hIWm2Jl7m3MihPBJH9uEEH4EdjrT/3ELr+muUdzMvc0BgL8N4D9kv7UB4D8G8I9u4TXdVYqdhYgdhBC+A5Z5dHunebl4m3vF7vR/sD2FX4QVWboA65S/g/cNAfwNWGXw8/z/YdzbS/jSoWs9j709ee8C8GVYRe1/wO/dDwvaCwCuwQpUfQK2xel5AD+871p/DcD/BZvdL/i6rgK4fIPf43M4WD38VwCc3nf/vwrgqwA2Afy/AJ64wev+NPbtQYetVkVwzyRscuYcf68/DRv8Psm/30X+jfyxfw7Ar77B3/4kgH/Anz0H4FPzjgfFzpGJnU/z2vv/+/S8Y0Jxs7hxA1tdibAieTv7/vuReceEYmfhY+cMX8d1vt7/D8CfnHc8KG4WO25e5/UfeN6j8J9iZ65tzhdg5Q+u82/xvfOOB8XNYscN789hBxFc59+k+/sehf8UO/PtrwD8LIDP3qn3M/CicoeFED4HC4z/ad6vRY4WxY7cCsWN3CrFjtwKxY3cKsWO3ArFjdwqxY5ZmK1eIiIiIiIiIiJyZ2niR0RERERERESkp7TVS0RERERERESkp5TxIyIiIiIiIiLSU5r4ERERERERERHpqexuPlkIQfvKeirGGL5Z11bc9Nc3M24AxU6fqc2RW6E2R26V2hy5FWpz5FapzZFb8UZxo4wfEREREREREZGe0sSPiIiIiIiIiEhPaeJHRERERERERKSnNPEjIiIiIiIiItJTmvgREREREREREekpTfyIiIiIiIiIiPSUJn5ERERERERERHpKEz8iIiIiIiIiIj2liR8RERERERERkZ7SxI+IiIiIiIiISE9p4kdEREREREREpKc08SMiIiIiIiIi0lOa+BERERERERER6SlN/IiIiIiIiIiI9JQm
fkREREREREREekoTPyIiIiIiIiIiPaWJHxERERERERGRntLEj4iIiIiIiIhIT2niR0RERERERESkpzTxIyIiIiIiIiLSU5r4ERERERERERHpKU38iIiIiIiIiIj0lCZ+RERERERERER6ShM/IiIiIiIiIiI9pYkfEREREREREZGe0sSPiIiIiIiIiEhPaeJHRERERERERKSnNPEjIiIiIiIiItJTmvgREREREREREekpTfyIiIiIiIiIiPSUJn5ERERERERERHpKEz8iIiIiIiIiIj2liR8RERERERERkZ7SxI+IiIiIiIiISE9p4kdEREREREREpKc08SMiIiIiIiIi0lOa+BERERERERER6SlN/IiIiIiIiIiI9JQmfkREREREREREekoTPyIiIiIiIiIiPZXN+wWIiAiQIPD/AiIiAHS3IjeSMm6i4kbeIrU58lZ4vCyFAQDgRLqKIuQAgK12BgAYxykAYNaWAIAWDYDDcaUYO248dkaMlzRkXRRMYwUAaGILAGjR8h7FyXHnPVTO6YoQEsRocVEfalvUd90cZfyIiIiIiIiIiPSUMn5ui89FapbxuAuMhSykAIBBVqBsbBWjamsAe7PRoZtvfb1Z6nDg5jWhFXhHbCFHWxYsDgaMmTPZEgBgKVvD5WoLAHCtHgMAmmgrG4E/061wxL0A6f6vi52DwRMSe57YNod/Qo4Qj5flxLrv+xk3y9lJnC+vAwAu1tsAgJpx4zHRxreyMsbHMOb24klxc1RlONjm3JMOAQDr2T24XFubc6mx2zpav+V9TmSfsxc74c3jqIsdreAfZcuJZWm8PT8BAPj4+nsAAN/32A8izy3759lXzwMAfufqkwCAJ0vru16avQQAuFJdBABUcYZJtc0rez92KD7YV0F91ZHlw5C1tAAAPFqsAwC+e+MRAMC33Pft2J7sAAB+/VXGzPgVAMBlZovtsE0qG8sma9Eiep/2ps+smDnKhsHGNz4u/vDoAQDAh9Y/jGulxc1vbz8NAHixugwAmMD6rN3G2p4qWhxFRLRvGjfHhzJ+RERERERERER6KsR492ZFQwhHYArWZosDV7lCFrpV9RDDwYcmNm/W1lwZu4t/y0UT4+E/zp2ziHGTcM604MrUqWwFAPAdpx4DAJxdT9E09rKv1na7zdj62uXnAAA7XNWoU66ojnJMatsf3xZ8ontsRTZ5u62WtF/bBAA0F23Gu92xlRDUNY7iCsc3M26AxYydQbA393RqKxnvy+8BAHz3+jsAAB84dQp5anE1ZlbHVm7x9tULvw0AuBxs9eLqkl1r6+QqfuvVawCAmSWaIS6dBgCEsxaTyWVbca0vvmD379jXsdxG5AoJcHQyyY5bm7OSjgAAj2S26v7t+X0AgD+4+nYAwPvufRtyxtZuY/FzNbVf48vnfhUA8Hy09uJcbt+/tFTg6zsWN3XLdqiw62PNVtiS7QsAgPbaq3b/zFZhYzsDuKJ2lBzHNmfIuDiTWj/13szanI8uWZvzkRPvQMGV+TFrcGwW1uZcPP95uw0TAMBLQ/vzXVwb4TfOWzZH6UOgxPopDBg7sH6q2XrR7mfWR2x2AXjsLNyf64aOXZuT2PjjoXQDAPBE9ggA4P2jxwEAP3Df92FpZO1FHFhszdYtjoa7ltWzmVoMvHSS8fP2FD/15D8DAFye7AIAqqH9TNw4CQAIX30GAFC/8BQAoN1+2e4vtxHbCV/dwv25bug4tTme/b7O8c0jqcXH+/OzAIAPLT8CAPjkox/DYLQGAKgH9tjdZY53XvxNAMDFxsa7T+fWbz2znuOXvvTPAQDTGbOAWs+St8yz4JmJE8sCiY3Fn/VVR2d8445dm8NxzsNscz6cvw0A8JHRuwAAn3jwoxgMra1pM2ufJkvWVzWXfhcAcDVeBQA8M7D3/utrLX7m8/8PAKAsrf3wzHf/2N7liVW7/j/8ztGLGeCN4+bYbvXyiR1PY+4menIbMOer1ogkawmSzL6XpPbn
yjhAml63hqe6ZoHSTJhW1rb7Ulff8ivDUerQjh+LkwE/kN/DNMTHl2wg/W1nHwYAfOv7H0daWKO0y39mm629r69sXQIAXElskufKijVA19cSXPigxdr1ey3Gmvs5EFqxzrN8yT6E7fBD/u5XLL169hvnsPtrX7aX2HqDJYvEJwvXE2tbHuUH+Cdye48fTFcBAI8/9ATyJev86qHdTgYWF//CH/wW+9pCC9c27EPa+ZUCa++y2Hs22vV3V+26WLLB1dJF6/B2r9okz+z8FQBA+ZXn8epnfwIAEJlaLYsjZdyc5oDo/bkNiN7L2/uCtUEP3vMOZLn9f53YY+8rrL16/2M26N5mm3Nh3T6pP7sS8Esfvh8A8GJm8TfZsMfGFbt+/nX70DU7z37uBYub6itP49o/+yl7kVFtziLyNueUTzJnNhn8wcze40cSayMeO/Me5MNlAEDD28mI25bf8QQAYDq08c2Vk9aPnTuVIj9j13thag3SztDatHbJBubFRYuZ2cvW9lTnrO+rnv0Grj35nwMAYvQP8rIoUnA7KScBzwaLiRPRxiU5/7nnDZDVFmNhYPcFWP8zYt8UC/vZe+63H4pvC/ieH/wxAMBXuefgChcwSl5j5VWLtdkrFj/1CzbpPPvyUzj/y3/WrtMtVsgi8f7qBCeb35bY+38W1kYsc5J4KR2gyKyfajkBlDEO7vno9wAAHgjWXz28Yj/02HrAq5/6NwAA35hyrJxY21ZxXDV8wSaLqhdt4rE9Z/1X+ezvYfvZn7MnV3+1cLzN8fHxQ4mNR85Ge39HbGdGcYQirvCHbKEhy62tWXni4wCAM7n1KWdXrR1520aNb/yRfw0A8OSUC/LcSla2bL++YAvy9cu2ONFe5OTz5SdRXv8iX+XRnAQ6TFu9RERERERERER66vhm/HCbVlLYnyDl6layYjOHw4dtlnr0nhMoTtisdMrVjwFXNLaet4yf3WdtFWv8tKUWtpMK9cRmqveKH/otZxe96OG+41TtJt0rhBi9KLAKIy4Kf7c2MhasG9qs9EOFxchqYo9YCevIMouhjIURPakyvf8MAOBEZrPRG8uW8fPqoER8wlY+0nWLj9mazWw3mcXcmUcfAgBM32kXqz5k15h84iK++iXbwlNfucRXqXhZRCe5ovEg0+jXor3XaWXvZVJHJFzdyJi941t4kmjtVMY6dUVpsVXMcpy4zzJ+1vnYdGTxV6f2mDNMyT/xNouL9H32+PIDj+HqP/xZAEB19ff4KvuxstEHfgzu2cTe1wd5u9rae5RVtrqVTmskjQUGk1OReT8HZmDwjhVuQ92IAfe+07YCjpkttL1sq2hNbvG5wdXY+H675qC0563OfSt++/M/b4+depsji8S3XdzLNufhYP3UOv95Z3XJWyDl9sCEbU7LOEvYf6VsdDhkwjAJ2HiPbRW7VnJvcm7XrzhWWrvXMori4/Y68vGjdu3n3o8v/eTfAgA0k2f4atVfLZo1vo8nYLcZ90U0DbeptzM0PMQi1BYftW9tbyzIJhwT7bBP26oCZuvWpgQGU8FMj8js+jVmKqZn7VpLT9i4p/7Ao7j4T/6GvYbNZ/kqFTeLxMfIJ5gVv8FMjoKHnIBtTluWaJk6Fituz+JDZjv2MyULh9cDtkl1grXH32nXBzPMhnZfyYL164/atRJmdgx2ecjK09+JL/1nf9+uM7t4R35XuXN8nOOZYie9zeGWrJbj42ZSo239MzLLrXCMO2vsGiXbj6q22GiSiJPfZVvGTrM/S5bsvik/i48etu2r2OVBKtzJU339JXzj5z9lT1dv3qHfdr6U8SMiIiIiIiIi0lPHMOOHtXxYUDdb5er4Guv3nLGvR4/ZCun6+x7EcM1WsTKugA1ZeTfhZtV0xWYOvc5lfXmC5jxXQNuDK+eB+wrzgnsUec0YOQuZrCKw6GFdWh2XprYCVZ75E31aXKvyc3OCGTjrnGnOuepUzizTa3K5QsqaCLvB3qdtFuy9ypns
7cxicHPZ5l93lzJcuW7v7RUePVhyxR5Du36RWNZGSO37S6zfsvrACp7h6mvzmxY3sdE+5kXSHW/Kf/MF591rroSNeWTpdGfXF8VQlxYzU76V05dt3/pOYvFwZcWudWFthHOPWebHJcbblG0ayzKg4POCK7BDxmM+GCI7Y4Xzqmtfscfcco0yueMYOBtcqWJTgIYdzqyxNqfa3UWT2PtWZ4wBrnhNWJh5mzUTLq3aRS9tDPDKOVvFusrV1+lJxo0lAGG5ZcYHV8Yy1sFbOr2KbN0KTDfTy3yxWn1fRL6K6mcGlIydMeukTMfXkZXW5zCZA7PWxijjLWtzxsz4uXLN4uCV0wVefIe1OZcnzO4oLDYadltpwWzXmuMuDjmHp5aQr1hR8maizI3FY+/FmtdNYdvj9UJ9ZDFNKgQWjAf7r5Ltxe7U4mXCgwgu7VqMXN4NeO6iPfbSsj3PLFpcNAO7/vUxD89g4V6vVTVYHiHdsOyfZvO5A69VFssqs8UK1k9t+T7V/Nwya0vEkgWaWah5OrA4uH7OPvNMc7Y5zIC/MCnw/Hl77GV+hput2fO0rIM4Ys2WjG3QYJlZ+A+dQHrCspzbi6/yVSp2Foe3Of6Zyt63mp+FpqzLNG2maBk3MbDNYY2frcvWH00zfo66bld+ZTfFN562sc9FeyimJxk3y4wbZuAnHGANT9n9q48nSJfssIJ6y2tgHu24UcaPiIiIiIiIiEhPHbuMHz/MK8mYYWOHUiC9nxk/77BVrvAO7k0/mwAnLcsiy6w2SxrttJON08y6eLfNMg7fYUebjr92EbNf+pz9PGelEx47l7POxsbpj9jzrNiKaRie4QtbAU+iQ8m6CbPtr9vtzpMAgOn49+3at/QXkNvhNRO8yEo5sJnlZmDvxlZhs9JfaC8D4V4AwDar029Hq7kyblgHgaujFXx1fobrlyw+xin3Qa9yv7L9CF7KTwEAIpczCmZtjJZajP7oHwIATH7X4iVuXb8Tv7LcIRn3up9km3CCmRMFm+GS2WNPNS2QWqxsB2trrs3s6+k66/KMWAOKp3uNRw2uje0x5ardF5csJpM1i8nLXNGIzPjJmTU0XBkgnH63Xegb/9huderFwsi5PjPi+9Um1uZM4JmEdv+LaQnwRKbt3HI7rvKI2+0z9v76aXA7SzxJcFhhe2I1fVq2R8nA3vtsjdkcDX+Ined1r2O2MkK2/kEAwOziV/lqPRtVFkHG/mrE24J1MDLm/lSMj+eqCcLQ4mA6svu2GEM7Z60G1JTZGbtrjLulCrPa+rbA7ML0lN3yIB/sVlxe5Sp8zRFnU2fITnzMvrj8G3YbdUrTovCTmZZgA4+VYLGxzFPgIk+g/No0xYgn4LbM7ppyiNSctLHKeM3amu0NaxuuZyXGrOXTeF/FmGL5TGwxM94zfWYJa/+s50iW3s5X+eu8VXbqIvFxzgazxVZ4omACaxQ2U/v+U7s10tze5ymzkbda+3p3zbLXxyPGzqrFztawws6Y42YbKqNZ5bHczGK92rK/ip7ZbF8uP1IgrLzPvrj4Jb7a5nZ/XblDcv6bX408/S2x3Qw5bAw8CfaGf32aIlmyx1Sp3W5zx0xdWB2f3QHjZmRxs5mUqC9aWxbW+Jl8nfWl+Flu815r0xJmGU7ZNw4eGAHhLF/l1+7QbztfyvgREREREREREempY5jxw/2fyzabV2zYbXbSZqPzFZtVzFObZczadyKvbH/fAJadM2LGT13wNJ7aZhXbNVuNz05dxvbaOfuen8IzstSi5ZO2WvHwYx8HACQnLCsk5lx1bQeYvGB7EVtuti93/wAAYPfS0wCAF5/+i/YzWiG76xLGT8LZ6Zozzde5zf38psVCNhqgZZbGOOFKKmwGu2VstVwNDTNWpi9q5Ju28pEXXDlfYdZGypOYUpu1jl6Jnq9j0DbIwdXXjOlB8Ar0yg1bBIEZEw3fs7E3v1ywrCrWRJgOEGGrZOPaYuZ6YpmI7abFTsOVsjjk
SvooIueKRrbNE1G4up9WXifK0x3tdsAaU/msRdhlYQ+tBSwczzJseCLXLk9LSb2QChct22oJkatk262ttl/jiv1WbfFT7jJexlzVGrXIr/AkQS7VB55YmDM+M2YV+kmYy7k9brkKCF7DzlNp1dQsFG9zImPGTzC5xvetYX2xMmRIMeT/Wz+yE+12c2yxU/MElYbtRpm0SKasi8ATwTJmjbXVwXjwEy0L1nBJJxGh9axCr34mi8Pjxt7PGeNmi7XfStaZbEODJZ70lfBEnVnJNmbC+nQTZu0wRiZ1iuGWjX28JpSnZfBQQgRLNmTuCFAwSzbMGsSp19mQRRS6bRX2Hpes9dN41h8zKr7e5shK+/xVzXiiJMc9E37thwXW/Kgzm7QYXLc+bcDsoJQnwjW8rvdT3m8uMelwULUAa8OozVlcLfuqMducJnodMNb6STOkzBBruBVih6dVVuWQj7VrlRl/Nm+AmV03+Pj4OsdPHMKk/nmeceTjnGIXew1TFzdHe6CjUb6IiIiIiIiISE8dv4wfzganrK/hp2ulYG2fHZ62dYWb1EePAlPL+Am5ZeckjT2m4WlMnniT7z5oX0+3ceLMBQBAyYX0NLWMn5WBXWvjtO2bbwrLBmmCzXA30wzpuv1Q4otmK1b/ZyWxfYbnn7espLLissgRn308ioa5xUCW2vs3rlhTY8qaGteX0GS2Ylqm9n7V3Cc/HLBOj58I5rU18oicB1UUm8zw4Vub7dqMc9rYSlvgCQZrPAUhLWss13b9lM/b4MId+m3lzuDpFtyXPOEqxYzHbpXRYueF2RIiV9vLhjWfgsXbqLIaC+CqSMJaT7GIGL1gd5UFV76mzPjhXvjipAUTF+7BhTKE7RY5X4OfOqgWZXFEP92CJwnuMm5Kni7ZBIun59s1xNoywsbR2phtHs1Vztb4WGbvMMswTiLy51gvzONmbDFQ8PSmsM4VMnZIJ0asnbC7ty8/cA1JcbNYuveDWVyz4G0NMywYD88mBdLEYqVlTbppYMwwUzXyZxJmEMZpxGDH/n8YfLWUq7RM8fHadL5OusxaeOm1BmnS8L5w8LXK3Pl70XKMMmX2hmftVPx3fyGLGLFd8NoubcM2Ycw6Ulwsr/xUp7rB6EVm/zBZe3+vAAAgAElEQVQDMWGGWMn4yGof/7AOHQdCyWaNEFVHbJHtxY71V1P2Tw37q5KZhS/FJSSNfc7y7PhpY21QPmWGs/dTjI923GLpJftexdOY6pptGU+hZNJrl9WwxnFOttMiZTwr42fxeNw07KvGrEO3y35nxjHv84MBwpAdy5K96bPcvk5ndtswW6f1pOi2RV7yZGQ2H+kV7rTYtpQffpRCzhMoeVgckqsRaWZBVfckbpTxIyIiIiIiIiLSU8cs4ycg4QkoKWcPc66sJ5xhDpt2m7T2/TS7D8nmPfa9zFZT09ZmGcM9PJWn5Mr6xGYOs/IE6o1vBQBMdu2ZE66ajXiNgjOI05rT1JXNjidlhiEzQPhSkfF0jeWGGSK5ZQCVlWd0aK3sm4+r4VwhHRW2oh54ysWMaRQ7jWXdXJ2uoOHJFzVXULulCGZ3Be59DjVX3DNg5SJXtpjMxcMvkHFWOoMFVMoYWR0xVsYRw027Xhq9xk8/9qMedd1JcJxnr7myvu0rYeCqKuuxXGqWgYp1NXgSRs1jck5VfkLGwfo9SRmx/qotrWaswRFYcyG7Yu//yphxwJsR9z+31xoUvj9emRsLo4sb1kwoC3vvr/PUktBYn1EyreJCXEJkNuqMJwiWvA38fvC6PXyf0xixdIEnE6YeS3Y7uMbMnnt5yhdX31kWD8VWiyL42lE/VsL6Yq/NYbZYzkzUaO9/01r7UTKT4/k0RcL+CMz+a/10lQFXU31c4vVW2ogTFfsrLq2WrJdQ8n+yUxZbCfu4JQ+XV8uuL1NrszgOtzkN42aXfVTL+ClZY+NKvld/x08fTFh7ZcDMn8g+qp3a+9xsNlg+x1pBrLMx3WGG
T+YnTVpfltT2OK8jFC9OkbD2pYY3i+VwmzPjScbXW4udiu3LkKclXQgFksbiKDJjvm3tvnX2bQmzPbxGGOq9cU7CFI2a8VUxdooHLGZSZm6scgydXK2R+YDa+y3FztwdHh9XBU94a7zGD/sW1ji8MAwIQ2YVMgZantK9us7Mdc+E9y4ta7EamGHK095afsaqt+z7awyGYmrxs8agay/VKAr7DDfrcmWO9mlwx2ziBwhMdc5ZbDnZ8lRTGxjH3LZT1Zvc3vVyjrjKD2grdjvNLZqKZ1mYjEUx202OeqYBS0sfsOsxfb5ixzXJbSA0vsQiVBU7Ug6gmqbC+pATAt7ZtRZkzdQG4AkH83sJWzrO8m5ZzmzfzPbUPoRF7qOZ1RYLVxub3Lk4OYeEW3pSbicshvzAFryAuF0z5duXtBFbLzGm+J6Onrf3PF+ygt/lR33SiB/yV+z+E0WJwVe5RbA62o1S3/jYYonbAq+1NpE85RGRvsULnDS8cukSEsZVPrSJxQFTWuvE01L5QYp95jBEXD5n10mn9v6fgcXDPZnFzuUPse2bWBu0HCx2VpMZ1q69AgC46kU7b/N3ltvncTPgBPKLU+uzLjFuai5ApNzedf3lywjc5pkOuO2TbY5P2mTs8YfcnlE0EVvPcTKxsvd+4ymLm43C9jAnj/Dggx2Ll1G027VkhuKKPVbbdRZL5DuxlFkcvBKtPbmWWNsz6/Zgsc25sIkwsKK5xZq1F6PTdo3t3G65CwfLfrBAGnH5KX4wKy121mornjqc2bUufZjPs82+KbXHreYVmku/y1frRZ5l3jxuhmxzzje2wHU12GJj4wulrcVNeX4TeW5jnsHA7huMrB9a4rb0jAWil7ktcJQBF3fY5+X23t+7ZnGzsmSTgS9/xK4Vx9zizm2Ba3mD5OoVvlj1UotkL3YsHl6orTRGFiyGZvy4GbmoXl54CYmX3Bj65zK7lo9/Bmx7lnNuKW0iXvoGj+UuLSZOZxZn66m1Wy9/C7esjlnAnmPl0bDG4Jo9dufgGpjMkb8Hy7mNY15t7BCl6z4+5kQhmDyx9co5JCx1ka0wftbtZ3czjokYLwUXJIq6xcWn7XN1xgN4zk6tT7qvsRmgi49aXOVbHGNzcudUUWM42QYA7Bx6zUeVtnqJiIiIiIiIiPTUscv48eO4Mxanw8xmASPTUlufC2MxuXZlCHDGMTDdNXCbWOmFLTmrmHMeMEsStI2tYCTcQpH4saY8Fq5lUTtUnLlk6nWGgMaPX249jZoFfGcFfwcvUCZ3Wx19rvfgEZVlY7PENbMsWgzQRj9n2e7z47oTxljBY5JHTD8dpBE7Ex61zfTpFS/izJRYsLCdr+CvLNn9S23A9ElbCavH23yNR31eul8mrbUJni7qoTRjRl8bLXaqZNRldEUWuWxnXLViPHgKdMGth4OQIJtxaw4LIW4wNfYeplovWVIPlqb2/Ce4XWwYEzChbC+lWhbG1I/c5raclu9byQDyI9VnSezami6LkFtulgo/mp23vtKaRoT24FaNVfZvJ5kdi0v2PCO2TSfYh67EAJTcrqPV94U0Y+zUXlCeBTMnjKHItPcZgMD3sG3Y5kyszclHLPLtq6wsqlokEQW3BHlx53uYgr+S2vOk24xRnsA9YluXl2M05av2zagM1UXjx7XX6cGxb9341hsW1AXQekfGWPLaywOOjzP2URnHO4MkxSozBAtuHTvLa6xO7Xl2n+YhFjsWGwNedFhNUZeX+CrVWS2icWvjmJKHEVQcf+yNe7iFB3u7Gjh8Rl1xHM2M5gFjJveMnzRBwQTBIX/mHsbdafZj7UW7I+GWnSUOakZJDcy8PIbanEUz8b6KaV+RW46ntccNd88EdOONhuOdhuPjbMAUeB7NjqGXTQkoWPB5mffdx7bnYc4BnGAftcLj3k+xjVpBxJAx3RfK+BERERERERER6aljlfETELripYEzy4HFxFBydbP1rB7epklX6C7zo+D5dc1j
bT2LKOdfMzZAxQwNz/TJODvt2UGBhcfSyu/3LJ/QTcclLM6acAWl5bHxTatj3O82jxufYYYXyeTKV+3H2PpKalIjtpZG0XK2uGVRQtbKxICJWyOGXhGBglleXoNunZk/65x9HpzjYxkby16rbrfG7MJFew11X3ai9kvFVcvKY4Qx1RWM81iKJUL3j92yhFBbQ5GyFnzwAqv80RCAnKtny1z02GDMnORK2L0XeD+LbK55xkjdYJVHXQbFzMKpuLpVMROn5oppxdvAeKritOskuoxTsO7TGvsu1pzz2wwRiR+ZzKLga5ln9nBVnrXrPG5O8PVkdY22ug4AiKoKtWC8tiCL6HI8UzOzeBq9iLtnbpSIDducmul/M+uvMitRh5S3OUsMpu3esdtMPMVJtj0rbE8Glzh4uW7xmPG0i8nWRTTVNb4Gxc7i4NjWM37YBnhCX42471FAjLNuPBq8r2qsr4pDP1KZhaL5Qy2AZa8X5VkbHAevcmG9foV3+FHLrCFV7+wicFwli+Zg7FTJwWLPezkT1li0KIHIzzLMho5euNsLihWMNz9aO0QMatZ7ChYb9/D2NHdTZJfsGsWM/ZZnFNa7iLVl/KjNWTwVawdWPGSCbzNKhlH0+MIEiNzVUFufFUvLaG2X7H2NhRd3ZhZrbLvPVMssAn6GfddZNm4PXLa4Wd21768zAympxhg19pnKx8dHfZSsjB8RERERERERkZ46Vhk/CKHLzvFsmgw8nYlHKYfGZg4THrWcjGsknPkLnDn2vcpTr+nD2eicM5SxsVo9ADDw1Q5OEfpM27Bb5O+mM+05ADAZqHtzePgOGtZraBpldNxtoXvnfI+7V5r38wJ9dYN1n2INq5yAboUqae2NLJj+NUi9HgdX4WPsTi9Y4uU3eKrBmmdxbNptwYWRgsedzrZqzHZs73vk6oksiuTAbQwHj8Ztu3/pTP1CDbDej6+EBZ6QknC/esIVK1/Mim3EgKscq7zcOh/LhXrcx4XSJZ5COPSFsFmLJbaHqvGzSA7GTcsjlIPXo/NlUNbisQxDroA1nhXKkwR5GpzHj9cFSmKLnJ3RCp9unbdrjI9TvF1mJzbiCZRtWXfxqb5o0RyMnYb1Vpouw8czfnjSTpwhMlsschUVlb23frLOkP1VxnjIqtidLniCTdcGn3WFqc5neCoTdpgFcN3GLq9eu4AYlbmxeHw86se6MzvDM3+6XB+vEVUCHN/ELmuDA1ZmYoRDWUOILZb4vQ3GyUk2H54p5nXEEmYZBobk5vUKqdarF9TBDJ82OVjLsO3Skz3juUKI7D+ij5U5zmFb47c+LknatssuPOVtDB9zsrafPTVjtiEzfvKaNYdmO0g8O60nmRv9wPGN1wjj52v/2uPGs1OBvczm6LtfmoOnRiYeajVrYjYNshHrpTKL7CQ/O51gPc17PEOMt4PKrtnMtrHMbKS+ZMSrBRURERERERER6aljlfETY4O64p7zytbBR8mQ99nG9bq17wdmAhXVNobZOgBgpbHvnWBtnxO5fe0nqAy4EppnEU+xxk/Kmi2Rs4uBs9JXued+NmH9Bu6nj02LJZ6MkfgpG5XNak5x1a6p/al3XQt73waJZYbVkafBsYZGzVX3mpljKSaIPDkg437UIli8bI+5ysEF+zWulq4PgYr1NQY8DWWNGT5nONF8wldfGU+r/Nl0rUCYPcPXpJXUxcKV8sTe/zHbhJwZGzW/v5fxE5BwdSNFxcdazOxy/3zKVbMTjKH1PGKXKyPl1OJuxucJXHndZmsfmJo4GLHmz1qO/MJ1Pnd9e7+q3EH2PqU8WWfqWV/MEm190zpXrEI9Q2CWYQprOIrcYmpWMuvQTwNkgbHRMGLCtgzbnj5mN14/LHpdj8yef4mpYqMkRXvO40Z90mLh6VmpjWumrBHmK6FtZnHhmRxJDAi8M09tLDIsbNyx4/W/2DetMstnbTmiusoV1Ss+vuFJKDNrbJaYubiS2/NvPGKxnDzyEP7r
F3iMSk9WUfuBfVVm4xg/OTDhKnzjdcUS1pyLq0jSNfseT3HKRtaebI0sBpZ5kuAqs5mX8xZjnipY8+S4wabddz/H2Cljbplr0ysbds3s9Ah/4ZxqXC6mg7EzY7ZF4Gcs7688mSxpUgTPkPexcm59zBV2RTNmIhfs89aGARnrscZd9o98zEmeApXzGgNmfywt8dTKYgXVM9YvRsXOAmFfxc/T48YzBVnrh1tquk0VsdjLcubYB+xfdmY8MZm17HKe6rW8EnCNGfbTXe6c4YGka5Vda7Wx2Fgrmu5nACDPljE+b+OctifjHGX8iIiIiIiIiIj01LHK+AH2avykXa0DuydyNTXyRIsImw4sZyMkzNjYjczm8FO2ZswK8muyVkuS14hTrrCxJHnup3tx5roe834mfwwaP6ElAGN7gobZSeXMXsustFnHsrnG30ez1ndbE33PsRdpYr0LP7mLj6vDFC1sRbONV3ifzT5nuzazHHg8ymbJk3dqwEsipD6RzQr0y6wLlG5yBYQ1OgoGVmwbzPgatJqxmLqyCV29BHuTvdaFv2sNKrSRpyVxBTRUbB+m9wEAUtZb2a6sCc8GEWM/nanyVVruVc75xNy77PXHCq7Epk1E3lUC0lrA4vA4IdaLaxk3TVcfwffAVwittS1tzVVWroC16Vl7LI/K2G2szSnaiBn7M88iLPKD2YQNV1b9oJURs82KNu9ON5RF47HivCYLVzu7/spruJRAywyc0jKLscu2ZcsyLBKegrJd2/cHZcT4IjN7+EzDwm7XGKvBT2Vim5fzdTVlA7U1i6sbQ3RtTs1br8XCLEEERK6UB7Y5gTXk8twem3PsUrKtKNEgsv8qGFNrXleMmUX1zNqcIePT4wdti7TLjJVFdKPYabwgodcXiw1ab3MaHnXLz0vZ1Gq4ZOynxo3F27htMGNCe+TJzCOOc5aZkdiWzCDhy8j5OpIGSIKnsXbpI7fzq8odcbCviqzdFJkt2nbFwXw8VAP8DBxLr3PIn9mxuPHsw1ll90+qFrV/dhvzMxUzi1a9pi7jyfuojKfshioiY3Z+6F7D0Y4b9bwiIiIiIiIiIj11rJbrAgISrzzP6vHRa/B4xk/j9VEsy6auMkxbr9fDn2VWEJipkfuJGZxFLooSTWOzzm0VuucGsHckBldek8pWScCTWhD3ZsjLqT3P9u5LAIBp9XX70ahTve42P9XLV9u7VXafjfZT2YLXSImI0Wafm9ZmpwOzusbTl+1apWVZJKzDUJcBdWPfy7ky0Qy8kj0zfqYWiwW/9syfsmyQsL7Q3ukKshgOrhKEbrFpL1b2Pw6o0bJGS8N96xXbKZSWudFWfK9Z/6WsGowZRwXbqQljtvSTMVo/6oAZil7wo4lYTx8CACReeEoWAFfC2Mb4Slhg3MRujczf1wptZA0Df68bL5Zwwr5fWYyEmcVPLIEKqwCAEfuvxk9jYduSst0qEj+pknUZmhajhJlE7Vf5WvqxB/7oO5hV2B3WFD3LkKcG+up70iBy9b2ueXoc66xMNu20yGbM7MPc2pxmGNFctfHLCrsplkhEYJ2gnE+TJ6y5wJXYZlKgCOz3ok6hXBweN/x3nDB+/PQ+H+90JzM1aKLFh5fbSHiKUmRfldT2nm8xO3VQ1RgWPE2XbU6eekYYs1LZR3mmzyDl6XRVi/XkfgDAdnORr1nj4MXg/RXbGNaBQrzByY9JgwjLbG7YXyWsRVZOLUs+8PPR5oz1e2Y1WmZBTznOiUPfLcHPVH76smdusD9rG2A5sXMHr7av3NZvKnfSwbhB4vVuvZaXZ6v6h6y220WRMGunLS3Gmt1XAQAla7GOJxY3W0WDJrP+pmFWmZ8OmLaskejbf/h5PnhtoSbidHKP3dc8ebu/7EJQxo+IiIiIiIiISE8dr4yfEJAknkFhq6bBJxG50pBEzlJzL3NZX0Ld8HvdqoetnlYzZvgkPHWF1yqKEj7H3bS+N5DPyxXYurRskMgMj8hVtjYC
08YyemYzmw3fnPyefV1/2R7brdbJ3RH2Th/gSlcS/X3lSgXjx1fKQoi2nA7As8f8n9t0ZrPGTVjl5e22bVOk7UkAwFJqqxo+c+2nfSWNPZYHqyDhqnxsGoSgjJ9F1P3b99jhfuHEY8dX4z120CDGCb+3xVuLnVn5PACg4QlxkW1PHSOa1GJnyO9tp1zt4KrqBk+ga3iiYO01h2LASnYvX8vAX4zMnWfyeNx4d80sUn//+Ga1qLradF2GKd/HavaUPYZZFuBJckmTAonFTZna6uqUtTomzPRZYbuSMm4CM4/aBlhObfU91F5nSBbB4TYneCZf9Axk5/1V09Ua62r9sJZPufmCfTtlhipP6QnDiHybY6EhMxN5ck/Fk3U8i6xglphncqDJMUrPAADGtWoWLorQrQUzi8LfL6+J6ScIMrM5oO5W4rvkZ2bT1+V5AMAsWLzsst8bNDUAa3Nq9lEN6/80jIFBao/NWfMn5YCnDQnuz98NAHi5sXFx1EmUC+JQ7LCf8N0VXX/lWfGhQux2L3iL5Bns5wAAkX3PDsc0edMgiXaK3Iz91cyzonmFzGM18VPFGDt1xAnWunuptgzVo16rpR88bg6Oa7wAb+jeI77DoUHw7B9minmyTrP7Ah/DUwn9s3nRAgNrc9rIcTDbntrbrcQz1Vj3zrNYG+DenBnxFe/zTNoj6lhN/MQYuw9XGY8zHQ0tmOpym4/ihEy0AXRbt6iZUjqeWXpgObF0wbL6LQBAntiAaMDbNB92E0wtAy9mfnQzj7Rko1VHS02b8oPerJ1gzIHWlOmI49q2eJXNc3YtdXR3WexS5gcsgjtOrQBmwvet6doue18LzHzoDTCWQmOxNdv9fd7BbTpMP92ZrON9yWMAgDOpDahPTe027FhDtsEzlpeGXtSXl1oKyF7U0cqLyNuclFvxymAxk3Ig3CYHCw627WxvSwb/qTettVPV9Gm7ZZw1bHOqyQBv54fwU/xenlvsbGf29ds27IPWoGCn6dsIRy1KvMRX68MnmT8WG0xsoOv1VGM32XtwQNS2ABqmSbfWn3kcleWLAPYmHVtOHLaTAu+MHjf2PIN8HQCwWdgk86OrVlB8vbBGbomjhrYABuA2j9v+XeVO8v4qY3/V+Af14P++uZCwNzOIhNtJA8dASWITQJPtp3ktjmUSLlQkQ3xs+AcAAPdyG9jJHbv1Lagby/b8y14QnEPOtbXTyFpfEJFF0Y2P2Ye0vq2Tk87e5sToB5BMEbwIK7fpxMo+zFfbNn7N+CGsSmzSuUwSfNfa4wCAB4O1NWupxVTkuHiN89MDFobOuLU+zRucDLas6kOuo/0RrE84zmHfUjfennBLVlfcmXHSlogc13QxVFmbMG1sK0/KUXTN8U4ZUnzf8L0AgAfYDq3uWqxOOBi+d9mCZ8j+MuOMZDpKkfmktiwQJl0kvjDpowkuJmV+oA63dbWzfUkYfmtj6mpicdOwAHzgwnmSZvhY8nYAwEOwcY63Obu53Z5Z5qJ64WVX2OasDFDA4rMv4xxt9RIRERERERER6aljlfGDAHg90yzzQpk2k5d2s9UsJBU8pavaK2TH2WkvQhU4CwimSDdcwWqbojv6tOUKa6x8NtP+5FvMAqmYBlnWtspWxhk2Z18EANStzV7W8Rqf1wtPKz3xbvNVi5Sp7D5j6gXkYpcK7T8QutXUrgB0YKzxsSnXqgKLQCMGNI1lgFXccjhjXA5am40uE5vJHjC9NbDAXVVNui0+XkQ6ai1sIezFjr1nCVOgu4WN9uDXbRu79PmEK6uJr3IypzXxSprMHorJDLNg1x9zo+kWm4khby9NuBWw5co90+vbtkaSHtqSKnPn2y7SnAWZD6++p/HA1wmaLkMsib7aau2IFzBMmFWWgVlnmCGNltGatIyl2lPx7WY2tbangq26VoybWVNh5G1Ob9bC+sFjJ8u54ulblXkb2VF5BhjitNv2kHiRVO+vEo8d8H4vplpiVlmm4KTmqntr
WYaj3FZVm8yX3bm9lYdYxLbBQ8U7AQCXprYl6Kinz/cDt1twi5Uno4ZuOwS/Dl4AvurKJSTBt4d52QTfylzx1gsgBExmdsDFmNllk8TGtqPMMoBm3Iqa5uwz4VvBGrxr2eLmn05yfk8Z8IuBbQrHEmnmYwreRo8Pbu9rJ0gazzz0cQ77K14x9ayL7vNZgrK2LNMp42nKTq+M1k9Npxxvtb6lnttPQ4VTuWU9h8rHyMqOnz/GgG/Nw8FiyxyyIGGbFJu4t23Qj2hnIXF/jH+2ShkjKWaoGTcVs4RK35LKJ5gW9rMZs1W7wtGTFicHp+06Y8bSbf/O86WMHxERERERERGRnjpWGT8BQFVzNXzbVjkHPE09sFhv23jGj80GDrICsWn3LrDvdpD5ypcXimbmT0zQ1LZ6WkcvSucrFvY8ZbQVj4pH1s0aHh/fzjBrnuV17HptV3BamT7zY3/73doycpZZg8czffYOVmYtoLRAy2npyIyOBvZ++opIwh8qGD9pHGPSWlG7663VEMpbO555ygKsCTPDRmNfkbOLnK+vYByv85VqFWOx2PvvK1WD5pR9O/F7fX2Lq11FYMGWvRUNLzafdqsiXFXl9yMaTBqLr9h64V5rU2bMKNupLD42mDW2zmK943oXvz/7fT7Wiy3KoqiqywCAtHwAwF5WYZdIyIYkGbYAi8AHr6/QWtbGwFfVeLhAzuzVPNYIrAdUg8fkRvt6hytjz/Gwgau7thI2Sq0vu1Jt4XJr/VirVfeFNKstdoqaxds5hIjMFG2ZOZjlsRvneJahtzFDX8HnOIRDJuRxgp3KDipIGDtFbSujVWUr683U2ry1LWtzVq/ZT2/VO6gSjWcWVVlatnleWra5Z2R0bY4XfM+nCNEzenjLtzVLPG58FZ6ZibHFpRkPLWCh3gCLlwmPTd5mNs8qa/6MWL9l0s5wobbaQcoQWzQc57C/ykobKyeelcxdEF4cPs1nCExzTtnmpPBdDfZ+p8wIKjhGymPA9cr7GtZUZZtTJna7ObWB1ZDbOwqOmbbaHVziLgqNkRcJC3rPLF7SiRXgTpiV3NZ+kg0/PxXoCnsFP4qd4+CcWdBJa4/NPGu6qbFT27jm1dbarqVg4/DAgt9XK7Y5zGJeYvHncTvB1yr7TF73ZJyjjB8RERERERERkZ46Vhk/MUbUPPVke9dWvwetzeplftoJj4eL/HqtOIWEM8eRK6ANV9ALVgX3ugkNV+mrtkbNE5xarrQ2XPWo+fVuw9N5Wt9vOOY1StSs/7NXo0UrY/PmKwSz2laqdoKtZOZhh/cbr7exnpztvuk1oWqusg8yZu8wWyhjZleIM2y1tsJWej2N1uJ0FOxnnuOxzJ5lNuZx3y+3r+Bqc+HAa5XF4O9HzdiZzizbMPPY6fa+W5tTpCfQpRWyVotn/aUDrlLATwCz2GrrGXZai70p6z+VXN26BluB+zozfnK/Ntue7fYqXolfstcIr8Mg8+ZxU1XWJqQ7FjeB9S/2Vt95wl8+2jvipt6LCwDImaGY8q3PfB99M0XNGJtGa3M2vf9hfDw/tef3uJmxvXq5fRUvt7YS1qrNWSiHY2c65eo7TxaM3bHuNi4ZpaMuoFreek3DUc76BnyLM45VQlthkyeeVlyhb9kvbtbW9nylsczVUfB6L3aNi+11vFgrdhYPs5OZLVqNrf5SCMwmbj3t3dqcfDRAd/qOH6ncsK/y8bHXhIp+f4NX2G/tMDt1F9Y+bcBW33ci46arTWU/fCXu4Aut1cBUbZ9Fw7EKY6fk0dpJ8DbHHuX91SBdQeDnrdZrSLHuShrYX3mmqmcoxhaXWmvLxuyvZoydy8Gyxq6153kxu9aU/dWFeBXPt8/wlarNWRyH2pwdyyIO/AzkfRW4W2aAU90Wi5YZzE3Dcc7ATij13RSeEdQ0M1zj2KjiZ/4WFidbtfVhL07sa6/NWnPMfSVu4avRTtZWxo+IiIiIiIiIiCy04KfH3JUnC2HuqSuBJ3Ql
iVd85wlL3M9XZNx/6qcnhQQ+P5bCvhe7mUebVmx83zxnsmsEJFz+qLsVe7tG6csi7V5NHwAoWUehbqhNjncAAAdXSURBVKevcxrT3P9sbyruTcvecYsQN91pF109FsYNTyjIuAqfBZ7eFmfdSlfm++F53yCxU3p81jpjXAEpCv8zxu47fHr72RlsRbXkXuhd1gAat9dQc0/sUYgX982MG2BRYscdPDEr8dOYEp68w1hq42SvnUp5uok/xpcy4O2YtVchZMiqg8VfMq7C+4rbhPWpwFWLmm3QrN1C5SfLHaGVsOPS5vh7HVjcJyR+OpzHhMVA02yi66u8jgL3q2eMrcCaCTnblRQJ1mEr877s6u1REez7m9FqajSsQbXbMusxXkXJ7CC1OXsWI3aMn+KFQ6d4JYf6pDru7J0Exu/t9W08VY616nKv1YKANWzY9ZnVmnFsVAT7mc3oWah2/4x91HZ7HbtdvY2jU6ul/22OO1jQ0vsjz0pN2K607bRrlzyWMt6XckzkNX8Ktis5EpxhZg9arwtk1xgE+5nLbHN8hb3i6bs7cRvbrH94lOqKHac2Zy92/EvP9jt4KmXbVl3sJF1b4zFk9Xs802cQrD/LkeFsuA/AXkaix463bVfZ5pTR2pox6xbuxi3M1F8dsFhx43konhp2qM3xXTmx7DLeE46FfJyTsv3o+ir+TBYS3IcH7ed5eqmnoGVd3FjmT8kMsRnjZxLHXSwdpUyxN4obZfyIiIiIiIiIiPTUscv42eMZHL6i4becneacWJ7sL4OUHLhdYtZQ02X32L0tgJyzlX5f4/sY/e/NFYyG+1D91mYUF+jPdJOOz6z0Qb5K2t2GfXOp3Z7mg4/x1dBu0/O+1TVfFfMTK7z+gdeKiqzBUrPeVI1y3+MX9s90Q8drJewwb4MOxk5EvOF9vupx+BqWi8aMxNgeeoS3RTwlgZkbkW2Oxc7RWclwx6/NObQK7/1R8Eyy5jWP9cekXDH1Nifs6/+W2B5FZm3stUo8gRK2Ytqwz6oOtDmKm8OOUux4Joe994fHRAdjx3+rvT9ewAiMHcZBZPQkXezYSSoN6yXUbHsa1Ecq08cdvzbHHc4A8iMp42ses9dXMcuwixuu0iNgjTUL264P4imW7KumjJv6UNy0aBQ3r2OxY8cdOhb5wDjnYLvUxU4XM/65LMUK6/943cOm64Pssd5feZZYsy92NEY+6GjFzaGMIOzFR5dV1u2eOPgTAQlWg2Wnel+0145wNw64C6eLK4+b/n0mV8aPiIiIiIiIiEhPHeOMn5uT7JsbiziYoeFZPa1Xs9/3t0y7+7gSxvsO53gcXik7qjQrfdj+P8fBdz05VOtl/3vvNRK6uOluD1617cmJb1oJu1kHV8Rer71IWHtj776DK2EeK3v7lI/2n0Ztzs3bi5vD3wdStjmv7YsOxo3anJvTv9g5WNduv6zLUD3YX92ozTlKNRJej9qct+LwuvJetmHBE3v2MprbQ4/0ePEsxKP9p1Gb81a9fk5CwF7twvia2IkHbl/7/aNJbc5b8dqsIHc4bl7bFx2fuNHEz23wNLPX65T27nOvN+R+ve8fTWqc7ozXfrDv96+uAdGddPhP2e9fXW3OnaK4uZMUO/2lNufOCIc+oN14nNwPanPuHI2R75zjFDfqq/Zoq5eIiIiIiIiISE9lb/4QuZE3Sj9989TUfs82yq056qnwMk9qU+RWKG7kVil25K3TOEdulWJHbo36KqeMHxERERERERGRntLEj4iIiIiIiIhIT2niR0RERERERESkpzTxIyIiIiIiIiLSU5r4ERERERERERHpKU38iIiIiIiIiIj0lCZ+RERERERERER6ShM/IiIiIiIiIiI9pYkfEREREREREZGe0sSPiIiIiIiIiEhPaeJHRERERERERKSnNPEjIiIiIiIiItJTmvgREREREREREekpTfyIiIiIiIiIiPSUJn5ERERERERERHpKEz8iIiIiIiIiIj2l
iR8RERERERERkZ7SxI+IiIiIiIiISE9p4kdEREREREREpKc08SMiIiIiIiIi0lOa+BERERERERER6SlN/IiIiIiIiIiI9JQmfkREREREREREekoTPyIiIiIiIiIiPaWJHxERERERERGRntLEj4iIiIiIiIhIT2niR0RERERERESkpzTxIyIiIiIiIiLSU5r4ERERERERERHpqRBjnPdrEBERERERERGRbwJl/IiIiIiIiIiI9JQmfkREREREREREekoTPyIiIiIiIiIiPaWJHxERERERERGRntLEj4iIiIiIiIhIT2niR0RERERERESkpzTxIyIiIiIiIiLSU5r4ERERERERERHpKU38iIiIiIiIiIj0lCZ+RERERERERER6ShM/IiIiIiIiIiI9pYkfEREREREREZGe0sSPiIiIiIiIiEhPaeJHRERERERERKSnNPEjIiIiIiIiItJTmvgREREREREREekpTfyIiIiIiIiIiPSUJn5ERERERERERHpKEz8iIiIiIiIiIj2liR8RERERERERkZ7SxI+IiIiIiIiISE9p4kdEREREREREpKc08SMiIiIiIiIi0lP/PwG0eoL1AX/NAAAAAElFTkSuQmCC\n",
321 | "text/plain": [
322 | ""
323 | ]
324 | },
325 | "metadata": {
326 | "needs_background": "light"
327 | },
328 | "output_type": "display_data"
329 | }
330 | ],
331 | "source": [
332 | "# Render the query view from a progressively growing context set:\n",
333 | "batch_size, n_views, c, h, w = x_c.shape\n",
334 | "\n",
335 | "f, axarr = plt.subplots(1, num_samples, figsize=(20, 7))\n",
336 | "for i, ax in enumerate(axarr.flat):\n",
337 | " x_ = x_c[scene_id][:i+1].view(-1, c, h, w)\n",
338 | " v_ = v_c[scene_id][:i+1].view(-1, 7)\n",
339 | " \n",
340 | " phi = model.representation(x_, v_)\n",
341 | " # Aggregate over dim 0, i.e. over the i+1 context views selected above\n",
342 | " r = torch.sum(phi, dim=0)\n",
343 | " x_mu = model.generator.sample((h, w), v_q[scene_id].unsqueeze(0), r)\n",
344 | " ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))\n",
345 | " ax.set_title(\"Context points: {}\".format(i + 1))  # the slice [:i+1] holds i+1 views\n",
346 | " ax.axis(\"off\")"
347 | ]
348 | },
349 | {
350 | "cell_type": "markdown",
351 | "metadata": {},
352 | "source": [
353 | "## Mental rotation task\n",
354 | "\n",
355 | "As an extension of the sampling procedure above, we can perform the mental rotation task by repeatedly sampling from the prior given a static representation $r$, varying the query viewpoint vector $v^q$ between samples to \"rotate the object\".\n",
356 | "\n",
357 | "In the example below we change the yaw (top row) and the pitch (bottom row) slightly at each frame for 8 frames."
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": 198,
363 | "metadata": {},
364 | "outputs": [
365 | {
366 | "data": {
367 | "image/png": "\n",
368 | "text/plain": [
369 | ""
370 | ]
371 | },
372 | "metadata": {
373 | "needs_background": "light"
374 | },
375 | "output_type": "display_data"
376 | }
377 | ],
378 | "source": [
379 | "# Sweep the query viewpoint: yaw along the top row, pitch along the bottom row\n",
380 | "batch_size, n_views, c, h, w = x_c.shape\n",
381 | "pi = np.pi\n",
382 | "\n",
383 | "x_ = x_c[scene_id].view(-1, c, h, w)\n",
384 | "v_ = v_c[scene_id].view(-1, 7)\n",
385 | "\n",
386 | "phi = model.representation(x_, v_)\n",
387 | "# Sum over dim 0 once; the same representation r is reused for every frame below\n",
388 | "r = torch.sum(phi, dim=0)\n",
389 | "\n",
390 | "f, axarr = plt.subplots(2, num_samples, figsize=(20, 7))\n",
391 | "for i, ax in enumerate(axarr[0].flat):\n",
392 | " v = torch.zeros(7).copy_(v_q[scene_id])  # fresh copy so v_q itself is left untouched\n",
393 | " \n",
394 | " yaw = (i+1) * (pi/8) - pi/2\n",
395 | " v[3], v[4] = np.cos(yaw), np.sin(yaw)  # entries 3 and 4 receive cos/sin of the yaw\n",
396 | "\n",
397 | " x_mu = model.generator.sample((h, w), v.unsqueeze(0), r)\n",
398 | " ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))\n",
399 | " ax.set_title(r\"Yaw:\" + str(i+1) + r\"$\\frac{\\pi}{8} - \\frac{\\pi}{2}$\")\n",
400 | " ax.axis(\"off\")\n",
401 | " \n",
402 | "for i, ax in enumerate(axarr[1].flat):\n",
403 | " v = torch.zeros(7).copy_(v_q[scene_id])  # fresh copy so v_q itself is left untouched\n",
404 | " \n",
405 | " pitch = (i+1) * (pi/8) - pi/2\n",
406 | " v[5], v[6] = np.cos(pitch), np.sin(pitch)  # entries 5 and 6 receive cos/sin of the pitch\n",
407 | "\n",
408 | " x_mu = model.generator.sample((h, w), v.unsqueeze(0), r)\n",
409 | " ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))\n",
410 | " ax.set_title(r\"Pitch:\" + str(i+1) + r\"$\\frac{\\pi}{8} - \\frac{\\pi}{2}$\")\n",
411 | " ax.axis(\"off\")"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": null,
417 | "metadata": {},
418 | "outputs": [],
419 | "source": []
420 | }
421 | ],
422 | "metadata": {
423 | "kernelspec": {
424 | "display_name": "Python 3",
425 | "language": "python",
426 | "name": "python3"
427 | },
428 | "language_info": {
429 | "codemirror_mode": {
430 | "name": "ipython",
431 | "version": 3
432 | },
433 | "file_extension": ".py",
434 | "mimetype": "text/x-python",
435 | "name": "python",
436 | "nbconvert_exporter": "python",
437 | "pygments_lexer": "ipython3",
438 | "version": "3.5.6"
439 | }
440 | },
441 | "nbformat": 4,
442 | "nbformat_minor": 2
443 | }
444 |
--------------------------------------------------------------------------------