├── lib
│   ├── __init__.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── Layout.py
│   │   ├── _init_paths.py
│   │   ├── ResReader.py
│   │   ├── ConceptMapper.py
│   │   ├── ResWriter.py
│   │   ├── Transform.py
│   │   ├── DistributionRender.py
│   │   ├── VAE.py
│   │   ├── Combine.py
│   │   └── Describe.py
│   ├── data_loader
│   │   ├── __init__.py
│   │   ├── clevr
│   │   │   ├── __init__.py
│   │   │   ├── _init_paths.py
│   │   │   ├── treeutils.py
│   │   │   └── clevr_tree.py
│   │   ├── _init_paths.py
│   │   └── color_mnist_tree_multi.py
│   ├── BiKLD.py
│   ├── weight_init.py
│   ├── reparameterize.py
│   ├── utils.py
│   ├── tree.py
│   ├── config.py
│   ├── LambdaBiKLD.py
│   └── ResidualModule.py
├── mains
│   ├── __init__.py
│   ├── _init_paths.py
│   └── pnpnet_main.py
├── models
│   ├── __init__.py
│   └── PNPNet
│       ├── __init__.py
│       └── pnp_net.py
├── data
│   ├── .gitplaceholder
│   └── CLEVR
│       └── add_parent.py
├── trainers
│   ├── __init__.py
│   └── pnpnet_trainer.py
├── .gitignore
├── requirements.txt
├── images
│   ├── combine.png
│   ├── layout.png
│   ├── mapping.png
│   ├── samples.png
│   ├── describe.png
│   ├── pipeline.png
│   └── transform.png
├── .gitmodules
├── configs
│   └── pnp_net_configs.yaml
├── README.md
└── LICENSE
/lib/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mains/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/.gitplaceholder:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/PNPNet/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/data_loader/clevr/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .idea
3 | results/
4 | .DS_Store
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==0.4.0
2 | torchvision==0.2.1
3 |
--------------------------------------------------------------------------------
/images/combine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/combine.png
--------------------------------------------------------------------------------
/images/layout.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/layout.png
--------------------------------------------------------------------------------
/images/mapping.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/mapping.png
--------------------------------------------------------------------------------
/images/samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/samples.png
--------------------------------------------------------------------------------
/images/describe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/describe.png
--------------------------------------------------------------------------------
/images/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/pipeline.png
--------------------------------------------------------------------------------
/images/transform.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/transform.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "data/CLEVR/clevr-dataset-gen"]
2 | path = data/CLEVR/clevr-dataset-gen
3 | url = git@github.com:woodfrog/clevr-dataset-gen.git
4 | [submodule "SemanticCorrectnessScore"]
5 | path = SemanticCorrectnessScore
6 | url = git@github.com:woodfrog/SemanticCorrectnessScore.git
7 |
--------------------------------------------------------------------------------
/lib/modules/Layout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 |
6 | class Layout(nn.Module):
7 | def __init__(self, hiddim=160):
8 | super(Layout, self).__init__()
9 |
10 | def forward(self, x, y):
11 |         pass  # placeholder: the layout operation itself is implemented in models/PNPNet/pnp_net.py
12 |
13 |
--------------------------------------------------------------------------------
/lib/data_loader/_init_paths.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 |
4 | def add_path(path):
5 | if path not in sys.path: sys.path.insert(0, path)
6 |
7 | this_dir = osp.dirname(__file__)
8 |
9 | # Add the project root to PYTHONPATH
10 | project_path = osp.abspath( osp.join(this_dir, '..', '..') )
11 | add_path(project_path)
12 |
--------------------------------------------------------------------------------
/lib/data_loader/clevr/_init_paths.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 |
4 | def add_path(path):
5 | if path not in sys.path: sys.path.insert(0, path)
6 |
7 | this_dir = osp.dirname(__file__)
8 |
9 | # Add the project root to PYTHONPATH
10 | project_path = osp.abspath(osp.join(this_dir, '..', '..', '..') )
11 | add_path(project_path)
12 |
--------------------------------------------------------------------------------
/mains/_init_paths.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 |
4 | def add_path(path):
5 | if path not in sys.path: sys.path.insert(0, path)
6 |
7 | this_dir = osp.dirname(__file__)
8 |
9 | # Add the project root to PYTHONPATH
10 | project_path = osp.abspath( osp.join(this_dir, '..'))
11 | print(project_path)
12 | add_path(project_path)
13 |
--------------------------------------------------------------------------------
/lib/modules/_init_paths.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 |
4 | def add_path(path):
5 | if path not in sys.path: sys.path.insert(0, path)
6 |
7 | this_dir = osp.dirname(__file__)
8 |
9 | # Add the project root to PYTHONPATH
10 | project_path = osp.abspath( osp.join(this_dir, '..', '..') )
11 | print(project_path)
12 | add_path(project_path)
13 |
--------------------------------------------------------------------------------
/lib/BiKLD.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 |
6 |
7 | class BiKLD(nn.Module):
8 | def __init__(self):
9 | super(BiKLD, self).__init__()
10 |
11 | def forward(self, q, p):
12 | q_mu, q_var = q[0], t.exp(q[1])
13 | p_mu, p_var = p[0], t.exp(p[1])
14 |
15 | kld = q_var / p_var - 1
16 | kld += (p_mu - q_mu).pow(2) / p_var
17 | kld += p[1] - q[1]
18 | kld = kld.sum() / 2
19 |
20 | return kld
21 |
--------------------------------------------------------------------------------
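A minimal usage sketch for the module above (illustrative shapes and values, not taken from the training code; both arguments are [mean, log-variance] pairs of diagonal Gaussians):

```python
import torch
from lib.BiKLD import BiKLD

bikld = BiKLD()
q_mu, q_logvar = torch.zeros(4, 64, 8, 8), torch.zeros(4, 64, 8, 8)  # posterior q = N(0, I)
p_mu, p_logvar = torch.ones(4, 64, 8, 8), torch.zeros(4, 64, 8, 8)   # prior p = N(1, I)

# Analytic KL(q || p), summed over all elements:
# 0.5 * sum(q_var / p_var - 1 + (q_mu - p_mu)^2 / p_var + log(p_var / q_var))
kld = bikld([q_mu, q_logvar], [p_mu, p_logvar])
print(kld.item())  # 8192.0 here: 0.5 * (4 * 64 * 8 * 8)
```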
/lib/weight_init.py:
--------------------------------------------------------------------------------
1 | def weights_init(m):
2 |     classname = m.__class__.__name__
3 |     if classname.find('Conv') != -1:
4 |         try:
5 |             m.weight.data.normal_(0.0, 0.02)
6 |             m.bias.data.fill_(0)
7 |         except AttributeError:
8 |             pass  # some conv layers are created without a bias term
9 |     elif classname.find('Linear') != -1:
10 |         try:
11 |             m.weight.data.normal_(0.0, 0.02)
12 |         except AttributeError:
13 |             pass
14 |     elif classname.find('Embedding') != -1:
15 |         try:
16 |             m.weight.data.normal_(0.0, 0.02)
17 |         except AttributeError:
18 |             pass  # nn.Embedding has no bias, so only the weight matrix is initialized
--------------------------------------------------------------------------------
/lib/reparameterize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 |
6 | class reparameterize(nn.Module):
7 | def __init__(self):
8 | super(reparameterize, self).__init__()
9 |
10 | def forward(self, mu, logvar, sample_num=1, phase='training'):
11 | if phase == 'training':
12 | std = logvar.mul(0.5).exp_()
13 | eps = Variable(std.data.new(std.size()).normal_())
14 | return eps.mul(std).add_(mu)
15 | else:
16 | raise ValueError('Wrong phase. Always assume training phase.')
17 | # elif phase == 'test':
18 | # return mu
19 | # elif phase == 'generation':
20 | # eps = Variable(logvar.data.new(logvar.size()).normal_())
21 | # return eps
22 |
--------------------------------------------------------------------------------
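A tiny sketch of the reparameterization trick implemented above (shapes are illustrative):

```python
import torch
from lib.reparameterize import reparameterize

sampler = reparameterize()
mu = torch.zeros(2, 16)
logvar = torch.zeros(2, 16)   # log-variance 0 -> unit variance
z = sampler(mu, logvar)       # z = mu + eps * exp(0.5 * logvar), with eps ~ N(0, I)
print(z.shape)                # torch.Size([2, 16])
```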
/configs/pnp_net_configs.yaml:
--------------------------------------------------------------------------------
1 | # Random seed
2 | seed: 12138
3 | #
4 | mode: train
5 | dataset: CLEVR
6 | # checkpoint: ./results/CLEVR_64_MULTI_LARGE/PNP-Net-5/checkpoints/model_epoch_360.pth
7 | checkpoint:
8 | data_folder: CLEVR_64_MULTI_LARGE
9 | base_dir: ./data/CLEVR/
10 | exp_dir_name: PNP-Net
12 | # Hyper-parameters
12 | hiddim: 160
13 | latentdim: 64
14 | pos_size: [8, 1, 1]
15 | nr_resnet: 5
16 | word_size: 16
17 | ds: 2
18 | combine_op: gPoE
19 | describe_op: CAT_gPoE
20 | maskweight: 2.0
21 | bg_bias: False
22 | normalize: batch_norm
23 | loss: l1
24 | # Training
25 | batch_size: 16
26 | epochs: 500
27 | gpu_id: 0
28 | log_interval: 10
29 | lr: 0.001
30 | kl_beta: 5
31 | alpha_ub: 0.6
32 | pos_beta: 1
33 | warmup_iter: 100
34 | sample_interval: 20
35 | validate_interval: 100
36 | save_interval: 30
37 |
38 |
--------------------------------------------------------------------------------
/lib/modules/ResReader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from torch.nn.utils import weight_norm
5 | import torch.nn.functional as F
6 | from lib.ResidualModule import ResidualModule
7 |
8 |
9 | class Reader(nn.Module):
10 | def __init__(self, indim, hiddim, outdim, ds_times, normalize, nlayers=4):
11 | super(Reader, self).__init__()
12 |
13 | self.ds_times = ds_times
14 |
15 | if normalize == 'gate':
16 | ifgate = True
17 | else:
18 | ifgate = False
19 |
20 | self.encoder = ResidualModule(modeltype='encoder', indim=indim, hiddim=hiddim, outdim=outdim,
21 | nres=self.ds_times, nlayers=nlayers, ifgate=ifgate, normalize=normalize)
22 |
23 | def forward(self, x):
24 | out = self.encoder(x)
25 |
26 | return out
27 |
--------------------------------------------------------------------------------
/lib/modules/ConceptMapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 |
6 | class ConceptMapper(nn.Module):
7 | def __init__(self, CHW, vocab_size):
8 | super(ConceptMapper, self).__init__()
9 | C, H, W = CHW[0], CHW[1], CHW[2]
10 | self.mean_dictionary = nn.Linear(vocab_size, C*H*W, bias=False)
11 | self.std_dictionary = nn.Linear(vocab_size, C*H*W, bias=False)
12 | self.C, self.H, self.W = C, H, W
13 |
14 | def forward(self, x):
15 | word_mean = self.mean_dictionary(x)
16 | word_std = self.std_dictionary(x)
17 | if self.H == 1 and self.W == 1:
18 | return [word_mean.view(-1, self.C, 1, 1), \
19 | word_std.view(-1, self.C, 1, 1)]
20 | else:
21 | return [word_mean.view(-1, self.C, self.H, self.W), \
22 | word_std.view(-1, self.C, self.H, self.W)]
23 |
24 |
--------------------------------------------------------------------------------
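An illustrative sketch of the concept mapping operator above; the vocabulary size and feature shape below are made up, not the paper's settings:

```python
import torch
from lib.modules.ConceptMapper import ConceptMapper

vocab_size = 20
mapper = ConceptMapper(CHW=[16, 1, 1], vocab_size=vocab_size)

word = torch.zeros(1, vocab_size)
word[0, 3] = 1.0                      # one-hot encoding of a single word concept
word_mean, word_std = mapper(word)    # each has shape [1, 16, 1, 1]
print(word_mean.shape, word_std.shape)
```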
/lib/modules/ResWriter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from torch.nn.utils import weight_norm
5 | import torch.nn.functional as F
6 | from lib.ResidualModule import ResidualModule
7 |
8 |
9 | class Writer(nn.Module):
10 | def __init__(self, indim, hiddim, outdim, ds_times, normalize, nlayers=4):
11 | super(Writer, self).__init__()
12 |
13 | self.ds_times = ds_times
14 |
15 | if normalize == 'gate':
16 | ifgate = True
17 | else:
18 | ifgate = False
19 |
20 | self.decoder = ResidualModule(modeltype='decoder', indim=indim, hiddim=hiddim, outdim=hiddim,
21 | nres=self.ds_times, nlayers=nlayers, ifgate=ifgate, normalize=normalize)
22 |
23 | self.out_conv = nn.Conv2d(hiddim, outdim, 3, 1, 1)
24 |
25 | def forward(self, x):
26 | out = self.decoder(x)
27 | out = self.out_conv(out)
28 |
29 | return out
30 |
--------------------------------------------------------------------------------
/lib/modules/Transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 |
8 | class Transform(nn.Module):
9 | def __init__(self, matrix='default'):
10 | super(Transform, self).__init__()
11 |
12 | def forward(self, x, hw, variance=False):
13 | if variance:
14 | x = torch.exp(x / 2.0)
15 | size = torch.Size([x.size(0), x.size(1), int(hw[0]), int(hw[1])])
16 |
17 | # grid generation
18 | theta = np.array([[[1, 0, 0], [0, 1, 0]]], dtype=np.float32)
19 | theta = Variable(torch.from_numpy(theta), requires_grad=False).cuda()
20 | theta = theta.expand(x.size(0), theta.size(1), theta.size(2))
21 | gridout = F.affine_grid(theta, size)
22 |
23 | # bilinear sampling
24 | out = F.grid_sample(x, gridout, mode='bilinear')
25 | if variance:
26 | out = torch.log(out) * 2.0
27 | return out
28 |
--------------------------------------------------------------------------------
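An illustrative sketch of the transform operator above: it resamples a latent map to a target (height, width) through a fixed identity affine grid. It assumes a GPU, since the module calls `.cuda()` internally; the shapes are made up:

```python
import torch
from lib.modules.Transform import Transform

transform = Transform()
mean_map = torch.randn(1, 16, 4, 4).cuda()   # appearance mean at the concept scale
resized = transform(mean_map, hw=(12, 12))   # bilinearly resampled to 12x12
print(resized.shape)                         # (1, 16, 12, 12)

# For log-variance maps, pass variance=True so the resampling is done on the
# variance (exp before the grid sample, log afterwards).
```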
/lib/utils.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import numpy as np
3 |
4 |
5 | def color_grid_vis(X, nh, nw, save_path):
6 | h, w = X[0].shape[:2]
7 | img = np.zeros((h * nh, w * nw, 3))
8 | for n, x in enumerate(X):
9 | j = int(n / nw)
10 | i = n % nw
11 | img[j * h:j * h + h, i * w:i * w + w, :] = x
12 | scipy.misc.imsave(save_path, img)
13 |
14 | class AverageMeter(object):
15 | """Computes and stores the average and current value"""
16 |
17 | def __init__(self):
18 | self.reset()
19 |
20 | def reset(self):
21 | self.val = 0
22 | self.avg = 0
23 | self.sum = 0
24 | self.pixel_count = 0
25 | self.batch_count = 0
26 |
27 | def update(self, val, n=1, batch=1):
28 | self.val = val
29 | self.sum += val * n
30 | self.pixel_count += n
31 | self.batch_count += batch
32 | self.pixel_avg = self.sum / self.pixel_count
33 | self.batch_avg = self.sum / self.batch_count
34 |
--------------------------------------------------------------------------------
/lib/tree.py:
--------------------------------------------------------------------------------
1 | # tree object from stanfordnlp/treelstm
2 | class Tree(object):
3 | def __init__(self):
4 | self.parent = None
5 | self.num_children = 0
6 | self.children = list()
7 |
8 | def add_child(self, child):
9 | child.parent = self
10 | self.num_children += 1
11 | self.children.append(child)
12 |
13 |     def size(self):
14 |         if getattr(self, '_size', None) is not None:
15 |             return self._size
16 |         count = 1
17 |         for i in range(self.num_children):
18 |             count += self.children[i].size()
19 |         self._size = count
20 |         return self._size
21 | 
22 |     def depth(self):
23 |         if getattr(self, '_depth', None) is not None:
24 |             return self._depth
25 |         count = 0
26 |         if self.num_children > 0:
27 |             for i in range(self.num_children):
28 |                 child_depth = self.children[i].depth()
29 |                 if child_depth > count:
30 |                     count = child_depth
31 |             count += 1
32 |         self._depth = count
33 |         return self._depth
34 |
--------------------------------------------------------------------------------
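A tiny sketch of how these tree objects are used (node semantics such as 'layout' or 'describe' are attached elsewhere by the data loaders):

```python
from lib.tree import Tree

root = Tree()            # e.g. a 'layout' node
left, right = Tree(), Tree()
root.add_child(left)     # sets left.parent = root and increments num_children
root.add_child(right)

print(root.num_children)  # 2
print(root.size())        # 3 nodes in total
print(root.depth())       # 1 (a leaf has depth 0)
```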
/lib/modules/DistributionRender.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from torch.nn.utils import weight_norm
5 | import torch.nn.functional as F
6 |
7 | class DistributionRender(nn.Module):
8 | def __init__(self, hiddim):
9 | super(DistributionRender, self).__init__()
10 | self.render_mean = nn.Sequential(
11 | nn.Conv2d(hiddim, hiddim, 3, 1, 1),
12 | nn.ELU(inplace=True),
13 | nn.Conv2d(hiddim, hiddim, 3, 1, 1),
14 | nn.ELU(inplace=True),
15 | nn.Conv2d(hiddim, hiddim, 3, 1, 1),
16 | )
17 |
18 | self.render_var = nn.Sequential(
19 | nn.Conv2d(hiddim, hiddim, 3, 1, 1),
20 | nn.ELU(inplace=True),
21 | nn.Conv2d(hiddim, hiddim, 3, 1, 1),
22 | nn.ELU(inplace=True),
23 | nn.Conv2d(hiddim, hiddim, 3, 1, 1),
24 | )
25 |
26 |
27 | def forward(self, x):
28 | # x = [mean, var]
29 | return self.render_mean(x[0]), self.render_var(x[1])
30 |
--------------------------------------------------------------------------------
/data/CLEVR/add_parent.py:
--------------------------------------------------------------------------------
1 | """
2 | This script adds a parent pointer to every tree node.
3 | The original trees generated together with the CLEVR images do not have parent pointers,
4 | but parent pointers are useful when implementing PNP-Net.
5 | 
6 | """
7 |
8 | import _init_paths
9 | import pickle
10 | from lib.tree import Tree
11 | import os.path as osp
12 | import os
13 |
14 | def add_parent(tree):
15 | tree = _add_parent(tree, None)
16 |
17 | return tree
18 |
19 | def _add_parent(tree, parent):
20 | tree.parent = parent
21 | for i in range(0, tree.num_children):
22 | tree.children[i] = _add_parent(tree.children[i], tree)
23 |
24 | return tree
25 |
26 | path = 'CLEVR_128_NEW/trees_no_parent'
27 | outpath = 'CLEVR_128_NEW/trees'
28 |
29 | split = ['train', 'test']
30 |
31 | for s in split:
32 | treepath = osp.join(path, s)
33 | files = os.listdir(treepath)
34 | try:
35 | os.makedirs(osp.join(outpath, s))
36 | except:
37 | pass
38 |
39 | for fi in files:
40 | if fi.endswith('tree'):
41 | with open(osp.join(path, s, fi), 'rb') as f:
42 | treei = pickle.load(f)
43 | treei = add_parent(treei)
44 | pickle.dump(treei, open(osp.join(outpath, s, fi), 'wb'))
45 |
--------------------------------------------------------------------------------
/lib/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 |
4 | def load_config(file_path):
5 | with open(file_path, 'r') as f:
6 | config_dict = yaml.load(f)
7 | return config_dict
8 |
9 |
10 | class Struct:
11 | def __init__(self, **entries):
12 | rec_entries = {}
13 | for k, v in entries.items():
14 | if isinstance(v, dict):
15 | rv = Struct(**v)
16 | elif isinstance(v, list):
17 | rv = []
18 | for item in v:
19 | if isinstance(item, dict):
20 | rv.append(Struct(**item))
21 | else:
22 | rv.append(item)
23 | else:
24 | rv = v
25 | rec_entries[k] = rv
26 | self.__dict__.update(rec_entries)
27 |
28 | def __str_helper(self, depth):
29 | lines = []
30 | for k, v in self.__dict__.items():
31 | if isinstance(v, Struct):
32 | v_str = v.__str_helper(depth + 1)
33 | lines.append("%s:\n%s" % (k, v_str))
34 | else:
35 | lines.append("%s: %r" % (k, v))
36 | indented_lines = [" " * depth + l for l in lines]
37 | return "\n".join(indented_lines)
38 |
39 | def __str__(self):
40 | return "struct {\n%s\n}" % self.__str_helper(1)
41 |
42 | def __repr__(self):
43 | return "Struct(%r)" % self.__dict__
44 |
45 |
46 | if __name__ == '__main__':
47 | config_dic = load_config('./configs/pnp_net_configs.yaml')
48 | configs = Struct(**config_dic)
49 | import pdb
50 | pdb.set_trace()
51 |
--------------------------------------------------------------------------------
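A minimal sketch of how the YAML config is loaded into an attribute-style object when run from the project root (the printed values follow configs/pnp_net_configs.yaml):

```python
from lib.config import load_config, Struct

config_dict = load_config('./configs/pnp_net_configs.yaml')
configs = Struct(**config_dict)

print(configs.dataset)     # CLEVR
print(configs.batch_size)  # 16
print(configs.pos_size)    # [8, 1, 1]
```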
/lib/LambdaBiKLD.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.autograd import Variable
3 | import torch.nn.functional as F
4 | import torch
5 | import numpy as np
6 |
7 |
8 | class BiKLD(nn.Module):
9 | def __init__(self, lambda_t, k):
10 | super(BiKLD, self).__init__()
11 | print('Using the modified KL Divergence analytic solution, thresholded by lambda:', lambda_t)
12 | self.lambda_t = lambda_t
13 | self.k = k
14 |
15 | def sample(self, mu, logvar, k):
16 | # input: [B, C, H, W], [B, C, H, W]
17 | # output: [B, K, C, H, W]
18 | std = logvar.mul(0.5).exp_()
19 | eps = Variable(std.data.new(std.size()).normal_())
20 | return eps.mul(std).add_(mu)
21 |
22 | def gaussian_diag_logps(self, mu, logvar, z):
23 | logps = logvar + ((z - mu) ** 2) / (logvar.exp())
24 | logps = logps.add(np.log(2 * np.pi))
25 | logps = logps.mul(-0.5)
26 | return logps
27 |
28 | def expand_dis(self, p, k):
29 | mu, logvar = p[0], p[1]
30 | logvar = logvar.unsqueeze(1).expand(logvar.size(0), k, logvar.size(1), logvar.size(2), logvar.size(3))
31 | mu = mu.unsqueeze(1).expand(mu.size(0), k, mu.size(1), mu.size(2), mu.size(3))
32 | return mu, logvar
33 |
34 | def forward(self, q, p):
35 | # input: [B, C, H, W], [B, C, H, W]
36 | # output: [1]
37 | # expand
38 | q_mu, q_var = q[0], torch.exp(q[1])
39 | p_mu, p_var = p[0], torch.exp(p[1])
40 |
41 | kld = q_var / p_var - 1
42 | kld += (p_mu - q_mu).pow(2) / p_var
43 | kld += p[1] - q[1]
44 | kld = kld.sum(dim=3).sum(dim=2).mean(0) / 2
45 |
46 | lambda_tensor = Variable(self.lambda_t * torch.ones(kld.size())).cuda()
47 | kld = torch.max(kld, lambda_tensor)
48 | kld = kld.sum() * p[0].size(0)
49 |
50 | return kld
51 |
52 |
53 | if __name__ == '__main__':
54 | # test code
55 | q_mu = Variable(torch.zeros(4, 32, 16, 16)).cuda()
56 | q_logvar = Variable(torch.zeros(4, 32, 16, 16)).cuda()
57 | p_mu = Variable(torch.zeros(4, 32, 16, 16)).cuda()
58 | p_logvar = Variable(torch.zeros(4, 32, 16, 16)).cuda()
59 |
60 | kldloss = BiKLD(lambda_t=0.01, k=10)
61 |
62 | kld = kldloss([q_mu, q_logvar], [p_mu, p_logvar])
63 |
--------------------------------------------------------------------------------
/lib/modules/VAE.py:
--------------------------------------------------------------------------------
1 | import _init_paths
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | from torch.nn.utils import weight_norm
6 | import torch.nn.functional as F
7 | from lib.BiKLD import BiKLD
8 | from lib.reparameterize import reparameterize
9 |
10 |
11 | class VAE(nn.Module):
12 | def __init__(self, indim, latentdim, half=False):
13 | super(VAE, self).__init__()
14 |
15 | self.half = half
16 | if self.half is False:
17 | self.encoder = nn.Sequential(
18 | nn.Linear(indim, latentdim * 2),
19 | nn.ELU(inplace=True),
20 | nn.Linear(latentdim * 2, latentdim * 2),
21 | nn.ELU(inplace=True),
22 | nn.Linear(latentdim * 2, latentdim * 2)
23 | )
24 | self.mean = nn.Linear(latentdim * 2, latentdim)
25 | self.logvar = nn.Linear(latentdim * 2, latentdim)
26 | self.bikld = BiKLD()
27 |
28 | dec_out = indim
29 | self.decoder = nn.Sequential(
30 | nn.Linear(latentdim, latentdim * 2),
31 | nn.ELU(inplace=True),
32 | nn.Linear(latentdim * 2, latentdim * 2),
33 | nn.ELU(inplace=True),
34 | nn.Linear(latentdim * 2, dec_out)
35 | )
36 |
37 | self.sampler = reparameterize()
38 |
39 | def forward(self, x=None, prior=None):
40 | prior = [prior[0].view(1, -1), prior[1].view(1, -1)]
41 |
42 | if self.half is False:
43 | encoding = self.encoder(x)
44 | mean, logvar = self.mean(encoding), self.logvar(encoding)
45 | kld = self.bikld([mean, logvar], prior)
46 | z = self.sampler(mean, logvar)
47 | else:
48 | z = self.sampler(prior[0], prior[1])
49 | kld = 0
50 |
51 | decoding = self.decoder(z)
52 |
53 | return decoding, kld
54 |
55 | def generate(self, prior):
56 | prior = [prior[0].view(1, -1), prior[1].view(1, -1)]
57 | z = self.sampler(*prior)
58 |
59 | decoding = self.decoder(z)
60 |
61 | return decoding
62 |
63 |
64 | '''
65 | #Test case 0
66 | model = VAE(6, 4).cuda()
67 | mean = Variable(torch.zeros(16, 4)).cuda()
68 | var = Variable(torch.zeros(16, 4)).cuda()
69 | data = Variable(torch.zeros(16, 6)).cuda()
70 | out, kld = model(data, [mean, var])
71 |
72 | #Test case 1
73 | model = VAE(6, 4, 10).cuda()
74 | mean = Variable(torch.zeros(16, 4)).cuda()
75 | var = Variable(torch.zeros(16, 4)).cuda()
76 | data = Variable(torch.zeros(16, 6)).cuda()
77 | condition = Variable(torch.zeros(16, 10)).cuda()
78 | out, kld = model(data, [mean, var], condition)
79 | '''
80 |
--------------------------------------------------------------------------------
/lib/modules/Combine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from torch.nn.utils import weight_norm
5 | import torch.nn.functional as F
6 |
7 | class Combine(nn.Module):
8 | def __init__(self, hiddim_v, hiddim_p=None, op='PROD'):
9 | super(Combine, self).__init__()
10 | self.op = op
11 | self.hiddim_v = hiddim_v
12 | self.hiddim_p = hiddim_p
13 | if self.op == 'DEEP':
14 | self.net_vis = nn.Sequential(
15 | weight_norm(nn.Conv2d(4*hiddim_v, hiddim_v, 3, 1, 1)),
16 | nn.ELU(inplace=True),
17 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)),
18 | nn.ELU(inplace=True),
19 | weight_norm(nn.Conv2d(hiddim_v, 2*hiddim_v, 3, 1, 1))
20 | )
21 | self.net_pos = nn.Sequential(
22 | weight_norm(nn.Conv2d(4*hiddim_p, hiddim_p, 1, 1)),
23 | nn.ELU(inplace=True),
24 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)),
25 | nn.ELU(inplace=True),
26 | weight_norm(nn.Conv2d(hiddim_p, 2*hiddim_p, 1, 1))
27 | )
28 |
29 | elif self.op == 'CAT':
30 | self.net_mean_vis = nn.Sequential(
31 | weight_norm(nn.Conv2d(hiddim_v*2, hiddim_v, 3, 1, 1)),
32 | nn.ELU(inplace=True),
33 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1))
34 | )
35 | self.net_var_vis = nn.Sequential(
36 | weight_norm(nn.Conv2d(hiddim_v*2, hiddim_v, 3, 1, 1)),
37 | nn.ELU(inplace=True),
38 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1))
39 | )
40 | self.net_mean_pos = nn.Sequential(
41 | weight_norm(nn.Conv2d(hiddim_p*2, hiddim_p, 1, 1)),
42 | nn.ELU(inplace=True),
43 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1))
44 | )
45 | self.net_var_pos = nn.Sequential(
46 | weight_norm(nn.Conv2d(hiddim_p*2, hiddim_p, 1, 1)),
47 | nn.ELU(inplace=True),
48 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1))
49 | )
50 | elif self.op == 'gPoE':
51 | self.gates_v = nn.Sequential(
52 | weight_norm(nn.Conv2d(hiddim_v*4, hiddim_v*4, 3, 1, 1)),
53 | nn.ELU(inplace=True),
54 | weight_norm(nn.Conv2d(hiddim_v*4, hiddim_v*4, 3, 1, 1))
55 | )
56 | self.gates_p = nn.Sequential(
57 | weight_norm(nn.Conv2d(hiddim_p*4, hiddim_p*4, 3, 1, 1)),
58 | nn.ELU(inplace=True),
59 | weight_norm(nn.Conv2d(hiddim_p*4, hiddim_p*4, 3, 1, 1))
60 | )
61 |
62 | def forward(self, x1, x2, mode='vis'):
63 | if self.op == 'PROD':
64 | return [x1[0]*x2[0], x1[1]*x2[1]]
65 | elif self.op == 'PoE':
66 | # logvar = -log(exp(-logvar1) + exp(-logvar2))
67 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2)
68 | mlogvar1 = -x1[1]
69 | mlogvar2 = -x2[1]
70 | mu1 = x1[0]
71 | mu2 = x2[0]
72 |
73 | logvar = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2))
74 | mu = torch.exp(logvar)*(torch.exp(mlogvar1)*mu1 + torch.exp(mlogvar2)*mu2)
75 | return [mu, logvar]
76 | elif self.op == 'gPoE':
77 | # logvar = -log(exp(-logvar1) + exp(-logvar2))
78 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2)
79 |
80 | if mode == 'vis':
81 | gates = torch.sigmoid(self.gates_v(torch.cat([x1[0], x1[1], x2[0], x2[1]], dim=1)))
82 | x1_mu_g = gates[:,:self.hiddim_v,:,:]
83 | x1_var_g = gates[:,self.hiddim_v:2*self.hiddim_v,:,:]
84 | x2_mu_g = gates[:,2*self.hiddim_v:3*self.hiddim_v,:,:]
85 | x2_var_g = gates[:,3*self.hiddim_v:4*self.hiddim_v,:,:]
86 | elif mode == 'pos':
87 | gates = torch.sigmoid(self.gates_p(torch.cat([x1[0], x1[1], x2[0], x2[1]], dim=1)))
88 | x1_mu_g = gates[:,:self.hiddim_p,:,:]
89 | x1_var_g = gates[:,self.hiddim_p:2*self.hiddim_p,:,:]
90 | x2_mu_g = gates[:,2*self.hiddim_p:3*self.hiddim_p,:,:]
91 | x2_var_g = gates[:,3*self.hiddim_p:4*self.hiddim_p,:,:]
92 |
93 | x1[0] = x1_mu_g*x1[0]
94 | x1[1] = torch.log(x1_var_g + 1e-5) + x1[1]
95 | x2[0] = x2_mu_g*x2[0]
96 | x2[1] = torch.log(x2_var_g + 1e-5) + x2[1]
97 |
98 | mlogvar1 = -x1[1]
99 | mlogvar2 = -x2[1]
100 | mu1 = x1[0]
101 | mu2 = x2[0]
102 |
103 | logvar = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2))
104 | mu = torch.exp(logvar)*(torch.exp(mlogvar1)*mu1 + torch.exp(mlogvar2)*mu2)
105 | return [mu, logvar]
106 | elif self.op == 'ADD':
107 | return [x1[0] + x2[0], x1[1], x2[1]]
108 | elif self.op == 'CAT':
109 | if mode == 'vis':
110 | return [self.net_mean_vis(torch.cat([x1[0], x2[0]], dim=1)), \
111 | self.net_var_vis(torch.cat([x1[1], x2[1]], dim=1))]
112 | elif mode == 'pos':
113 | return [self.net_mean_pos(torch.cat([x1[0], x2[0]], dim=1)), \
114 | self.net_var_pos(torch.cat([x1[1], x2[1]], dim=1))]
115 | elif self.op == 'DEEP':
116 | if mode == 'vis':
117 | gaussian_out = self.net_vis(torch.cat([x1[0], x1[1], x2[0], x2[1]], dim=1))
118 | return [gaussian_out[:,:self.hiddim_v,:,:], gaussian_out[:,self.hiddim_v:,:,:]]
119 | elif mode == 'pos':
120 | gaussian_out = self.net_pos(torch.cat([x1[0], x1[1], x2[0], x2[1]], dim=1))
121 | return [gaussian_out[:,:self.hiddim_p,:,:], gaussian_out[:,self.hiddim_p:,:,:]]
122 | else:
123 | print('Operator:', self.op)
124 | raise ValueError('Unknown operator for combine module.')
125 |
--------------------------------------------------------------------------------
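An illustrative check of the 'PoE' branch above: the product of two diagonal Gaussians, each passed as a [mean, log-variance] pair (precisions add, means are precision-weighted):

```python
import torch
from lib.modules.Combine import Combine

combine = Combine(hiddim_v=8, op='PoE')   # the 'PoE' op needs no extra layers

x1 = [torch.zeros(1, 8, 4, 4), torch.zeros(1, 8, 4, 4)]   # N(0, 1)
x2 = [torch.ones(1, 8, 4, 4),  torch.zeros(1, 8, 4, 4)]   # N(1, 1)
mu, logvar = combine(x1, x2)

print(mu[0, 0, 0, 0].item())       # 0.5
print(logvar[0, 0, 0, 0].item())   # log(0.5) ~ -0.693
```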
/lib/ResidualModule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.autograd import Variable
5 | from torch.nn.utils import weight_norm
6 |
7 |
8 | class ResidualModule(nn.Module):
9 | def __init__(self, modeltype, indim, hiddim, outdim, nlayers, nres, ifgate=False, nonlinear='elu', normalize='instance_norm'):
10 | super(ResidualModule, self).__init__()
11 | if ifgate:
12 | print('Using gated version.')
13 | if modeltype == 'encoder':
14 | self.model = self.encoder(indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize=normalize)
15 | elif modeltype == 'decoder':
16 | self.model = self.decoder(indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize=normalize)
17 | elif modeltype == 'plain':
18 | self.model = self.plain(indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize=normalize)
19 | else:
20 |             raise ValueError('Unknown model type.')
21 |
22 | def encoder(self, indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize):
23 | layers = []
24 | layers.append(ResidualBlock(None, nonlinear, ifgate, indim, hiddim, normalize))
25 |
26 | for i in range(0, nres):
27 | for j in range(0, nlayers):
28 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, hiddim, normalize))
29 | layers.append(ResidualBlock('down', nonlinear, ifgate, hiddim, hiddim, normalize))
30 |
31 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, outdim, normalize))
32 |
33 | return nn.Sequential(*layers)
34 |
35 | def decoder(self, indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize):
36 | layers = []
37 | layers.append(ResidualBlock(None, nonlinear, ifgate, indim, hiddim, normalize))
38 |
39 | for i in range(0, nres):
40 | for j in range(0, nlayers):
41 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, hiddim, normalize))
42 | layers.append(ResidualBlock('up', nonlinear, ifgate, hiddim, hiddim, normalize))
43 |
44 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, outdim, normalize))
45 |
46 | return nn.Sequential(*layers)
47 |
48 | def plain(self, indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize):
49 | layers = []
50 | layers.append(ResidualBlock(None, nonlinear, ifgate, indim, hiddim, normalize))
51 |
52 | for i in range(0, nres):
53 | for j in range(0, nlayers):
54 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, hiddim, normalize))
55 |
56 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, outdim, normalize))
57 |
58 | return nn.Sequential(*layers)
59 |
60 | def forward(self, x):
61 | return self.model(x)
62 |
63 |
64 | class ResidualBlock(nn.Module):
65 | def __init__(self, resample, nonlinear, ifgate, indim, outdim, normalize):
66 | super(ResidualBlock, self).__init__()
67 |
68 | self.ifgate = ifgate
69 | self.indim = indim
70 | self.outdim = outdim
71 | self.resample = resample
72 |
73 | if resample == 'down':
74 | convtype = 'sconv_d'
75 | elif resample == 'up':
76 | convtype = 'upconv'
77 | elif resample == None:
78 | convtype = 'sconv'
79 |
80 | self.shortflag = False
81 | if not (indim == outdim and resample == None):
82 | self.shortcut = self.conv(convtype, indim, outdim)
83 | self.shortflag = True
84 |
85 | if ifgate:
86 | self.conv1 = nn.Conv2d(indim, outdim, 3, 1, 1)
87 | self.conv2 = nn.Conv2d(indim, outdim, 3, 1, 1)
88 | self.c = nn.Sigmoid()
89 | self.g = nn.Tanh()
90 | self.conv3 = self.conv(convtype, outdim, outdim)
91 | self.act = self.nonlinear(nonlinear)
92 | elif normalize == 'batch_norm':
93 | self.resblock = nn.Sequential(
94 | self.conv('sconv', indim, outdim),
95 | nn.BatchNorm2d(outdim),
96 | self.nonlinear(nonlinear),
97 | self.conv(convtype, outdim, outdim),
98 | nn.BatchNorm2d(outdim),
99 | self.nonlinear(nonlinear)
100 | )
101 | elif normalize == 'instance_norm':
102 | self.resblock = nn.Sequential(
103 | self.conv('sconv', indim, outdim),
104 | nn.InstanceNorm2d(outdim),
105 | self.nonlinear(nonlinear),
106 | self.conv(convtype, outdim, outdim),
107 | nn.InstanceNorm2d(outdim),
108 | self.nonlinear(nonlinear)
109 | )
110 | elif normalize == 'no_norm':
111 | self.resblock = nn.Sequential(
112 | self.conv('sconv', indim, outdim),
113 | self.nonlinear(nonlinear),
114 | self.conv(convtype, outdim, outdim),
115 | self.nonlinear(nonlinear)
116 | )
117 | elif normalize == 'weight_norm':
118 | self.resblock = nn.Sequential(
119 | self.conv('sconv', indim, outdim, 'weight_norm'),
120 | self.nonlinear(nonlinear),
121 | self.conv(convtype, outdim, outdim, 'weight_norm'),
122 | self.nonlinear(nonlinear)
123 | )
124 |
125 | def conv(self, name, indim, outdim, normalize=None):
126 | if name == 'sconv_d':
127 | if normalize == 'weight_norm':
128 | return weight_norm(nn.Conv2d(indim, outdim, 4, 2, 1))
129 | else:
130 | return nn.Conv2d(indim, outdim, 4, 2, 1)
131 | elif name == 'sconv':
132 | if normalize == 'weight_norm':
133 | return weight_norm(nn.Conv2d(indim, outdim, 3, 1, 1))
134 | else:
135 | return nn.Conv2d(indim, outdim, 3, 1, 1)
136 | elif name == 'upconv':
137 | if normalize == 'weight_norm':
138 | return weight_norm(nn.ConvTranspose2d(indim, outdim, 4, 2, 1))
139 | else:
140 | return nn.ConvTranspose2d(indim, outdim, 4, 2, 1)
141 | else:
142 |             raise ValueError('Unknown convolution type.')
143 |
144 | def nonlinear(self, name):
145 | if name == 'elu':
146 | return nn.ELU(1, True)
147 | elif name == 'relu':
148 | return nn.ReLU(True)
149 |
150 | def forward(self, x):
151 | if self.ifgate:
152 | conv1 = self.conv1(x)
153 | conv2 = self.conv2(x)
154 | c = self.c(conv1)
155 | g = self.g(conv2)
156 | gated = c * g
157 | conv3 = self.conv3(gated)
158 | res = self.act(conv3)
159 | if not (self.indim == self.outdim and self.resample == None):
160 | out = self.shortcut(x) + res
161 | else:
162 | out = x + res
163 | else:
164 | if self.shortflag:
165 | out = self.shortcut(x) + self.resblock(x)
166 | else:
167 | out = x + self.resblock(x)
168 |
169 | return out
170 |
--------------------------------------------------------------------------------
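A small sketch of the residual encoder above: each 'down' block halves the spatial resolution, so the encoder downsamples by a factor of 2**nres (the dimensions below are illustrative):

```python
import torch
from lib.ResidualModule import ResidualModule

enc = ResidualModule(modeltype='encoder', indim=3, hiddim=32, outdim=64,
                     nlayers=1, nres=2, normalize='batch_norm')
x = torch.randn(2, 3, 64, 64)
print(enc(x).shape)   # (2, 64, 16, 16): two 'down' blocks halve H and W twice
```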
/README.md:
--------------------------------------------------------------------------------
1 | # Probabilistic Neural Programmed Network
2 |
3 | This is the official implementation of the paper:
4 |
5 | [Probabilistic Neural Programmed Networks for Scene Generation](http://www2.cs.sfu.ca/~mori/research/papers/deng-nips18.pdf)
6 |
7 | [Zhiwei Deng](http://www.sfu.ca/~zhiweid/), [Jiacheng Chen](http://jcchen.me/), [Yifang Fu](https://yifangfu.wordpress.com/) and [Greg Mori](http://www2.cs.sfu.ca/~mori/)
8 |
9 |
10 | Published at NeurIPS 2018
11 |
12 | [Poster](http://www.sfu.ca/~zhiweid/papers/PNP_Net_Poster.pdf)
13 |
14 | If you find this code helpful in your research, please cite
15 |
16 | ```
17 | @inproceedings{deng2018probabilistic,
18 | title={Probabilistic Neural Programmed Networks for Scene Generation},
19 | author={Deng, Zhiwei and Chen, Jiacheng and Fu, Yifang and Mori, Greg},
20 | booktitle={Advances in Neural Information Processing Systems},
21 | pages={4032--4042},
22 | year={2018}
23 | }
24 | ```
25 |
26 | Other re-implementations:
27 |
28 | [Tensorflow version](https://github.com/mihirp1998/ProbabilisticNeuralProgrammedNetwork_Tensorflow)
29 |
30 | ## Contents
31 | 1. [Overview](#overview)
32 | 2. [Environment Setup](#environment)
33 | 3. [Data and Pre-trained Models](#data-and-models)
34 | - [CLEVR-G](#CLEVR-G)
35 | 4. [Configurations](#configurations)
36 | 5. [Code Guide](#code-guide)
37 | - [Neural Operators](#neural-operators)
38 | 6. [Training Model](#training)
39 | 7. [Evaluation](#evaluation)
40 | 8. [Results](#results)
41 |
42 |
43 | ## Overview
44 |
45 | Generating scenes from rich and complex semantics is an important step towards understanding the visual world. The Probabilistic Neural Programmed Network (PNP-Net) brings symbolic methods into generative models: it exploits a set of **reusable neural modules** to compose latent distributions for scenes described by complex semantics in a **programmatic** manner, and a decoder then samples from the latent scene distribution to generate realistic images. PNP-Net is naturally formulated as a learnable prior in the canonical VAE framework so that its parameters can be learned efficiently.
46 |
47 |
48 |
49 |
50 |
![pipeline](images/pipeline.png)
51 |
52 |
53 |
54 |
55 | ## Environment
56 |
57 | All code was tested on Ubuntu 16.04 with Python 2.7 and **PyTorch 0.4.0** (the code should also work with Python 3). To install the required packages, run:
58 |
59 | ```bash
60 | pip install -r requirements.txt
61 | ```
62 |
63 | For running our evaluation metric (a detector-based semantic correctness score), check [this submodule](https://github.com/woodfrog/SemanticCorrectnessScore/tree/483c6ef2e0548fcc629059b84c489cd4e0c19f86) for full details (**now released**).
64 |
65 | ## Data and Models
66 |
67 | ### CLEVR-G
68 |
69 | We used the released code of [CLEVR (Johnson et al.)](https://arxiv.org/pdf/1612.06890.pdf) to generate a modified CLEVR dataset for the task of scene image generation, and we call it CLEVR-G. The generation code is in the [submodule](https://github.com/woodfrog/clevr-dataset-gen/tree/42a5c4914bbae49a0cd36cf96607c05111394ddc).
70 |
71 | We also provide the [64x64 CLEVR-G](https://drive.google.com/open?id=1QCIINcIOdcIl5U0IZrHpj5XZW2FZBbJv) used in our experiments. Please download and unzip it into **./data/CLEVR** if you want to use it with our model.
72 |
73 | **Pre-trained Model**
74 |
75 | Please download pre-trained models from:
76 |
77 | - [PNP-Net CLEVR-G 64x64](https://drive.google.com/open?id=1VusqEqIHZibRqKbXyIxJDxBRTp9AZP0y)
78 |
79 |
80 | ### COLOR-MNIST
81 |
82 | - [PNP-Net COLOR-MNIST 64x64](https://www.dropbox.com/s/jf5u7yosyisf8zd/MULTI_8000_64_100.tar?dl=0)
83 |
84 |
85 |
86 | ## Configurations
87 |
88 | We use a [global configuration file](configs/pnp_net_configs.yaml) to set up all configs, including the training settings and model hyper-parameters. Please check the file and the corresponding code for more details.
89 |
90 |
91 | ## Code Guide
92 |
93 | ### Neural Operators
94 |
95 | The core of PNP-Net is a set of **neural modular operators**. We briefly introduce them here and provide pointers to the corresponding code.
96 |
97 |
98 | - **Concept Mapping Operator**
99 |
100 |
101 |
102 |
![mapping](images/mapping.png)
103 |
104 |
105 |
106 | Converts the one-hot representation of a word concept into appearance and scale distributions.
107 | [code](lib/modules/ConceptMapper.py)
108 |
109 | - **Combine Operator**
110 |
111 |
112 |
![combine](images/combine.png)
113 |
114 |
115 |
116 | The combine module merges the latent distributions of two attributes. [code](lib/modules/Combine.py)
117 |
118 |
119 | - **Describe Operator**
120 |
121 |
122 |
![describe](images/describe.png)
123 |
124 |
125 | Attributes describe an object: this module takes the attribute distributions (merged by the combine module) and uses them to render the distribution of an object. [code](lib/modules/Describe.py)
126 |
127 |
128 | - **Transform Operator**
129 |
130 |
131 |
![transform](images/transform.png)
132 |
133 |
134 | This module first samples a size instance from an object's scale distribution and then uses bilinear interpolation to resize the appearance distribution. [code](lib/modules/Transform.py)
135 |
136 |
137 | - **Layout Operator**
138 |
139 |
140 |
![layout](images/layout.png)
141 |
142 |
143 |
144 | The layout module places the latent distributions of two objects (from its child nodes) onto a background latent canvas according to the offsets of the two children. [code](models/PNPNet/pnp_net.py#L267)
145 |
146 |
147 |
148 | ## Training
149 |
150 | The default training can be started by:
151 |
152 | ```bash
153 | python mains/pnpnet_main.py --config_path configs/pnp_net_configs.yaml
154 | ```
155 |
156 | Make sure that you are in the project root directory when typing the above command.
157 |
158 |
159 |
160 | ## Evaluation
161 |
162 | The evaluation has two major steps:
163 |
164 | 1. Generate images according to the semantics in the test set using a pre-trained model.
165 |
166 | 2. Run our [detector-based semantic correctness score](https://github.com/woodfrog/SemanticCorrectnessScore) to evaluate the quality of images. Please check that repo for more details about our proposed metric for measuring semantic correctness of scene images.
167 |
168 |
169 | To generate test images with a pre-trained model, first set the mode to **test** in the config file, then set the checkpoint path, and finally run the same command as for training:
170 |
171 | ```bash
172 | python mains/pnpnet_main.py --config_path configs/pnp_net_configs.yaml
173 | ```
174 |
175 |
176 | ## Results
177 |
178 | Detailed results can be found in our paper. We provide some samples here to show PNP-Net's capability of generating images for different complex scenes. The dataset used here is CLEVR-G 128x128; every scene contains at most 8 objects.
179 |
180 |
181 |
182 |
![samples](images/samples.png)
183 |
184 |
185 | When the scene becomes too complex, PNP-Net can suffer from the following problems:
186 |
187 | 1. It might fail to handle occlusion between objects. When multiple objects overlap, their latents get mixed on the background latent canvas, and the appearance of objects can be distorted.
188 |
189 | 2. It might place some objects outside the image boundary, so some images do not contain the correct number of objects as described by the semantics.
190 |
--------------------------------------------------------------------------------
/trainers/pnpnet_trainer.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import numpy as np
3 | import os
4 | import os.path as osp
5 | import datetime
6 | import pytz
7 | import pdb
8 |
9 | import torch
10 | from torch.autograd import Variable
11 |
12 | from lib.utils import color_grid_vis, AverageMeter
13 |
14 |
15 | class PNPNetTrainer:
16 | def __init__(self, model, train_loader, val_loader, gen_loader, optimizer, configs):
17 | self.model = model
18 | self.train_loader = train_loader
19 | self.val_loader = val_loader
20 | self.gen_loader = gen_loader
21 | self.optimizer = optimizer
22 | self.configs = configs
23 |
24 | def train_epoch(self, epoch_num, timestamp_start):
25 | self.model.train()
26 | train_rec_loss = AverageMeter()
27 | train_kld_loss = AverageMeter()
28 | train_pos_loss = AverageMeter()
29 | batch_idx = 0
30 | epoch_end = False
31 | # annealing for kl penalty
32 | kl_coeff = float(epoch_num) / float(self.configs.warmup_iter + 1)
33 | if kl_coeff >= self.configs.alpha_ub:
34 | kl_coeff = self.configs.alpha_ub
35 | print('kl penalty coefficient: ', kl_coeff, 'alpha upperbound:', self.configs.alpha_ub)
36 | while epoch_end is False:
37 | data, trees, _, epoch_end = self.train_loader.next_batch()
38 | data = Variable(data).cuda()
39 |
40 | self.optimizer.zero_grad()
41 | ifmask = False
42 | if self.configs.maskweight > 0:
43 | ifmask = True
44 | rec_loss, kld_loss, pos_loss, modelout = self.model(data, trees, alpha=kl_coeff, ifmask=ifmask,
45 | maskweight=self.configs.maskweight)
46 | recon = modelout
47 | rec_loss, kld_loss, pos_loss = rec_loss.sum() / self._total(data), kld_loss.sum() / self._total(data), pos_loss.sum() / self._total(data)
48 | loss = rec_loss + self.configs.kl_beta * kld_loss + self.configs.pos_beta * pos_loss
49 | loss.backward()
50 | self.optimizer.step()
51 | train_rec_loss.update(rec_loss.item(), self._total(data), data.size(0))
52 | train_kld_loss.update(kld_loss.item(), self._total(data), data.size(0))
53 | train_pos_loss.update(pos_loss.item(), self._total(data), data.size(0))
54 |
55 | if batch_idx % 30 == 0:
56 | scipy.misc.imsave(osp.join(self.configs.exp_dir, 'samples', 'generativenmn_data.png'),
57 | (data.cpu().data.numpy().transpose(0, 2, 3, 1)[0] + 1) / 2.0)
58 | scipy.misc.imsave(osp.join(self.configs.exp_dir, 'samples', 'generativenmn_reconstruction.png'), \
59 | (recon.cpu().data.numpy().transpose(0, 2, 3, 1)[0] + 1) / 2.0)
60 | scipy.misc.imsave(osp.join(self.configs.exp_dir, 'samples', 'generativenmn_reconstruction_clip.png'), \
61 | np.clip(recon.cpu().data.numpy().transpose(0, 2, 3, 1)[0], -1, 1))
62 | print('Epoch:{0}\tIter:{1}/{2}\tRecon {3:.6f}\t KL {4:.6f}'.format(epoch_num, batch_idx,
63 | len(self.train_loader) // self.configs.batch_size,
64 | train_rec_loss.batch_avg, train_kld_loss.batch_avg))
65 |
66 | self.model.clean_tree(trees)
67 | batch_idx += 1
68 |
69 | elapsed_time = \
70 | datetime.datetime.now(pytz.timezone('America/New_York')) - \
71 | timestamp_start
72 |
73 | print('====> Epoch: {} Average rec loss: {:.6f} Average kld loss: {:.6f} Average pos loss: {:.6f}'.format(
74 | epoch_num, train_rec_loss.batch_avg, train_kld_loss.batch_avg, train_pos_loss.batch_avg))
75 | print('Elapsed time:', elapsed_time)
76 |
77 | @staticmethod
78 | def _total(tensor):
79 | return tensor.size(0) * tensor.size(1) * tensor.size(2) * tensor.size(3)
80 |
81 | def validate(self, epoch_num, timestamp_start, minloss):
82 | self.model.eval()
83 | test_rec_loss = AverageMeter()
84 | test_kld_loss = AverageMeter()
85 | test_pos_loss = AverageMeter()
86 | epoch_end = False
87 | count = 0.
88 | while epoch_end is False:
89 | data, trees, _, epoch_end = self.val_loader.next_batch()
90 | data = Variable(data, volatile=True).cuda()
91 | rec_loss, kld_loss, pos_loss, modelout = self.model(data, trees)
92 |
93 | rec_loss, kld_loss, pos_loss = rec_loss.sum(), kld_loss.sum(), pos_loss.sum()
94 | loss = rec_loss + kld_loss + pos_loss
95 |
96 | test_rec_loss.update(rec_loss.item(), self._total(data), data.size(0))
97 | test_kld_loss.update(kld_loss.item(), self._total(data), data.size(0))
98 | test_pos_loss.update(pos_loss.item(), self._total(data), data.size(0))
99 |
100 | self.model.clean_tree(trees)
101 |
102 | elapsed_time = \
103 | datetime.datetime.now(pytz.timezone('America/New_York')) - \
104 | timestamp_start
105 |
106 | torch.save(self.model.state_dict(),
107 | osp.join(self.configs.exp_dir, 'checkpoints', 'model_epoch_{0}.pth'.format(epoch_num)))
108 |
109 | print('====> Epoch: {} Test rec loss: {:.6f} Test kld loss: {:.6f} Test pos loss: {:.6f}'.format(
110 | epoch_num, test_rec_loss.batch_avg, test_kld_loss.batch_avg, test_pos_loss.batch_avg))
111 | print('Elapsed time:', elapsed_time)
112 |
113 | return minloss
114 |
115 | def sample(self, epoch_num, sample_num, timestamp_start):
116 | self.model.eval()
117 |
118 | data, trees, _, _ = self.gen_loader.next_batch()
119 | data = Variable(data, volatile=True).cuda()
120 | epoch_result_dir = osp.join(self.configs.exp_dir, 'samples', 'epoch-{}'.format(epoch_num))
121 |
122 | try:
123 | os.makedirs(epoch_result_dir)
124 | except:
125 | pass
126 |
127 | samples_image_dict = dict()
128 | data_image_dict = dict()
129 | batch_size = None
130 | for j in range(sample_num):
131 | sample = self.model.generate(data, trees)
132 | if not batch_size:
133 | batch_size = sample.size(0)
134 | for i in range(0, sample.size(0)):
135 | samples_image_dict.setdefault(i, list()).append(sample.cpu().data.numpy().transpose(0, 2, 3, 1)[i])
136 | if j == sample_num - 1:
137 | data_image_dict[i] = data.cpu().data.numpy().transpose(0, 2, 3, 1)[i]
138 | self.model.clean_tree(trees)
139 | print(j)
140 |
141 | for i in range(batch_size):
142 | samples = np.clip(np.stack(samples_image_dict[i], axis=0), -1, 1)
143 | data = data_image_dict[i]
144 | color_grid_vis(samples, nh=2, nw=sample_num // 2,
145 | save_path=osp.join(epoch_result_dir, 'generativenmn_{}_sample.png'.format(i)))
146 | scipy.misc.imsave(osp.join(epoch_result_dir, 'generativenmn_{}_real.png'.format(i)), data)
147 | torch.save(trees[i], osp.join(epoch_result_dir, 'generativenmn_tree_' + str(i) + '.pth'))
148 | print('====> Epoch: {} Generating image number: {:d}'.format(epoch_num, i))
149 |
150 | elapsed_time = \
151 | datetime.datetime.now(pytz.timezone('America/New_York')) - \
152 | timestamp_start
153 | print('Elapsed time:', elapsed_time)
154 |
--------------------------------------------------------------------------------
/lib/data_loader/clevr/treeutils.py:
--------------------------------------------------------------------------------
1 | import _init_paths
2 | import os
3 | import sys
4 | import argparse
5 | import os.path as osp
6 | import random
7 | from lib.tree import Tree
8 |
9 | # parser = argparse.ArgumentParser(description='generate trees for CLEVR dataset')
10 | # parser.add_argument('--output_dir', type=str, default='',
11 | # help='output path for trees')
12 | # parser.add_argument('--train_sample', type=int, default=0,
13 | # help='number of samples for training')
14 | # parser.add_argument('--test_sample', type=int, default=0,
15 | # help='number of samples for testing')
16 | # args = parser.parse_args()
17 |
18 |
19 | ######### hyperparameters ##########
20 | # max level of the tree
21 | max_level = 2
22 |
23 | # module list
24 | module_list = ['layout', 'describe', 'combine']
25 |
26 | # children dict
27 | children_dict = dict()
28 | children_dict['layout'] = 2
29 | children_dict['describe'] = 1
30 | children_dict['combine'] = 1
31 |
32 | # we will have two split dict for modules for designing a zero-shot setting
33 |
34 | module_dict_split1 = dict()
35 | module_dict_split2 = dict()
36 |
37 | # objects list
38 | module_dict_split1['describe'] = ['cylinder', 'sphere']
39 | module_dict_split2['describe'] = ['cube', 'sphere']
40 |
41 | # attributes list
42 | attribute_list = ['material', 'color', 'size']
43 |
44 | module_dict_split1['combine'] = {'material': ['metal'],
45 | 'color': ['green', 'blue', 'yellow', 'red'],
46 | 'size': ['large']}
47 |
48 | module_dict_split2['combine'] = {'material': ['rubber'],
49 | 'color': ['cyan', 'brown', 'gray', 'purple'],
50 | 'size': ['small']}
51 |
52 | # relations list
53 | module_dict_split1['layout'] = ['front', 'left', 'left-front', 'right-front']
54 | module_dict_split2['layout'] = ['right', 'behind', 'left-behind', 'right-behind']
55 |
56 | module_dicts = [module_dict_split1, module_dict_split2]
57 |
58 | pattern_map = {'describe': 0, 'material': 1, 'color': 2, 'size': 3, 'layout': 4}
59 |
60 | training_patterns = [(0, 1, 0, 1, 0), (1, 0, 1, 0, 1)]
61 | test_patterns = [(1, 1, 1, 1, 1), (0, 0, 0, 0, 0), (0, 0, 1, 1, 1), (1, 1, 0, 0, 0), (0, 1, 1, 1, 0), (1, 0, 0, 0, 1),
62 | (0, 1, 1, 1, 1), (1, 0, 0, 0, 0)]
63 |
64 |
65 | # degree range: currently randomize this number, \
66 | # no need for input from the tree
67 | # deg_range = [0, 360]
68 |
69 | # def get_flag(level, maxlevel):
70 | # if level + 1 >= max_level:
71 | # flag = 0
72 | # else:
73 | # # flag = random.randint(0, 1)
74 | # flag = 1
75 | #
76 | # return flag
77 |
78 |
79 | def expand_tree(tree, level, parent, memorylist, child_idx, max_level, metadata_pattern):
80 | if parent is None or parent.function == 'layout':
81 | if level + 1 >= max_level:
82 | valid = [1]
83 | else:
84 | valid = [0, 1]
85 |
86 | # sample module, the module can be either layout or describe here
87 | module_id = random.randint(0, len(valid) - 1)
88 | tree.function = module_list[valid[module_id]]
89 |
90 | # sample content
91 | dict_index = metadata_pattern[pattern_map[tree.function]]
92 | module_dict = module_dicts[dict_index]
93 |
94 | word_id = random.randint(0, len(module_dict[tree.function]) - 1)
95 | tree.word = module_dict[tree.function][word_id]
96 |
97 | if tree.function == 'layout':
98 | tree.function_obj = Layout(tree.word)
99 | print('add layout')
100 | else:
101 | tree.function_obj = Describe(tree.word)
102 | print('add describe')
103 |
104 | # num children
105 | if level + 1 > max_level:
106 | tree.num_children = 0
107 | else:
108 | tree.num_children = children_dict[tree.function]
109 | if parent is not None: # then the parent must be a layout node
110 | if child_idx == 0:
111 | parent.function_obj.left_child = tree.function_obj
112 | else:
113 | parent.function_obj.right_child = tree.function_obj
114 |
115 | for i in range(tree.num_children):
116 | tree.children.append(Tree())
117 | tree.children[i] = expand_tree(tree.children[i], level + 1, tree, [], i, max_level, metadata_pattern)
118 |
119 | # must contain only one child node, which is a combine node
120 | elif parent.function == 'describe' or parent.function == 'combine':
121 | print('add combine')
122 | valid = [2]
123 | # no need to sample module for now
124 | module_id = 0
125 | tree.function = module_list[valid[module_id]]
126 |
127 | # sample content
128 | # sample which attributes
129 | if len(set(attribute_list) - set(memorylist)) <= 1:
130 | full_attribute = True
131 | else:
132 | full_attribute = False
133 |
134 | attribute = random.sample(set(attribute_list) - set(memorylist), 1)[0]
135 | memorylist += [attribute]
136 |
137 | dict_idx = metadata_pattern[pattern_map[attribute]]
138 | module_dict = module_dicts[dict_idx]
139 | word_id = random.randint(0, len(module_dict[tree.function][attribute]) - 1)
140 | tree.word = module_dict[tree.function][attribute][word_id]
141 |
142 | if isinstance(parent.function_obj, Describe):
143 | carrier = parent.function_obj
144 | else:
145 | carrier = parent.function_obj.get_carrier()
146 |
147 | tree.function_obj = Combine(attribute, tree.word)
148 | tree.function_obj.set_carrier(carrier)
149 | carrier.set_attribute(attribute, tree.function_obj)
150 |
151 | if not full_attribute:
152 | tree.num_children = children_dict[tree.function]
153 |
154 | for i in range(tree.num_children):
155 | tree.children.append(Tree())
156 | tree.children[i] = expand_tree(tree.children[i], level + 1, tree, memorylist, i, max_level, metadata_pattern)
157 |
158 | else:
159 | raise ValueError('Wrong function.')
160 | return tree
161 |
162 |
163 | def visualize_tree(trees):
164 | for i in range(len(trees)):
165 | print('************** tree **************')
166 | _visualize_tree(trees[i], 0)
167 | print('**********************************')
168 |
169 |
170 | def _visualize_tree(tree, level):
171 | if tree == None:
172 | return
173 | for i in range(tree.num_children - 1, (tree.num_children - 1) // 2, -1):
174 | _visualize_tree(tree.children[i], level + 1)
175 |
176 | print(' ' * level + tree.word)
177 |
178 | # if isinstance(tree.function_obj, Describe):
179 | # print(tree.function_obj.attributes, tree.function_obj)
180 | # if tree.function != 'combine':
181 | # print('position {}'.format(tree.function_obj.position))
182 |
183 | if hasattr(tree, 'bbox'):
184 | print('Bounding box of {} is {}'.format(tree.word, tree.bbox))
185 |
186 | for i in range((tree.num_children - 1) // 2, -1, -1):
187 | _visualize_tree(tree.children[i], level + 1)
188 |
189 | return
190 |
191 |
192 | def allign_tree(tree, level):
193 | """
194 | Assign object positions with a pre-order traversal.
195 | :param tree: root of the (sub)tree to align
196 | :param level: current depth (0 for the root)
197 | """
198 | if tree is None:
199 | return
200 |
201 | if tree.function == 'describe' and level == 0:
202 | tree.function_obj.set_random_pos()
203 | elif tree.function == 'layout':
204 | tree.function_obj.set_children_pos()
205 | for i in range(tree.num_children):
206 | allign_tree(tree.children[i], level + 1)
207 | else:
208 | pass
209 |
210 |
211 | def extract_objects(tree):
212 | objects = list()
213 |
214 | if tree is None:
215 | return objects
216 |
217 | if tree.function == 'describe':
218 | objects.append(tree.function_obj)
219 | elif tree.function == 'layout':
220 | for i in range(tree.num_children):
221 | objects += extract_objects(tree.children[i])
222 | else:
223 | pass
224 |
225 | return objects
226 |
227 |
228 | def sample_tree(max_level, train=True):
229 | tree = Tree()
230 | if train:
231 | pattern = random.sample(training_patterns, 1)[0] # sample a pattern for training data
232 | else:
233 | pattern = random.sample(test_patterns, 1)[0] # sample a pattern for test data
234 | tree = expand_tree(tree, 0, None, [], 0, max_level, pattern)
235 | allign_tree(tree, 0)
236 | return tree
237 |
238 |
239 | if __name__ == '__main__':
240 | # random.seed(12113)
241 | #
242 | # # tree = Tree()
243 | # # tree = expand_tree(tree, 0, None, [], 0)
244 | # # allign_tree(tree)
245 | #
246 | # num_sample = 1
247 | # trees = []
248 | # for i in range(num_sample):
249 | # treei = Tree()
250 | # treei = expand_tree(treei, 0, None, [], 0, max_level=2)
251 | # allign_tree(treei, 0)
252 | # objects = extract_objects(treei)
253 | # trees += [treei]
254 | # print(objects)
255 | #
256 | # visualize_tree(trees)
257 |
258 | tree = sample_tree(max_level=3)
259 |
--------------------------------------------------------------------------------
/lib/data_loader/clevr/clevr_tree.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import _init_paths
4 | import os
5 | import os.path as osp
6 | import pickle
7 |
8 | import numpy as np
9 | import PIL.Image
10 | import torch
11 | import random
12 |
13 |
14 | class CLEVRTREE():
15 | # set the thresholds for the size of objects here
16 | SMALL_THRESHOLD = 16
17 | # remove the medium size for simplicity and clarity
18 |
19 | NEW_SIZE_WORDS = ['small', 'large']
20 | OLD_SIZE_WORDS = ['small', 'large']
21 |
22 | def __init__(self, batch_size=16,
23 | base_dir='/cs/vml4/zhiweid/ECCV18/GenerativeNeuralModuleNetwork/data/CLEVR/CLEVR_128',
24 | random_seed=12138,
25 | phase='train', shuffle=True, file_format='png'):
26 | self.phase = phase
27 | self.base_dir = base_dir
28 | self.batch_size = batch_size
29 | self.fileformat = file_format
30 |
31 | if phase not in ['train', 'test']:
32 | raise ValueError('invalid phase name {}, should be train or test'.format(phase))
33 |
34 | self.image_dir = osp.join(base_dir, 'images', phase)
35 | self.tree_dir = osp.join(base_dir, 'trees', phase)
36 |
37 | # get the file names for images, the corresponding trees, and the dataset dictionary
38 | self.image_files, self.tree_files = self.prepare_file_list()
39 | self.dictionary_path = osp.join(self.base_dir, 'dictionary_tree.pickle')
40 | self.dictionary = self.load_dictionary(self.tree_files, self.dictionary_path)
41 |
42 | # update the size words, since we need to adjust the size for all objects
43 | # according to the actual 2-D bounding box
44 | for word in self.OLD_SIZE_WORDS:
45 | self.dictionary.remove(word)
46 | for word in self.NEW_SIZE_WORDS:
47 | self.dictionary.append(word)
48 |
49 | # iterator part
50 | self.random_generator = random.Random()
51 | self.random_generator.seed(random_seed)
52 |
53 | self.files = dict()
54 | self.files[phase] = {
55 | 'img': self.image_files,
56 | 'tree': self.tree_files,
57 | }
58 |
59 | self.index_ptr = 0
60 | self.index_list = list(range(len(self.image_files)))
61 |
62 | self.shuffle = shuffle
63 | if shuffle:
64 | self.random_generator.shuffle(self.index_list)
65 |
66 | self.im_size = self.read_first()
67 | self.ttdim = len(self.dictionary) + 1
68 |
69 | def __len__(self):
70 | return len(self.files[self.phase]['img'])
71 |
72 | def read_first(self):
73 | images, _, _, _ = self.next_batch()
74 | self.index_ptr = 0
75 |
76 | return images.size()
77 |
78 | def prepare_file_list(self):
79 | """
80 | Get the file name lists for images and their corresponding trees
81 | :return:
82 | """
83 | image_list = []
84 | tree_list = []
85 | for image_filename in sorted(os.listdir(self.image_dir)):
86 | image_path = os.path.join(self.image_dir, image_filename)
87 | image_list.append(image_path)
88 | filename, _ = os.path.splitext(image_filename)
89 | tree_filename = filename + '.tree'
90 | tree_path = os.path.join(self.tree_dir, tree_filename)
91 | tree_list.append(tree_path)
92 | return image_list, tree_list
93 |
94 | def load_dictionary(self, tree_files, dictionary_path):
95 | if osp.isfile(dictionary_path):  # the dictionary has already been created, so just load and return it
96 | with open(dictionary_path, 'rb') as f:
97 | dictionary = pickle.load(f)
98 | else:
99 | dictionary_set = set()
100 | for idx, tree_file_path in enumerate(tree_files):
101 | with open(tree_file_path, 'rb') as f:
102 | tree = pickle.load(f)
103 | tree_words = self.get_tree_words(tree)
104 | dictionary_set.update(set(tree_words))
105 |
106 | dictionary = list(dictionary_set)
107 |
108 | # update the size words, since we need to adjust the size for all objects
109 | # according to the actual 2-D bounding box
110 | for word in self.OLD_SIZE_WORDS:
111 | dictionary.remove(word)
112 | for word in self.NEW_SIZE_WORDS:
113 | dictionary.append(word)
114 | with open(dictionary_path, 'wb') as f:
115 | pickle.dump(dictionary, f)
116 |
117 | return dictionary
118 |
119 | def get_tree_words(self, tree):
120 | words = [tree.word]
121 | for child in tree.children:
122 | words += self.get_tree_words(child)
123 | return words
124 |
125 | def next_batch(self):
126 | data_file = self.files[self.phase]
127 |
128 | images, trees, categories = [], [], []
129 | for i in range(0, min(self.batch_size, len(self) - self.index_ptr)):
130 | index = self.index_list[self.index_ptr]
131 |
132 | # load image
133 | if self.fileformat == 'png':
134 | img_file = data_file['img'][index]
135 | img = PIL.Image.open(img_file)
136 | else:
137 | raise ValueError('wrong file format for images')
138 |
139 | # remove the alpha-channel
140 | img = np.array(img, dtype=np.float32)[:, :, :-1]
141 | img = (img - 127.5) / 127.5
142 | images.append(img)
143 |
144 | # load tree
145 | with open(data_file['tree'][index], 'rb') as f:
146 | tree = pickle.load(f)
147 | tree = self.adapt_tree(tree)
148 | trees.append(tree)
149 | categories.append(self.get_categorical_list(tree))
150 |
151 | self.index_ptr += 1
152 |
153 | images = np.array(images, dtype=np.float32).transpose(0, 3, 1, 2)
154 |
155 | refetch = False
156 | if self.index_ptr >= len(self):
157 | self.index_ptr = 0
158 | refetch = True
159 | if self.shuffle:
160 | self.random_generator.shuffle(self.index_list)
161 |
162 | return torch.from_numpy(images), trees, categories, refetch
163 |
164 | def get_all(self):
165 | data_file = self.files[self.phase]
166 |
167 | images, trees, categories = [], [], []
168 | for i in range(len(self.index_list)):
169 | index = self.index_list[self.index_ptr]
170 |
171 | # load image
172 | if self.fileformat == 'png':
173 | img_file = data_file['img'][index]
174 | img = PIL.Image.open(img_file)
175 | else:
176 | raise ValueError('wrong file format for images')
177 |
178 | # remove the alpha-channel
179 | img = np.array(img, dtype=np.float32)[:, :, :-1]
180 | img = (img - 127.5) / 127.5
181 | images.append(img)
182 |
183 | # load tree
184 | with open(data_file['tree'][index], 'rb') as f:
185 | tree = pickle.load(f)
186 | tree = self.adapt_tree(tree)
187 | trees.append(tree)
188 |
189 | self.index_ptr += 1
190 |
191 | images = np.array(images, dtype=np.float32).transpose(0, 3, 1, 2)
192 |
193 | return torch.from_numpy(images), trees
194 |
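   | # adapt_tree rewrites the 3-D size words ('small'/'large') using the projected
   | # 2-D bounding box and reorders every bbox from (x, y, w, h) to (y, x, h, w).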
195 | def adapt_tree(self, tree):
196 | tree = self._adapt_tree(tree, parent_bbox=None)
197 | return tree
198 |
199 | def _adapt_tree(self, tree, parent_bbox):
200 | # adjust tree.word for object size according to the bounding box, since the original size is for 3-D world
201 | if tree.function == 'combine' and tree.word in self.OLD_SIZE_WORDS:
202 | width = parent_bbox[2]
203 | height = parent_bbox[3]
204 | tree.word = self._get_size_word(width, height)
205 |
206 | # set the bbox for passing to children
207 | if tree.function == 'combine':
208 | bbox_xywh = parent_bbox
209 | elif tree.function == 'describe':
210 | bbox_xywh = tree.bbox
211 | else:
212 | bbox_xywh = None
213 |
214 | # then swap the bbox to (y,x,h,w)
215 | if hasattr(tree, 'bbox'):
216 | bbox_yxhw = (tree.bbox[1], tree.bbox[0], tree.bbox[3], tree.bbox[2])
217 | tree.bbox = np.array(bbox_yxhw)
218 |
219 | # pre-order traversal
220 | for child in tree.children:
221 | self._adapt_tree(child, parent_bbox=bbox_xywh)
222 | return tree
223 |
224 | def get_categorical_list(self, tree):
225 | categorical_list, attr_list = self._get_categorical_list(tree)
226 | return categorical_list
227 |
228 | def _get_categorical_list(self, tree):
229 | # must be a post-order traversal: parents need info from their children
230 | category_list = list()
231 | attr_list = list()
232 | for child in tree.children:
233 | children_category_list, children_attr_list = self._get_categorical_list(child)
234 | category_list += children_category_list
235 | attr_list += children_attr_list
236 |
237 | if tree.function == 'describe':
238 | bbox = tree.bbox
239 | adapted_bbox = (bbox[1], bbox[0], bbox[3], bbox[2])
240 | attr_list.append(tree.word)
241 | attr_vec = self._get_attr_vec(attr_list)
242 | obj_category = (adapted_bbox, attr_vec)
243 | category_list.append(obj_category)
244 |
245 | if tree.function == 'combine':
246 | attr_list.append(tree.word) # just pass its word to parent
247 |
248 | return category_list, attr_list
249 |
250 | def _get_attr_vec(self, attr_list):
251 | vec = np.zeros(len(self.dictionary), dtype=np.float64)
252 | for attr in attr_list:
253 | attr_idx = self.dictionary.index(attr)
254 | vec[attr_idx] = 1.0
255 | return vec
256 |
257 | @classmethod
258 | def _get_size_word(cls, width, height):
259 | maximum = max(width, height)
260 | if maximum < cls.SMALL_THRESHOLD:
261 | return 'small'
262 | else:
263 | return 'large'
264 |
265 |
266 | if __name__ == '__main__':
267 | loader = CLEVRTREE(phase='test',
268 | base_dir='/zhiweid/work/gnmn/GenerativeNeuralModuleNetwork/data/CLEVR/CLEVR_64_MULTI_LARGE')
269 | for i in range(3):
270 | im, trees, categories, ref = loader.next_batch()
271 | print(im[0].shape)
272 | print(categories[0])
273 | print(loader.dictionary)
274 | import IPython; IPython.embed()
275 | break
278 |
279 | '''
280 | # For testing the format of PIL.Image and cv2
281 | img2 = PIL.Image.open('/local-scratch/cjc/GenerativeNeuralModuleNetwork/data/CLEVR/clevr-dataset-gen/output/images/train/CLEVR_new_000003.png')
282 | img2 = np.array(img2)
283 | img2 = img2[:,:,:-1]
284 |
285 | out = PIL.Image.fromarray(img2)
286 | out.save('test.png')
287 | # img = cv2.imread('/local-scratch/cjc/GenerativeNeuralModuleNetwork/data/CLEVR/clevr-dataset-gen/output/images/train/CLEVR_new_000003.png')
288 | # cv2.imwrite('test.png',img)
289 | print(img2)
290 | #
291 | img = PIL.Image.open('test.png')
292 | img = np.array(img)
293 |
294 | print(img2 - img)
295 |
296 | '''
297 |
--------------------------------------------------------------------------------
/lib/data_loader/color_mnist_tree_multi.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os.path as osp
3 |
4 | import _init_paths
5 | import numpy as np
6 | import PIL.Image
7 | import scipy.misc
8 | import torch
9 | import random
10 | from lib.tree import Tree
11 | import scipy.io
12 |
13 |
14 | class COLORMNISTTREE():
15 | def __init__(self, batch_size=16, directory='/cs/vml4/Datasets/COLORMNIST', random_seed=12138, folder='TWO_5000',
16 | phase='train', shuffle=True, fileformat='png'):
17 | # basic
18 | self.folder = folder
19 | self.phase = phase
20 | self.dir = directory
21 | self.batch_size = batch_size
22 | self.fileformat = fileformat
23 |
24 | # static tree structure
25 | self.treefile = osp.join('data', 'COLORMNIST', folder, phase + '_parents.list')
26 | self.functionfile = osp.join('data', 'COLORMNIST', folder, phase + '_functions.list')
27 | self.infofile = osp.join(directory, folder, phase + '_text.txt')
28 | with open(self.infofile, 'r') as f:
29 | self.info = []
30 | self.number = []
31 | for line in f:
32 | line = line[:-1].split(' ')
33 | self.number += [int(line[0][0])]
34 | line[0] = line[0][2:]
35 | self.info += [line]
36 |
37 | self.trees = self.read_trees()
38 | self.functions = self.read_functions()
39 |
40 | # parse info
41 | imgdir = osp.join(self.dir, folder, phase)
42 | self.imglist, self.trees, self.dictionary = self.read_info()
43 |
44 | # iterator part
45 | self.random_generator = random.Random()
46 | self.random_generator.seed(random_seed)
47 | self.files = {}
48 | self.files[phase] = {
49 | 'img': self.imglist,
50 | 'tree': self.trees,
51 | }
52 |
53 | self.idx_ptr = 0
54 | self.indexlist = list(range(len(self.imglist)))
55 |
56 | self.shuffle = shuffle
57 | if shuffle:
58 | self.random_generator.shuffle(self.indexlist)
59 |
60 | # (H, W, C)
61 | self.im_size = self.test_read()
62 |
63 | def __len__(self):
64 | return len(self.files[self.phase]['img'])
65 |
66 | def test_read(self):
67 | data_file = self.files[self.phase]
68 |
69 | # load image
70 | if self.fileformat == 'mat':
71 | mat_file = data_file['img'][0] + '.mat'
72 | img = scipy.io.loadmat(mat_file)['data']
73 | elif self.fileformat == 'png':
74 | img_file = data_file['img'][0] + '.png'
75 | img = PIL.Image.open(img_file)
76 | img = np.array(img, dtype=np.float32)
77 |
78 | return img.shape
79 |
80 | def read_functions(self):
81 | functionfile = self.functionfile
82 | f = open(functionfile)
83 |
84 | functions = []
85 | for line in f:
86 | functions += [line[:-1].split(" ")]
87 |
88 | return functions
89 |
90 | def next_batch(self):
91 | data_file = self.files[self.phase]
92 |
93 | imgs, trees, categories = [], [], []
94 | for i in range(0, min(self.batch_size, self.__len__() - self.idx_ptr)):
95 | index = self.indexlist[self.idx_ptr]
96 |
97 | # load image
98 | if self.fileformat == 'mat':
99 | mat_file = data_file['img'][index] + '.mat'
100 | img = scipy.io.loadmat(mat_file)['data']
101 | elif self.fileformat == 'png':
102 | img_file = data_file['img'][index] + '.png'
103 | img = PIL.Image.open(img_file)
104 | img = np.array(img, dtype=np.float32)
105 | img = (img - 127.5) / 127.5
106 |
107 | imgs.append(img)
108 |
109 | # load label
110 | tree = data_file['tree'][index]
111 | trees.append(tree)
112 | categories.append(self.get_categorical_list(tree))
113 |
114 | self.idx_ptr += 1
115 |
116 | imgs = np.array(imgs, dtype=np.float32).transpose(0, 3, 1, 2)
117 |
118 | refetch = False
119 | if self.idx_ptr >= self.__len__():
120 | self.idx_ptr = 0
121 | refetch = True
122 | if self.shuffle:
123 | self.random_generator.shuffle(self.indexlist)
124 |
125 | return torch.from_numpy(imgs), trees, categories, refetch
126 |
127 | def next_batch_multigpu(self):
128 | data_file = self.files[self.phase]
129 |
130 | imgs, tree_indices = [], []
131 | for i in range(0, min(self.batch_size, self.__len__() - self.idx_ptr)):
132 | index = self.indexlist[self.idx_ptr]
133 |
134 | # load image
135 | if self.fileformat == 'mat':
136 | mat_file = data_file['img'][index] + '.mat'
137 | img = scipy.io.loadmat(mat_file)['data']
138 | elif self.fileformat == 'png':
139 | img_file = data_file['img'][index] + '.png'
140 | img = PIL.Image.open(img_file)
141 | img = np.array(img, dtype=np.float32)
142 | img = (img - 127.5) / 127.5
143 |
144 | imgs += [img]
145 |
146 | # load label
147 | tree_indices.append(index)
148 |
149 | self.idx_ptr += 1
150 |
151 | imgs = np.array(imgs, dtype=np.float32).transpose(0, 3, 1, 2)
152 |
153 | refetch = False
154 | if self.idx_ptr >= self.__len__():
155 | self.idx_ptr = 0
156 | refetch = True
157 | if self.shuffle:
158 | self.random_generator.shuffle(self.indexlist)
159 |
160 | return torch.from_numpy(imgs), torch.from_numpy(np.array(tree_indices)), refetch
161 |
162 | def get_tree_by_idx(self, index):
163 | return self.files[self.phase]['tree'][index]
164 |
165 | def get_tree_list_current_epoch(self):
166 | tree_list = list()
167 | for idx in self.indexlist:
168 | tree_list.append(self.files[self.phase]['tree'][idx])
169 | return tree_list
170 |
171 | def read_info(self):
172 | imglist = []
173 | dictionary = []
174 | functions = self.functions
175 |
176 | count = 0
177 | for tree in self.trees:
178 | imglist.append(osp.join(self.dir, self.folder, self.phase, 'image{:05d}'.format(count)))
179 | words, numbers, bboxes = self._extract_info(self.info[count], self.number[count])
180 | dictionary = list(set(dictionary + words))
181 | self.trees[count] = self._read_info(tree, words, functions[count], numbers, bboxes)
182 | count += 1
183 |
184 | return imglist, self.trees, dictionary
185 |
186 | def _extract_info(self, line, num):
187 | words = []
188 | numbers = []
189 | bboxes = []
190 | count = 0
191 |
192 | numid = 0
193 | for ele in line:
194 | words.append(ele)
195 | if self._is_number(ele):
196 | numbers.append(ele)
197 | numid += 1
198 | if numid == num:
199 | count += 1
200 | break
201 | count += 1
202 |
203 | newid = count
204 | for i in range(0, newid):
205 | if self._is_number(line[i]):
206 | bboxes += [[int(ele) for ele in line[count:count + 4]]]
207 | count += 4
208 | else:
209 | bboxes += [[]]
210 |
211 | return words, numbers, bboxes
212 |
213 | def _is_number(self, n):
214 | try:
215 | int(n)
216 | return True
217 | except (TypeError, ValueError):
218 | return False
219 |
220 | def _read_info(self, tree, words, functions, numbers, bboxes):
221 | for i in range(0, tree.num_children):
222 | tree.children[i] = self._read_info(tree.children[i], words, functions, numbers, bboxes)
223 |
224 | tree.word = words[tree.idx]
225 | tree.function = functions[tree.idx]
226 | tree.bbox = np.array(bboxes[tree.idx])
227 |
228 | return tree
229 |
230 | def read_trees(self):
231 | filename = self.treefile
232 | with open(filename, 'r') as f:
233 | trees = [self.read_tree(line) for line in f.readlines()]
234 | return trees
235 |
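   | # read_tree parses one line of 1-based parent indices into a Tree: the i-th
   | # entry holds the parent of node i, 0 marks the root and -1 marks unused
   | # slots (this appears to mirror the common dependency-tree list format).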
236 | def read_tree(self, line):
237 | parents = list(map(int, line.split()))
238 | trees = dict()
239 | root = None
240 | for i in range(1, len(parents) + 1):
241 | # if not trees[i-1] and parents[i-1]!=-1:
242 | if i - 1 not in trees.keys() and parents[i - 1] != -1:
243 | idx = i
244 | prev = None
245 | while True:
246 | parent = parents[idx - 1]
247 | if parent == -1:
248 | break
249 | tree = Tree()
250 | if prev is not None:
251 | tree.add_child(prev)
252 | trees[idx - 1] = tree
253 | tree.idx = idx - 1
254 | # if trees[parent-1] is not None:
255 | if parent - 1 in trees.keys():
256 | trees[parent - 1].add_child(tree)
257 | break
258 | elif parent == 0:
259 | root = tree
260 | break
261 | else:
262 | prev = tree
263 | idx = parent
264 | return root
265 |
266 | def get_categorical_list(self, tree):
267 | categorical_list, attr_list = self._get_categorical_list(tree)
268 | return categorical_list
269 |
270 | def _get_categorical_list(self, tree):
271 | # must be a post-order traversal: parents need info from their children
272 | category_list = list()
273 | attr_list = list()
274 | for child in tree.children:
275 | children_category_list, children_attr_list = self._get_categorical_list(child)
276 | category_list += children_category_list
277 | attr_list += children_attr_list
278 |
279 | if tree.function == 'describe':
280 | bbox = tree.bbox
281 | adapted_bbox = (bbox[0], bbox[1], bbox[2], bbox[3])
282 | attr_list.append(tree.word)
283 | attr_vec = self._get_attr_vec(attr_list)
284 | obj_category = (adapted_bbox, attr_vec)
285 | category_list.append(obj_category)
286 |
287 | if tree.function == 'combine':
288 | attr_list.append(tree.word) # just pass its word to parent
289 |
290 | return category_list, attr_list
291 |
292 | def _get_attr_vec(self, attr_list):
293 | vec = np.zeros(len(self.dictionary), dtype=np.float64)
294 | for attr in attr_list:
295 | attr_idx = self.dictionary.index(attr)
296 | vec[attr_idx] = 1.0
297 | return vec
298 | '''
299 | # Single card
300 | loader = COLORMNISTTREE(directory='data/COLORMNIST', folder='TWO_5000_64_modified')
301 | trees = loader.trees
302 | print(trees[0].word)
303 | print(loader.dictionary)
304 | for i in range(0, 1):
305 | im, tr, cats, ref = loader.next_batch()
306 | print(tr[0].function)
307 | print(tr[0].word)
308 | print(tr[0].children[0].word)
309 | print(cats[0])
310 | '''
311 | '''
312 | # Multi-gpu
313 | loader = COLORMNISTTREE(directory='data/COLORMNIST', folder='ONE')
314 | trees = loader.trees
315 | print(trees[0].word)
316 | print(loader.dictionary)
317 | for i in range(0, 1000):
318 | im, tr_indices, ref = loader.next_batch_multigpu()
319 | tree = loader.get_tree_by_idx(tr_indices[0])
320 | print(tree.function)
321 | print(tree.word)
322 | print(tree.children[0].word)
323 | '''
324 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/mains/pnpnet_main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import _init_paths
3 | import datetime
4 | import argparse
5 | import torch
6 | import torch.optim as optim
7 | from torch.autograd import Variable
8 | import pytz
9 | import scipy.misc
10 | import os.path as osp
11 | import os
12 | import numpy as np
13 | import random
14 | import pdb
15 |
16 | from lib.data_loader.color_mnist_tree_multi import COLORMNISTTREE
17 | from lib.data_loader.clevr.clevr_tree import CLEVRTREE
18 | from lib.config import load_config, Struct
19 | from models.PNPNet.pnp_net import PNPNet
20 | from trainers.pnpnet_trainer import PNPNetTrainer
21 | from lib.weight_init import weights_init
22 |
23 | parser = argparse.ArgumentParser(description='PNPNet - main model experiment')
24 | parser.add_argument('--config_path', type=str, default='./configs/pnp_net_configs.yaml', metavar='C',
25 | help='path to the configuration file')
26 |
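   | # Example invocation (assuming the config shipped with the repo, run from the repo root):
   | #   python mains/pnpnet_main.py --config_path ./configs/pnp_net_configs.yaml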
27 |
28 | def main():
29 | args = parser.parse_args()
30 |
31 | config_dic = load_config(args.config_path)
32 | configs = Struct(**config_dic)
33 |
34 | assert (torch.cuda.is_available()) # assume CUDA is always available
35 |
36 | print('configurations:', configs)
37 |
38 | torch.cuda.set_device(configs.gpu_id)
39 | torch.manual_seed(configs.seed)
40 | torch.cuda.manual_seed(configs.seed)
41 | np.random.seed(configs.seed)
42 | random.seed(configs.seed)
43 | torch.backends.cudnn.benchmark = False
44 | torch.backends.cudnn.deterministic = True
45 |
46 | configs.exp_dir = 'results/' + configs.data_folder + '/' + configs.exp_dir_name
47 | exp_dir = configs.exp_dir
48 |
49 | try:
50 | os.makedirs(configs.exp_dir)
51 | except:
52 | pass
53 | try:
54 | os.makedirs(osp.join(configs.exp_dir, 'samples'))
55 | except:
56 | pass
57 | try:
58 | os.makedirs(osp.join(configs.exp_dir, 'checkpoints'))
59 | except:
60 | pass
61 |
62 | # loaders
63 | if 'CLEVR' in configs.data_folder:
64 | # we need the module's label->index dictionary from train loader
65 | train_loader = CLEVRTREE(phase='train', base_dir=osp.join(configs.base_dir, configs.data_folder),
66 | batch_size=configs.batch_size,
67 | random_seed=configs.seed, shuffle=True)
68 | test_loader = CLEVRTREE(phase='test', base_dir=osp.join(configs.base_dir, configs.data_folder),
69 | batch_size=configs.batch_size,
70 | random_seed=configs.seed, shuffle=False)
71 | gen_loader = CLEVRTREE(phase='test', base_dir=osp.join(configs.base_dir, configs.data_folder),
72 | batch_size=configs.batch_size,
73 | random_seed=configs.seed, shuffle=False)
74 | elif 'COLORMNIST' in configs.data_folder:
75 | train_loader = COLORMNISTTREE(phase='train', directory=configs.base_dir, folder=configs.data_folder,
76 | batch_size=configs.batch_size,
77 | random_seed=configs.seed, shuffle=True)
78 | test_loader = COLORMNISTTREE(phase='test', directory=configs.base_dir, folder=configs.data_folder,
79 | batch_size=configs.batch_size,
80 | random_seed=configs.seed, shuffle=False)
81 | gen_loader = COLORMNISTTREE(phase='test', directory=configs.base_dir, folder=configs.data_folder,
82 | batch_size=configs.batch_size,
83 | random_seed=configs.seed, shuffle=False)
84 | else:
85 | raise ValueError('invalid dataset folder name {}'.format(configs.data_folder))
86 |
87 | # hack: derive the image-size parameter from the loader's sample shape
88 | im_size = gen_loader.im_size[2]
89 |
90 | # model
91 | model = PNPNet(hiddim=configs.hiddim, latentdim=configs.latentdim,
92 | word_size=[configs.latentdim, configs.word_size, configs.word_size], pos_size=[8, 1, 1],
93 | nres=configs.nr_resnet, nlayers=4,
94 | nonlinear='elu', dictionary=train_loader.dictionary,
95 | op=[configs.combine_op, configs.describe_op],
96 | lmap_size=im_size // 2 ** configs.ds,
97 | downsample=configs.ds, lambdakl=-1, bg_bias=configs.bg_bias,
98 | normalize=configs.normalize,
99 | loss=configs.loss, debug_mode=False)
100 |
101 | if configs.checkpoint is not None and len(configs.checkpoint) > 0:
102 | model.load_state_dict(torch.load(configs.checkpoint))
103 | print('load model from {}'.format(configs.checkpoint))
104 | else:
105 | model.apply(weights_init)
106 |
107 | if configs.mode == 'train':
108 | train(model, train_loader, test_loader, gen_loader, configs=configs)
109 | elif configs.mode == 'test':
110 | print(exp_dir)
111 | print('Start generating...')
112 | generate(model, gen_loader=gen_loader, num_sample=2, target_dir=exp_dir)
113 | elif configs.mode == 'visualize':
114 | print(exp_dir)
115 | print('Start visualizing...')
116 | visualize(model, num_sample=50, target_dir=exp_dir)
117 | elif configs.mode == 'sample':
118 | print('Sampling')
119 | sample_tree(model, test_loader=test_loader, tree_idx=4, base_dir='samples', num_sample=500)
120 | else:
121 | raise ValueError('Wrong mode given:{}'.format(configs.mode))
122 |
123 |
124 | def train(model, train_loader, test_loader, gen_loader, configs):
125 | model.train()
126 | # optimizer: set separate learning rates for some modules so that training is more stable
127 | optimizer = optim.Adamax([
128 | {'params': model.reader.parameters(), 'lr': 0.2 * configs.lr},
129 | {'params': model.h_mean.parameters(), 'lr': 0.1 * configs.lr},
130 | {'params': model.h_var.parameters(), 'lr': 0.1 * configs.lr},
131 | {'params': model.writer.parameters()},
132 | {'params': model.vis_dist.parameters()},
133 | {'params': model.pos_dist.parameters()},
134 | {'params': model.combine.parameters()},
135 | {'params': model.describe.parameters()},
136 | {'params': model.box_vae.parameters(), 'lr': 10 * configs.lr},
137 | {'params': model.offset_vae.parameters(), 'lr': 10 * configs.lr},
138 | {'params': model.renderer.parameters()},
139 | {'params': model.bias_mean.parameters()},
140 | {'params': model.bias_var.parameters()}
141 | ], lr=configs.lr)
142 |
143 | model.cuda()
144 |
145 | trainer = PNPNetTrainer(model=model, train_loader=train_loader, val_loader=test_loader, gen_loader=gen_loader,
146 | optimizer=optimizer,
147 | configs=configs)
148 |
149 | minloss = 1000
150 | for epoch_num in range(0, configs.epochs + 1):
151 | timestamp_start = datetime.datetime.now(pytz.timezone('America/New_York'))
152 | trainer.train_epoch(epoch_num, timestamp_start)
153 | if epoch_num % configs.validate_interval == 0 and epoch_num > 0:
154 | minloss = trainer.validate(epoch_num, timestamp_start, minloss)
155 | if epoch_num % configs.sample_interval == 0 and epoch_num > 0:
156 | trainer.sample(epoch_num, sample_num=8, timestamp_start=timestamp_start)
157 | if epoch_num % configs.save_interval == 0 and epoch_num > 0:
158 | torch.save(model.state_dict(),
159 | osp.join(configs.exp_dir, 'checkpoints', 'model_epoch_{0}.pth'.format(epoch_num)))
160 |
161 |
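   | # generate() draws num_sample images for every tree in the test split and writes
   | # the j-th sample of the k-th test image to <target_dir>/test-data-j/image{k:05d}.png.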
162 | def generate(model, gen_loader, num_sample, target_dir):
163 | model.eval()
164 | model.cuda()
165 |
166 | epoch_end = False
167 | sample_dirs = []
168 | for i in range(num_sample):
169 | sample_dir = osp.join(target_dir, 'test-data-{}'.format(i))
170 | if not osp.isdir(sample_dir):
171 | os.mkdir(sample_dir)
172 | sample_dirs.append(sample_dir)
173 |
174 | image_idx = 0
175 | while epoch_end is False:
176 | data, trees, _, epoch_end = gen_loader.next_batch()
177 |
178 | with torch.no_grad():
179 | data = Variable(data).cuda()
180 |
181 | samples_image_dict = dict()
182 | batch_size = None
183 | for j in range(num_sample):
184 | sample = model.generate(data, trees)
185 | if not batch_size:
186 | batch_size = sample.size(0)
187 | for i in range(0, sample.size(0)):
188 | samples_image_dict.setdefault(i, list()).append(sample.cpu().data.numpy().transpose(0, 2, 3, 1)[i])
189 | model.clean_tree(trees)
190 |
191 | for i in range(batch_size):
192 | samples = np.clip(np.stack(samples_image_dict[i], axis=0), -1, 1)
193 | for j in range(num_sample):
194 | sample = samples[j]
195 | scipy.misc.imsave(osp.join(sample_dirs[j], 'image{:05d}.png'.format(image_idx)), sample)
196 | image_idx += 1
197 |
198 |
199 | def sample_tree(model, test_loader, tree_idx, base_dir, num_sample):
200 | """
201 | Sample multiple image instances for a specified tree in the test dataset
202 | :param model: model
203 | :param test_loader: test loader
204 | :param tree_idx: the tree's index
205 | :param base_dir: base directory for saving results
206 | :param num_sample: number of images to sample for the specified tree structure
207 | """
208 | model.eval()
209 |
210 | target_dir = osp.join(base_dir, 'tree-{}'.format(tree_idx))
211 | if not osp.isdir(target_dir):
212 | os.makedirs(target_dir)
213 |
214 | all_data, all_trees = test_loader.get_all()
215 |
216 | data = all_data[tree_idx]
217 | tree = [all_trees[tree_idx]]
218 |
219 | for i in range(num_sample):
220 | sample = model.generate(data, tree)
221 | sample = np.clip(sample.cpu().data.numpy().transpose(0, 2, 3, 1), -1, 1)
222 | sample = sample[0]
223 | scipy.misc.imsave(osp.join(target_dir, 'sample-{:05d}.png'.format(i)), sample)
224 |
225 |
226 | def visualize(model, num_sample, target_dir):
227 | """
228 | Visualize intermediate results of PNPNet (e.g. visual concepts or partially composed image)
229 | :param model: model
230 | :param num_sample: number of samples to generate
231 | :param target_dir: target directory for saving results
232 | """
233 | model.eval()
234 |
235 | sample_dir = osp.join(target_dir, 'visualize-data')
236 | if not osp.isdir(sample_dir):
237 | os.mkdir(sample_dir)
238 |
239 | mode = 'transform'
240 |
241 | for i in range(0, len(model.dictionary)):
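   | # only visualize a few hand-picked dictionary indices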
242 | if i not in [1, 2, 5]:
243 | continue
244 | print('index ', i, 'current word:', model.dictionary[i])
245 |
246 | data = Variable(torch.zeros(num_sample, len(model.dictionary))).cuda()
247 | data[:, i] = 1
248 | vis_dist = model.vis_dist(data)
249 |
250 | if mode == 'full':
251 | prior_mean, prior_var = model.renderer(vis_dist)
252 | elif mode == 'transform':
253 | prior_mean, prior_var = Variable(torch.zeros(vis_dist[0].size())).cuda(), Variable(
254 | torch.zeros(vis_dist[1].size())).cuda()
255 | sz = vis_dist[0].size(2)
256 | tsz = 5
257 | vis_dist[0] = model.transform(vis_dist[0], [tsz, tsz])
258 | vis_dist[1] = model.transform(vis_dist[1], [tsz, tsz])
259 | prior_mean[:, :, 5:5 + tsz, 5:5 + tsz] = vis_dist[0]
260 | prior_var[:, :, 5:5 + tsz, 5:5 + tsz] = vis_dist[1]
261 | prior_mean, prior_var = model.renderer([prior_mean, prior_var])
262 | else:
263 | raise ValueError('Invalid mode name {}'.format(mode))
264 |
265 | z_map = model.sampler(prior_mean, prior_var)
266 |
267 | rec = model.writer(z_map)
268 |
269 | for im in range(0, rec.size(0)):
270 | scipy.misc.imsave(osp.join(sample_dir, 'image-{}-{:05d}.png'.format(model.dictionary[i], im)),
271 | rec[im].cpu().data.numpy().transpose(1, 2, 0))
272 |
273 |
274 | if __name__ == '__main__':
275 | main()
276 |
--------------------------------------------------------------------------------
/lib/modules/Describe.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from torch.nn.utils import weight_norm
5 | import torch.nn.functional as F
6 |
7 |
8 | # a more advanced plan: predict an attention map a, then
9 | # render the visual word by a*x*y + (1-a)*y
10 | class Describe(nn.Module):
11 | def __init__(self, hiddim_v=None, hiddim_p=None, op='CAT'):
12 | super(Describe, self).__init__()
13 | self.op = op
14 | self.hiddim_v = hiddim_v
15 | self.hiddim_p = hiddim_p
16 | if op == 'CAT' or op == 'CAT_PoE':
17 | self.net1_mean_vis = nn.Sequential(
18 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)),
19 | nn.ELU(inplace=True),
20 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1))
21 | )
22 |
23 | self.net1_var_vis = nn.Sequential(
24 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)),
25 | nn.ELU(inplace=True),
26 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1))
27 | )
28 |
29 | self.net1_mean_pos = nn.Sequential(
30 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)),
31 | nn.ELU(inplace=True),
32 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1))
33 | )
34 |
35 | self.net1_var_pos = nn.Sequential(
36 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)),
37 | nn.ELU(inplace=True),
38 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1))
39 | )
40 | elif op == 'DEEP':
41 | self.net_vis = nn.Sequential(
42 | weight_norm(nn.Conv2d(4 * hiddim_v, hiddim_v, 3, 1, 1)),
43 | nn.ELU(inplace=True),
44 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)),
45 | nn.ELU(inplace=True),
46 | weight_norm(nn.Conv2d(hiddim_v, 2 * hiddim_v, 3, 1, 1))
47 | )
48 |
49 | self.net_pos = nn.Sequential(
50 | weight_norm(nn.Conv2d(4 * hiddim_p, hiddim_p, 1, 1)),
51 | nn.ELU(inplace=True),
52 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)),
53 | nn.ELU(inplace=True),
54 | weight_norm(nn.Conv2d(hiddim_p, 2 * hiddim_p, 1, 1))
55 | )
56 |
57 | elif op == 'CAT_PROD':
58 | self.net1_mean_vis = nn.Sequential(
59 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)),
60 | nn.ELU(inplace=True),
61 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1))
62 | )
63 |
64 | self.net1_var_vis = nn.Sequential(
65 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)),
66 | nn.ELU(inplace=True),
67 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1))
68 | )
69 |
70 | self.net2_mean_vis = nn.Sequential(
71 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)),
72 | nn.ELU(inplace=True),
73 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1))
74 | )
75 |
76 | self.net2_var_vis = nn.Sequential(
77 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)),
78 | nn.ELU(inplace=True),
79 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1))
80 | )
81 |
82 | self.net1_mean_pos = nn.Sequential(
83 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)),
84 | nn.ELU(inplace=True),
85 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1))
86 | )
87 |
88 | self.net1_var_pos = nn.Sequential(
89 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)),
90 | nn.ELU(inplace=True),
91 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1))
92 | )
93 |
94 | self.net2_mean_pos = nn.Sequential(
95 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)),
96 | nn.ELU(inplace=True),
97 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1))
98 | )
99 |
100 | self.net2_var_pos = nn.Sequential(
101 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)),
102 | nn.ELU(inplace=True),
103 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1))
104 | )
105 | elif op == 'CAT_gPoE':
106 | self.net1_mean_vis = nn.Sequential(
107 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)),
108 | nn.ELU(inplace=True),
109 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1))
110 | )
111 |
112 | self.net1_var_vis = nn.Sequential(
113 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)),
114 | nn.ELU(inplace=True),
115 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1))
116 | )
117 |
118 | self.net1_mean_pos = nn.Sequential(
119 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)),
120 | nn.ELU(inplace=True),
121 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1))
122 | )
123 |
124 | self.net1_var_pos = nn.Sequential(
125 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)),
126 | nn.ELU(inplace=True),
127 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1))
128 | )
129 | self.gates_v = nn.Sequential(
130 | weight_norm(nn.Conv2d(hiddim_v * 4, hiddim_v * 4, 3, 1, 1)),
131 | nn.ELU(inplace=True),
132 | weight_norm(nn.Conv2d(hiddim_v * 4, hiddim_v * 4, 3, 1, 1))
133 | )
134 | self.gates_p = nn.Sequential(
135 | weight_norm(nn.Conv2d(hiddim_p * 4, hiddim_p * 4, 3, 1, 1)),
136 | nn.ELU(inplace=True),
137 | weight_norm(nn.Conv2d(hiddim_p * 4, hiddim_p * 4, 3, 1, 1))
138 | )
139 |
140 | def forward(self, x, y, mode, lognormal=False):  # -> x describes y
141 | if mode == 'vis':
142 | if self.op == 'CAT_PROD':
143 | x_mean = self.net1_mean_vis(torch.cat([x[0], y[0]], dim=1))
144 | x_var = self.net1_var_vis(torch.cat([x[1], y[1]], dim=1))
145 |
146 | if lognormal == True:
147 | x_mean = torch.exp(x_mean)
148 |
149 | y_mean = self.net2_mean_vis(x_mean * y[0])
150 | y_var = self.net2_var_vis(x_var * y[1])
151 | elif self.op == 'CAT_PoE':
152 | # logvar = -log(exp(-logvar1) + exp(-logvar2))
153 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2)
154 | x_mean = self.net1_mean_vis(torch.cat([x[0], y[0]], dim=1))
155 | x_var = self.net1_var_vis(torch.cat([x[1], y[1]], dim=1))
156 | mlogvar1 = -x_var
157 | mlogvar2 = -y[1]
158 | mu1 = x_mean
159 | mu2 = y[0]
160 |
161 | y_var = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2))
162 | y_mean = torch.exp(y_var) * (torch.exp(mlogvar1) * mu1 + torch.exp(mlogvar2) * mu2)
163 | elif self.op == 'CAT_gPoE':
164 | # logvar = -log(exp(-logvar1) + exp(-logvar2))
165 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2)
166 | x_mean = self.net1_mean_vis(torch.cat([x[0], y[0]], dim=1))
167 | x_var = self.net1_var_vis(torch.cat([x[1], y[1]], dim=1))
168 |
169 | # gates
170 | gates = torch.sigmoid(self.gates_v(torch.cat([x_mean, x_var, y[0], y[1]], dim=1)))
171 | x1_mu_g = gates[:, :self.hiddim_v, :, :]
172 | x1_var_g = gates[:, self.hiddim_v:2 * self.hiddim_v, :, :]
173 | x2_mu_g = gates[:, 2 * self.hiddim_v:3 * self.hiddim_v, :, :]
174 | x2_var_g = gates[:, 3 * self.hiddim_v:4 * self.hiddim_v, :, :]
175 |
176 | x_mean = x1_mu_g * x_mean
177 | x_var = torch.log(x1_var_g + 1e-5) + x_var
178 | y[0] = x2_mu_g * y[0]
179 | y[1] = torch.log(x2_var_g + 1e-5) + y[1]
180 |
181 | mlogvar1 = -x_var
182 | mlogvar2 = -y[1]
183 | mu1 = x_mean
184 | mu2 = y[0]
185 |
186 | y_var = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2))
187 | y_mean = torch.exp(y_var) * (torch.exp(mlogvar1) * mu1 + torch.exp(mlogvar2) * mu2)
188 | elif self.op == 'CAT':
189 | y_mean = self.net1_mean_vis(torch.cat([x[0], y[0]], dim=1))
190 | y_var = self.net1_var_vis(torch.cat([x[1], y[1]], dim=1))
191 | elif self.op == 'PROD':
192 | y_mean = x[0] * y[0]
193 | y_var = x[1] * y[1]
194 | elif self.op == 'DEEP':
195 | gaussian_out = self.net_vis(torch.cat([x[0], x[1], y[0], y[1]], dim=1))
196 | y_mean = gaussian_out[:, :self.hiddim_v, :, :]
197 | y_var = gaussian_out[:, self.hiddim_v:, :, :]
198 | else:
199 | raise ValueError('invalid operator name {} for Describe module'.format(self.op))
200 |
201 | elif mode == 'pos':
202 | if self.op == 'CAT_PROD':
203 | x_mean = self.net1_mean_pos(torch.cat([x[0], y[0]], dim=1))
204 | x_var = self.net1_var_pos(torch.cat([x[1], y[1]], dim=1))
205 |
206 | y_mean = self.net2_mean_pos(x_mean * y[0])
207 | y_var = self.net2_var_pos(x_var * y[1])
208 | elif self.op == 'CAT_PoE':
209 | # logvar = -log(exp(-logvar1) + exp(-logvar2))
210 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2)
211 | x_mean = self.net1_mean_pos(torch.cat([x[0], y[0]], dim=1))
212 | x_var = self.net1_var_pos(torch.cat([x[1], y[1]], dim=1))
213 |
214 | mlogvar1 = -x_var
215 | mlogvar2 = -y[1]
216 | mu1 = x_mean
217 | mu2 = y[0]
218 |
219 | y_var = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2))
220 | y_mean = torch.exp(y_var) * (torch.exp(mlogvar1) * mu1 + torch.exp(mlogvar2) * mu2)
221 | elif self.op == 'CAT_gPoE':
222 | # logvar = -log(exp(-logvar1) + exp(-logvar2))
223 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2)
224 | x_mean = self.net1_mean_pos(torch.cat([x[0], y[0]], dim=1))
225 | x_var = self.net1_var_pos(torch.cat([x[1], y[1]], dim=1))
226 |
227 | # gates
228 | gates = torch.sigmoid(self.gates_p(torch.cat([x_mean, x_var, y[0], y[1]], dim=1)))
229 | x1_mu_g = gates[:, :self.hiddim_p, :, :]
230 | x1_var_g = gates[:, self.hiddim_p:2 * self.hiddim_p, :, :]
231 | x2_mu_g = gates[:, 2 * self.hiddim_p:3 * self.hiddim_p, :, :]
232 | x2_var_g = gates[:, 3 * self.hiddim_p:4 * self.hiddim_p, :, :]
233 |
234 | x_mean = x1_mu_g * x_mean
235 | x_var = torch.log(x1_var_g + 1e-5) + x_var
236 | y[0] = x2_mu_g * y[0]
237 | y[1] = torch.log(x2_var_g + 1e-5) + y[1]
238 |
239 | mlogvar1 = -x_var
240 | mlogvar2 = -y[1]
241 | mu1 = x_mean
242 | mu2 = y[0]
243 |
244 | y_var = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2))
245 | y_mean = torch.exp(y_var) * (torch.exp(mlogvar1) * mu1 + torch.exp(mlogvar2) * mu2)
246 | elif self.op == 'CAT':
247 | y_mean = self.net1_mean_pos(torch.cat([x[0], y[0]], dim=1))
248 | y_var = self.net1_var_pos(torch.cat([x[1], y[1]], dim=1))
249 | elif self.op == 'PROD':
250 | y_mean = x[0] * y[0]
251 | y_var = x[1] * y[1]
252 | elif self.op == 'DEEP':
253 | gaussian_out = self.net_pos(torch.cat([x[0], x[1], y[0], y[1]], dim=1))
254 | y_mean = gaussian_out[:, :self.hiddim_p, :, :]
255 | y_var = gaussian_out[:, self.hiddim_p:, :, :]
256 | else:
257 | raise ValueError('invalid operator name {} for Describe module'.format(self.op))
258 |
259 | else:
260 | raise ValueError('invalid mode {}'.format(mode))
261 |
262 | return [y_mean, y_var]
263 |
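   | '''
   | # Minimal smoke test for the Describe module (editor's sketch; the shapes and
   | # hyper-parameters below are illustrative assumptions, not values from the repo):
   | import torch
   | describe = Describe(hiddim_v=12, hiddim_p=8, op='CAT')
   | x = [torch.randn(2, 12, 16, 16), torch.randn(2, 12, 16, 16)]  # attribute [mean, logvar]
   | y = [torch.randn(2, 12, 16, 16), torch.randn(2, 12, 16, 16)]  # object [mean, logvar]
   | mean, logvar = describe(x, y, mode='vis')
   | print(mean.size(), logvar.size())  # both torch.Size([2, 12, 16, 16])
   | '''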
--------------------------------------------------------------------------------
/models/PNPNet/pnp_net.py:
--------------------------------------------------------------------------------
1 | '''
2 | *
3 | PNP-Net: flexibly takes a tree-structured program and
4 | assembles modules for image generative modeling
5 | *
6 |
7 | It contains:
8 |
9 | -- primitive visual elements
10 |
11 | -- unit modules
12 |
13 | -- tree recursive function
14 |
15 | -- forward(), generate()
16 |
17 | -- utility functions: clean_tree(), etc.
18 | '''
19 |
20 | import torch
21 | import torch.nn as nn
22 | from torch.autograd import Variable
23 | import torch.nn.functional as F
24 | import numpy as np
25 | import math
26 |
27 | from lib.reparameterize import reparameterize
28 | from lib.modules.VAE import VAE
29 | from lib.modules.ResReader import Reader
30 | from lib.modules.ResWriter import Writer
31 | from lib.modules.ConceptMapper import ConceptMapper
32 | from lib.modules.Combine import Combine
33 | from lib.modules.Describe import Describe
34 | from lib.modules.Transform import Transform
35 | from lib.modules.DistributionRender import DistributionRender
36 |
37 |
38 | class PNPNet(nn.Module):
39 | def __init__(self, hiddim=160, latentdim=12,
40 | word_size=[-1, 16, 16], pos_size=[4, 1, 1], nres=4, nlayers=1,
41 | nonlinear='elu', dictionary=None, op=['PROD', 'CAT'],
42 | lmap_size=0, downsample=2, gpu_ids=None,
43 | multigpu_full=False, lambdakl=-1, bg_bias=False, normalize='instance_norm',
44 | loss=None, debug_mode=True):
45 | super(PNPNet, self).__init__()
46 | ## basic settings
47 | word_size[0] = latentdim
48 | self.word_size = word_size
49 | self.latentdim = latentdim
50 | self.hiddim = hiddim
51 | self.downsample = downsample # -> downsample times
52 | self.ds = 2 ** self.downsample # -> downsample ratio
53 | self.nres = nres
54 | self.nlayers = nlayers
55 | self.lmap_size = lmap_size
56 | self.im_size = lmap_size * self.ds
57 | self.multigpu_full = multigpu_full
58 | self.bg_bias = bg_bias
59 | self.normalize = normalize
60 | self.debug_mode = debug_mode
61 |
62 | # dictionary
63 | self.dictionary = dictionary
64 |
65 | ########## modules ##########
66 | # proposal networks
67 | self.reader = Reader(indim=3, hiddim=hiddim, outdim=hiddim, ds_times=self.downsample, normalize=normalize,
68 | nlayers=nlayers)
69 | self.h_mean = nn.Conv2d(hiddim, latentdim, 3, 1, 1)
70 | self.h_var = nn.Conv2d(hiddim, latentdim, 3, 1, 1)
71 |
72 | # pixel writer
73 | self.writer = Writer(indim=latentdim, hiddim=hiddim, outdim=3, ds_times=self.downsample, normalize=normalize,
74 | nlayers=nlayers)
75 |
76 | # visual words
77 | self.vis_dist = ConceptMapper(word_size, len(dictionary))
78 | self.pos_dist = ConceptMapper(pos_size, len(dictionary))
79 |
80 | self.renderer = DistributionRender(hiddim=latentdim)
81 |
82 | # neural modules
83 | self.combine = Combine(hiddim_v=latentdim, hiddim_p=pos_size[0], op=op[0])
84 | self.describe = Describe(hiddim_v=latentdim, hiddim_p=pos_size[0], op=op[1])
85 | self.transform = Transform(matrix='default')
86 |
87 | # small vaes for bounding boxes and offsets learning
88 | # input: H, W
89 | self.box_vae = VAE(indim=2, latentdim=pos_size[0])
90 | # input: [x0, y0, x1, y1] + condition: [H0, W0, H1, W1, im_H, im_W, (im_H-H0), (im_W-W0), (im_H-H1), (im_W-W1)]
91 | self.offset_vae = VAE(indim=4, latentdim=pos_size[0])
92 |
93 | ## loss functions&sampler
94 | self.sampler = reparameterize()
95 | if lambdakl > 0:
96 | from lib.LambdaBiKLD import BiKLD
97 | self.bikld = BiKLD(lambda_t=lambdakl, k=None)
98 | else:
99 | from lib.BiKLD import BiKLD
100 | self.bikld = BiKLD()
101 |
102 | if loss == 'l1':
103 | self.pixelrecon_criterion = nn.L1Loss()
104 | elif loss == 'l2':
105 | self.pixelrecon_criterion = nn.MSELoss()
106 | self.pixelrecon_criterion.size_average = False  # sum over pixels instead of averaging
107 |
108 | self.pos_criterion = nn.MSELoss()
109 | self.pos_criterion.size_average = False
110 |
111 | ## biases
112 | self.bias_mean = nn.Linear(1, self.latentdim * self.lmap_size * self.lmap_size, bias=False)
113 | self.bias_var = nn.Linear(1, self.latentdim * self.lmap_size * self.lmap_size, bias=False)
114 | self.latent_canvas_size = torch.Size([1, self.latentdim, self.lmap_size, self.lmap_size])
115 |
116 | def get_mask_from_tree(self, tree, size):
117 | mask = Variable(torch.zeros(size), requires_grad=False).cuda()
118 |
119 | return self._get_mask_from_tree(tree, mask)
120 |
121 | def _get_mask_from_tree(self, tree, mask):
122 | for i in range(0, tree.num_children):
123 | mask = self._get_mask_from_tree(tree.children[i], mask)
124 |
125 | if tree.function == 'describe':
126 | bbx = tree.bbox
127 | mask[:, :, bbx[0]:bbx[0] + bbx[2], bbx[1]:bbx[1] + bbx[3]] = 1.0
128 |
129 | return mask
130 |
131 | def forward(self, x, treex, treeindex=None, alpha=1.0, ifmask=False, maskweight=1.0):
132 | ################################
133 | ## input: images, trees ##
134 | ################################
135 |
136 | # if multigpu_full, pick the trees by treeindex
137 | if self.multigpu_full:
138 | treex_pick = [treex[ele[0]] for ele in treeindex.data.cpu().numpy().astype(int)]
139 | treex = treex_pick
140 |
141 | if ifmask:
142 | mask = []
143 | for i in range(0, len(treex)):
144 | mask += [self.get_mask_from_tree(treex[i], x[0:1, :, :, :].size())]
145 | mask = torch.cat(mask, dim=0)
146 |
147 | # encoding the images
148 | h = self.reader(x)
149 | # proposal distribution
150 | latent_mean = self.h_mean(h)
151 | latent_var = self.h_var(h)
152 |
153 | # losses
154 | kld_loss, rec_loss, pos_loss = 0, 0, 0
155 |
156 | # forward GNMN
157 | prior_mean_all = []
158 | prior_var_all = []
159 | trees = []
160 | for i in range(0, len(treex)): # iterate through every tree of the batch
161 | trees.append(self.compose_tree(treex[i], self.latent_canvas_size))
162 | prior_mean_all += [trees[i].vis_dist[0]]
163 | prior_var_all += [trees[i].vis_dist[1]]
164 | pos_loss += trees[i].pos_loss
165 | if np.isnan(trees[i].pos_loss.data.cpu().numpy()):
166 | print('found nan pos loss')
167 | import IPython;
168 | IPython.embed()
169 |
170 | prior_mean = torch.cat(prior_mean_all, dim=0)
171 | prior_var = torch.cat(prior_var_all, dim=0)
172 |
173 | prior_mean, prior_var = self.renderer([prior_mean, prior_var])
174 |
175 | # sample z map
176 | z_map = self.sampler(latent_mean, latent_var)
177 |
178 | # kld loss: alpha mixes the full KL with a KL whose posterior is detached (trains the prior only)
179 | kld_loss = alpha * self.bikld([latent_mean, latent_var], [prior_mean, prior_var]) + \
180 | (1 - alpha) * self.bikld([latent_mean.detach(), latent_var.detach()], [prior_mean, prior_var])
181 |
182 | rec = self.writer(z_map)
183 |
184 | if ifmask is True:
185 | mask = (mask + maskweight) / (maskweight + 1.0)
186 | rec_loss = self.pixelrecon_criterion(mask * rec, mask * x)
187 | else:
188 | rec_loss = self.pixelrecon_criterion(rec, x)
189 | rec_loss = rec_loss.sum()
190 |
191 | return rec_loss, kld_loss, pos_loss, rec
192 |
193 | def get_code(self, dictionary, word):
194 | code = Variable(torch.zeros(1, len(dictionary))).cuda()
195 | code[0, dictionary.index(word)] = 1
196 |
197 | return code
198 |
199 | def compose_tree(self, treex, latent_canvas_size):
200 | for i in range(0, treex.num_children):
201 | treex.children[i] = self.compose_tree(treex.children[i], latent_canvas_size)
202 |
203 | # one hot embedding of a word
204 | ohe = self.get_code(self.dictionary, treex.word)
205 |
206 | if treex.function == 'combine':
207 | vis_dist = self.vis_dist(ohe)
208 | pos_dist = self.pos_dist(ohe)
209 | if treex.num_children > 0:
210 | # visual content
211 | vis_dist_child = treex.children[0].vis_dist
212 | vis_dist = self.combine(vis_dist, vis_dist_child, 'vis')
213 | # visual position
214 | pos_dist_child = treex.children[0].pos_dist
215 | pos_dist = self.combine(pos_dist, pos_dist_child, 'pos')
216 |
217 | treex.vis_dist = vis_dist
218 | treex.pos_dist = pos_dist
219 |
220 | elif treex.function == 'describe':
221 | # blend visual words
222 | vis_dist = self.vis_dist(ohe)
223 | pos_dist = self.pos_dist(ohe)
224 | if treex.num_children > 0:
225 | # visual content
226 | vis_dist_child = treex.children[0].vis_dist
227 | vis_dist = self.describe(vis_dist_child, vis_dist, 'vis')
228 | # visual position
229 | pos_dist_child = treex.children[0].pos_dist
230 | pos_dist = self.describe(pos_dist_child, pos_dist, 'pos')
231 |
232 | treex.pos_dist = pos_dist
233 |
234 | # regress bbox
235 | treex.pos = np.maximum(treex.bbox[2:] // self.ds, [1, 1])
236 | target_box = Variable(torch.from_numpy(np.array(treex.bbox[2:])[np.newaxis, ...].astype(np.float32))).cuda()
237 | regress_box, kl_box = self.box_vae(target_box, prior=treex.pos_dist)
238 | treex.pos_loss = self.pos_criterion(regress_box, target_box) + kl_box
239 |
240 | if treex.parent is None:
241 | ones = self.get_ones(torch.Size([1, 1]))
242 | if not self.bg_bias:
243 | bg_vis_dist = [Variable(torch.zeros(latent_canvas_size)).cuda(), \
244 | Variable(torch.zeros(latent_canvas_size)).cuda()]
245 | else:
246 | bg_vis_dist = [self.bias_mean(ones).view(*latent_canvas_size), \
247 | self.bias_var(ones).view(*latent_canvas_size)]
248 | b = np.maximum(treex.bbox // self.ds, [0, 0, 1, 1])
249 |
250 | bg_vis_dist = [self.assign_util(bg_vis_dist[0], b, self.transform(vis_dist[0], treex.pos),
251 | 'assign'), \
252 | self.assign_util(bg_vis_dist[1], b,
253 | self.transform(vis_dist[1], treex.pos, variance=True),
254 | 'assign')]
255 | vis_dist = bg_vis_dist
256 | else:
257 | try:
258 | # resize vis_dist
259 | vis_dist = [self.transform(vis_dist[0], treex.pos), \
260 | self.transform(vis_dist[1], treex.pos, variance=True)]
261 | except Exception:
262 | import IPython;
263 | IPython.embed()
264 |
265 | treex.vis_dist = vis_dist
266 |
267 | elif treex.function == 'layout':
268 | # get pos word as position prior
269 | treex.pos_dist = self.pos_dist(ohe)
270 | assert (treex.num_children > 0)
271 |
272 | # get offsets: use gt for training
273 | l_pos = treex.children[0].pos
274 | l_offset = np.maximum(treex.children[0].bbox[:2] // self.ds, [1, 1])
275 |
276 | r_pos = treex.children[1].pos
277 | r_offset = np.maximum(treex.children[1].bbox[:2] // self.ds, [1, 1])
278 |
279 | # regress offsets
280 | target_offset = np.append(l_offset * self.ds, r_offset * self.ds).astype(np.float32)
281 | target_offset = Variable(torch.from_numpy(target_offset[np.newaxis, ...])).cuda()
282 | regress_offset, kl_offset = self.offset_vae(target_offset, prior=treex.pos_dist)
283 | treex.pos_loss = self.pos_criterion(regress_offset, target_offset) + kl_offset + treex.children[
284 | 0].pos_loss + \
285 | treex.children[1].pos_loss
286 |
287 | ######################### constructing latent map ###############################
288 | # bias filled mean&var
289 | ones = self.get_ones(torch.Size([1, 1]))
290 | if not self.bg_bias:
291 | vis_dist = [Variable(torch.zeros(latent_canvas_size)).cuda(), \
292 | Variable(torch.zeros(latent_canvas_size)).cuda()]
293 | else:
294 | vis_dist = [self.bias_mean(ones).view(*latent_canvas_size), \
295 | self.bias_var(ones).view(*latent_canvas_size)]
296 |
297 | # arrange the layout of two children
298 | vis_dist[0] = self.assign_util(vis_dist[0], list(l_offset) + list(l_pos), treex.children[0].vis_dist[0],
299 | 'assign')
300 | vis_dist[1] = self.assign_util(vis_dist[1], list(l_offset) + list(l_pos), treex.children[0].vis_dist[1],
301 | 'assign')
302 |
303 | vis_dist[0] = self.assign_util(vis_dist[0], list(r_offset) + list(r_pos), treex.children[1].vis_dist[0],
304 | 'assign')
305 | vis_dist[1] = self.assign_util(vis_dist[1], list(r_offset) + list(r_pos), treex.children[1].vis_dist[1],
306 | 'assign')
307 |
308 | # continue layout
309 | if treex.parent is not None:
310 | p = [min(l_offset[0], r_offset[0]), min(l_offset[1], r_offset[1]), \
311 | max(l_offset[0] + l_pos[0], r_offset[0] + r_pos[0]),
312 | max(l_offset[1] + l_pos[1], r_offset[1] + r_pos[1])]
313 | treex.pos = [p[2] - p[0], p[3] - p[1]]
314 | treex.vis_dist = [vis_dist[0][:, :, p[0]:p[2], p[1]:p[3]], \
315 | vis_dist[1][:, :, p[0]:p[2], p[1]:p[3]]]
316 | else:
317 | treex.vis_dist = vis_dist
318 |
319 | return treex
320 |
321 | def assign_util(self, a, bx, b, mode):
322 | if mode == 'assign':
323 | a[:, :, bx[0]:bx[0] + bx[2], bx[1]:bx[1] + bx[3]] = b
324 | elif mode == 'add':
325 | a[:, :, bx[0]:bx[0] + bx[2], bx[1]:bx[1] + bx[3]] = \
326 | a[:, :, bx[0]:bx[0] + bx[2], bx[1]:bx[1] + bx[3]] + b
327 | elif mode == 'slice':
328 | a = a[:, :, bx[0]:bx[0] + bx[2], bx[1]:bx[1] + bx[3]].clone()
329 | else:
330 | raise ValueError('Please specify the correct mode.')
331 | return a
332 |
333 | def overlap_box(self, box_left, box_right):
334 | x1, y1, h1, w1 = box_left[0], box_left[1], box_left[2], box_left[3]
335 | x2, y2, h2, w2 = box_right[0], box_right[1], box_right[2], box_right[3]
336 |
337 | ox1 = max(x1, x2)
338 | oy1 = max(y1, y2)
339 | ox2 = min(x1 + h1, x2 + h2)
340 | oy2 = min(y1 + w1, y2 + w2)
341 |
342 | if ox2 > ox1 and oy2 > oy1:
343 | return [ox1, oy1, ox2 - ox1, oy2 - oy1]
344 | else:
345 | return []
346 |
347 | def generate(self, x, treex, treeindex=None):
348 | ################################
349 | ## input: images, trees ##
350 | ################################
351 | if self.multigpu_full:
352 | treex_pick = [treex[ele[0]] for ele in treeindex.data.cpu().numpy().astype(int)]
353 | treex = treex_pick
354 |
355 | # traverse trees to compose visual words
356 | prior_mean = []
357 | prior_var = []
358 |
359 | for i in range(0, len(treex)):
360 | treex[i] = self.generate_compose_tree(treex[i], self.latent_canvas_size)
361 | prior_mean += [treex[i].vis_dist[0]]
362 | prior_var += [treex[i].vis_dist[1]]
363 | prior_mean = torch.cat(prior_mean, dim=0)
364 | prior_var = torch.cat(prior_var, dim=0)
365 |
366 | # sample z map
367 | prior_mean, prior_var = self.renderer([prior_mean, prior_var])
368 |
369 | z_map = self.sampler(prior_mean, prior_var)
370 |
371 | rec = self.writer(z_map)
372 |
373 | return rec
374 |
375 | def check_valid(self, offsets, l_pos, r_pos, im_size):
376 | flag = True
377 | if offsets[0] + l_pos[0] > im_size:
378 | flag = False
379 | return flag
380 | if offsets[1] + l_pos[1] > im_size:
381 | flag = False
382 | return flag
383 | if offsets[2] + r_pos[0] > im_size:
384 | flag = False
385 | return flag
386 | if offsets[3] + r_pos[1] > im_size:
387 | flag = False
388 | return flag
389 |
390 | return flag
391 |
392 | def generate_compose_tree(self, treex, latent_canvas_size):
393 | for i in range(0, treex.num_children):
394 | treex.children[i] = self.generate_compose_tree(treex.children[i], latent_canvas_size)
395 |
396 | # one hot embedding of a word
397 | ohe = self.get_code(self.dictionary, treex.word)
398 | if treex.function == 'combine':
399 | vis_dist = self.vis_dist(ohe)
400 | pos_dist = self.pos_dist(ohe)
401 | if treex.num_children > 0:
402 | # visual content
403 | vis_dist_child = treex.children[0].vis_dist
404 | vis_dist = self.combine(vis_dist, vis_dist_child, 'vis')
405 | # visual position
406 | pos_dist_child = treex.children[0].pos_dist
407 | pos_dist = self.combine(pos_dist, pos_dist_child, 'pos')
408 |
409 | treex.vis_dist = vis_dist
410 | treex.pos_dist = pos_dist
411 |
412 | elif treex.function == 'describe':
413 | # blend visual words
414 | vis_dist = self.vis_dist(ohe)
415 | pos_dist = self.pos_dist(ohe)
416 | if treex.num_children > 0:
417 | # visual content
418 | vis_dist_child = treex.children[0].vis_dist
419 | vis_dist = self.describe(vis_dist_child, vis_dist, 'vis')
420 | # visual position
421 | pos_dist_child = treex.children[0].pos_dist
422 | pos_dist = self.describe(pos_dist_child, pos_dist, 'pos')
423 |
424 | treex.pos_dist = pos_dist
425 |
426 | # regress bbox
427 | treex.pos = np.clip(self.box_vae.generate(prior=treex.pos_dist).data.cpu().numpy().astype(int),
428 | int(self.ds),
429 | self.im_size).flatten() // self.ds
430 |
431 | if treex.parent is None:
432 | ones = self.get_ones(torch.Size([1, 1]))
433 | if not self.bg_bias:
434 | bg_vis_dist = [Variable(torch.zeros(latent_canvas_size)).cuda(), \
435 | Variable(torch.zeros(latent_canvas_size)).cuda()]
436 | else:
437 | bg_vis_dist = [self.bias_mean(ones).view(*latent_canvas_size), \
438 | self.bias_var(ones).view(*latent_canvas_size)]
439 | b = [int(latent_canvas_size[2]) // 2 - treex.pos[0] // 2,
440 | int(latent_canvas_size[3]) // 2 - treex.pos[1] // 2, treex.pos[0], treex.pos[1]]
441 |
442 | bg_vis_dist = [self.assign_util(bg_vis_dist[0], b, self.transform(vis_dist[0], treex.pos),
443 | 'assign'), \
444 | self.assign_util(bg_vis_dist[1], b,
445 | self.transform(vis_dist[1], treex.pos, variance=True),
446 | 'assign')]
447 |
448 | vis_dist = bg_vis_dist
449 | treex.offsets = b
450 | else:
451 | # resize vis_dist
452 | vis_dist = [self.transform(vis_dist[0], treex.pos), \
453 | self.transform(vis_dist[1], treex.pos, variance=True)]
454 |
455 | treex.vis_dist = vis_dist
456 |
457 | elif treex.function == 'layout':
458 | # get pos word as position prior
459 | treex.pos_dist = self.pos_dist(ohe)
460 | assert (treex.num_children > 0)
461 |
462 | # get offsets: sample from the offset VAE at generation time
463 | l_pos = treex.children[0].pos
464 | r_pos = treex.children[1].pos
465 |
466 | offsets = np.clip(self.offset_vae.generate(prior=treex.pos_dist).data.cpu().numpy().astype(int), 0,
467 | self.im_size).flatten() // self.ds
468 | countdown = 0
469 | while not self.check_valid(offsets, l_pos, r_pos, self.im_size // self.ds):
470 | offsets = np.clip(self.offset_vae.generate(prior=treex.pos_dist).data.cpu().numpy().astype(int), 0,
471 | self.im_size).flatten() // self.ds
472 | if countdown >= 100:
473 | print('Tried proposing more than 100 times.')
474 | if self.debug_mode:
475 | import IPython;
476 | IPython.embed()
477 | print('Warning! Manually adapting offsets.')
478 | lat_size = self.im_size // self.ds
479 | if offsets[0] + l_pos[0] > lat_size:
480 | offsets[0] = lat_size - l_pos[0]
481 | if offsets[1] + l_pos[1] > lat_size:
482 | offsets[1] = lat_size - l_pos[1]
483 | if offsets[2] + r_pos[0] > lat_size:
484 | offsets[2] = lat_size - r_pos[0]
485 | if offsets[3] + r_pos[1] > lat_size:
486 | offsets[3] = lat_size - r_pos[1]
487 |
488 | countdown += 1
489 | treex.offsets = offsets
490 | l_offset = offsets[:2]
491 | r_offset = offsets[2:]
492 |
493 | ######################### constructing latent map ###############################
494 | # bias filled mean&var
495 | ones = self.get_ones(torch.Size([1, 1]))
496 | if not self.bg_bias:
497 | bg_vis_dist = [Variable(torch.zeros(latent_canvas_size)).cuda(), \
498 | Variable(torch.zeros(latent_canvas_size)).cuda()]
499 | else:
500 | bg_vis_dist = [self.bias_mean(ones).view(*latent_canvas_size), \
501 | self.bias_var(ones).view(*latent_canvas_size)]
502 |
503 | vis_dist = bg_vis_dist
504 | try:
505 | # arrange the layout of two children
506 | vis_dist[0] = self.assign_util(vis_dist[0], list(l_offset) + list(l_pos), treex.children[0].vis_dist[0],
507 | 'assign')
508 | vis_dist[1] = self.assign_util(vis_dist[1], list(l_offset) + list(l_pos), treex.children[0].vis_dist[1],
509 | 'assign')
510 |
511 | vis_dist[0] = self.assign_util(vis_dist[0], list(r_offset) + list(r_pos), treex.children[1].vis_dist[0],
512 | 'assign')
513 | vis_dist[1] = self.assign_util(vis_dist[1], list(r_offset) + list(r_pos), treex.children[1].vis_dist[1],
514 | 'assign')
515 | except Exception:
516 | print("latent distribution doesn't fit size.")
517 | import IPython;
518 | IPython.embed()
519 |
520 | if treex.parent is not None:
521 | p = [min(l_offset[0], r_offset[0]), min(l_offset[1], r_offset[1]), \
522 | max(l_offset[0] + l_pos[0], r_offset[0] + r_pos[0]),
523 | max(l_offset[1] + l_pos[1], r_offset[1] + r_pos[1])]
524 | treex.pos = [p[2] - p[0], p[3] - p[1]]
525 | treex.vis_dist = [vis_dist[0][:, :, p[0]:p[2], p[1]:p[3]], \
526 | vis_dist[1][:, :, p[0]:p[2], p[1]:p[3]]]
527 | else:
528 | treex.vis_dist = vis_dist
529 |
530 | return treex
531 |
532 | def get_ones(self, size):
533 | return Variable(torch.ones(size), requires_grad=False).cuda()
534 |
535 | def clean_tree(self, treex):
536 | for i in range(0, len(treex)):
537 | self._clean_tree(treex[i])
538 |
539 | def _clean_tree(self, treex):
540 | for i in range(0, treex.num_children):
541 | self._clean_tree(treex.children[i])
542 |
543 | if treex.function == 'combine':
544 | treex.vis_dist = None
545 | treex.pos_dist = None
546 | elif treex.function == 'describe':
547 | treex.vis_dist = None
548 | treex.pos_dist = None
549 | treex.pos = None
550 | elif treex.function == 'layout':
551 | treex.vis_dist = None
552 | treex.pos_dist = None
553 | treex.pos = None
554 |
--------------------------------------------------------------------------------
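
The latent-canvas bookkeeping in compose_tree / generate_compose_tree is easiest to follow on a toy example. Below is a minimal sketch of the 'assign' mode of PNPNet.assign_util and of the union box that a non-root 'layout' node keeps before cropping its vis_dist; assign_patch, the offsets, and the sizes are illustrative assumptions, not part of the repository:

import torch

def assign_patch(canvas, box, patch):
    # Write `patch` into `canvas` at box = [top, left, height, width].
    top, left, h, w = box
    canvas = canvas.clone()
    canvas[:, :, top:top + h, left:left + w] = patch
    return canvas

canvas = torch.zeros(1, 12, 16, 16)          # latent canvas [N, C, H, W]
l_offset, l_pos = [2, 3], [4, 4]             # left child: offset and size on the latent grid
r_offset, r_pos = [8, 9], [5, 6]             # right child
canvas = assign_patch(canvas, l_offset + l_pos, torch.ones(1, 12, 4, 4))
canvas = assign_patch(canvas, r_offset + r_pos, torch.ones(1, 12, 5, 6))

# Union box p = [top, left, bottom, right] covering both children, kept before cropping.
p = [min(l_offset[0], r_offset[0]), min(l_offset[1], r_offset[1]),
     max(l_offset[0] + l_pos[0], r_offset[0] + r_pos[0]),
     max(l_offset[1] + l_pos[1], r_offset[1] + r_pos[1])]
print(p, [p[2] - p[0], p[3] - p[1]])         # [2, 3, 13, 15], size [11, 12]
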