├── lib ├── __init__.py ├── modules │ ├── __init__.py │ ├── Layout.py │ ├── _init_paths.py │ ├── ResReader.py │ ├── ConceptMapper.py │ ├── ResWriter.py │ ├── Transform.py │ ├── DistributionRender.py │ ├── VAE.py │ ├── Combine.py │ └── Describe.py ├── data_loader │ ├── __init__.py │ ├── clevr │ │ ├── __init__.py │ │ ├── _init_paths.py │ │ ├── treeutils.py │ │ └── clevr_tree.py │ ├── _init_paths.py │ └── color_mnist_tree_multi.py ├── BiKLD.py ├── weight_init.py ├── reparameterize.py ├── utils.py ├── tree.py ├── config.py ├── LambdaBiKLD.py └── ResidualModule.py ├── mains ├── __init__.py ├── _init_paths.py └── pnpnet_main.py ├── models ├── __init__.py └── PNPNet │ ├── __init__.py │ └── pnp_net.py ├── data ├── .gitplaceholder └── CLEVR │ └── add_parent.py ├── trainers ├── __init__.py └── pnpnet_trainer.py ├── .gitignore ├── requirements.txt ├── images ├── combine.png ├── layout.png ├── mapping.png ├── samples.png ├── describe.png ├── pipeline.png └── transform.png ├── .gitmodules ├── configs └── pnp_net_configs.yaml ├── README.md └── LICENSE /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mains/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/.gitplaceholder: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/PNPNet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/data_loader/clevr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .idea 3 | results/ 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==0.4.0 2 | torchvision==0.2.1 3 | -------------------------------------------------------------------------------- /images/combine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/combine.png -------------------------------------------------------------------------------- /images/layout.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/layout.png -------------------------------------------------------------------------------- /images/mapping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/mapping.png -------------------------------------------------------------------------------- /images/samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/samples.png -------------------------------------------------------------------------------- /images/describe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/describe.png -------------------------------------------------------------------------------- /images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/pipeline.png -------------------------------------------------------------------------------- /images/transform.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas2012/ProbabilisticNeuralProgrammedNetwork/HEAD/images/transform.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "data/CLEVR/clevr-dataset-gen"] 2 | path = data/CLEVR/clevr-dataset-gen 3 | url = git@github.com:woodfrog/clevr-dataset-gen.git 4 | [submodule "SemanticCorrectnessScore"] 5 | path = SemanticCorrectnessScore 6 | url = git@github.com:woodfrog/SemanticCorrectnessScore.git 7 | -------------------------------------------------------------------------------- /lib/modules/Layout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | class Layout(nn.Module):  # placeholder stub: the layout operation actually used at runtime is implemented in models/PNPNet/pnp_net.py 7 | def __init__(self, hiddim=160):  # hiddim is unused in this stub 8 | super(Layout, self).__init__() 9 | 10 | def forward(self, x, y): 11 | pass 12 | -------------------------------------------------------------------------------- /lib/data_loader/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: sys.path.insert(0, path) 6 | 7 | this_dir = osp.dirname(__file__) 8 | 9 | # Add the project root to PYTHONPATH 10 | project_path = osp.abspath( osp.join(this_dir, '..', '..') ) 11 | add_path(project_path) 12 | -------------------------------------------------------------------------------- /lib/data_loader/clevr/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: sys.path.insert(0, path) 6 | 7 | this_dir = osp.dirname(__file__) 8 | 9 | # Add the project root to PYTHONPATH 10 | project_path = osp.abspath(osp.join(this_dir, '..', '..', '..') ) 11 | add_path(project_path) 12 |
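All of these `_init_paths.py` helpers follow the same pattern: they prepend the project root to `sys.path` so that project-absolute imports such as `lib.tree` resolve when a script is launched from a subdirectory. A minimal usage sketch (mirroring what `data/CLEVR/add_parent.py` and the tree utilities do; the print is just for illustration):

```python
# Run from a directory that ships one of the _init_paths.py helpers.
import _init_paths  # side effect: inserts the project root into sys.path

from lib.tree import Tree  # project-absolute imports now resolve

root = Tree()
root.add_child(Tree())
print(root.num_children)  # -> 1
```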
-------------------------------------------------------------------------------- /mains/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: sys.path.insert(0, path) 6 | 7 | this_dir = osp.dirname(__file__) 8 | 9 | # Add the project root to PYTHONPATH 10 | project_path = osp.abspath( osp.join(this_dir, '..')) 11 | print(project_path) 12 | add_path(project_path) 13 | -------------------------------------------------------------------------------- /lib/modules/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: sys.path.insert(0, path) 6 | 7 | this_dir = osp.dirname(__file__) 8 | 9 | # Add the project root to PYTHONPATH 10 | project_path = osp.abspath( osp.join(this_dir, '..', '..') ) 11 | print(project_path) 12 | add_path(project_path) 13 | -------------------------------------------------------------------------------- /lib/BiKLD.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | 7 | class BiKLD(nn.Module): 8 | def __init__(self): 9 | super(BiKLD, self).__init__() 10 | 11 | def forward(self, q, p):  # q and p are [mu, logvar] pairs of diagonal Gaussians; returns KL(q || p) summed over all elements 12 | q_mu, q_var = q[0], t.exp(q[1]) 13 | p_mu, p_var = p[0], t.exp(p[1]) 14 | 15 | kld = q_var / p_var - 1 16 | kld += (p_mu - q_mu).pow(2) / p_var 17 | kld += p[1] - q[1] 18 | kld = kld.sum() / 2 19 | 20 | return kld 21 | -------------------------------------------------------------------------------- /lib/weight_init.py: -------------------------------------------------------------------------------- 1 | def weights_init(m): 2 | classname = m.__class__.__name__ 3 | if classname.find('Conv') != -1: 4 | try: 5 | m.weight.data.normal_(0.0, 0.02) 6 | m.bias.data.fill_(0) 7 | except AttributeError:  # e.g. conv layers created without a bias term 8 | pass 9 | elif classname.find('Linear') != -1: 10 | try: 11 | m.weight.data.normal_(0.0, 0.02) 12 | 13 | except AttributeError: 14 | pass 15 | elif classname.find('Embedding') != -1: 16 | try: 17 | m.weight.data.normal_(0.0, 0.02) 18 | m.bias.data.fill_(0) 19 | except AttributeError:  # nn.Embedding has no bias; the weight above is still initialized 20 | pass -------------------------------------------------------------------------------- /lib/reparameterize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class reparameterize(nn.Module): 7 | def __init__(self): 8 | super(reparameterize, self).__init__() 9 | 10 | def forward(self, mu, logvar, sample_num=1, phase='training'): 11 | if phase == 'training': 12 | std = logvar.mul(0.5).exp_() 13 | eps = Variable(std.data.new(std.size()).normal_()) 14 | return eps.mul(std).add_(mu) 15 | else: 16 | raise ValueError('Unknown phase: 
only the training phase is implemented.') 17 | # elif phase == 'test': 18 | # return mu 19 | # elif phase == 'generation': 20 | # eps = Variable(logvar.data.new(logvar.size()).normal_()) 21 | # return eps 22 | -------------------------------------------------------------------------------- /configs/pnp_net_configs.yaml: -------------------------------------------------------------------------------- 1 | # Random seed 2 | seed: 12138 3 | # 4 | mode: train 5 | dataset: CLEVR 6 | # checkpoint: ./results/CLEVR_64_MULTI_LARGE/PNP-Net-5/checkpoints/model_epoch_360.pth 7 | checkpoint:   # leave empty to train from scratch; see the commented example above 8 | data_folder: CLEVR_64_MULTI_LARGE 9 | base_dir: ./data/CLEVR/ 10 | exp_dir_name: PNP-Net 11 | # Hyper parameters 12 | hiddim: 160 13 | latentdim: 64 14 | pos_size: [8, 1, 1] 15 | nr_resnet: 5 16 | word_size: 16 17 | ds: 2 18 | combine_op: gPoE 19 | describe_op: CAT_gPoE 20 | maskweight: 2.0 21 | bg_bias: False 22 | normalize: batch_norm 23 | loss: l1 24 | # Training 25 | batch_size: 16 26 | epochs: 500 27 | gpu_id: 0 28 | log_interval: 10 29 | lr: 0.001 30 | kl_beta: 5 31 | alpha_ub: 0.6 32 | pos_beta: 1 33 | warmup_iter: 100 34 | sample_interval: 20 35 | validate_interval: 100 36 | save_interval: 30 37 | 38 | -------------------------------------------------------------------------------- /lib/modules/ResReader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch.nn.utils import weight_norm 5 | import torch.nn.functional as F 6 | from lib.ResidualModule import ResidualModule 7 | 8 | 9 | class Reader(nn.Module): 10 | def __init__(self, indim, hiddim, outdim, ds_times, normalize, nlayers=4): 11 | super(Reader, self).__init__() 12 | 13 | self.ds_times = ds_times 14 | 15 | if normalize == 'gate': 16 | ifgate = True 17 | else: 18 | ifgate = False 19 | 20 | self.encoder = ResidualModule(modeltype='encoder', indim=indim, hiddim=hiddim, outdim=outdim, 21 | nres=self.ds_times, nlayers=nlayers, ifgate=ifgate, normalize=normalize) 22 | 23 | def forward(self, x): 24 | out = self.encoder(x) 25 | 26 | return out 27 | -------------------------------------------------------------------------------- /lib/modules/ConceptMapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | class ConceptMapper(nn.Module): 7 | def __init__(self, CHW, vocab_size): 8 | super(ConceptMapper, self).__init__() 9 | C, H, W = CHW[0], CHW[1], CHW[2] 10 | self.mean_dictionary = nn.Linear(vocab_size, C*H*W, bias=False) 11 | self.std_dictionary = nn.Linear(vocab_size, C*H*W, bias=False) 12 | self.C, self.H, self.W = C, H, W 13 | 14 | def forward(self, x): 15 | word_mean = self.mean_dictionary(x) 16 | word_std = self.std_dictionary(x) 17 | if self.H == 1 and self.W == 1: 18 | return [word_mean.view(-1, self.C, 1, 1), \ 19 | word_std.view(-1, self.C, 1, 1)] 20 | else: 21 | return [word_mean.view(-1, self.C, self.H, self.W), \ 22 | word_std.view(-1, self.C, self.H, self.W)] 23 | 24 | -------------------------------------------------------------------------------- /lib/modules/ResWriter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch.nn.utils import weight_norm 5 | import torch.nn.functional as F 6 | from lib.ResidualModule import ResidualModule 7 | 8 |
9 | class Writer(nn.Module): 10 | def __init__(self, indim, hiddim, outdim, ds_times, normalize, nlayers=4): 11 | super(Writer, self).__init__() 12 | 13 | self.ds_times = ds_times 14 | 15 | if normalize == 'gate': 16 | ifgate = True 17 | else: 18 | ifgate = False 19 | 20 | self.decoder = ResidualModule(modeltype='decoder', indim=indim, hiddim=hiddim, outdim=hiddim, 21 | nres=self.ds_times, nlayers=nlayers, ifgate=ifgate, normalize=normalize) 22 | 23 | self.out_conv = nn.Conv2d(hiddim, outdim, 3, 1, 1) 24 | 25 | def forward(self, x): 26 | out = self.decoder(x) 27 | out = self.out_conv(out) 28 | 29 | return out 30 | -------------------------------------------------------------------------------- /lib/modules/Transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class Transform(nn.Module): 9 | def __init__(self, matrix='default'):  # matrix is unused; an identity affine is always applied 10 | super(Transform, self).__init__() 11 | 12 | def forward(self, x, hw, variance=False): 13 | if variance: 14 | x = torch.exp(x / 2.0) 15 | size = torch.Size([x.size(0), x.size(1), int(hw[0]), int(hw[1])]) 16 | 17 | # grid generation 18 | theta = np.array([[[1, 0, 0], [0, 1, 0]]], dtype=np.float32) 19 | theta = Variable(torch.from_numpy(theta), requires_grad=False).cuda() 20 | theta = theta.expand(x.size(0), theta.size(1), theta.size(2)) 21 | gridout = F.affine_grid(theta, size) 22 | 23 | # bilinear sampling 24 | out = F.grid_sample(x, gridout, mode='bilinear') 25 | if variance: 26 | out = torch.log(out) * 2.0 27 | return out 28 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import scipy.misc  # note: scipy.misc.imsave was removed in SciPy >= 1.2; use an older SciPy or imageio 2 | import numpy as np 3 | 4 | 5 | def color_grid_vis(X, nh, nw, save_path): 6 | h, w = X[0].shape[:2] 7 | img = np.zeros((h * nh, w * nw, 3)) 8 | for n, x in enumerate(X): 9 | j = int(n / nw) 10 | i = n % nw 11 | img[j * h:j * h + h, i * w:i * w + w, :] = x 12 | scipy.misc.imsave(save_path, img) 13 | 14 | class AverageMeter(object): 15 | """Computes and stores the average and current value""" 16 | 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.val = 0 22 | self.avg = 0 23 | self.sum = 0 24 | self.pixel_count = 0 25 | self.batch_count = 0 26 | 27 | def update(self, val, n=1, batch=1): 28 | self.val = val 29 | self.sum += val * n 30 | self.pixel_count += n 31 | self.batch_count += batch 32 | self.pixel_avg = self.sum / self.pixel_count 33 | self.batch_avg = self.sum / self.batch_count 34 | -------------------------------------------------------------------------------- /lib/tree.py: -------------------------------------------------------------------------------- 1 | # tree object from stanfordnlp/treelstm 2 | class Tree(object): 3 | def __init__(self): 4 | self.parent = None 5 | self.num_children = 0 6 | self.children = list() 7 | 8 | def add_child(self, child): 9 | child.parent = self 10 | self.num_children += 1 11 | self.children.append(child) 12 | 13 | def size(self): 14 | if getattr(self, '_size', None) is not None:  # default avoids AttributeError before the cache exists 15 | return self._size 16 | count = 1 17 | for i in range(self.num_children):  # range, not xrange, so Python 3 also works 18 | count += self.children[i].size() 19 | self._size = count 20 | return self._size 21 | 22 | def depth(self): 23 | if getattr(self, '_depth', None) is not None: 24 | return self._depth 25 | count = 0 26 | if self.num_children > 0: 27 | for i in 
range(self.num_children): 28 | child_depth = self.children[i].depth() 29 | if child_depth > count: 30 | count = child_depth 31 | count += 1 32 | self._depth = count 33 | return self._depth 34 | -------------------------------------------------------------------------------- /lib/modules/DistributionRender.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch.nn.utils import weight_norm 5 | import torch.nn.functional as F 6 | 7 | class DistributionRender(nn.Module): 8 | def __init__(self, hiddim): 9 | super(DistributionRender, self).__init__() 10 | self.render_mean = nn.Sequential( 11 | nn.Conv2d(hiddim, hiddim, 3, 1, 1), 12 | nn.ELU(inplace=True), 13 | nn.Conv2d(hiddim, hiddim, 3, 1, 1), 14 | nn.ELU(inplace=True), 15 | nn.Conv2d(hiddim, hiddim, 3, 1, 1), 16 | ) 17 | 18 | self.render_var = nn.Sequential( 19 | nn.Conv2d(hiddim, hiddim, 3, 1, 1), 20 | nn.ELU(inplace=True), 21 | nn.Conv2d(hiddim, hiddim, 3, 1, 1), 22 | nn.ELU(inplace=True), 23 | nn.Conv2d(hiddim, hiddim, 3, 1, 1), 24 | ) 25 | 26 | 27 | def forward(self, x): 28 | # x = [mean, var] 29 | return self.render_mean(x[0]), self.render_var(x[1]) 30 | -------------------------------------------------------------------------------- /data/CLEVR/add_parent.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script adds a parent pointer for all tree nodes. 3 | The original trees generated together with CLEVR images don't have this kind of parent pointer, 4 | but the parent pointer can be useful when implementing PNP-Net. 5 | 6 | """ 7 | 8 | import _init_paths 9 | import pickle 10 | from lib.tree import Tree 11 | import os.path as osp 12 | import os 13 | 14 | def add_parent(tree): 15 | tree = _add_parent(tree, None) 16 | 17 | return tree 18 | 19 | def _add_parent(tree, parent): 20 | tree.parent = parent 21 | for i in range(0, tree.num_children): 22 | tree.children[i] = _add_parent(tree.children[i], tree) 23 | 24 | return tree 25 | 26 | path = 'CLEVR_128_NEW/trees_no_parent' 27 | outpath = 'CLEVR_128_NEW/trees' 28 | 29 | split = ['train', 'test'] 30 | 31 | for s in split: 32 | treepath = osp.join(path, s) 33 | files = os.listdir(treepath) 34 | try: 35 | os.makedirs(osp.join(outpath, s)) 36 | except OSError:  # directory may already exist 37 | pass 38 | 39 | for fi in files: 40 | if fi.endswith('tree'): 41 | with open(osp.join(path, s, fi), 'rb') as f: 42 | treei = pickle.load(f) 43 | treei = add_parent(treei) 44 | pickle.dump(treei, open(osp.join(outpath, s, fi), 'wb')) 45 | -------------------------------------------------------------------------------- /lib/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | 4 | def load_config(file_path): 5 | with open(file_path, 'r') as f: 6 | config_dict = yaml.safe_load(f)  # safe_load avoids constructing arbitrary Python objects 7 | return config_dict 8 | 9 | 10 | class Struct: 11 | def __init__(self, **entries): 12 | rec_entries = {} 13 | for k, v in entries.items(): 14 | if isinstance(v, dict): 15 | rv = Struct(**v) 16 | elif isinstance(v, list): 17 | rv = [] 18 | for item in v: 19 | if isinstance(item, dict): 20 | rv.append(Struct(**item)) 21 | else: 22 | rv.append(item) 23 | else: 24 | rv = v 25 | rec_entries[k] = rv 26 | self.__dict__.update(rec_entries) 27 | 28 | def __str_helper(self, depth): 29 | lines = [] 30 | for k, v in self.__dict__.items(): 31 | if isinstance(v, Struct): 32 | v_str = v.__str_helper(depth + 1) 33 | lines.append("%s:\n%s" % (k, 
v_str)) 34 | else: 35 | lines.append("%s: %r" % (k, v)) 36 | indented_lines = [" " * depth + l for l in lines] 37 | return "\n".join(indented_lines) 38 | 39 | def __str__(self): 40 | return "struct {\n%s\n}" % self.__str_helper(1) 41 | 42 | def __repr__(self): 43 | return "Struct(%r)" % self.__dict__ 44 | 45 | 46 | if __name__ == '__main__': 47 | config_dic = load_config('./configs/pnp_net_configs.yaml') 48 | configs = Struct(**config_dic) 49 | import pdb 50 | pdb.set_trace() 51 | -------------------------------------------------------------------------------- /lib/LambdaBiKLD.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | import torch.nn.functional as F 4 | import torch 5 | import numpy as np 6 | 7 | 8 | class BiKLD(nn.Module): 9 | def __init__(self, lambda_t, k): 10 | super(BiKLD, self).__init__() 11 | print('Using the modified KL Divergence analytic solution, thresholded by lambda:', lambda_t) 12 | self.lambda_t = lambda_t 13 | self.k = k 14 | 15 | def sample(self, mu, logvar, k): 16 | # input: [B, C, H, W], [B, C, H, W] 17 | # output: [B, K, C, H, W] 18 | std = logvar.mul(0.5).exp_() 19 | eps = Variable(std.data.new(std.size()).normal_()) 20 | return eps.mul(std).add_(mu) 21 | 22 | def gaussian_diag_logps(self, mu, logvar, z): 23 | logps = logvar + ((z - mu) ** 2) / (logvar.exp()) 24 | logps = logps.add(np.log(2 * np.pi)) 25 | logps = logps.mul(-0.5) 26 | return logps 27 | 28 | def expand_dis(self, p, k): 29 | mu, logvar = p[0], p[1] 30 | logvar = logvar.unsqueeze(1).expand(logvar.size(0), k, logvar.size(1), logvar.size(2), logvar.size(3)) 31 | mu = mu.unsqueeze(1).expand(mu.size(0), k, mu.size(1), mu.size(2), mu.size(3)) 32 | return mu, logvar 33 | 34 | def forward(self, q, p): 35 | # input: [B, C, H, W], [B, C, H, W] 36 | # output: [1] 37 | # expand 38 | q_mu, q_var = q[0], torch.exp(q[1]) 39 | p_mu, p_var = p[0], torch.exp(p[1]) 40 | 41 | kld = q_var / p_var - 1 42 | kld += (p_mu - q_mu).pow(2) / p_var 43 | kld += p[1] - q[1] 44 | kld = kld.sum(dim=3).sum(dim=2).mean(0) / 2 45 | 46 | lambda_tensor = Variable(self.lambda_t * torch.ones(kld.size())).cuda() 47 | kld = torch.max(kld, lambda_tensor) 48 | kld = kld.sum() * p[0].size(0) 49 | 50 | return kld 51 | 52 | 53 | if __name__ == '__main__': 54 | # test code 55 | q_mu = Variable(torch.zeros(4, 32, 16, 16)).cuda() 56 | q_logvar = Variable(torch.zeros(4, 32, 16, 16)).cuda() 57 | p_mu = Variable(torch.zeros(4, 32, 16, 16)).cuda() 58 | p_logvar = Variable(torch.zeros(4, 32, 16, 16)).cuda() 59 | 60 | kldloss = BiKLD(lambda_t=0.01, k=10) 61 | 62 | kld = kldloss([q_mu, q_logvar], [p_mu, p_logvar]) 63 | -------------------------------------------------------------------------------- /lib/modules/VAE.py: -------------------------------------------------------------------------------- 1 | import _init_paths 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from torch.nn.utils import weight_norm 6 | import torch.nn.functional as F 7 | from lib.BiKLD import BiKLD 8 | from lib.reparameterize import reparameterize 9 | 10 | 11 | class VAE(nn.Module): 12 | def __init__(self, indim, latentdim, half=False): 13 | super(VAE, self).__init__() 14 | 15 | self.half = half 16 | if self.half is False: 17 | self.encoder = nn.Sequential( 18 | nn.Linear(indim, latentdim * 2), 19 | nn.ELU(inplace=True), 20 | nn.Linear(latentdim * 2, latentdim * 2), 21 | nn.ELU(inplace=True), 22 | nn.Linear(latentdim * 2, latentdim * 2) 
23 | ) 24 | self.mean = nn.Linear(latentdim * 2, latentdim) 25 | self.logvar = nn.Linear(latentdim * 2, latentdim) 26 | self.bikld = BiKLD() 27 | 28 | dec_out = indim 29 | self.decoder = nn.Sequential( 30 | nn.Linear(latentdim, latentdim * 2), 31 | nn.ELU(inplace=True), 32 | nn.Linear(latentdim * 2, latentdim * 2), 33 | nn.ELU(inplace=True), 34 | nn.Linear(latentdim * 2, dec_out) 35 | ) 36 | 37 | self.sampler = reparameterize() 38 | 39 | def forward(self, x=None, prior=None): 40 | prior = [prior[0].view(1, -1), prior[1].view(1, -1)] 41 | 42 | if self.half is False: 43 | encoding = self.encoder(x) 44 | mean, logvar = self.mean(encoding), self.logvar(encoding) 45 | kld = self.bikld([mean, logvar], prior) 46 | z = self.sampler(mean, logvar) 47 | else: 48 | z = self.sampler(prior[0], prior[1]) 49 | kld = 0 50 | 51 | decoding = self.decoder(z) 52 | 53 | return decoding, kld 54 | 55 | def generate(self, prior): 56 | prior = [prior[0].view(1, -1), prior[1].view(1, -1)] 57 | z = self.sampler(*prior) 58 | 59 | decoding = self.decoder(z) 60 | 61 | return decoding 62 | 63 | 64 | ''' 65 | #Test case 0 66 | model = VAE(6, 4).cuda() 67 | mean = Variable(torch.zeros(16, 4)).cuda() 68 | var = Variable(torch.zeros(16, 4)).cuda() 69 | data = Variable(torch.zeros(16, 6)).cuda() 70 | out, kld = model(data, [mean, var]) 71 | 72 | #Test case 1 73 | model = VAE(6, 4, 10).cuda() 74 | mean = Variable(torch.zeros(16, 4)).cuda() 75 | var = Variable(torch.zeros(16, 4)).cuda() 76 | data = Variable(torch.zeros(16, 6)).cuda() 77 | condition = Variable(torch.zeros(16, 10)).cuda() 78 | out, kld = model(data, [mean, var], condition) 79 | ''' 80 | -------------------------------------------------------------------------------- /lib/modules/Combine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch.nn.utils import weight_norm 5 | import torch.nn.functional as F 6 | 7 | class Combine(nn.Module): 8 | def __init__(self, hiddim_v, hiddim_p=None, op='PROD'): 9 | super(Combine, self).__init__() 10 | self.op = op 11 | self.hiddim_v = hiddim_v 12 | self.hiddim_p = hiddim_p 13 | if self.op == 'DEEP': 14 | self.net_vis = nn.Sequential( 15 | weight_norm(nn.Conv2d(4*hiddim_v, hiddim_v, 3, 1, 1)), 16 | nn.ELU(inplace=True), 17 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)), 18 | nn.ELU(inplace=True), 19 | weight_norm(nn.Conv2d(hiddim_v, 2*hiddim_v, 3, 1, 1)) 20 | ) 21 | self.net_pos = nn.Sequential( 22 | weight_norm(nn.Conv2d(4*hiddim_p, hiddim_p, 1, 1)), 23 | nn.ELU(inplace=True), 24 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)), 25 | nn.ELU(inplace=True), 26 | weight_norm(nn.Conv2d(hiddim_p, 2*hiddim_p, 1, 1)) 27 | ) 28 | 29 | elif self.op == 'CAT': 30 | self.net_mean_vis = nn.Sequential( 31 | weight_norm(nn.Conv2d(hiddim_v*2, hiddim_v, 3, 1, 1)), 32 | nn.ELU(inplace=True), 33 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)) 34 | ) 35 | self.net_var_vis = nn.Sequential( 36 | weight_norm(nn.Conv2d(hiddim_v*2, hiddim_v, 3, 1, 1)), 37 | nn.ELU(inplace=True), 38 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)) 39 | ) 40 | self.net_mean_pos = nn.Sequential( 41 | weight_norm(nn.Conv2d(hiddim_p*2, hiddim_p, 1, 1)), 42 | nn.ELU(inplace=True), 43 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)) 44 | ) 45 | self.net_var_pos = nn.Sequential( 46 | weight_norm(nn.Conv2d(hiddim_p*2, hiddim_p, 1, 1)), 47 | nn.ELU(inplace=True), 48 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)) 49 | ) 50 | elif 
self.op == 'gPoE': 51 | self.gates_v = nn.Sequential( 52 | weight_norm(nn.Conv2d(hiddim_v*4, hiddim_v*4, 3, 1, 1)), 53 | nn.ELU(inplace=True), 54 | weight_norm(nn.Conv2d(hiddim_v*4, hiddim_v*4, 3, 1, 1)) 55 | ) 56 | self.gates_p = nn.Sequential( 57 | weight_norm(nn.Conv2d(hiddim_p*4, hiddim_p*4, 3, 1, 1)), 58 | nn.ELU(inplace=True), 59 | weight_norm(nn.Conv2d(hiddim_p*4, hiddim_p*4, 3, 1, 1)) 60 | ) 61 | 62 | def forward(self, x1, x2, mode='vis'): 63 | if self.op == 'PROD': 64 | return [x1[0]*x2[0], x1[1]*x2[1]] 65 | elif self.op == 'PoE': 66 | # logvar = -log(exp(-logvar1) + exp(-logvar2)) 67 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2) 68 | mlogvar1 = -x1[1] 69 | mlogvar2 = -x2[1] 70 | mu1 = x1[0] 71 | mu2 = x2[0] 72 | 73 | logvar = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2)) 74 | mu = torch.exp(logvar)*(torch.exp(mlogvar1)*mu1 + torch.exp(mlogvar2)*mu2) 75 | return [mu, logvar] 76 | elif self.op == 'gPoE': 77 | # logvar = -log(exp(-logvar1) + exp(-logvar2)) 78 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2) 79 | 80 | if mode == 'vis': 81 | gates = torch.sigmoid(self.gates_v(torch.cat([x1[0], x1[1], x2[0], x2[1]], dim=1))) 82 | x1_mu_g = gates[:,:self.hiddim_v,:,:] 83 | x1_var_g = gates[:,self.hiddim_v:2*self.hiddim_v,:,:] 84 | x2_mu_g = gates[:,2*self.hiddim_v:3*self.hiddim_v,:,:] 85 | x2_var_g = gates[:,3*self.hiddim_v:4*self.hiddim_v,:,:] 86 | elif mode == 'pos': 87 | gates = torch.sigmoid(self.gates_p(torch.cat([x1[0], x1[1], x2[0], x2[1]], dim=1))) 88 | x1_mu_g = gates[:,:self.hiddim_p,:,:] 89 | x1_var_g = gates[:,self.hiddim_p:2*self.hiddim_p,:,:] 90 | x2_mu_g = gates[:,2*self.hiddim_p:3*self.hiddim_p,:,:] 91 | x2_var_g = gates[:,3*self.hiddim_p:4*self.hiddim_p,:,:] 92 | 93 | x1[0] = x1_mu_g*x1[0] 94 | x1[1] = torch.log(x1_var_g + 1e-5) + x1[1] 95 | x2[0] = x2_mu_g*x2[0] 96 | x2[1] = torch.log(x2_var_g + 1e-5) + x2[1] 97 | 98 | mlogvar1 = -x1[1] 99 | mlogvar2 = -x2[1] 100 | mu1 = x1[0] 101 | mu2 = x2[0] 102 | 103 | logvar = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2)) 104 | mu = torch.exp(logvar)*(torch.exp(mlogvar1)*mu1 + torch.exp(mlogvar2)*mu2) 105 | return [mu, logvar] 106 | elif self.op == 'ADD': 107 | return [x1[0] + x2[0], x1[1], x2[1]] 108 | elif self.op == 'CAT': 109 | if mode == 'vis': 110 | return [self.net_mean_vis(torch.cat([x1[0], x2[0]], dim=1)), \ 111 | self.net_var_vis(torch.cat([x1[1], x2[1]], dim=1))] 112 | elif mode == 'pos': 113 | return [self.net_mean_pos(torch.cat([x1[0], x2[0]], dim=1)), \ 114 | self.net_var_pos(torch.cat([x1[1], x2[1]], dim=1))] 115 | elif self.op == 'DEEP': 116 | if mode == 'vis': 117 | gaussian_out = self.net_vis(torch.cat([x1[0], x1[1], x2[0], x2[1]], dim=1)) 118 | return [gaussian_out[:,:self.hiddim_v,:,:], gaussian_out[:,self.hiddim_v:,:,:]] 119 | elif mode == 'pos': 120 | gaussian_out = self.net_pos(torch.cat([x1[0], x1[1], x2[0], x2[1]], dim=1)) 121 | return [gaussian_out[:,:self.hiddim_p,:,:], gaussian_out[:,self.hiddim_p:,:,:]] 122 | else: 123 | print('Operator:', self.op) 124 | raise ValueError('Unknown operator for combine module.') 125 | -------------------------------------------------------------------------------- /lib/ResidualModule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.autograd import Variable 5 | from torch.nn.utils import weight_norm 6 | 7 | 8 | class ResidualModule(nn.Module): 9 | def __init__(self, modeltype, indim, hiddim, outdim, 
nlayers, nres, ifgate=False, nonlinear='elu', normalize='instance_norm'): 10 | super(ResidualModule, self).__init__() 11 | if ifgate: 12 | print('Using gated version.') 13 | if modeltype == 'encoder': 14 | self.model = self.encoder(indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize=normalize) 15 | elif modeltype == 'decoder': 16 | self.model = self.decoder(indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize=normalize) 17 | elif modeltype == 'plain': 18 | self.model = self.plain(indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize=normalize) 19 | else: 20 | raise ValueError('Unknown model type.') 21 | 22 | def encoder(self, indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize): 23 | layers = [] 24 | layers.append(ResidualBlock(None, nonlinear, ifgate, indim, hiddim, normalize)) 25 | 26 | for i in range(0, nres): 27 | for j in range(0, nlayers): 28 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, hiddim, normalize)) 29 | layers.append(ResidualBlock('down', nonlinear, ifgate, hiddim, hiddim, normalize)) 30 | 31 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, outdim, normalize)) 32 | 33 | return nn.Sequential(*layers) 34 | 35 | def decoder(self, indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize): 36 | layers = [] 37 | layers.append(ResidualBlock(None, nonlinear, ifgate, indim, hiddim, normalize)) 38 | 39 | for i in range(0, nres): 40 | for j in range(0, nlayers): 41 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, hiddim, normalize)) 42 | layers.append(ResidualBlock('up', nonlinear, ifgate, hiddim, hiddim, normalize)) 43 | 44 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, outdim, normalize)) 45 | 46 | return nn.Sequential(*layers) 47 | 48 | def plain(self, indim, hiddim, outdim, nlayers, nres, ifgate, nonlinear, normalize): 49 | layers = [] 50 | layers.append(ResidualBlock(None, nonlinear, ifgate, indim, hiddim, normalize)) 51 | 52 | for i in range(0, nres): 53 | for j in range(0, nlayers): 54 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, hiddim, normalize)) 55 | 56 | layers.append(ResidualBlock(None, nonlinear, ifgate, hiddim, outdim, normalize)) 57 | 58 | return nn.Sequential(*layers) 59 | 60 | def forward(self, x): 61 | return self.model(x) 62 | 63 | 64 | class ResidualBlock(nn.Module): 65 | def __init__(self, resample, nonlinear, ifgate, indim, outdim, normalize): 66 | super(ResidualBlock, self).__init__() 67 | 68 | self.ifgate = ifgate 69 | self.indim = indim 70 | self.outdim = outdim 71 | self.resample = resample 72 | 73 | if resample == 'down': 74 | convtype = 'sconv_d' 75 | elif resample == 'up': 76 | convtype = 'upconv' 77 | elif resample is None: 78 | convtype = 'sconv' 79 | 80 | self.shortflag = False 81 | if not (indim == outdim and resample is None): 82 | self.shortcut = self.conv(convtype, indim, outdim) 83 | self.shortflag = True 84 | 85 | if ifgate: 86 | self.conv1 = nn.Conv2d(indim, outdim, 3, 1, 1) 87 | self.conv2 = nn.Conv2d(indim, outdim, 3, 1, 1) 88 | self.c = nn.Sigmoid() 89 | self.g = nn.Tanh() 90 | self.conv3 = self.conv(convtype, outdim, outdim) 91 | self.act = self.nonlinear(nonlinear) 92 | elif normalize == 'batch_norm': 93 | self.resblock = nn.Sequential( 94 | self.conv('sconv', indim, outdim), 95 | nn.BatchNorm2d(outdim), 96 | self.nonlinear(nonlinear), 97 | self.conv(convtype, outdim, outdim), 98 | nn.BatchNorm2d(outdim), 99 | self.nonlinear(nonlinear) 100 | ) 101 | elif normalize == 'instance_norm': 102 | self.resblock = 
nn.Sequential( 103 | self.conv('sconv', indim, outdim), 104 | nn.InstanceNorm2d(outdim), 105 | self.nonlinear(nonlinear), 106 | self.conv(convtype, outdim, outdim), 107 | nn.InstanceNorm2d(outdim), 108 | self.nonlinear(nonlinear) 109 | ) 110 | elif normalize == 'no_norm': 111 | self.resblock = nn.Sequential( 112 | self.conv('sconv', indim, outdim), 113 | self.nonlinear(nonlinear), 114 | self.conv(convtype, outdim, outdim), 115 | self.nonlinear(nonlinear) 116 | ) 117 | elif normalize == 'weight_norm': 118 | self.resblock = nn.Sequential( 119 | self.conv('sconv', indim, outdim, 'weight_norm'), 120 | self.nonlinear(nonlinear), 121 | self.conv(convtype, outdim, outdim, 'weight_norm'), 122 | self.nonlinear(nonlinear) 123 | ) 124 | 125 | def conv(self, name, indim, outdim, normalize=None): 126 | if name == 'sconv_d': 127 | if normalize == 'weight_norm': 128 | return weight_norm(nn.Conv2d(indim, outdim, 4, 2, 1)) 129 | else: 130 | return nn.Conv2d(indim, outdim, 4, 2, 1) 131 | elif name == 'sconv': 132 | if normalize == 'weight_norm': 133 | return weight_norm(nn.Conv2d(indim, outdim, 3, 1, 1)) 134 | else: 135 | return nn.Conv2d(indim, outdim, 3, 1, 1) 136 | elif name == 'upconv': 137 | if normalize == 'weight_norm': 138 | return weight_norm(nn.ConvTranspose2d(indim, outdim, 4, 2, 1)) 139 | else: 140 | return nn.ConvTranspose2d(indim, outdim, 4, 2, 1) 141 | else: 142 | raise ValueError("Unknown convolution type") 143 | 144 | def nonlinear(self, name): 145 | if name == 'elu': 146 | return nn.ELU(1, True) 147 | elif name == 'relu': 148 | return nn.ReLU(True) 149 | 150 | def forward(self, x): 151 | if self.ifgate: 152 | conv1 = self.conv1(x) 153 | conv2 = self.conv2(x) 154 | c = self.c(conv1) 155 | g = self.g(conv2) 156 | gated = c * g 157 | conv3 = self.conv3(gated) 158 | res = self.act(conv3) 159 | if not (self.indim == self.outdim and self.resample is None): 160 | out = self.shortcut(x) + res 161 | else: 162 | out = x + res 163 | else: 164 | if self.shortflag: 165 | out = self.shortcut(x) + self.resblock(x) 166 | else: 167 | out = x + self.resblock(x) 168 | 169 | return out 170 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Probabilistic Neural Programmed Network 2 | 3 | This is the official implementation of the paper: 4 | 5 | [Probabilistic Neural Programmed Networks for Scene Generation](http://www2.cs.sfu.ca/~mori/research/papers/deng-nips18.pdf) 6 | 7 | [Zhiwei Deng](http://www.sfu.ca/~zhiweid/), [Jiacheng Chen](http://jcchen.me/), [Yifang Fu](https://yifangfu.wordpress.com/) and [Greg Mori](http://www2.cs.sfu.ca/~mori/) 8 | 9 | 10 | Published at NeurIPS 2018 11 | 12 | [Poster](http://www.sfu.ca/~zhiweid/papers/PNP_Net_Poster.pdf) 13 | 14 | If you find this code helpful in your research, please cite 15 | 16 | ``` 17 | @inproceedings{deng2018probabilistic, 18 | title={Probabilistic Neural Programmed Networks for Scene Generation}, 19 | author={Deng, Zhiwei and Chen, Jiacheng and Fu, Yifang and Mori, Greg}, 20 | booktitle={Advances in Neural Information Processing Systems}, 21 | pages={4032--4042}, 22 | year={2018} 23 | } 24 | ``` 25 | 26 | Other re-implementations: 27 | 28 | [TensorFlow version](https://github.com/mihirp1998/ProbabilisticNeuralProgrammedNetwork_Tensorflow) 29 | 30 | ## Contents 31 | 1. [Overview](#overview) 32 | 2. [Environment Setup](#environment) 33 | 3. [Data and Pre-trained Models](#data-and-models) 34 | - [CLEVR-G](#clevr-g) 35 | 4. [Configurations](#configurations)
36 | 5. [Code Guide](#code-guide) 37 | - [Neural Operators](#neural-operators) 38 | 6. [Training](#training) 39 | 7. [Evaluation](#evaluation) 40 | 8. [Results](#results) 41 | 42 | 43 | ## Overview 44 | 45 | Generating scenes from rich and complex semantics is an important step towards understanding the visual world. The Probabilistic Neural Programmed Network (PNP-Net) brings symbolic methods into generative models: it exploits a set of **reusable neural modules** to compose latent distributions for scenes described by complex semantics in a **programmatic** manner, and a decoder can then sample from the latent scene distributions and generate realistic images. PNP-Net is naturally formulated as a learnable prior in the canonical VAE framework, so its parameters can be learned efficiently. A minimal code sketch of this composition is given after the pipeline figure below. 46 | 47 | 48 |
![PNP-Net pipeline](images/pipeline.png) 50 | 51 |
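As a concrete sketch of that composition (hypothetical usage only: the vocabulary, dimensions and operator choice below are illustrative, not the paper's settings), two word concepts can be mapped to latent distributions and fused with the modules under `lib/modules`:

```python
import torch
from lib.modules.ConceptMapper import ConceptMapper
from lib.modules.Combine import Combine

vocab = ['red', 'cube']  # toy two-word vocabulary
mapper = ConceptMapper(CHW=[160, 1, 1], vocab_size=len(vocab))
combine = Combine(hiddim_v=160, hiddim_p=8, op='PoE')  # parameter-free PoE variant

red = torch.zeros(1, len(vocab)); red[0, 0] = 1    # one-hot 'red'
cube = torch.zeros(1, len(vocab)); cube[0, 1] = 1  # one-hot 'cube'

# Each concept becomes a [mean, log-variance] latent distribution ...
z_red, z_cube = mapper(red), mapper(cube)
# ... and Combine fuses them into the distribution for 'red cube'.
mu, logvar = combine(z_red, z_cube, mode='vis')
```

PNP-Net applies this kind of composition recursively over a semantic parse tree; see `models/PNPNet/pnp_net.py`.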
52 | 53 | 54 | 55 | ## Environment 56 | 57 | All code was tested on Ubuntu 16.04 with Python 2.7 and **PyTorch 0.4.0** (but the code should also work well with Python 3). To install the required packages, run: 58 | 59 | ```bash 60 | pip install -r requirements.txt 61 | ``` 62 | 63 | To run our evaluation metric (a detector-based semantic correctness score), check [this submodule](https://github.com/woodfrog/SemanticCorrectnessScore/tree/483c6ef2e0548fcc629059b84c489cd4e0c19f86) for full details (**it's released now**). 64 | 65 | ## Data and Models 66 | 67 | ### CLEVR-G 68 | 69 | We used the released code of [CLEVR (Johnson et al.)](https://arxiv.org/pdf/1612.06890.pdf) to generate a modified CLEVR dataset for the task of scene image generation, which we call CLEVR-G. The generation code is in the [submodule](https://github.com/woodfrog/clevr-dataset-gen/tree/42a5c4914bbae49a0cd36cf96607c05111394ddc). 70 | 71 | We also provide the [64x64 CLEVR-G](https://drive.google.com/open?id=1QCIINcIOdcIl5U0IZrHpj5XZW2FZBbJv) used in our experiments. Please download and unzip it into **./data/CLEVR** if you want to use it with our model. 72 | 73 | **Pre-trained Model** 74 | 75 | Please download pre-trained models from: 76 | 77 | - [PNP-Net CLEVR-G 64x64](https://drive.google.com/open?id=1VusqEqIHZibRqKbXyIxJDxBRTp9AZP0y) 78 | 79 | 80 | ### COLOR-MNIST 81 | 82 | - [PNP-Net COLOR-MNIST 64x64](https://www.dropbox.com/s/jf5u7yosyisf8zd/MULTI_8000_64_100.tar?dl=0) 83 | 84 | 85 | 86 | ## Configurations 87 | 88 | We use [global configuration files](configs/pnp_net_configs.yaml) to set up all configs, including the training settings and model hyper-parameters. Please check the file and the corresponding code for more details. 89 | 90 | 91 | ## Code Guide 92 | 93 | ### Neural Operators 94 | 95 | The core of PNP-Net is a set of **neural modular operators**. We briefly introduce them here and provide pointers to the corresponding code. 96 | 97 | 98 | - **Concept Mapping Operator** 99 | 100 | 101 |
![Concept Mapping operator](images/mapping.png) 102 | 103 |
104 | 105 | 106 | Converts the one-hot representation of a word concept into appearance and scale distributions. 107 | [code](lib/modules/ConceptMapper.py) 108 |
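A shape-level sketch (the sizes below are made up for illustration): the module is just a pair of linear "dictionaries" indexed by a one-hot vector:

```python
import torch
from lib.modules.ConceptMapper import ConceptMapper

mapper = ConceptMapper(CHW=[16, 1, 1], vocab_size=20)  # a 16x1x1 latent per word
onehot = torch.zeros(1, 20)
onehot[0, 3] = 1  # select word 3
mean, std = mapper(onehot)  # two dictionary lookups, reshaped to maps
print(mean.shape, std.shape)  # torch.Size([1, 16, 1, 1]) for both
```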
109 | - **Combine Operator** 110 | 111 | ![Combine operator](images/combine.png) 112 | 113 |
114 | 115 | 116 | The Combine module merges the latent distributions of two attributes into a single latent distribution. [code](lib/modules/Combine.py) 117 | 118 |
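For two diagonal Gaussians, the parameter-free `PoE` option has a closed form, which is what the module's `PoE` branch computes in log-variance parameterization (the `gPoE` variant additionally rescales each expert with learned sigmoid gates before fusing):

```python
import torch

def poe(mu1, logvar1, mu2, logvar2):
    # Product of experts for diagonal Gaussians:
    # precisions add, and the mean is the precision-weighted average.
    logvar = -torch.log(torch.exp(-logvar1) + torch.exp(-logvar2))
    mu = torch.exp(logvar) * (torch.exp(-logvar1) * mu1 + torch.exp(-logvar2) * mu2)
    return mu, logvar
```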
119 | - **Describe Operator** 120 | 121 | ![Describe operator](images/describe.png) 122 | 123 |
124 | 125 | Attributes describe an object: this module takes the attribute distributions (merged by the Combine module) and uses them to render the distribution of an object. [code](lib/modules/Describe.py) 126 | 127 | 128 | - **Transform Operator** 129 | 130 |
![Transform operator](images/transform.png) 131 | 132 |
133 | 134 | This module first samples a size instance from an object's scale distribution and then uses bilinear interpolation to resize the appearance distribution. [code](lib/modules/Transform.py) 135 | 136 |
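A usage sketch (shapes are illustrative; note the module builds its sampling grid with `.cuda()`, so a GPU is assumed):

```python
import torch
from lib.modules.Transform import Transform

transform = Transform()
appearance_mean = torch.randn(1, 16, 4, 4).cuda()  # a 4x4 latent appearance map
# Bilinearly resample the map to a sampled size, e.g. 9x13.
resized = transform(appearance_mean, hw=(9, 13))
print(resized.shape)  # torch.Size([1, 16, 9, 13])
```

Passing `variance=True` makes the module resize a log-variance map instead: it converts to standard deviations, resamples, and converts back.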
137 | - **Layout Operator** 138 | 139 | ![Layout operator](images/layout.png) 140 | 141 | 142 | 143 | 144 | The Layout module puts the latent distributions of two different objects (from its child nodes) onto a background latent canvas according to the offsets of the two children. [code](models/PNPNet/pnp_net.py#L267) 145 | 146 |
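The actual operator lives in `pnp_net.py`; conceptually it behaves like the following toy sketch (hypothetical code, not the repo's implementation), which pastes a child's latent map onto a larger canvas at a given offset:

```python
import torch

def layout_sketch(canvas_mean, child_mean, offset_hw):
    # Toy illustration: place one child's latent mean map on the canvas.
    # The real module also handles variance maps and both children.
    oh, ow = offset_hw
    _, _, h, w = child_mean.shape
    canvas_mean[:, :, oh:oh + h, ow:ow + w] = child_mean
    return canvas_mean

canvas = torch.zeros(1, 16, 16, 16)  # background latent canvas
child = torch.randn(1, 16, 4, 4)     # one child object's latent map
canvas = layout_sketch(canvas, child, offset_hw=(2, 5))
```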
147 | 148 | ## Training 149 | 150 | The default training can be started by: 151 | 152 | ```bash 153 | python mains/pnpnet_main.py --config_path configs/pnp_net_configs.yaml 154 | ``` 155 | 156 | Make sure that you are in the project root directory when typing the above command. 157 | 158 | 159 | 160 | ## Evaluation 161 | 162 | The evaluation has two major steps: 163 | 164 | 1. Generate images according to the semantics in the test set using a pre-trained model. 165 | 166 | 2. Run our [detector-based semantic correctness score](https://github.com/woodfrog/SemanticCorrectnessScore) to evaluate the quality of the images. Please check that repo for more details about our proposed metric for measuring the semantic correctness of scene images. 167 | 168 | 169 | To generate test images with a pre-trained model, first set the mode to **test** and set the checkpoint path in the config file (a sketch of these changes is shown below), then run the same command as for training: 170 | 171 | ```bash 172 | python mains/pnpnet_main.py --config_path configs/pnp_net_configs.yaml 173 | ``` 174 | 175 |
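For example, the relevant edits in `configs/pnp_net_configs.yaml` look like this (the checkpoint path is the commented example shipped in the config; substitute your own):

```yaml
mode: test
checkpoint: ./results/CLEVR_64_MULTI_LARGE/PNP-Net-5/checkpoints/model_epoch_360.pth
```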
176 | ## Results 177 | 178 | Detailed results can be checked in our paper. We provide some samples here to show PNP-Net's capability of generating images for different complex scenes. The dataset used here is CLEVR-G 128x128; every scene contains at most 8 objects. 179 | 180 | 181 | ![Generated samples on CLEVR-G](images/samples.png) 182 | 183 |
184 | 185 | When the scene becomes too complex, PNP-Net can suffer from the following problems: 186 | 187 | 1. It might fail to handle occlusion between objects. When multiple objects overlap, their latents get mixed on the background latent canvas, and the appearance of objects can be distorted. 188 | 189 | 2. It might put some of the objects out of the image boundary, therefore some images do not contain the correct number of objects as described by the semantics. 190 | -------------------------------------------------------------------------------- /trainers/pnpnet_trainer.py: -------------------------------------------------------------------------------- 1 | import scipy.misc 2 | import numpy as np 3 | import os 4 | import os.path as osp 5 | import datetime 6 | import pytz 7 | import pdb 8 | 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | from lib.utils import color_grid_vis, AverageMeter 13 | 14 | 15 | class PNPNetTrainer: 16 | def __init__(self, model, train_loader, val_loader, gen_loader, optimizer, configs): 17 | self.model = model 18 | self.train_loader = train_loader 19 | self.val_loader = val_loader 20 | self.gen_loader = gen_loader 21 | self.optimizer = optimizer 22 | self.configs = configs 23 | 24 | def train_epoch(self, epoch_num, timestamp_start): 25 | self.model.train() 26 | train_rec_loss = AverageMeter() 27 | train_kld_loss = AverageMeter() 28 | train_pos_loss = AverageMeter() 29 | batch_idx = 0 30 | epoch_end = False 31 | # annealing for kl penalty 32 | kl_coeff = float(epoch_num) / float(self.configs.warmup_iter + 1) 33 | if kl_coeff >= self.configs.alpha_ub: 34 | kl_coeff = self.configs.alpha_ub 35 | print('kl penalty coefficient: ', kl_coeff, 'alpha upperbound:', self.configs.alpha_ub) 36 | while epoch_end is False: 37 | data, trees, _, epoch_end = self.train_loader.next_batch() 38 | data = Variable(data).cuda() 39 | 40 | self.optimizer.zero_grad() 41 | ifmask = False 42 | if self.configs.maskweight > 0: 43 | ifmask = True 44 | rec_loss, kld_loss, pos_loss, modelout = self.model(data, trees, alpha=kl_coeff, ifmask=ifmask, 45 | maskweight=self.configs.maskweight) 46 | recon = modelout 47 | rec_loss, kld_loss, pos_loss = rec_loss.sum() / self._total(data), kld_loss.sum() / self._total(data), pos_loss.sum() / self._total(data) 48 | loss = rec_loss + self.configs.kl_beta * kld_loss + self.configs.pos_beta * pos_loss 49 | loss.backward() 50 | self.optimizer.step() 51 | train_rec_loss.update(rec_loss.item(), self._total(data), data.size(0)) 52 | train_kld_loss.update(kld_loss.item(), self._total(data), data.size(0)) 53 | train_pos_loss.update(pos_loss.item(), self._total(data), data.size(0)) 54 | 55 | if batch_idx % 30 == 0: 56 | scipy.misc.imsave(osp.join(self.configs.exp_dir, 'samples', 'generativenmn_data.png'), 57 | (data.cpu().data.numpy().transpose(0, 2, 3, 1)[0] + 1) / 2.0) 58 | scipy.misc.imsave(osp.join(self.configs.exp_dir, 'samples', 'generativenmn_reconstruction.png'), \ 59 | (recon.cpu().data.numpy().transpose(0, 2, 3, 1)[0] + 1) / 2.0) 60 | scipy.misc.imsave(osp.join(self.configs.exp_dir, 'samples', 'generativenmn_reconstruction_clip.png'), \ 61 | np.clip(recon.cpu().data.numpy().transpose(0, 2, 3, 1)[0], -1, 1)) 62 | print('Epoch:{0}\tIter:{1}/{2}\tRecon {3:.6f}\t KL {4:.6f}'.format(epoch_num, batch_idx, 63 | len(self.train_loader) // self.configs.batch_size, 64 | train_rec_loss.batch_avg, train_kld_loss.batch_avg)) 65 | 66 | self.model.clean_tree(trees) 67 | batch_idx += 1 68 | 69 | elapsed_time = \ 70 | 
datetime.datetime.now(pytz.timezone('America/New_York')) - \ 71 | timestamp_start 72 | 73 | print('====> Epoch: {} Average rec loss: {:.6f} Average kld loss: {:.6f} Average pos loss: {:.6f}'.format( 74 | epoch_num, train_rec_loss.batch_avg, train_kld_loss.batch_avg, train_pos_loss.batch_avg)) 75 | print('Elapsed time:', elapsed_time) 76 | 77 | @staticmethod 78 | def _total(tensor): 79 | return tensor.size(0) * tensor.size(1) * tensor.size(2) * tensor.size(3) 80 | 81 | def validate(self, epoch_num, timestamp_start, minloss): 82 | self.model.eval() 83 | test_rec_loss = AverageMeter() 84 | test_kld_loss = AverageMeter() 85 | test_pos_loss = AverageMeter() 86 | epoch_end = False 87 | count = 0. 88 | while epoch_end is False: 89 | data, trees, _, epoch_end = self.val_loader.next_batch() 90 | data = Variable(data, volatile=True).cuda() 91 | rec_loss, kld_loss, pos_loss, modelout = self.model(data, trees) 92 | 93 | rec_loss, kld_loss, pos_loss = rec_loss.sum(), kld_loss.sum(), pos_loss.sum() 94 | loss = rec_loss + kld_loss + pos_loss 95 | 96 | test_rec_loss.update(rec_loss.item(), self._total(data), data.size(0)) 97 | test_kld_loss.update(kld_loss.item(), self._total(data), data.size(0)) 98 | test_pos_loss.update(pos_loss.item(), self._total(data), data.size(0)) 99 | 100 | self.model.clean_tree(trees) 101 | 102 | elapsed_time = \ 103 | datetime.datetime.now(pytz.timezone('America/New_York')) - \ 104 | timestamp_start 105 | 106 | torch.save(self.model.state_dict(), 107 | osp.join(self.configs.exp_dir, 'checkpoints', 'model_epoch_{0}.pth'.format(epoch_num))) 108 | 109 | print('====> Epoch: {} Test rec loss: {:.6f} Test kld loss: {:.6f} Test pos loss: {:.6f}'.format( 110 | epoch_num, test_rec_loss.batch_avg, test_kld_loss.batch_avg, test_pos_loss.batch_avg)) 111 | print('Elapsed time:', elapsed_time) 112 | 113 | return minloss 114 | 115 | def sample(self, epoch_num, sample_num, timestamp_start): 116 | self.model.eval() 117 | 118 | data, trees, _, _ = self.gen_loader.next_batch() 119 | data = Variable(data, volatile=True).cuda() 120 | epoch_result_dir = osp.join(self.configs.exp_dir, 'samples', 'epoch-{}'.format(epoch_num)) 121 | 122 | try: 123 | os.makedirs(epoch_result_dir) 124 | except: 125 | pass 126 | 127 | samples_image_dict = dict() 128 | data_image_dict = dict() 129 | batch_size = None 130 | for j in range(sample_num): 131 | sample = self.model.generate(data, trees) 132 | if not batch_size: 133 | batch_size = sample.size(0) 134 | for i in range(0, sample.size(0)): 135 | samples_image_dict.setdefault(i, list()).append(sample.cpu().data.numpy().transpose(0, 2, 3, 1)[i]) 136 | if j == sample_num - 1: 137 | data_image_dict[i] = data.cpu().data.numpy().transpose(0, 2, 3, 1)[i] 138 | self.model.clean_tree(trees) 139 | print(j) 140 | 141 | for i in range(batch_size): 142 | samples = np.clip(np.stack(samples_image_dict[i], axis=0), -1, 1) 143 | data = data_image_dict[i] 144 | color_grid_vis(samples, nh=2, nw=sample_num // 2, 145 | save_path=osp.join(epoch_result_dir, 'generativenmn_{}_sample.png'.format(i))) 146 | scipy.misc.imsave(osp.join(epoch_result_dir, 'generativenmn_{}_real.png'.format(i)), data) 147 | torch.save(trees[i], osp.join(epoch_result_dir, 'generativenmn_tree_' + str(i) + '.pth')) 148 | print('====> Epoch: {} Generating image number: {:d}'.format(epoch_num, i)) 149 | 150 | elapsed_time = \ 151 | datetime.datetime.now(pytz.timezone('America/New_York')) - \ 152 | timestamp_start 153 | print('Elapsed time:', elapsed_time) 154 | 
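As an aside, the KL annealing used in `train_epoch` above is a linear warm-up in the epoch index, clipped at `alpha_ub`; a standalone sketch of the schedule, using the defaults from `configs/pnp_net_configs.yaml`:

```python
def kl_coefficient(epoch, warmup_iter=100, alpha_ub=0.6):
    # Linear warm-up of the KL penalty coefficient, capped at alpha_ub.
    return min(float(epoch) / float(warmup_iter + 1), alpha_ub)

# e.g. kl_coefficient(10) ~= 0.099; from epoch 61 onwards it stays at 0.6
```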
-------------------------------------------------------------------------------- /lib/data_loader/clevr/treeutils.py: -------------------------------------------------------------------------------- 1 | import _init_paths 2 | import os 3 | import sys 4 | import argparse 5 | import os.path as osp 6 | import random 7 | from lib.tree import Tree  # NOTE: Layout, Describe and Combine used below are not defined or imported in this file; they are expected to come from the CLEVR tree-generation utilities 8 | 9 | # parser = argparse.ArgumentParser(description='generate trees for CLEVR dataset') 10 | # parser.add_argument('--output_dir', type=str, default='', 11 | # help='output path for trees') 12 | # parser.add_argument('--train_sample', type=int, default=0, 13 | # help='number of samples for training') 14 | # parser.add_argument('--test_sample', type=int, default=0, 15 | # help='number of samples for testing') 16 | # args = parser.parse_args() 17 | 18 | 19 | ######### hyperparameters ########## 20 | # max level of the tree 21 | max_level = 2 22 | 23 | # module list 24 | module_list = ['layout', 'describe', 'combine'] 25 | 26 | # children dict 27 | children_dict = dict() 28 | children_dict['layout'] = 2 29 | children_dict['describe'] = 1 30 | children_dict['combine'] = 1 31 | 32 | # we will have two split dicts for modules for designing a zero-shot setting 33 | 34 | module_dict_split1 = dict() 35 | module_dict_split2 = dict() 36 | 37 | # objects list 38 | module_dict_split1['describe'] = ['cylinder', 'sphere'] 39 | module_dict_split2['describe'] = ['cube', 'sphere'] 40 | 41 | # attributes list 42 | attribute_list = ['material', 'color', 'size'] 43 | 44 | module_dict_split1['combine'] = {'material': ['metal'], 45 | 'color': ['green', 'blue', 'yellow', 'red'], 46 | 'size': ['large']} 47 | 48 | module_dict_split2['combine'] = {'material': ['rubber'], 49 | 'color': ['cyan', 'brown', 'gray', 'purple'], 50 | 'size': ['small']} 51 | 52 | # relations list 53 | module_dict_split1['layout'] = ['front', 'left', 'left-front', 'right-front'] 54 | module_dict_split2['layout'] = ['right', 'behind', 'left-behind', 'right-behind'] 55 | 56 | module_dicts = [module_dict_split1, module_dict_split2] 57 | 58 | pattern_map = {'describe': 0, 'material': 1, 'color': 2, 'size': 3, 'layout': 4} 59 | 60 | training_patterns = [(0, 1, 0, 1, 0), (1, 0, 1, 0, 1)] 61 | test_patterns = [(1, 1, 1, 1, 1), (0, 0, 0, 0, 0), (0, 0, 1, 1, 1), (1, 1, 0, 0, 0), (0, 1, 1, 1, 0), (1, 0, 0, 0, 1), 62 | (0, 1, 1, 1, 1), (1, 0, 0, 0, 0)] 63 | 64 | 65 | # degree range: currently randomize this number, \ 66 | # no need for input from the tree 67 | # deg_range = [0, 360] 68 | 69 | # def get_flag(level, maxlevel): 70 | # if level + 1 >= max_level: 71 | # flag = 0 72 | # else: 73 | # # flag = random.randint(0, 1) 74 | # flag = 1 75 | # 76 | # return flag 77 | 78 | 79 | def expand_tree(tree, level, parent, memorylist, child_idx, max_level, metadata_pattern): 80 | if parent is None or parent.function == 'layout': 81 | if level + 1 >= max_level: 82 | valid = [1] 83 | else: 84 | valid = [0, 1] 85 | 86 | # sample module, the module can be either layout or describe here 87 | module_id = random.randint(0, len(valid) - 1) 88 | tree.function = module_list[valid[module_id]] 89 | 90 | # sample content 91 | dict_index = metadata_pattern[pattern_map[tree.function]] 92 | module_dict = module_dicts[dict_index] 93 | 94 | word_id = random.randint(0, len(module_dict[tree.function]) - 1) 95 | tree.word = module_dict[tree.function][word_id] 96 | 97 | if tree.function == 'layout': 98 | tree.function_obj = Layout(tree.word) 99 | print('add layout') 100 | else: 101 | tree.function_obj = Describe(tree.word) 102 | print('add 
describe') 103 | 104 | # num children 105 | if level + 1 > max_level: 106 | tree.num_children = 0 107 | else: 108 | tree.num_children = children_dict[tree.function] 109 | if parent is not None: # then the parent must be a layout node 110 | if child_idx == 0: 111 | parent.function_obj.left_child = tree.function_obj 112 | else: 113 | parent.function_obj.right_child = tree.function_obj 114 | 115 | for i in range(tree.num_children): 116 | tree.children.append(Tree()) 117 | tree.children[i] = expand_tree(tree.children[i], level + 1, tree, [], i, max_level, metadata_pattern) 118 | 119 | # must contain only one child node, which is a combine node 120 | elif parent.function == 'describe' or parent.function == 'combine': 121 | print('add combine') 122 | valid = [2] 123 | # no need to sample module for now 124 | module_id = 0 125 | tree.function = module_list[valid[module_id]] 126 | 127 | # sample content 128 | # sample which attributes 129 | if len(set(attribute_list) - set(memorylist)) <= 1: 130 | full_attribute = True 131 | else: 132 | full_attribute = False 133 | 134 | attribute = random.sample(set(attribute_list) - set(memorylist), 1)[0] 135 | memorylist += [attribute] 136 | 137 | dict_idx = metadata_pattern[pattern_map[attribute]] 138 | module_dict = module_dicts[dict_idx] 139 | word_id = random.randint(0, len(module_dict[tree.function][attribute]) - 1) 140 | tree.word = module_dict[tree.function][attribute][word_id] 141 | 142 | if isinstance(parent.function_obj, Describe): 143 | carrier = parent.function_obj 144 | else: 145 | carrier = parent.function_obj.get_carrier() 146 | 147 | tree.function_obj = Combine(attribute, tree.word) 148 | tree.function_obj.set_carrier(carrier) 149 | carrier.set_attribute(attribute, tree.function_obj) 150 | 151 | if not full_attribute: 152 | tree.num_children = children_dict[tree.function] 153 | 154 | for i in range(tree.num_children): 155 | tree.children.append(Tree()) 156 | tree.children[i] = expand_tree(tree.children[i], level + 1, tree, memorylist, i, max_level, metadata_pattern) 157 | 158 | else: 159 | raise ValueError('Wrong function.') 160 | return tree 161 | 162 | 163 | def visualize_tree(trees): 164 | for i in range(len(trees)): 165 | print('************** tree **************') 166 | _visualize_tree(trees[i], 0) 167 | print('**********************************') 168 | 169 | 170 | def _visualize_tree(tree, level): 171 | if tree is None: 172 | return 173 | for i in range(tree.num_children - 1, (tree.num_children - 1) // 2, -1): 174 | _visualize_tree(tree.children[i], level + 1) 175 | 176 | print(' ' * level + tree.word) 177 | 178 | # if isinstance(tree.function_obj, Describe): 179 | # print(tree.function_obj.attributes, tree.function_obj) 180 | # if tree.function != 'combine': 181 | # print('position {}'.format(tree.function_obj.position)) 182 | 183 | if hasattr(tree, 'bbox'): 184 | print('Bounding box of {} is {}'.format(tree.word, tree.bbox)) 185 | 186 | for i in range((tree.num_children - 1) // 2, -1, -1): 187 | _visualize_tree(tree.children[i], level + 1) 188 | 189 | return 190 | 191 | 192 | def allign_tree(tree, level): 193 | """ 194 | Assign positions to objects via a pre-order traversal. 195 | :param tree: 196 | :return: 197 | """ 198 | if tree is None: 199 | return 200 | 201 | if tree.function == 'describe' and level == 0: 202 | tree.function_obj.set_random_pos() 203 | elif tree.function == 'layout': 204 | tree.function_obj.set_children_pos() 205 | for i in range(tree.num_children): 206 | allign_tree(tree.children[i], level + 1) 207 | else: 208 | pass 209 | 210 | 211 | def 
extract_objects(tree): 212 | objects = list() 213 | 214 | if tree is None: 215 | return objects 216 | 217 | if tree.function == 'describe': 218 | objects.append(tree.function_obj) 219 | elif tree.function == 'layout': 220 | for i in range(tree.num_children): 221 | objects += extract_objects(tree.children[i]) 222 | else: 223 | pass 224 | 225 | return objects 226 | 227 | 228 | def sample_tree(max_level, train=True): 229 | tree = Tree() 230 | if train: 231 | pattern = random.sample(training_patterns, 1)[0] # sample a pattern for training data 232 | else: 233 | pattern = random.sample(test_patterns, 1)[0] # sample a pattern for test data 234 | tree = expand_tree(tree, 0, None, [], 0, max_level, pattern) 235 | allign_tree(tree, 0) 236 | return tree 237 | 238 | 239 | if __name__ == '__main__': 240 | # random.seed(12113) 241 | # 242 | # # tree = Tree() 243 | # # tree = expand_tree(tree, 0, None, [], 0) 244 | # # allign_tree(tree) 245 | # 246 | # num_sample = 1 247 | # trees = [] 248 | # for i in range(num_sample): 249 | # treei = Tree() 250 | # treei = expand_tree(treei, 0, None, [], 0, max_level=2) 251 | # allign_tree(treei, 0) 252 | # objects = extract_objects(treei) 253 | # trees += [treei] 254 | # print(objects) 255 | # 256 | # visualize_tree(trees) 257 | 258 | tree = sample_tree(max_level=3) 259 | -------------------------------------------------------------------------------- /lib/data_loader/clevr/clevr_tree.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import _init_paths 4 | import os 5 | import os.path as osp 6 | import pickle 7 | 8 | import numpy as np 9 | import PIL.Image 10 | import torch 11 | import random 12 | 13 | 14 | class CLEVRTREE(): 15 | # set the thresholds for the size of objects here 16 | SMALL_THRESHOLD = 16 17 | # remove the medium size for simplicity and clarity 18 | 19 | NEW_SIZE_WORDS = ['small', 'large'] 20 | OLD_SIZE_WORDS = ['small', 'large'] 21 | 22 | def __init__(self, batch_size=16, 23 | base_dir='/cs/vml4/zhiweid/ECCV18/GenerativeNeuralModuleNetwork/data/CLEVR/CLEVR_128', 24 | random_seed=12138, 25 | phase='train', shuffle=True, file_format='png'): 26 | self.phase = phase 27 | self.base_dir = base_dir 28 | self.batch_size = batch_size 29 | self.fileformat = file_format 30 | 31 | if phase not in ['train', 'test']: 32 | raise ValueError('invalid phase name {}, should be train or test'.format(phase)) 33 | 34 | self.image_dir = osp.join(base_dir, 'images', phase) 35 | self.tree_dir = osp.join(base_dir, 'trees', phase) 36 | 37 | # get the file names for images, and corresponding trees, and the dictionary of the dataset 38 | self.image_files, self.tree_files = self.prepare_file_list() 39 | self.dictionary_path = osp.join(self.base_dir, 'dictionary_tree.pickle') 40 | self.dictionary = self.load_dictionary(self.tree_files, self.dictionary_path) 41 | 42 | # update the size words, since we need to adjust the size for all objects 43 | # according to the actual 2-D bounding box 44 | for word in self.OLD_SIZE_WORDS: 45 | self.dictionary.remove(word) 46 | for word in self.NEW_SIZE_WORDS: 47 | self.dictionary.append(word) 48 | 49 | # iterator part 50 | self.random_generator = random.Random() 51 | self.random_generator.seed(random_seed) 52 | 53 | self.files = dict() 54 | self.files[phase] = { 55 | 'img': self.image_files, 56 | 'tree': self.tree_files, 57 | } 58 | 59 | self.index_ptr = 0 60 | self.index_list = list(range(len(self.image_files))) 61 | 62 | self.shuffle = shuffle 63 | if shuffle: 64 | 
self.random_generator.shuffle(self.index_list)
65 | 
66 |         self.im_size = self.read_first()
67 |         self.ttdim = len(self.dictionary) + 1
68 | 
69 |     def __len__(self):
70 |         return len(self.files[self.phase]['img'])
71 | 
72 |     def read_first(self):
73 |         images, _, _, _ = self.next_batch()
74 |         self.index_ptr = 0
75 | 
76 |         return images.size()
77 | 
78 |     def prepare_file_list(self):
79 |         """
80 |         Get the filename list for images and corresponding trees
81 |         :return:
82 |         """
83 |         image_list = []
84 |         tree_list = []
85 |         for image_filename in sorted(os.listdir(self.image_dir)):
86 |             image_path = os.path.join(self.image_dir, image_filename)
87 |             image_list.append(image_path)
88 |             filename, _ = os.path.splitext(image_filename)
89 |             tree_filename = filename + '.tree'
90 |             tree_path = os.path.join(self.tree_dir, tree_filename)
91 |             tree_list.append(tree_path)
92 |         return image_list, tree_list
93 | 
94 |     def load_dictionary(self, tree_files, dictionary_path):
95 |         if osp.isfile(dictionary_path):  # the dictionary has already been created, so just load and return it
96 |             with open(dictionary_path, 'rb') as f:
97 |                 dictionary = pickle.load(f)
98 |         else:
99 |             dictionary_set = set()
100 |             for idx, tree_file_path in enumerate(tree_files):
101 |                 with open(tree_file_path, 'rb') as f:
102 |                     tree = pickle.load(f)
103 |                 tree_words = self.get_tree_words(tree)
104 |                 dictionary_set.update(set(tree_words))
105 | 
106 |             dictionary = list(dictionary_set)
107 | 
108 |             # update the size words, since we need to adjust the size for all objects
109 |             # according to the actual 2-D bounding box
110 |             for word in self.OLD_SIZE_WORDS:
111 |                 dictionary.remove(word)
112 |             for word in self.NEW_SIZE_WORDS:
113 |                 dictionary.append(word)
114 |             with open(dictionary_path, 'wb') as f:
115 |                 pickle.dump(dictionary, f)
116 | 
117 |         return dictionary
118 | 
119 |     def get_tree_words(self, tree):
120 |         words = [tree.word]
121 |         for child in tree.children:
122 |             words += self.get_tree_words(child)
123 |         return words
124 | 
125 |     def next_batch(self):
126 |         data_file = self.files[self.phase]
127 | 
128 |         images, trees, categories = [], [], []
129 |         for i in range(0, min(self.batch_size, len(self) - self.index_ptr)):
130 |             index = self.index_list[self.index_ptr]
131 | 
132 |             # load image
133 |             if self.fileformat == 'png':
134 |                 img_file = data_file['img'][index]
135 |                 img = PIL.Image.open(img_file)
136 |             else:
137 |                 raise ValueError('wrong file format for images')
138 | 
139 |             # remove the alpha-channel
140 |             img = np.array(img, dtype=np.float32)[:, :, :-1]
141 |             img = (img - 127.5) / 127.5
142 |             images.append(img)
143 | 
144 |             # load tree
145 |             with open(data_file['tree'][index], 'rb') as f:
146 |                 tree = pickle.load(f)
147 |             tree = self.adapt_tree(tree)
148 |             trees.append(tree)
149 |             categories.append(self.get_categorical_list(tree))
150 | 
151 |             self.index_ptr += 1
152 | 
153 |         images = np.array(images, dtype=np.float32).transpose(0, 3, 1, 2)
154 | 
155 |         refetch = False
156 |         if self.index_ptr >= len(self):
157 |             self.index_ptr = 0
158 |             refetch = True
159 |             if self.shuffle:
160 |                 self.random_generator.shuffle(self.index_list)
161 | 
162 |         return torch.from_numpy(images), trees, categories, refetch
163 | 
164 |     def get_all(self):
165 |         data_file = self.files[self.phase]
166 | 
167 |         images, trees, categories = [], [], []
168 |         for i in range(len(self.index_list)):
169 |             index = self.index_list[self.index_ptr]
170 | 
171 |             # load image
172 |             if self.fileformat == 'png':
173 |                 img_file = data_file['img'][index]
174 |                 img = PIL.Image.open(img_file)
175 |             else:
176 |                 raise 
ValueError('wrong file format for images') 177 | 178 | # remove the alpha-channel 179 | img = np.array(img, dtype=np.float32)[:, :, :-1] 180 | img = (img - 127.5) / 127.5 181 | images.append(img) 182 | 183 | # load tree 184 | with open(data_file['tree'][index], 'rb') as f: 185 | tree = pickle.load(f) 186 | tree = self.adapt_tree(tree) 187 | trees.append(tree) 188 | 189 | self.index_ptr += 1 190 | 191 | images = np.array(images, dtype=np.float32).transpose(0, 3, 1, 2) 192 | 193 | return torch.from_numpy(images), trees 194 | 195 | def adapt_tree(self, tree): 196 | tree = self._adapt_tree(tree, parent_bbox=None) 197 | return tree 198 | 199 | def _adapt_tree(self, tree, parent_bbox): 200 | # adjust tree.word for object size according to the bounding box, since the original size is for 3-D world 201 | if tree.function == 'combine' and tree.word in self.OLD_SIZE_WORDS: 202 | width = parent_bbox[2] 203 | height = parent_bbox[3] 204 | tree.word = self._get_size_word(width, height) 205 | 206 | # set the bbox for passing to children 207 | if tree.function == 'combine': 208 | bbox_xywh = parent_bbox 209 | elif tree.function == 'describe': 210 | bbox_xywh = tree.bbox 211 | else: 212 | bbox_xywh = None 213 | 214 | # then swap the bbox to (y,x,h,w) 215 | if hasattr(tree, 'bbox'): 216 | bbox_yxhw = (tree.bbox[1], tree.bbox[0], tree.bbox[3], tree.bbox[2]) 217 | tree.bbox = np.array(bbox_yxhw) 218 | 219 | # pre-order traversal 220 | for child in tree.children: 221 | self._adapt_tree(child, parent_bbox=bbox_xywh) 222 | return tree 223 | 224 | def get_categorical_list(self, tree): 225 | categorical_list, attr_list = self._get_categorical_list(tree) 226 | return categorical_list 227 | 228 | def _get_categorical_list(self, tree): 229 | # must be post-ordering traversal, parent need info from children 230 | category_list = list() 231 | attr_list = list() 232 | for child in tree.children: 233 | children_category_list, children_attr_list = self._get_categorical_list(child) 234 | category_list += children_category_list 235 | attr_list += children_attr_list 236 | 237 | if tree.function == 'describe': 238 | bbox = tree.bbox 239 | adapted_bbox = (bbox[1], bbox[0], bbox[3], bbox[2]) 240 | attr_list.append(tree.word) 241 | attr_vec = self._get_attr_vec(attr_list) 242 | obj_category = (adapted_bbox, attr_vec) 243 | category_list.append(obj_category) 244 | 245 | if tree.function == 'combine': 246 | attr_list.append(tree.word) # just pass its word to parent 247 | 248 | return category_list, attr_list 249 | 250 | def _get_attr_vec(self, attr_list): 251 | vec = np.zeros(len(self.dictionary), dtype=np.float64) 252 | for attr in attr_list: 253 | attr_idx = self.dictionary.index(attr) 254 | vec[attr_idx] = 1.0 255 | return vec 256 | 257 | @classmethod 258 | def _get_size_word(cls, width, height): 259 | maximum = max(width, height) 260 | if maximum < cls.SMALL_THRESHOLD: 261 | return 'small' 262 | else: 263 | return 'large' 264 | 265 | 266 | if __name__ == '__main__': 267 | loader = CLEVRTREE(phase='test', 268 | base_dir='/zhiweid/work/gnmn/GenerativeNeuralModuleNetwork/data/CLEVR/CLEVR_64_MULTI_LARGE') 269 | for i in range(3): 270 | im, trees, categories, ref = loader.next_batch() 271 | import IPython; 272 | 273 | IPython.embed() 274 | break 275 | print(im[0].shape) 276 | print(categories[0]) 277 | print(loader.dictionary) 278 | 279 | ''' 280 | # For testing the format of PIL.Image and cv2 281 | img2 = 
PIL.Image.open('/local-scratch/cjc/GenerativeNeuralModuleNetwork/data/CLEVR/clevr-dataset-gen/output/images/train/CLEVR_new_000003.png')
282 | img2 = np.array(img2)
283 | img2 = img2[:,:,:-1]
284 | 
285 | out = PIL.Image.fromarray(img2)
286 | out.save('test.png')
287 | # img = cv2.imread('/local-scratch/cjc/GenerativeNeuralModuleNetwork/data/CLEVR/clevr-dataset-gen/output/images/train/CLEVR_new_000003.png')
288 | # cv2.imwrite('test.png',img)
289 | print(img2)
290 | #
291 | img = PIL.Image.open('test.png')
292 | img = np.array(img)
293 | 
294 | print(img2 - img)
295 | 
296 | '''
297 | 
--------------------------------------------------------------------------------
/lib/data_loader/color_mnist_tree_multi.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os.path as osp
3 | 
4 | import _init_paths
5 | import numpy as np
6 | import PIL.Image
7 | import scipy.misc
8 | import torch
9 | import random
10 | from lib.tree import Tree
11 | import scipy.io
12 | 
13 | 
14 | class COLORMNISTTREE():
15 |     def __init__(self, batch_size=16, directory='/cs/vml4/Datasets/COLORMNIST', random_seed=12138, folder='TWO_5000',
16 |                  phase='train', shuffle=True, fileformat='png'):
17 |         # basic
18 |         self.folder = folder
19 |         self.phase = phase
20 |         self.dir = directory
21 |         self.batch_size = batch_size
22 |         self.fileformat = fileformat
23 | 
24 |         # static tree structure
25 |         self.treefile = osp.join('data', 'COLORMNIST', folder, phase + '_parents.list')
26 |         self.functionfile = osp.join('data', 'COLORMNIST', folder, phase + '_functions.list')
27 |         self.infofile = osp.join(directory, folder, phase + '_text.txt')
28 |         with open(self.infofile, 'r') as f:
29 |             self.info = []
30 |             self.number = []
31 |             for line in f:
32 |                 line = line[:-1].split(' ')
33 |                 self.number += [int(line[0][0])]
34 |                 line[0] = line[0][2:]
35 |                 self.info += [line]
36 | 
37 |         self.trees = self.read_trees()
38 |         self.functions = self.read_functions()
39 | 
40 |         # parse info
41 |         imgdir = osp.join(self.dir, folder, phase)
42 |         self.imglist, self.trees, self.dictionary = self.read_info()
43 | 
44 |         # iterator part
45 |         self.random_generator = random.Random()
46 |         self.random_generator.seed(random_seed)
47 |         self.files = {}
48 |         self.files[phase] = {
49 |             'img': self.imglist,
50 |             'tree': self.trees,
51 |         }
52 | 
53 |         self.idx_ptr = 0
54 |         self.indexlist = list(range(len(self.imglist)))
55 | 
56 |         self.shuffle = shuffle
57 |         if shuffle:
58 |             self.random_generator.shuffle(self.indexlist)
59 | 
60 |         # (H, W, C)
61 |         self.im_size = self.test_read()
62 | 
63 |     def __len__(self):
64 |         return len(self.files[self.phase]['img'])
65 | 
66 |     def test_read(self):
67 |         data_file = self.files[self.phase]
68 | 
69 |         # load image
70 |         if self.fileformat == 'mat':
71 |             mat_file = data_file['img'][0] + '.mat'
72 |             img = scipy.io.loadmat(mat_file)['data']
73 |         elif self.fileformat == 'png':
74 |             img_file = data_file['img'][0] + '.png'
75 |             img = PIL.Image.open(img_file)
76 |             img = np.array(img, dtype=np.float32)
77 | 
78 |         return img.shape
79 | 
80 |     def read_functions(self):
81 |         functionfile = self.functionfile
82 |         f = open(functionfile)
83 | 
84 |         functions = []
85 |         for line in f:
86 |             functions += [line[:-1].split(" ")]
87 | 
88 |         return functions
89 | 
90 |     def next_batch(self):
91 |         data_file = self.files[self.phase]
92 | 
93 |         imgs, trees, categories = [], [], []
94 |         for i in range(0, min(self.batch_size, self.__len__() - self.idx_ptr)):
95 |             index = self.indexlist[self.idx_ptr]
96 | 
97 |             # load image
98 |             if 
self.fileformat == 'mat':
99 |                 mat_file = data_file['img'][index] + '.mat'
100 |                 img = scipy.io.loadmat(mat_file)['data']
101 |             elif self.fileformat == 'png':
102 |                 img_file = data_file['img'][index] + '.png'
103 |                 img = PIL.Image.open(img_file)
104 |                 img = np.array(img, dtype=np.float32)
105 |                 img = (img - 127.5) / 127.5
106 | 
107 |             imgs.append(img)
108 | 
109 |             # load label
110 |             tree = data_file['tree'][index]
111 |             trees.append(tree)
112 |             categories.append(self.get_categorical_list(tree))
113 | 
114 |             self.idx_ptr += 1
115 | 
116 |         imgs = np.array(imgs, dtype=np.float32).transpose(0, 3, 1, 2)
117 | 
118 |         refetch = False
119 |         if self.idx_ptr >= self.__len__():
120 |             self.idx_ptr = 0
121 |             refetch = True
122 |             if self.shuffle:
123 |                 self.random_generator.shuffle(self.indexlist)
124 | 
125 |         return torch.from_numpy(imgs), trees, categories, refetch
126 | 
127 |     def next_batch_multigpu(self):
128 |         data_file = self.files[self.phase]
129 | 
130 |         imgs, tree_indices = [], []
131 |         for i in range(0, min(self.batch_size, self.__len__() - self.idx_ptr)):
132 |             index = self.indexlist[self.idx_ptr]
133 | 
134 |             # load image
135 |             if self.fileformat == 'mat':
136 |                 mat_file = data_file['img'][index] + '.mat'
137 |                 img = scipy.io.loadmat(mat_file)['data']
138 |             elif self.fileformat == 'png':
139 |                 img_file = data_file['img'][index] + '.png'
140 |                 img = PIL.Image.open(img_file)
141 |                 img = np.array(img, dtype=np.float32)
142 |                 img = (img - 127.5) / 127.5
143 | 
144 |             imgs += [img]
145 | 
146 |             # load label
147 |             tree_indices.append(index)
148 | 
149 |             self.idx_ptr += 1
150 | 
151 |         imgs = np.array(imgs, dtype=np.float32).transpose(0, 3, 1, 2)
152 | 
153 |         refetch = False
154 |         if self.idx_ptr >= self.__len__():
155 |             self.idx_ptr = 0
156 |             refetch = True
157 |             if self.shuffle:
158 |                 self.random_generator.shuffle(self.indexlist)
159 | 
160 |         return torch.from_numpy(imgs), torch.from_numpy(np.array(tree_indices)), refetch
161 | 
162 |     def get_tree_by_idx(self, index):
163 |         return self.files[self.phase]['tree'][index]
164 | 
165 |     def get_tree_list_current_epoch(self):
166 |         tree_list = list()
167 |         for idx in self.indexlist:
168 |             tree_list.append(self.files[self.phase]['tree'][idx])
169 |         return tree_list
170 | 
171 |     def read_info(self):
172 |         imglist = []
173 |         dictionary = []
174 |         functions = self.functions
175 | 
176 |         count = 0
177 |         for tree in self.trees:
178 |             imglist.append(osp.join(self.dir, self.folder, self.phase, 'image{:05d}'.format(count)))
179 |             words, numbers, bboxes = self._extract_info(self.info[count], self.number[count])
180 |             dictionary = list(set(dictionary + words))
181 |             self.trees[count] = self._read_info(tree, words, functions[count], numbers, bboxes)
182 |             count += 1
183 | 
184 |         return imglist, self.trees, dictionary
185 | 
186 |     def _extract_info(self, line, num):
187 |         words = []
188 |         numbers = []
189 |         bboxes = []
190 |         count = 0
191 | 
192 |         numid = 0
193 |         for ele in line:
194 |             words.append(ele)
195 |             if self._is_number(ele):
196 |                 numbers.append(ele)
197 |                 numid += 1
198 |                 if numid == num:
199 |                     count += 1
200 |                     break
201 |             count += 1
202 | 
203 |         newid = count
204 |         for i in range(0, newid):
205 |             if self._is_number(line[i]):
206 |                 bboxes += [[int(ele) for ele in line[count:count + 4]]]
207 |                 count += 4
208 |             else:
209 |                 bboxes += [[]]
210 | 
211 |         return words, numbers, bboxes
212 | 
213 |     def _is_number(self, n):
214 |         try:
215 |             int(n)
216 |             return True
217 |         except (TypeError, ValueError):
218 |             return False
219 | 
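    # Editorial note, inferred from read_tree() below (an assumption about the
    # data format, not documented by the authors): each line of the
    # *_parents.list file encodes one tree as whitespace-separated, 1-based
    # parent indices, where 0 marks the root and -1 marks an unused slot. For
    # example, the line "2 0 2" describes a three-node tree in which node 2 is
    # the root and nodes 1 and 3 are its children.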
220 |     def _read_info(self, tree, words, functions, numbers, bboxes):
221 |         for i in range(0, tree.num_children):
222 |             tree.children[i] = self._read_info(tree.children[i], words, functions, numbers, bboxes)
223 | 
224 |         tree.word = words[tree.idx]
225 |         tree.function = functions[tree.idx]
226 |         tree.bbox = np.array(bboxes[tree.idx])
227 | 
228 |         return tree
229 | 
230 |     def read_trees(self):
231 |         filename = self.treefile
232 |         with open(filename, 'r') as f:
233 |             trees = [self.read_tree(line) for line in f.readlines()]
234 |         return trees
235 | 
236 |     def read_tree(self, line):
237 |         parents = list(map(int, line.split()))
238 |         trees = dict()
239 |         root = None
240 |         for i in range(1, len(parents) + 1):
241 |             # if not trees[i-1] and parents[i-1]!=-1:
242 |             if i - 1 not in trees.keys() and parents[i - 1] != -1:
243 |                 idx = i
244 |                 prev = None
245 |                 while True:
246 |                     parent = parents[idx - 1]
247 |                     if parent == -1:
248 |                         break
249 |                     tree = Tree()
250 |                     if prev is not None:
251 |                         tree.add_child(prev)
252 |                     trees[idx - 1] = tree
253 |                     tree.idx = idx - 1
254 |                     # if trees[parent-1] is not None:
255 |                     if parent - 1 in trees.keys():
256 |                         trees[parent - 1].add_child(tree)
257 |                         break
258 |                     elif parent == 0:
259 |                         root = tree
260 |                         break
261 |                     else:
262 |                         prev = tree
263 |                         idx = parent
264 |         return root
265 | 
266 |     def get_categorical_list(self, tree):
267 |         categorical_list, attr_list = self._get_categorical_list(tree)
268 |         return categorical_list
269 | 
270 |     def _get_categorical_list(self, tree):
271 |         # must be a post-order traversal, the parent needs info from its children
272 |         category_list = list()
273 |         attr_list = list()
274 |         for child in tree.children:
275 |             children_category_list, children_attr_list = self._get_categorical_list(child)
276 |             category_list += children_category_list
277 |             attr_list += children_attr_list
278 | 
279 |         if tree.function == 'describe':
280 |             bbox = tree.bbox
281 |             adapted_bbox = (bbox[0], bbox[1], bbox[2], bbox[3])
282 |             attr_list.append(tree.word)
283 |             attr_vec = self._get_attr_vec(attr_list)
284 |             obj_category = (adapted_bbox, attr_vec)
285 |             category_list.append(obj_category)
286 | 
287 |         if tree.function == 'combine':
288 |             attr_list.append(tree.word)  # just pass its word to parent
289 | 
290 |         return category_list, attr_list
291 | 
292 |     def _get_attr_vec(self, attr_list):
293 |         vec = np.zeros(len(self.dictionary), dtype=np.float64)
294 |         for attr in attr_list:
295 |             attr_idx = self.dictionary.index(attr)
296 |             vec[attr_idx] = 1.0
297 |         return vec
298 | '''
299 | # Single card
300 | loader = COLORMNISTTREE(directory='data/COLORMNIST', folder='TWO_5000_64_modified')
301 | trees = loader.trees
302 | print(trees[0].word)
303 | print(loader.dictionary)
304 | for i in range(0, 1):
305 |     im, tr, cats, ref = loader.next_batch()
306 |     print(tr[0].function)
307 |     print(tr[0].word)
308 |     print(tr[0].children[0].word)
309 |     print(cats[0])
310 | '''
311 | '''
312 | # Multi-gpu
313 | loader = COLORMNISTTREE(directory='data/COLORMNIST', folder='ONE')
314 | trees = loader.trees
315 | print(trees[0].word)
316 | print(loader.dictionary)
317 | for i in range(0, 1000):
318 |     im, tr_indices, ref = loader.next_batch_multigpu()
319 |     tree = loader.get_tree_by_idx(tr_indices[0])
320 |     print(tree.function)
321 |     print(tree.word)
322 |     print(tree.children[0].word)
323 | '''
324 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, 
AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /mains/pnpnet_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import _init_paths 3 | import datetime 4 | import argparse 5 | import torch 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | import pytz 9 | import scipy.misc 10 | import os.path as osp 11 | import os 12 | import numpy as np 13 | import random 14 | import pdb 15 | 16 | from lib.data_loader.color_mnist_tree_multi import COLORMNISTTREE 17 | from lib.data_loader.clevr.clevr_tree import CLEVRTREE 18 | from lib.config import load_config, Struct 19 | from models.PNPNet.pnp_net import PNPNet 20 | from trainers.pnpnet_trainer import PNPNetTrainer 21 | from lib.weight_init import weights_init 22 | 23 | parser = argparse.ArgumentParser(description='PNPNet - main model experiment') 24 | parser.add_argument('--config_path', type=str, default='./configs/pnp_net_configs.yaml', metavar='C', 25 | help='path to the configuration file') 26 | 27 | 28 | def main(): 29 | args = parser.parse_args() 30 | 31 | config_dic = load_config(args.config_path) 32 | configs = Struct(**config_dic) 33 | 34 | assert (torch.cuda.is_available()) # assume CUDA is always available 35 | 36 | print('configurations:', configs) 37 | 38 | torch.cuda.set_device(configs.gpu_id) 39 | torch.manual_seed(configs.seed) 40 | torch.cuda.manual_seed(configs.seed) 41 | np.random.seed(configs.seed) 42 | random.seed(configs.seed) 43 | torch.backends.cudnn.benchmark = False 44 | torch.backends.cudnn.deterministic = True 45 | 46 | configs.exp_dir = 'results/' + configs.data_folder + '/' + configs.exp_dir_name 47 | exp_dir = configs.exp_dir 48 | 49 | try: 50 | os.makedirs(configs.exp_dir) 51 | except: 52 | pass 53 | try: 54 | os.makedirs(osp.join(configs.exp_dir, 'samples')) 55 | except: 56 | pass 57 | try: 58 | os.makedirs(osp.join(configs.exp_dir, 'checkpoints')) 59 | except: 60 | pass 61 | 62 | # loaders 63 | if 'CLEVR' in configs.data_folder: 64 | # we need the module's label->index dictionary from train loader 65 | train_loader = CLEVRTREE(phase='train', base_dir=osp.join(configs.base_dir, configs.data_folder), 66 | batch_size=configs.batch_size, 67 | random_seed=configs.seed, shuffle=True) 68 | test_loader = CLEVRTREE(phase='test', 
base_dir=osp.join(configs.base_dir, configs.data_folder),
69 |                                 batch_size=configs.batch_size,
70 |                                 random_seed=configs.seed, shuffle=False)
71 |         gen_loader = CLEVRTREE(phase='test', base_dir=osp.join(configs.base_dir, configs.data_folder),
72 |                                batch_size=configs.batch_size,
73 |                                random_seed=configs.seed, shuffle=False)
74 |     elif 'COLORMNIST' in configs.data_folder:
75 |         train_loader = COLORMNISTTREE(phase='train', directory=configs.base_dir, folder=configs.data_folder,
76 |                                       batch_size=configs.batch_size,
77 |                                       random_seed=configs.seed, shuffle=True)
78 |         test_loader = COLORMNISTTREE(phase='test', directory=configs.base_dir, folder=configs.data_folder,
79 |                                      batch_size=configs.batch_size,
80 |                                      random_seed=configs.seed, shuffle=False)
81 |         gen_loader = COLORMNISTTREE(phase='test', directory=configs.base_dir, folder=configs.data_folder,
82 |                                     batch_size=configs.batch_size,
83 |                                     random_seed=configs.seed, shuffle=False)
84 |     else:
85 |         raise ValueError('invalid dataset folder name {}'.format(configs.data_folder))
86 | 
87 |     # hack: infer the spatial image size from the loader
88 |     im_size = gen_loader.im_size[2]
89 | 
90 |     # model
91 |     model = PNPNet(hiddim=configs.hiddim, latentdim=configs.latentdim,
92 |                    word_size=[configs.latentdim, configs.word_size, configs.word_size], pos_size=[8, 1, 1],
93 |                    nres=configs.nr_resnet, nlayers=4,
94 |                    nonlinear='elu', dictionary=train_loader.dictionary,
95 |                    op=[configs.combine_op, configs.describe_op],
96 |                    lmap_size=im_size // 2 ** configs.ds,
97 |                    downsample=configs.ds, lambdakl=-1, bg_bias=configs.bg_bias,
98 |                    normalize=configs.normalize,
99 |                    loss=configs.loss, debug_mode=False)
100 | 
101 |     if configs.checkpoint is not None and len(configs.checkpoint) > 0:
102 |         model.load_state_dict(torch.load(configs.checkpoint))
103 |         print('load model from {}'.format(configs.checkpoint))
104 |     else:
105 |         model.apply(weights_init)
106 | 
107 |     if configs.mode == 'train':
108 |         train(model, train_loader, test_loader, gen_loader, configs=configs)
109 |     elif configs.mode == 'test':
110 |         print(exp_dir)
111 |         print('Start generating...')
112 |         generate(model, gen_loader=gen_loader, num_sample=2, target_dir=exp_dir)
113 |     elif configs.mode == 'visualize':
114 |         print(exp_dir)
115 |         print('Start visualizing...')
116 |         visualize(model, num_sample=50, target_dir=exp_dir)
117 |     elif configs.mode == 'sample':
118 |         print('Sampling')
119 |         sample_tree(model, test_loader=test_loader, tree_idx=4, base_dir='samples', num_sample=500)
120 |     else:
121 |         raise ValueError('Wrong mode given:{}'.format(configs.mode))
122 | 
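# Editorial usage sketch: every hyperparameter read in main() comes from the
# YAML config (configs/pnp_net_configs.yaml by default) and is exposed as an
# attribute via the Struct wrapper, so a typical run is just
#
#   python mains/pnpnet_main.py --config_path ./configs/pnp_net_configs.yaml
#
# with the 'mode' key set to 'train', 'test', 'visualize' or 'sample' inside
# the config. Outputs land under results/<data_folder>/<exp_dir_name>/, with
# 'samples' and 'checkpoints' subdirectories created at startup.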
123 | 
124 | def train(model, train_loader, test_loader, gen_loader, configs):
125 |     model.train()
126 |     # optimizer: set up learning rates for some modules separately so that the whole training becomes more stable
127 |     optimizer = optim.Adamax([
128 |         {'params': model.reader.parameters(), 'lr': 0.2 * configs.lr},
129 |         {'params': model.h_mean.parameters(), 'lr': 0.1 * configs.lr},
130 |         {'params': model.h_var.parameters(), 'lr': 0.1 * configs.lr},
131 |         {'params': model.writer.parameters()},
132 |         {'params': model.vis_dist.parameters()},
133 |         {'params': model.pos_dist.parameters()},
134 |         {'params': model.combine.parameters()},
135 |         {'params': model.describe.parameters()},
136 |         {'params': model.box_vae.parameters(), 'lr': 10 * configs.lr},
137 |         {'params': model.offset_vae.parameters(), 'lr': 10 * configs.lr},
138 |         {'params': model.renderer.parameters()},
139 |         {'params': model.bias_mean.parameters()},
140 |         {'params': model.bias_var.parameters()}
141 |     ], lr=configs.lr)
142 | 
143 |     model.cuda()
144 | 
145 |     trainer = PNPNetTrainer(model=model, train_loader=train_loader, val_loader=test_loader, gen_loader=gen_loader,
146 |                             optimizer=optimizer,
147 |                             configs=configs)
148 | 
149 |     minloss = 1000
150 |     for epoch_num in range(0, configs.epochs + 1):
151 |         timestamp_start = datetime.datetime.now(pytz.timezone('America/New_York'))
152 |         trainer.train_epoch(epoch_num, timestamp_start)
153 |         if epoch_num % configs.validate_interval == 0 and epoch_num > 0:
154 |             minloss = trainer.validate(epoch_num, timestamp_start, minloss)
155 |         if epoch_num % configs.sample_interval == 0 and epoch_num > 0:
156 |             trainer.sample(epoch_num, sample_num=8, timestamp_start=timestamp_start)
157 |         if epoch_num % configs.save_interval == 0 and epoch_num > 0:
158 |             torch.save(model.state_dict(),
159 |                        osp.join(configs.exp_dir, 'checkpoints', 'model_epoch_{0}.pth'.format(epoch_num)))
160 | 
161 | 
162 | def generate(model, gen_loader, num_sample, target_dir):
163 |     model.eval()
164 |     model.cuda()
165 | 
166 |     epoch_end = False
167 |     sample_dirs = []
168 |     for i in range(num_sample):
169 |         sample_dir = osp.join(target_dir, 'test-data-{}'.format(i))
170 |         if not osp.isdir(sample_dir):
171 |             os.mkdir(sample_dir)
172 |         sample_dirs.append(sample_dir)
173 | 
174 |     image_idx = 0
175 |     while epoch_end is False:
176 |         data, trees, _, epoch_end = gen_loader.next_batch()
177 | 
178 |         with torch.no_grad():
179 |             data = Variable(data).cuda()
180 | 
181 |         samples_image_dict = dict()
182 |         batch_size = None
183 |         for j in range(num_sample):
184 |             sample = model.generate(data, trees)
185 |             if not batch_size:
186 |                 batch_size = sample.size(0)
187 |             for i in range(0, sample.size(0)):
188 |                 samples_image_dict.setdefault(i, list()).append(sample.cpu().data.numpy().transpose(0, 2, 3, 1)[i])
189 |         model.clean_tree(trees)
190 | 
191 |         for i in range(batch_size):
192 |             samples = np.clip(np.stack(samples_image_dict[i], axis=0), -1, 1)
193 |             for j in range(num_sample):
194 |                 sample = samples[j]
195 |                 scipy.misc.imsave(osp.join(sample_dirs[j], 'image{:05d}.png'.format(image_idx)), sample)
196 |             image_idx += 1
197 | 
198 | 
199 | def sample_tree(model, test_loader, tree_idx, base_dir, num_sample):
200 |     """
201 |     Sample multiple image instances for a specified tree in the test dataset
202 |     :param model: model
203 |     :param test_loader: test loader
204 |     :param tree_idx: the tree's index
205 |     :param base_dir: base directory for saving results
206 |     :param num_sample: number of samples to sample for the specified tree structure
207 |     """
208 |     model.eval()
209 | 
210 |     target_dir = osp.join(base_dir, 'tree-{}'.format(tree_idx))
211 |     if not osp.isdir(target_dir):
212 |         os.makedirs(target_dir)
213 | 
214 |     all_data, all_trees = test_loader.get_all()
215 | 
216 |     data = all_data[tree_idx]
217 |     tree = [all_trees[tree_idx]]
218 | 
219 |     for i in range(num_sample):
220 |         sample = model.generate(data, tree)
221 |         sample = np.clip(sample.cpu().data.numpy().transpose(0, 2, 3, 1), -1, 1)
222 |         sample = sample[0]
223 |         scipy.misc.imsave(osp.join(target_dir, 'sample-{:05d}.png'.format(i)), sample)
224 | 
225 | 
226 | def visualize(model, num_sample, target_dir):
227 |     """
228 |     Visualize intermediate results of PNPNet (e.g. 
visual concepts or partially composed image) 229 | :param model: model 230 | :param num_sample: number of samples to generate 231 | :param target_dir: target directory for saving results 232 | """ 233 | model.eval() 234 | 235 | sample_dir = osp.join(target_dir, 'visualize-data') 236 | if not osp.isdir(sample_dir): 237 | os.mkdir(sample_dir) 238 | 239 | mode = 'transform' 240 | 241 | for i in range(0, len(model.dictionary)): 242 | if i not in [1, 2, 5]: 243 | continue 244 | print('index ', i, 'current word:', model.dictionary[i]) 245 | 246 | data = Variable(torch.zeros(num_sample, len(model.dictionary))).cuda() 247 | data[:, i] = 1 248 | vis_dist = model.vis_dist(data) 249 | 250 | if mode == 'full': 251 | prior_mean, prior_var = model.renderer(vis_dist) 252 | elif mode == 'transform': 253 | prior_mean, prior_var = Variable(torch.zeros(vis_dist[0].size())).cuda(), Variable( 254 | torch.zeros(vis_dist[1].size())).cuda() 255 | sz = vis_dist[0].size(2) 256 | tsz = 5 257 | vis_dist[0] = model.transform(vis_dist[0], [tsz, tsz]) 258 | vis_dist[1] = model.transform(vis_dist[1], [tsz, tsz]) 259 | prior_mean[:, :, 5:5 + tsz, 5:5 + tsz] = vis_dist[0] 260 | prior_var[:, :, 5:5 + tsz, 5:5 + tsz] = vis_dist[1] 261 | prior_mean, prior_var = model.renderer([prior_mean, prior_var]) 262 | else: 263 | raise ValueError('Invalid mode name {}'.format(mode)) 264 | 265 | z_map = model.sampler(prior_mean, prior_var) 266 | 267 | rec = model.writer(z_map) 268 | 269 | for im in range(0, rec.size(0)): 270 | scipy.misc.imsave(osp.join(sample_dir, 'image-{}-{:05d}.png'.format(model.dictionary[i], im)), 271 | rec[im].cpu().data.numpy().transpose(1, 2, 0)) 272 | 273 | 274 | if __name__ == '__main__': 275 | main() 276 | -------------------------------------------------------------------------------- /lib/modules/Describe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch.nn.utils import weight_norm 5 | import torch.nn.functional as F 6 | 7 | 8 | # one more advanced plan: predict an attention map a, then 9 | # render the vw by a*x*y + (1-a)*y 10 | class Describe(nn.Module): 11 | def __init__(self, hiddim_v=None, hiddim_p=None, op='CAT'): 12 | super(Describe, self).__init__() 13 | self.op = op 14 | self.hiddim_v = hiddim_v 15 | self.hiddim_p = hiddim_p 16 | if op == 'CAT' or op == 'CAT_PoE': 17 | self.net1_mean_vis = nn.Sequential( 18 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)), 19 | nn.ELU(inplace=True), 20 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)) 21 | ) 22 | 23 | self.net1_var_vis = nn.Sequential( 24 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)), 25 | nn.ELU(inplace=True), 26 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)) 27 | ) 28 | 29 | self.net1_mean_pos = nn.Sequential( 30 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)), 31 | nn.ELU(inplace=True), 32 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)) 33 | ) 34 | 35 | self.net1_var_pos = nn.Sequential( 36 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)), 37 | nn.ELU(inplace=True), 38 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)) 39 | ) 40 | elif op == 'DEEP': 41 | self.net_vis = nn.Sequential( 42 | weight_norm(nn.Conv2d(4 * hiddim_v, hiddim_v, 3, 1, 1)), 43 | nn.ELU(inplace=True), 44 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)), 45 | nn.ELU(inplace=True), 46 | weight_norm(nn.Conv2d(hiddim_v, 2 * hiddim_v, 3, 1, 1)) 47 | ) 48 | 49 | self.net_pos = nn.Sequential( 50 | 
weight_norm(nn.Conv2d(4 * hiddim_p, hiddim_p, 1, 1)), 51 | nn.ELU(inplace=True), 52 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)), 53 | nn.ELU(inplace=True), 54 | weight_norm(nn.Conv2d(hiddim_p, 2 * hiddim_p, 1, 1)) 55 | ) 56 | 57 | elif op == 'CAT_PROD': 58 | self.net1_mean_vis = nn.Sequential( 59 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)), 60 | nn.ELU(inplace=True), 61 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)) 62 | ) 63 | 64 | self.net1_var_vis = nn.Sequential( 65 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)), 66 | nn.ELU(inplace=True), 67 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)) 68 | ) 69 | 70 | self.net2_mean_vis = nn.Sequential( 71 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)), 72 | nn.ELU(inplace=True), 73 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)) 74 | ) 75 | 76 | self.net2_var_vis = nn.Sequential( 77 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)), 78 | nn.ELU(inplace=True), 79 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)) 80 | ) 81 | 82 | self.net1_mean_pos = nn.Sequential( 83 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)), 84 | nn.ELU(inplace=True), 85 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)) 86 | ) 87 | 88 | self.net1_var_pos = nn.Sequential( 89 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)), 90 | nn.ELU(inplace=True), 91 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)) 92 | ) 93 | 94 | self.net2_mean_pos = nn.Sequential( 95 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)), 96 | nn.ELU(inplace=True), 97 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)) 98 | ) 99 | 100 | self.net2_var_pos = nn.Sequential( 101 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)), 102 | nn.ELU(inplace=True), 103 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)) 104 | ) 105 | elif op == 'CAT_gPoE': 106 | self.net1_mean_vis = nn.Sequential( 107 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)), 108 | nn.ELU(inplace=True), 109 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)) 110 | ) 111 | 112 | self.net1_var_vis = nn.Sequential( 113 | weight_norm(nn.Conv2d(hiddim_v * 2, hiddim_v, 3, 1, 1)), 114 | nn.ELU(inplace=True), 115 | weight_norm(nn.Conv2d(hiddim_v, hiddim_v, 3, 1, 1)) 116 | ) 117 | 118 | self.net1_mean_pos = nn.Sequential( 119 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)), 120 | nn.ELU(inplace=True), 121 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)) 122 | ) 123 | 124 | self.net1_var_pos = nn.Sequential( 125 | weight_norm(nn.Conv2d(hiddim_p * 2, hiddim_p, 1, 1)), 126 | nn.ELU(inplace=True), 127 | weight_norm(nn.Conv2d(hiddim_p, hiddim_p, 1, 1)) 128 | ) 129 | self.gates_v = nn.Sequential( 130 | weight_norm(nn.Conv2d(hiddim_v * 4, hiddim_v * 4, 3, 1, 1)), 131 | nn.ELU(inplace=True), 132 | weight_norm(nn.Conv2d(hiddim_v * 4, hiddim_v * 4, 3, 1, 1)) 133 | ) 134 | self.gates_p = nn.Sequential( 135 | weight_norm(nn.Conv2d(hiddim_p * 4, hiddim_p * 4, 3, 1, 1)), 136 | nn.ELU(inplace=True), 137 | weight_norm(nn.Conv2d(hiddim_p * 4, hiddim_p * 4, 3, 1, 1)) 138 | ) 139 | 140 | def forward(self, x, y, mode, lognormal=False): # -> x describe y 141 | if mode == 'vis': 142 | if self.op == 'CAT_PROD': 143 | x_mean = self.net1_mean_vis(torch.cat([x[0], y[0]], dim=1)) 144 | x_var = self.net1_var_vis(torch.cat([x[1], y[1]], dim=1)) 145 | 146 | if lognormal == True: 147 | x_mean = torch.exp(x_mean) 148 | 149 | y_mean = self.net2_mean_vis(x_mean * y[0]) 150 | y_var = self.net2_var_vis(x_var * y[1]) 151 | elif self.op == 'CAT_PoE': 152 | # logvar = -log(exp(-logvar1) + 
exp(-logvar2)) 153 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2) 154 | x_mean = self.net1_mean_vis(torch.cat([x[0], y[0]], dim=1)) 155 | x_var = self.net1_var_vis(torch.cat([x[1], y[1]], dim=1)) 156 | mlogvar1 = -x_var 157 | mlogvar2 = -y[1] 158 | mu1 = x_mean 159 | mu2 = y[0] 160 | 161 | y_var = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2)) 162 | y_mean = torch.exp(y_var) * (torch.exp(mlogvar1) * mu1 + torch.exp(mlogvar2) * mu2) 163 | elif self.op == 'CAT_gPoE': 164 | # logvar = -log(exp(-logvar1) + exp(-logvar2)) 165 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2) 166 | x_mean = self.net1_mean_vis(torch.cat([x[0], y[0]], dim=1)) 167 | x_var = self.net1_var_vis(torch.cat([x[1], y[1]], dim=1)) 168 | 169 | # gates 170 | gates = torch.sigmoid(self.gates_v(torch.cat([x_mean, x_var, y[0], y[1]], dim=1))) 171 | x1_mu_g = gates[:, :self.hiddim_v, :, :] 172 | x1_var_g = gates[:, self.hiddim_v:2 * self.hiddim_v, :, :] 173 | x2_mu_g = gates[:, 2 * self.hiddim_v:3 * self.hiddim_v, :, :] 174 | x2_var_g = gates[:, 3 * self.hiddim_v:4 * self.hiddim_v, :, :] 175 | 176 | x_mean = x1_mu_g * x_mean 177 | x_var = torch.log(x1_var_g + 1e-5) + x_var 178 | y[0] = x2_mu_g * y[0] 179 | y[1] = torch.log(x2_var_g + 1e-5) + y[1] 180 | 181 | mlogvar1 = -x_var 182 | mlogvar2 = -y[1] 183 | mu1 = x_mean 184 | mu2 = y[0] 185 | 186 | y_var = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2)) 187 | y_mean = torch.exp(y_var) * (torch.exp(mlogvar1) * mu1 + torch.exp(mlogvar2) * mu2) 188 | elif self.op == 'CAT': 189 | y_mean = self.net1_mean_vis(torch.cat([x[0], y[0]], dim=1)) 190 | y_var = self.net1_var_vis(torch.cat([x[1], y[1]], dim=1)) 191 | elif self.op == 'PROD': 192 | y_mean = x[0] * y[0] 193 | y_var = x[1] * y[1] 194 | elif self.op == 'DEEP': 195 | gaussian_out = self.net_vis(torch.cat([x[0], x[1], y[0], y[1]], dim=1)) 196 | y_mean = gaussian_out[:, :self.hiddim_v, :, :] 197 | y_var = gaussian_out[:, self.hiddim_v:, :, :] 198 | else: 199 | raise ValueError('invalid operator name {} for Describe module'.format(self.op)) 200 | 201 | elif mode == 'pos': 202 | if self.op == 'CAT_PROD': 203 | x_mean = self.net1_mean_pos(torch.cat([x[0], y[0]], dim=1)) 204 | x_var = self.net1_var_pos(torch.cat([x[1], y[1]], dim=1)) 205 | 206 | y_mean = self.net2_mean_pos(x_mean * y[0]) 207 | y_var = self.net2_var_pos(x_var * y[1]) 208 | elif self.op == 'CAT_PoE': 209 | # logvar = -log(exp(-logvar1) + exp(-logvar2)) 210 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2) 211 | x_mean = self.net1_mean_pos(torch.cat([x[0], y[0]], dim=1)) 212 | x_var = self.net1_var_pos(torch.cat([x[1], y[1]], dim=1)) 213 | 214 | mlogvar1 = -x_var 215 | mlogvar2 = -y[1] 216 | mu1 = x_mean 217 | mu2 = y[0] 218 | 219 | y_var = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2)) 220 | y_mean = torch.exp(y_var) * (torch.exp(mlogvar1) * mu1 + torch.exp(mlogvar2) * mu2) 221 | elif self.op == 'CAT_gPoE': 222 | # logvar = -log(exp(-logvar1) + exp(-logvar2)) 223 | # mu = exp(logvar) * (exp(-logvar1) * mu1 + exp(-logvar2) * mu2) 224 | x_mean = self.net1_mean_pos(torch.cat([x[0], y[0]], dim=1)) 225 | x_var = self.net1_var_pos(torch.cat([x[1], y[1]], dim=1)) 226 | 227 | # gates 228 | gates = torch.sigmoid(self.gates_p(torch.cat([x_mean, x_var, y[0], y[1]], dim=1))) 229 | x1_mu_g = gates[:, :self.hiddim_p, :, :] 230 | x1_var_g = gates[:, self.hiddim_p:2 * self.hiddim_p, :, :] 231 | x2_mu_g = gates[:, 2 * self.hiddim_p:3 * self.hiddim_p, :, :] 232 | x2_var_g = gates[:, 3 * self.hiddim_p:4 * self.hiddim_p, 
:, :]
233 | 
234 |                 x_mean = x1_mu_g * x_mean
235 |                 x_var = torch.log(x1_var_g + 1e-5) + x_var
236 |                 y[0] = x2_mu_g * y[0]
237 |                 y[1] = torch.log(x2_var_g + 1e-5) + y[1]
238 | 
239 |                 mlogvar1 = -x_var
240 |                 mlogvar2 = -y[1]
241 |                 mu1 = x_mean
242 |                 mu2 = y[0]
243 | 
244 |                 y_var = -torch.log(torch.exp(mlogvar1) + torch.exp(mlogvar2))
245 |                 y_mean = torch.exp(y_var) * (torch.exp(mlogvar1) * mu1 + torch.exp(mlogvar2) * mu2)
246 |             elif self.op == 'CAT':
247 |                 y_mean = self.net1_mean_pos(torch.cat([x[0], y[0]], dim=1))
248 |                 y_var = self.net1_var_pos(torch.cat([x[1], y[1]], dim=1))
249 |             elif self.op == 'PROD':
250 |                 y_mean = x[0] * y[0]
251 |                 y_var = x[1] * y[1]
252 |             elif self.op == 'DEEP':
253 |                 gaussian_out = self.net_pos(torch.cat([x[0], x[1], y[0], y[1]], dim=1))
254 |                 y_mean = gaussian_out[:, :self.hiddim_p, :, :]
255 |                 y_var = gaussian_out[:, self.hiddim_p:, :, :]
256 |             else:
257 |                 raise ValueError('invalid operator name {} for Describe module'.format(self.op))
258 | 
259 |         else:
260 |             raise ValueError('invalid mode {}'.format(mode))
261 | 
262 |         return [y_mean, y_var]
263 | 
--------------------------------------------------------------------------------
/models/PNPNet/pnp_net.py:
--------------------------------------------------------------------------------
1 | '''
2 | *
3 | PNP-Net: flexibly takes a tree-structured program and
4 | assembles modules for image generative modeling
5 | *
6 | 
7 | It contains:
8 | 
9 |   -- primitive visual elements
10 | 
11 |   -- unit modules
12 | 
13 |   -- recursive tree-composition function
14 | 
15 |   -- forward(), generate()
16 | 
17 |   -- utility functions: clean_tree(), etc.
18 | '''
19 | 
20 | import torch
21 | import torch.nn as nn
22 | from torch.autograd import Variable
23 | import torch.nn.functional as F
24 | import numpy as np
25 | import math
26 | 
27 | from lib.reparameterize import reparameterize
28 | from lib.modules.VAE import VAE
29 | from lib.modules.ResReader import Reader
30 | from lib.modules.ResWriter import Writer
31 | from lib.modules.ConceptMapper import ConceptMapper
32 | from lib.modules.Combine import Combine
33 | from lib.modules.Describe import Describe
34 | from lib.modules.Transform import Transform
35 | from lib.modules.DistributionRender import DistributionRender
36 | 
37 | 
38 | class PNPNet(nn.Module):
39 |     def __init__(self, hiddim=160, latentdim=12,
40 |                  word_size=[-1, 16, 16], pos_size=[4, 1, 1], nres=4, nlayers=1,
41 |                  nonlinear='elu', dictionary=None, op=['PROD', 'CAT'],
42 |                  lmap_size=0, downsample=2, gpu_ids=None,
43 |                  multigpu_full=False, lambdakl=-1, bg_bias=False, normalize='instance_norm',
44 |                  loss=None, debug_mode=True):
45 |         super(PNPNet, self).__init__()
46 |         ## basic settings
47 |         word_size[0] = latentdim
48 |         self.word_size = word_size
49 |         self.latentdim = latentdim
50 |         self.hiddim = hiddim
51 |         self.downsample = downsample  # -> downsample times
52 |         self.ds = 2 ** self.downsample  # -> downsample ratio
53 |         self.nres = nres
54 |         self.nlayers = nlayers
55 |         self.lmap_size = lmap_size
56 |         self.im_size = lmap_size * self.ds
57 |         self.multigpu_full = multigpu_full
58 |         self.bg_bias = bg_bias
59 |         self.normalize = normalize
60 |         self.debug_mode = debug_mode
61 | 
62 |         # dictionary
63 |         self.dictionary = dictionary
64 | 
65 |         ########## modules ##########
66 |         # proposal networks
67 |         self.reader = Reader(indim=3, hiddim=hiddim, outdim=hiddim, ds_times=self.downsample, normalize=normalize,
68 |                              nlayers=nlayers)
69 |         self.h_mean = nn.Conv2d(hiddim, latentdim, 3, 1, 1)
70 |         self.h_var = nn.Conv2d(hiddim, latentdim, 3, 1, 1)
71 | 
72 |         # pixel 
writer
73 |         self.writer = Writer(indim=latentdim, hiddim=hiddim, outdim=3, ds_times=self.downsample, normalize=normalize,
74 |                              nlayers=nlayers)
75 | 
76 |         # visual words
77 |         self.vis_dist = ConceptMapper(word_size, len(dictionary))
78 |         self.pos_dist = ConceptMapper(pos_size, len(dictionary))
79 | 
80 |         self.renderer = DistributionRender(hiddim=latentdim)
81 | 
82 |         # neural modules
83 |         self.combine = Combine(hiddim_v=latentdim, hiddim_p=pos_size[0], op=op[0])
84 |         self.describe = Describe(hiddim_v=latentdim, hiddim_p=pos_size[0], op=op[1])
85 |         self.transform = Transform(matrix='default')
86 | 
87 |         # small vaes for bounding boxes and offsets learning
88 |         # input: H, W
89 |         self.box_vae = VAE(indim=2, latentdim=pos_size[0])
90 |         # input: [x0, y0, x1, y1] + condition: [H0, W0, H1, W1, im_H, im_W, (im_H-H0), (im_W-W0), (im_H-H1), (im_W-W1)]
91 |         self.offset_vae = VAE(indim=4, latentdim=pos_size[0])
92 | 
93 |         ## loss functions & sampler
94 |         self.sampler = reparameterize()
95 |         if lambdakl > 0:
96 |             from lib.LambdaBiKLD import BiKLD
97 |             self.bikld = BiKLD(lambda_t=lambdakl, k=None)
98 |         else:
99 |             from lib.BiKLD import BiKLD
100 |             self.bikld = BiKLD()
101 | 
102 |         if loss == 'l1':
103 |             self.pixelrecon_criterion = nn.L1Loss()
104 |         elif loss == 'l2':
105 |             self.pixelrecon_criterion = nn.MSELoss()
106 |         self.pixelrecon_criterion.size_average = False
107 | 
108 |         self.pos_criterion = nn.MSELoss()
109 |         self.pos_criterion.size_average = False
110 | 
111 |         ## biases
112 |         self.bias_mean = nn.Linear(1, self.latentdim * self.lmap_size * self.lmap_size, bias=False)
113 |         self.bias_var = nn.Linear(1, self.latentdim * self.lmap_size * self.lmap_size, bias=False)
114 |         self.latent_canvas_size = torch.Size([1, self.latentdim, self.lmap_size, self.lmap_size])
115 | 
116 |     def get_mask_from_tree(self, tree, size):
117 |         mask = Variable(torch.zeros(size), requires_grad=False).cuda()
118 | 
119 |         return self._get_mask_from_tree(tree, mask)
120 | 
121 |     def _get_mask_from_tree(self, tree, mask):
122 |         for i in range(0, tree.num_children):
123 |             mask = self._get_mask_from_tree(tree.children[i], mask)
124 | 
125 |         if tree.function == 'describe':
126 |             bbx = tree.bbox
127 |             mask[:, :, bbx[0]:bbx[0] + bbx[2], bbx[1]:bbx[1] + bbx[3]] = 1.0
128 | 
129 |         return mask
130 | 
131 |     def forward(self, x, treex, treeindex=None, alpha=1.0, ifmask=False, maskweight=1.0):
132 |         ################################
133 |         ##  input: images, trees      ##
134 |         ################################
135 | 
136 |         # if multigpu_full, pick the trees by treeindex
137 |         if self.multigpu_full:
138 |             treex_pick = [treex[ele[0]] for ele in treeindex.data.cpu().numpy().astype(int)]
139 |             treex = treex_pick
140 | 
141 |         if ifmask:
142 |             mask = []
143 |             for i in range(0, len(treex)):
144 |                 mask += [self.get_mask_from_tree(treex[i], x[0:1, :, :, :].size())]
145 |             mask = torch.cat(mask, dim=0)
146 | 
147 |         # encoding the images
148 |         h = self.reader(x)
149 |         # proposal distribution
150 |         latent_mean = self.h_mean(h)
151 |         latent_var = self.h_var(h)
152 | 
153 |         # losses
154 |         kld_loss, rec_loss, pos_loss = 0, 0, 0
155 | 
156 |         # forward pass of the neural module network (GNMN) over each tree
157 |         prior_mean_all = []
158 |         prior_var_all = []
159 |         trees = []
160 |         for i in range(0, len(treex)):  # iterate through every tree of the batch
161 |             trees.append(self.compose_tree(treex[i], self.latent_canvas_size))
162 |             prior_mean_all += [trees[i].vis_dist[0]]
163 |             prior_var_all += [trees[i].vis_dist[1]]
164 |             pos_loss += trees[i].pos_loss
165 |             if np.isnan(trees[i].pos_loss.data.cpu().numpy()):
166 |                 print('found nan pos loss')
167 |                 import IPython;
168 |                 IPython.embed()
169 | 
170 |         prior_mean = torch.cat(prior_mean_all, dim=0)
171 |         prior_var = torch.cat(prior_var_all, dim=0)
172 | 
173 |         prior_mean, prior_var = self.renderer([prior_mean, prior_var])
174 | 
175 |         # sample z map
176 |         z_map = self.sampler(latent_mean, latent_var)
177 | 
178 |         # kld loss
179 |         kld_loss = alpha * self.bikld([latent_mean, latent_var], [prior_mean, prior_var]) + \
180 |                    (1 - alpha) * self.bikld([latent_mean.detach(), latent_var.detach()], [prior_mean, prior_var])
181 | 
182 |         rec = self.writer(z_map)
183 | 
184 |         if ifmask:
185 |             mask = (mask + maskweight) / (maskweight + 1.0)
186 |             rec_loss = self.pixelrecon_criterion(mask * rec, mask * x)
187 |         else:
188 |             rec_loss = self.pixelrecon_criterion(rec, x)
189 |         rec_loss = rec_loss.sum()
190 | 
191 |         return rec_loss, kld_loss, pos_loss, rec
192 | 
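    # Editorial note on the alpha-weighted KL term in forward() above (our
    # reading of the code, not an authors' comment): both terms are the KL
    # between the image posterior q = N(latent_mean, latent_var) and the
    # tree-composed prior p = N(prior_mean, prior_var), but the second term
    # detaches the posterior, so its gradient updates only the prior:
    #
    #   kld = alpha * KL(q || p) + (1 - alpha) * KL(stopgrad(q) || p)
    #
    # With alpha < 1 the prior is therefore pulled toward the posterior more
    # strongly than the posterior is pulled toward the prior.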
165 |             if np.isnan(trees[i].pos_loss.data.cpu().numpy()):
166 |                 # fail fast instead of dropping into an interactive debugger
167 |                 print('found nan pos loss')
168 |                 raise RuntimeError('NaN encountered in position loss')
169 | 
170 |         prior_mean = torch.cat(prior_mean_all, dim=0)
171 |         prior_var = torch.cat(prior_var_all, dim=0)
172 | 
173 |         prior_mean, prior_var = self.renderer([prior_mean, prior_var])
174 | 
175 |         # sample z map
176 |         z_map = self.sampler(latent_mean, latent_var)
177 | 
178 |         # kld loss: the (1 - alpha) term detaches the proposal, so it only trains the prior
179 |         kld_loss = alpha * self.bikld([latent_mean, latent_var], [prior_mean, prior_var]) + \
180 |                    (1 - alpha) * self.bikld([latent_mean.detach(), latent_var.detach()], [prior_mean, prior_var])
181 | 
182 |         rec = self.writer(z_map)
183 | 
184 |         if ifmask:
185 |             mask = (mask + maskweight) / (maskweight + 1.0)  # background keeps weight maskweight / (maskweight + 1)
186 |             rec_loss = self.pixelrecon_criterion(mask * rec, mask * x)
187 |         else:
188 |             rec_loss = self.pixelrecon_criterion(rec, x)
189 |         rec_loss = rec_loss.sum()
190 | 
191 |         return rec_loss, kld_loss, pos_loss, rec
192 | 
193 |     def get_code(self, dictionary, word):
194 |         code = Variable(torch.zeros(1, len(dictionary))).cuda()  # one-hot code of `word`
195 |         code[0, dictionary.index(word)] = 1
196 | 
197 |         return code
198 | 
199 |     def compose_tree(self, treex, latent_canvas_size):
200 |         for i in range(0, treex.num_children):
201 |             treex.children[i] = self.compose_tree(treex.children[i], latent_canvas_size)
202 | 
203 |         # one-hot embedding of this node's word
204 |         ohe = self.get_code(self.dictionary, treex.word)
205 | 
206 |         if treex.function == 'combine':
207 |             vis_dist = self.vis_dist(ohe)
208 |             pos_dist = self.pos_dist(ohe)
209 |             if treex.num_children > 0:
210 |                 # visual content
211 |                 vis_dist_child = treex.children[0].vis_dist
212 |                 vis_dist = self.combine(vis_dist, vis_dist_child, 'vis')
213 |                 # visual position
214 |                 pos_dist_child = treex.children[0].pos_dist
215 |                 pos_dist = self.combine(pos_dist, pos_dist_child, 'pos')
216 | 
217 |             treex.vis_dist = vis_dist
218 |             treex.pos_dist = pos_dist
219 | 
220 |         elif treex.function == 'describe':
221 |             # blend visual words
222 |             vis_dist = self.vis_dist(ohe)
223 |             pos_dist = self.pos_dist(ohe)
224 |             if treex.num_children > 0:
225 |                 # visual content
226 |                 vis_dist_child = treex.children[0].vis_dist
227 |                 vis_dist = self.describe(vis_dist_child, vis_dist, 'vis')
228 |                 # visual position
229 |                 pos_dist_child = treex.children[0].pos_dist
230 |                 pos_dist = self.describe(pos_dist_child, pos_dist, 'pos')
231 | 
232 |             treex.pos_dist = pos_dist
233 | 
234 |             # regress bbox size; the ground-truth box is the reconstruction target
235 |             treex.pos = np.maximum(treex.bbox[2:] // self.ds, [1, 1])
236 |             target_box = Variable(torch.from_numpy(np.array(treex.bbox[2:])[np.newaxis, ...].astype(np.float32))).cuda()
237 |             regress_box, kl_box = self.box_vae(target_box, prior=treex.pos_dist)
238 |             treex.pos_loss = self.pos_criterion(regress_box, target_box) + kl_box
239 | 
240 |             if treex.parent is None:
241 |                 ones = self.get_ones(torch.Size([1, 1]))
242 |                 if not self.bg_bias:
243 |                     bg_vis_dist = [Variable(torch.zeros(latent_canvas_size)).cuda(),
244 |                                    Variable(torch.zeros(latent_canvas_size)).cuda()]
245 |                 else:
246 |                     bg_vis_dist = [self.bias_mean(ones).view(*latent_canvas_size),
247 |                                    self.bias_var(ones).view(*latent_canvas_size)]
248 |                 b = np.maximum(treex.bbox // self.ds, [0, 0, 1, 1])
249 | 
250 |                 bg_vis_dist = [self.assign_util(bg_vis_dist[0], b, self.transform(vis_dist[0], treex.pos),
251 |                                                 'assign'),
252 |                                self.assign_util(bg_vis_dist[1], b,
253 |                                                 self.transform(vis_dist[1], treex.pos, variance=True),
254 |                                                 'assign')]
255 |                 vis_dist = bg_vis_dist
256 |             else:
257 |                 try:
258 |                     # resize vis_dist to the box size
259 |                     vis_dist = [self.transform(vis_dist[0], treex.pos),
260 |                                 self.transform(vis_dist[1], treex.pos, variance=True)]
261 |                 except Exception:
262 |                     print('failed to resize vis_dist to size %s' % str(treex.pos))
263 |                     raise
264 | 
265 |             treex.vis_dist = vis_dist
266 | 
267 |         elif treex.function == 'layout':
268 |             # get pos word as position prior
269 |             treex.pos_dist = self.pos_dist(ohe)
270 |             assert treex.num_children >= 2  # layout arranges two children
271 | 
272 |             # get offsets: use ground truth during training
273 |             l_pos = treex.children[0].pos
274 |             l_offset = np.maximum(treex.children[0].bbox[:2] // self.ds, [1, 1])
275 | 
276 |             r_pos = treex.children[1].pos
277 |             r_offset = np.maximum(treex.children[1].bbox[:2] // self.ds, [1, 1])
278 | 
279 |             # regress offsets
280 |             target_offset = np.append(l_offset * self.ds, r_offset * self.ds).astype(np.float32)
281 |             target_offset = Variable(torch.from_numpy(target_offset[np.newaxis, ...])).cuda()
282 |             regress_offset, kl_offset = self.offset_vae(target_offset, prior=treex.pos_dist)
283 |             treex.pos_loss = self.pos_criterion(regress_offset, target_offset) + kl_offset + \
284 |                              treex.children[0].pos_loss + \
285 |                              treex.children[1].pos_loss
286 | 
287 |             ######################### constructing latent map ###############################
288 |             # bias-filled mean & var
289 |             ones = self.get_ones(torch.Size([1, 1]))
290 |             if not self.bg_bias:
291 |                 vis_dist = [Variable(torch.zeros(latent_canvas_size)).cuda(),
292 |                             Variable(torch.zeros(latent_canvas_size)).cuda()]
293 |             else:
294 |                 vis_dist = [self.bias_mean(ones).view(*latent_canvas_size),
295 |                             self.bias_var(ones).view(*latent_canvas_size)]
296 | 
297 |             # arrange the layout of the two children
298 |             vis_dist[0] = self.assign_util(vis_dist[0], list(l_offset) + list(l_pos), treex.children[0].vis_dist[0],
299 |                                            'assign')
300 |             vis_dist[1] = self.assign_util(vis_dist[1], list(l_offset) + list(l_pos), treex.children[0].vis_dist[1],
301 |                                            'assign')
302 | 
303 |             vis_dist[0] = self.assign_util(vis_dist[0], list(r_offset) + list(r_pos), treex.children[1].vis_dist[0],
304 |                                            'assign')
305 |             vis_dist[1] = self.assign_util(vis_dist[1], list(r_offset) + list(r_pos), treex.children[1].vis_dist[1],
306 |                                            'assign')
307 | 
308 |             # continue layout: crop the union box and pass it up the tree
309 |             if treex.parent is not None:
310 |                 p = [min(l_offset[0], r_offset[0]), min(l_offset[1], r_offset[1]),
311 |                      max(l_offset[0] + l_pos[0], r_offset[0] + r_pos[0]),
312 |                      max(l_offset[1] + l_pos[1], r_offset[1] + r_pos[1])]
313 |                 treex.pos = [p[2] - p[0], p[3] - p[1]]
314 |                 treex.vis_dist = [vis_dist[0][:, :, p[0]:p[2], p[1]:p[3]],
315 |                                   vis_dist[1][:, :, p[0]:p[2], p[1]:p[3]]]
316 |             else:
317 |                 treex.vis_dist = vis_dist
318 | 
319 |         return treex
320 | 
321 |     def assign_util(self, a, bx, b, mode):  # bx = [y, x, h, w] on the canvas `a`
322 |         if mode == 'assign':
323 |             a[:, :, bx[0]:bx[0] + bx[2], bx[1]:bx[1] + bx[3]] = b
324 |         elif mode == 'add':
325 |             a[:, :, bx[0]:bx[0] + bx[2], bx[1]:bx[1] + bx[3]] = \
326 |                 a[:, :, bx[0]:bx[0] + bx[2], bx[1]:bx[1] + bx[3]] + b
327 |         elif mode == 'slice':
328 |             a = a[:, :, bx[0]:bx[0] + bx[2], bx[1]:bx[1] + bx[3]].clone()
329 |         else:
330 |             raise ValueError("mode must be 'assign', 'add' or 'slice'.")
331 |         return a
332 | 
333 |     def overlap_box(self, box_left, box_right):  # boxes are [y, x, h, w]; returns the intersection or []
334 |         x1, y1, h1, w1 = box_left[0], box_left[1], box_left[2], box_left[3]
335 |         x2, y2, h2, w2 = box_right[0], box_right[1], box_right[2], box_right[3]
336 | 
337 |         ox1 = max(x1, x2)
338 |         oy1 = max(y1, y2)
339 |         ox2 = min(x1 + h1, x2 + h2)
340 |         oy2 = min(y1 + w1, y2 + w2)
341 | 
342 |         if ox2 > ox1 and oy2 > oy1:
343 |             return [ox1, oy1, ox2 - ox1, oy2 - oy1]
344 |         else:
345 |             return []
346 | 
347 |     def generate(self, x, treex, treeindex=None):
348 |         ################################
349 |         ##   input: images, trees    ##
350 |         ################################
351 |         if self.multigpu_full:
352 |             treex_pick = [treex[ele[0]] for ele in treeindex.data.cpu().numpy().astype(int)]
353 |             treex = treex_pick
354 | 
355 |         # traverse trees to compose visual words
356 |         prior_mean = []
357 |         prior_var = []
358 | 
359 |         for i in range(0, len(treex)):
360 |             treex[i] = self.generate_compose_tree(treex[i], self.latent_canvas_size)
361 |             prior_mean += [treex[i].vis_dist[0]]
362 |             prior_var += [treex[i].vis_dist[1]]
363 |         prior_mean = torch.cat(prior_mean, dim=0)
364 |         prior_var = torch.cat(prior_var, dim=0)
365 | 
366 |         # render the composed prior, then sample the z map from it
367 |         prior_mean, prior_var = self.renderer([prior_mean, prior_var])
368 | 
369 |         z_map = self.sampler(prior_mean, prior_var)
370 | 
371 |         rec = self.writer(z_map)
372 | 
373 |         return rec
374 | 
375 |     def check_valid(self, offsets, l_pos, r_pos, im_size):
376 |         # offsets = [l_y, l_x, r_y, r_x] are the top-left corners of the two
377 |         # children on the latent canvas; l_pos / r_pos are their (H, W) sizes.
378 |         # A proposal is valid only if both children fit inside the
379 |         # im_size x im_size canvas; generate_compose_tree re-samples offsets
380 |         # from the offset VAE until this check passes.
381 |         if offsets[0] + l_pos[0] > im_size:
382 |             return False
383 |         if offsets[1] + l_pos[1] > im_size:
384 |             return False
385 |         if offsets[2] + r_pos[0] > im_size:
386 |             return False
387 |         if offsets[3] + r_pos[1] > im_size:
388 |             return False
389 | 
390 |         return True
391 | 
392 |     def generate_compose_tree(self, treex, latent_canvas_size):
393 |         for i in range(0, treex.num_children):
394 |             treex.children[i] = self.generate_compose_tree(treex.children[i], latent_canvas_size)
395 | 
396 |         # one-hot embedding of this node's word
397 |         ohe = self.get_code(self.dictionary, treex.word)
398 |         if treex.function == 'combine':
399 |             vis_dist = self.vis_dist(ohe)
400 |             pos_dist = self.pos_dist(ohe)
401 |             if treex.num_children > 0:
402 |                 # visual content
403 |                 vis_dist_child = treex.children[0].vis_dist
404 |                 vis_dist = self.combine(vis_dist, vis_dist_child, 'vis')
405 |                 # visual position
406 |                 pos_dist_child = treex.children[0].pos_dist
407 |                 pos_dist = self.combine(pos_dist, pos_dist_child, 'pos')
408 | 
409 |             treex.vis_dist = vis_dist
410 |             treex.pos_dist = pos_dist
411 | 
412 |         elif treex.function == 'describe':
413 |             # blend visual words
414 |             vis_dist = self.vis_dist(ohe)
415 |             pos_dist = self.pos_dist(ohe)
416 |             if treex.num_children > 0:
417 |                 # visual content
418 |                 vis_dist_child = treex.children[0].vis_dist
419 |                 vis_dist = self.describe(vis_dist_child, vis_dist, 'vis')
420 |                 # visual position
421 |                 pos_dist_child = treex.children[0].pos_dist
422 |                 pos_dist = self.describe(pos_dist_child, pos_dist, 'pos')
423 | 
424 |             treex.pos_dist = pos_dist
425 | 
426 |             # sample a box size from the box VAE (no ground-truth boxes at generation time)
427 |             treex.pos = np.clip(self.box_vae.generate(prior=treex.pos_dist).data.cpu().numpy().astype(int),
428 |                                 int(self.ds),
429 |                                 self.im_size).flatten() // self.ds
430 | 
431 |             if treex.parent is None:
432 |                 ones = self.get_ones(torch.Size([1, 1]))
433 |                 if not self.bg_bias:
434 |                     bg_vis_dist = [Variable(torch.zeros(latent_canvas_size)).cuda(),
435 |                                    Variable(torch.zeros(latent_canvas_size)).cuda()]
436 |                 else:
437 |                     bg_vis_dist = [self.bias_mean(ones).view(*latent_canvas_size),
438 |                                    self.bias_var(ones).view(*latent_canvas_size)]
439 |                 b = [int(latent_canvas_size[2]) // 2 - treex.pos[0] // 2,  # center the box on the canvas
440 |                      int(latent_canvas_size[3]) // 2 - treex.pos[1] // 2, treex.pos[0], treex.pos[1]]
441 | 
442 |                 bg_vis_dist = [self.assign_util(bg_vis_dist[0], b, self.transform(vis_dist[0], treex.pos),
443 |                                                 'assign'),
444 |                                self.assign_util(bg_vis_dist[1], b,
445 |                                                 self.transform(vis_dist[1], treex.pos, variance=True),
446 |                                                 'assign')]
447 | 
448 |                 vis_dist = bg_vis_dist
449 |                 treex.offsets = b
450 |             else:
451 |                 # resize vis_dist to the sampled box size
452 |                 vis_dist = [self.transform(vis_dist[0], treex.pos),
453 |                             self.transform(vis_dist[1], treex.pos, variance=True)]
454 | 
455 |             treex.vis_dist = vis_dist
456 | 
457 |         elif treex.function == 'layout':
458 |             # get pos word as position prior
459 |             treex.pos_dist = self.pos_dist(ohe)
460 |             assert treex.num_children >= 2  # layout arranges two children
461 | 
462 |             # sample offsets from the offset VAE (ground truth is only used during training)
463 |             l_pos = treex.children[0].pos
464 |             r_pos = treex.children[1].pos
465 | 
466 |             offsets = np.clip(self.offset_vae.generate(prior=treex.pos_dist).data.cpu().numpy().astype(int), 0,
467 |                               self.im_size).flatten() // self.ds
468 |             countdown = 0
469 |             while not self.check_valid(offsets, l_pos, r_pos, self.im_size // self.ds):
470 |                 offsets = np.clip(self.offset_vae.generate(prior=treex.pos_dist).data.cpu().numpy().astype(int), 0,
471 |                                   self.im_size).flatten() // self.ds
472 |                 if countdown >= 100:
473 |                     print('Tried proposing more than 100 times.')
474 |                     if self.debug_mode:
475 |                         import IPython
476 |                         IPython.embed()
477 |                     print('Warning! Manually adapting offsets.')
478 |                     lat_size = self.im_size // self.ds
479 |                     if offsets[0] + l_pos[0] > lat_size:
480 |                         offsets[0] = lat_size - l_pos[0]
481 |                     if offsets[1] + l_pos[1] > lat_size:
482 |                         offsets[1] = lat_size - l_pos[1]
483 |                     if offsets[2] + r_pos[0] > lat_size:
484 |                         offsets[2] = lat_size - r_pos[0]
485 |                     if offsets[3] + r_pos[1] > lat_size:
486 |                         offsets[3] = lat_size - r_pos[1]
487 | 
488 |                 countdown += 1
489 |             treex.offsets = offsets
490 |             l_offset = offsets[:2]
491 |             r_offset = offsets[2:]
492 | 
493 |             ######################### constructing latent map ###############################
494 |             # bias-filled mean & var
495 |             ones = self.get_ones(torch.Size([1, 1]))
496 |             if not self.bg_bias:
497 |                 bg_vis_dist = [Variable(torch.zeros(latent_canvas_size)).cuda(),
498 |                                Variable(torch.zeros(latent_canvas_size)).cuda()]
499 |             else:
500 |                 bg_vis_dist = [self.bias_mean(ones).view(*latent_canvas_size),
501 |                                self.bias_var(ones).view(*latent_canvas_size)]
502 | 
503 |             vis_dist = bg_vis_dist
504 |             try:
505 |                 # arrange the layout of the two children
506 |                 vis_dist[0] = self.assign_util(vis_dist[0], list(l_offset) + list(l_pos), treex.children[0].vis_dist[0],
507 |                                                'assign')
508 |                 vis_dist[1] = self.assign_util(vis_dist[1], list(l_offset) + list(l_pos), treex.children[0].vis_dist[1],
509 |                                                'assign')
510 | 
511 |                 vis_dist[0] = self.assign_util(vis_dist[0], list(r_offset) + list(r_pos), treex.children[1].vis_dist[0],
512 |                                                'assign')
513 |                 vis_dist[1] = self.assign_util(vis_dist[1], list(r_offset) + list(r_pos), treex.children[1].vis_dist[1],
514 |                                                'assign')
515 |             except Exception:
516 |                 # should not happen once check_valid has passed, so treat it as fatal
517 |                 print("latent distribution doesn't fit the canvas size.")
518 |                 raise
519 | 
520 |             if treex.parent is not None:
521 |                 p = [min(l_offset[0], r_offset[0]), min(l_offset[1], r_offset[1]),
522 |                      max(l_offset[0] + l_pos[0], r_offset[0] + r_pos[0]),
523 |                      max(l_offset[1] + l_pos[1], r_offset[1] + r_pos[1])]
524 |                 treex.pos = [p[2] - p[0], p[3] - p[1]]
525 |                 treex.vis_dist = [vis_dist[0][:, :, p[0]:p[2], p[1]:p[3]],
526 |                                   vis_dist[1][:, :, p[0]:p[2], p[1]:p[3]]]
527 |             else:
528 |                 treex.vis_dist = vis_dist
529 | 
530 |         return treex
531 | 
532 |     def get_ones(self, size):
533 |         return Variable(torch.ones(size), requires_grad=False).cuda()
534 | 
535 |     def clean_tree(self, treex):
536 |         for i in range(0, len(treex)):
537 |             self._clean_tree(treex[i])
538 | 
539 |     def _clean_tree(self, treex):  # drop cached per-node state so trees can be reused
540 |         for i in range(0, treex.num_children):
541 |             self._clean_tree(treex.children[i])
542 | 
543 |         if treex.function == 'combine':
544 |             treex.vis_dist = None
545 |             treex.pos_dist = None
546 |         elif treex.function == 'describe':
547 |             treex.vis_dist = None
548 |             treex.pos_dist = None
549 |             treex.pos = None
550 |         elif treex.function == 'layout':
551 |             treex.vis_dist = None
552 |             treex.pos_dist = None
553 |             treex.pos = None
554 | 
--------------------------------------------------------------------------------