├── scene_generation ├── __init__.py ├── data │ ├── __init__.py │ └── utils.py ├── metrics.py ├── utils.py ├── generators.py ├── args.py ├── graph.py ├── losses.py ├── vis.py ├── layout.py ├── discriminators.py ├── bilinear.py ├── layers.py └── model.py ├── images ├── arch.png └── scene_generation.png ├── requirements.txt ├── scripts ├── download_coco.sh ├── gui │ ├── simple-server.py │ ├── model.py │ ├── index.css │ ├── index_panoptic.html │ ├── index.js │ └── index.html ├── inception_score.py ├── create_attributes_file.py ├── encode_features.py ├── train_accuracy_net.py └── sample_images.py ├── README.md ├── train.py └── LICENSE /scene_generation/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /images/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashual/scene_generation/HEAD/images/arch.png -------------------------------------------------------------------------------- /images/scene_generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashual/scene_generation/HEAD/images/scene_generation.png -------------------------------------------------------------------------------- /scene_generation/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import imagenet_preprocess, imagenet_deprocess 2 | from .utils import imagenet_deprocess_batch 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cloudpickle==1.2.2 2 | imageio==2.5.0 3 | matplotlib==3.1.1 4 | numpy==1.17.2 5 | Pillow==6.2.0 6 | scikit-image==0.15.0 7 | scipy==1.3.1 8 | pytorch==1.0.0 9 | tensorboardX==1.8 10 | torchvision==0.2.2 11 | graphviz==2.40.1 -------------------------------------------------------------------------------- /scripts/download_coco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | 3 | COCO_DIR=datasets/coco 4 | mkdir -p $COCO_DIR 5 | 6 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -O $COCO_DIR/annotations_trainval2017.zip 7 | wget http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip -O $COCO_DIR/stuff_annotations_trainval2017.zip 8 | wget http://images.cocodataset.org/zips/train2017.zip -O $COCO_DIR/train2017.zip 9 | wget http://images.cocodataset.org/zips/val2017.zip -O $COCO_DIR/val2017.zip 10 | 11 | unzip $COCO_DIR/annotations_trainval2017.zip -d $COCO_DIR 12 | unzip $COCO_DIR/stuff_annotations_trainval2017.zip -d $COCO_DIR 13 | unzip $COCO_DIR/train2017.zip -d $COCO_DIR/images 14 | unzip $COCO_DIR/val2017.zip -d $COCO_DIR/images 15 | -------------------------------------------------------------------------------- /scene_generation/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | 19 | 20 | def intersection(bbox_pred, bbox_gt): 21 | max_xy = torch.min(bbox_pred[:, 2:], bbox_gt[:, 2:]) 22 | min_xy = torch.max(bbox_pred[:, :2], bbox_gt[:, :2]) 23 | inter = torch.clamp((max_xy - min_xy), min=0) 24 | return inter[:, 0] * inter[:, 1] 25 | 26 | 27 | def jaccard(bbox_pred, bbox_gt): 28 | inter = intersection(bbox_pred, bbox_gt) 29 | area_pred = (bbox_pred[:, 2] - bbox_pred[:, 0]) * (bbox_pred[:, 3] - 30 | bbox_pred[:, 1]) 31 | area_gt = (bbox_gt[:, 2] - bbox_gt[:, 0]) * (bbox_gt[:, 3] - 32 | bbox_gt[:, 1]) 33 | union = area_pred + area_gt - inter 34 | iou = torch.div(inter, union) 35 | return torch.sum(iou), (iou > 0.5).sum().item(), (iou > 0.3).sum().item() 36 | -------------------------------------------------------------------------------- /scripts/gui/simple-server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import json 3 | import os 4 | from http.server import HTTPServer, BaseHTTPRequestHandler 5 | from urllib.parse import unquote 6 | 7 | from scripts.gui.model import get_model, json_to_img 8 | 9 | model = get_model() 10 | 11 | 12 | class StaticServer(BaseHTTPRequestHandler): 13 | 14 | def do_GET(self): 15 | root = os.path.dirname(os.path.realpath(__file__)) 16 | path = unquote(self.path) 17 | if path == '/': 18 | filename = root + '/index.html' 19 | elif path.startswith('/get_data?'): 20 | img_path, layout_path = json_to_img(path.split('/get_data?data=')[1], model) 21 | self.send_response(200) 22 | self.send_header('Content-type', 'application/json') 23 | self.end_headers() 24 | self.wfile.write(str.encode(json.dumps({'img_pred': img_path, 'layout_pred': layout_path}))) 25 | return 26 | else: 27 | filename = root + self.path 28 | 29 | self.send_response(200) 30 | if filename[-4:] == '.css': 31 | self.send_header('Content-type', 'text/css') 32 | elif filename[-5:] == '.json': 33 | self.send_header('Content-type', 'application/javascript') 34 | elif filename[-3:] == '.js': 35 | self.send_header('Content-type', 'application/javascript') 36 | elif filename[-4:] == '.ico': 37 | return 38 | # self.send_header('Content-type', 'image/x-icon') 39 | else: 40 | self.send_header('Content-type', 'text/html') 41 | self.end_headers() 42 | with open(filename, 'rb') as fh: 43 | html = fh.read() 44 | # html = bytes(html, 'utf8') 45 | self.wfile.write(html) 46 | 47 | 48 | def run(server_class=HTTPServer, handler_class=StaticServer, port=8000): 49 | print('Loading model') 50 | server_address = ('', port) 51 | httpd = server_class(server_address, handler_class) 52 | print('Starting httpd on port {}'.format(port)) 53 | httpd.serve_forever() 54 | 55 | 56 | run() 57 | -------------------------------------------------------------------------------- /scene_generation/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the 
License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import random 18 | 19 | import torch 20 | 21 | 22 | def int_tuple(s): 23 | return tuple(int(i) for i in s.split(',')) 24 | 25 | 26 | def float_tuple(s): 27 | return tuple(float(i) for i in s.split(',')) 28 | 29 | 30 | def str_tuple(s): 31 | return tuple(s.split(',')) 32 | 33 | 34 | def bool_flag(s): 35 | if s == '1': 36 | return True 37 | elif s == '0': 38 | return False 39 | msg = 'Invalid value "%s" for bool flag (should be 0 or 1)' 40 | raise ValueError(msg % s) 41 | 42 | 43 | class LossManager(object): 44 | def __init__(self): 45 | self.total_loss = None 46 | self.all_losses = {} 47 | 48 | def add_loss(self, loss, name, weight=1.0, use_loss=True): 49 | cur_loss = loss * weight 50 | if use_loss: 51 | if self.total_loss is not None: 52 | self.total_loss += cur_loss 53 | else: 54 | self.total_loss = cur_loss 55 | 56 | self.all_losses[name] = cur_loss.data.cpu().item() 57 | 58 | def items(self): 59 | return self.all_losses.items() 60 | 61 | 62 | class VectorPool: 63 | def __init__(self, pool_size): 64 | self.pool_size = pool_size 65 | self.vectors = {} 66 | 67 | def query(self, objs, vectors): 68 | if self.pool_size == 0: 69 | return vectors 70 | return_vectors = [] 71 | for obj, vector in zip(objs, vectors): 72 | obj = obj.item() 73 | vector = vector.cpu().clone().detach() 74 | if obj not in self.vectors: 75 | self.vectors[obj] = [] 76 | obj_pool_size = len(self.vectors[obj]) 77 | if obj_pool_size == 0: 78 | return_vectors.append(vector) 79 | self.vectors[obj].append(vector) 80 | elif obj_pool_size < self.pool_size: 81 | random_id = random.randint(0, obj_pool_size - 1) 82 | self.vectors[obj].append(vector) 83 | return_vectors.append(self.vectors[obj][random_id]) 84 | else: 85 | random_id = random.randint(0, obj_pool_size - 1) 86 | tmp = self.vectors[obj][random_id] 87 | self.vectors[obj][random_id] = vector 88 | return_vectors.append(tmp) 89 | return_vectors = torch.stack(return_vectors).to(vectors.device) 90 | return return_vectors 91 | -------------------------------------------------------------------------------- /scene_generation/data/utils.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import torch 3 | import numpy as np 4 | import torchvision.transforms as T 5 | 6 | MEAN = [0.5, 0.5, 0.5] 7 | STD = [0.5, 0.5, 0.5] 8 | 9 | INV_MEAN = [-m for m in MEAN] 10 | INV_STD = [1.0 / s for s in STD] 11 | 12 | 13 | def imagenet_preprocess(): 14 | return T.Normalize(mean=MEAN, std=STD) 15 | 16 | 17 | def rescale(x): 18 | lo, hi = x.min(), x.max() 19 | return x.sub(lo).div(hi - lo) 20 | 21 | 22 | def imagenet_deprocess(rescale_image=True): 23 | transforms = [ 24 | T.Normalize(mean=[0, 0, 0], std=INV_STD), 25 | T.Normalize(mean=INV_MEAN, std=[1.0, 1.0, 1.0]), 26 | ] 27 | if rescale_image: 28 | transforms.append(rescale) 29 | return T.Compose(transforms) 30 | 31 | 32 | def imagenet_deprocess_batch(imgs, rescale=True): 33 | """ 34 | Input: 35 | - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images 36 | 37 | Output: 38 | - imgs_de: ByteTensor of shape (N, C, H, W) 
giving deprocessed images 39 | in the range [0, 255] 40 | """ 41 | if isinstance(imgs, torch.autograd.Variable): 42 | imgs = imgs.data 43 | imgs = imgs.cpu().clone() 44 | deprocess_fn = imagenet_deprocess(rescale_image=rescale) 45 | imgs_de = [] 46 | for i in range(imgs.size(0)): 47 | img_de = deprocess_fn(imgs[i])[None] 48 | img_de = img_de.mul(255).clamp(0, 255) 49 | imgs_de.append(img_de) 50 | imgs_de = torch.cat(imgs_de, dim=0) 51 | return imgs_de 52 | 53 | 54 | class Resize(object): 55 | def __init__(self, size, interp=PIL.Image.BILINEAR): 56 | if isinstance(size, tuple): 57 | H, W = size 58 | self.size = (W, H) 59 | else: 60 | self.size = (size, size) 61 | self.interp = interp 62 | 63 | def __call__(self, img): 64 | return img.resize(self.size, self.interp) 65 | 66 | 67 | def unpack_var(v): 68 | if isinstance(v, torch.autograd.Variable): 69 | return v.data 70 | return v 71 | 72 | 73 | def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img): 74 | triples = unpack_var(triples) 75 | obj_data = [unpack_var(o) for o in obj_data] 76 | obj_to_img = unpack_var(obj_to_img) 77 | triple_to_img = unpack_var(triple_to_img) 78 | 79 | triples_out = [] 80 | obj_data_out = [[] for _ in obj_data] 81 | obj_offset = 0 82 | N = obj_to_img.max() + 1 83 | for i in range(N): 84 | o_idxs = (obj_to_img == i).nonzero().view(-1) 85 | t_idxs = (triple_to_img == i).nonzero().view(-1) 86 | 87 | cur_triples = triples[t_idxs].clone() 88 | cur_triples[:, 0] -= obj_offset 89 | cur_triples[:, 2] -= obj_offset 90 | triples_out.append(cur_triples) 91 | 92 | for j, o_data in enumerate(obj_data): 93 | cur_o_data = None 94 | if o_data is not None: 95 | cur_o_data = o_data[o_idxs] 96 | obj_data_out[j].append(cur_o_data) 97 | 98 | obj_offset += o_idxs.size(0) 99 | 100 | return triples_out, obj_data_out 101 | 102 | 103 | def rgb2id(color): 104 | if isinstance(color, np.ndarray) and len(color.shape) == 3: 105 | if color.dtype == np.uint8: 106 | color = color.astype(np.uint32) 107 | return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] 108 | return color[0] + 256 * color[1] + 256 * 256 * color[2] 109 | -------------------------------------------------------------------------------- /scene_generation/generators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from scene_generation.layers import GlobalAvgPool, build_cnn, ResnetBlock, get_norm_layer, Interpolate 5 | 6 | 7 | def weights_init(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Conv') != -1: 10 | m.weight.data.normal_(0.0, 0.02) 11 | elif classname.find('BatchNorm2d') != -1: 12 | m.weight.data.normal_(1.0, 0.02) 13 | m.bias.data.fill_(0) 14 | 15 | 16 | def mask_net(dim, mask_size): 17 | output_dim = 1 18 | layers, cur_size = [], 1 19 | while cur_size < mask_size: 20 | layers.append(Interpolate(scale_factor=2, mode='nearest')) 21 | layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=1)) 22 | layers.append(nn.BatchNorm2d(dim)) 23 | layers.append(nn.ReLU()) 24 | cur_size *= 2 25 | if cur_size != mask_size: 26 | raise ValueError('Mask size must be a power of 2') 27 | layers.append(nn.Conv2d(dim, output_dim, kernel_size=1)) 28 | return nn.Sequential(*layers) 29 | 30 | 31 | class AppearanceEncoder(nn.Module): 32 | def __init__(self, vocab, arch, normalization='none', activation='relu', 33 | padding='same', vecs_size=1024, pooling='avg'): 34 | super(AppearanceEncoder, self).__init__() 35 | self.vocab = vocab 36 | 37 | cnn_kwargs = { 38 | 
'arch': arch, 39 | 'normalization': normalization, 40 | 'activation': activation, 41 | 'pooling': pooling, 42 | 'padding': padding, 43 | } 44 | cnn, channels = build_cnn(**cnn_kwargs) 45 | self.cnn = nn.Sequential(cnn, GlobalAvgPool(), nn.Linear(channels, vecs_size)) 46 | 47 | def forward(self, crops): 48 | return self.cnn(crops) 49 | 50 | 51 | def define_G(input_nc, output_nc, ngf, n_downsample_global=3, n_blocks_global=9, norm='instance'): 52 | norm_layer = get_norm_layer(norm_type=norm) 53 | netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer) 54 | assert (torch.cuda.is_available()) 55 | netG.cuda() 56 | netG.apply(weights_init) 57 | return netG 58 | 59 | ############################################################################## 60 | # Generator 61 | ############################################################################## 62 | class GlobalGenerator(nn.Module): 63 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, 64 | padding_type='reflect'): 65 | assert (n_blocks >= 0) 66 | super(GlobalGenerator, self).__init__() 67 | activation = nn.ReLU(True) 68 | 69 | model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation] 70 | ### downsample 71 | for i in range(n_downsampling): 72 | mult = 2 ** i 73 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), 74 | norm_layer(ngf * mult * 2), activation] 75 | 76 | ### resnet blocks 77 | mult = 2 ** n_downsampling 78 | for i in range(n_blocks): 79 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)] 80 | 81 | ### upsample 82 | for i in range(n_downsampling): 83 | mult = 2 ** (n_downsampling - i) 84 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, 85 | output_padding=1), 86 | norm_layer(int(ngf * mult / 2)), activation] 87 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] 88 | self.model = nn.Sequential(*model) 89 | 90 | def forward(self, input): 91 | return self.model(input) -------------------------------------------------------------------------------- /scripts/inception_score.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data 6 | import torchvision.datasets as dset 7 | import torchvision.transforms as transforms 8 | from scipy.stats import entropy 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torchvision.models.inception import inception_v3 12 | from scene_generation.layers import Interpolate 13 | 14 | 15 | class InceptionScore(nn.Module): 16 | def __init__(self, cuda=True, batch_size=32, resize=False): 17 | super(InceptionScore, self).__init__() 18 | assert batch_size > 0 19 | self.resize = resize 20 | self.batch_size = batch_size 21 | self.cuda = cuda 22 | # Set up dtype 23 | self.device = 'cuda' if cuda else 'cpu' 24 | if not cuda and torch.cuda.is_available(): 25 | print("WARNING: You have a CUDA device, so you should probably set cuda=True") 26 | 27 | # Load inception model 28 | self.inception_model = inception_v3(pretrained=True, transform_input=False).to(self.device) 29 | self.inception_model.eval() 30 | self.up = Interpolate(size=(299, 299), mode='bilinear').to(self.device) 31 | 
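        # clean() resets self.preds, the running buffer of Inception softmax outputs;
        # forward() appends each batch's predictions to it and compute_score() consumes it.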
self.clean() 32 | 33 | def clean(self): 34 | self.preds = np.zeros((0, 1000)) 35 | 36 | def get_pred(self, x): 37 | if self.resize: 38 | x = self.up(x) 39 | x = self.inception_model(x) 40 | return F.softmax(x, dim=1).data.cpu().numpy() 41 | 42 | def forward(self, imgs): 43 | # Get predictions 44 | preds_imgs = self.get_pred(imgs.to(self.device)) 45 | self.preds = np.append(self.preds, preds_imgs, axis=0) 46 | 47 | def compute_score(self, splits=1): 48 | # Now compute the mean kl-div 49 | split_scores = [] 50 | preds = self.preds 51 | N = self.preds.shape[0] 52 | for k in range(splits): 53 | part = preds[k * (N // splits): (k + 1) * (N // splits), :] 54 | py = np.mean(part, axis=0) 55 | scores = [] 56 | for i in range(part.shape[0]): 57 | pyx = part[i, :] 58 | scores.append(entropy(pyx, py)) 59 | split_scores.append(np.exp(np.mean(scores))) 60 | 61 | return np.mean(split_scores), np.std(split_scores) 62 | 63 | 64 | # def inception_score(): 65 | # """Computes the inception score of the generated images imgs 66 | # 67 | # imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1] 68 | # cuda -- whether or not to run on GPU 69 | # batch_size -- batch size for feeding into Inception v3 70 | # splits -- number of splits 71 | # """ 72 | 73 | 74 | if __name__ == '__main__': 75 | class IgnoreLabelDataset(torch.utils.data.Dataset): 76 | def __init__(self, orig): 77 | self.orig = orig 78 | 79 | def __getitem__(self, index): 80 | return self.orig[index][0] 81 | 82 | def __len__(self): 83 | return len(self.orig) 84 | 85 | 86 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 87 | parser.add_argument('--dir', default='', type=str) 88 | parser.add_argument('--splits', default=1, type=int) 89 | args = parser.parse_args() 90 | imagenet_data = dset.ImageFolder(args.dir, transform=transforms.Compose([ 91 | transforms.ToTensor(), 92 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 93 | ])) 94 | 95 | # data_loader = torch.utils.data.DataLoader(imagenet_data, 96 | # batch_size=4, 97 | # shuffle=False, 98 | # num_workers=4) 99 | 100 | print("Calculating Inception Score...") 101 | # print(inception_score(imagenet_data, cuda=True, batch_size=32, resize=True, splits=args.splits)) 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Specifying Object Attributes and Relations in Interactive Scene Generation 2 | A PyTorch implementation of the paper [Specifying Object Attributes and Relations in Interactive Scene Generation](https://arxiv.org/abs/1909.05379) 3 |
![Scene generation examples](images/scene_generation.png)
4 | 5 | ## Paper 6 | [Specifying Object Attributes and Relations in Interactive Scene Generation](https://arxiv.org/abs/1909.05379) 7 |
8 | [Oron Ashual](https://www.linkedin.com/in/oronashual/)<sup>1</sup>, [Lior Wolf](https://www.cs.tau.ac.il/~wolf/)<sup>1,2</sup>
9 | <sup>1</sup> Tel-Aviv University, <sup>2</sup> Facebook AI Research
10 | The IEEE International Conference on Computer Vision ([ICCV](http://iccv2019.thecvf.com/)), 2019, (Oral) 11 | 12 | ## Network Architecture 13 |
![Network architecture](images/arch.png)
14 | 15 | ## YouTube 16 |
17 | paper_video 18 |
19 | 20 | ## Usage 21 | 22 | ### 1. Create a virtual environment (optional) 23 | All code was developed and tested on Ubuntu 18.04 with Python 3.6 (Anaconda) and PyTorch 1.0. 24 | 25 | ```bash 26 | conda create -n scene_generation python=3.7 27 | conda activate scene_generation 28 | ``` 29 | ### 2. Clone the repository 30 | ```bash 31 | cd ~ 32 | git clone https://github.com/ashual/scene_generation.git 33 | cd scene_generation 34 | ``` 35 | 36 | ### 3. Install dependencies 37 | ```bash 38 | conda install --file requirements.txt -c conda-forge -c pytorch 39 | ``` 40 | * Install a PyTorch version that matches your CUDA toolkit 41 | 42 | ### 4. Install COCO API 43 | Note: we did not train our models on the COCO panoptic dataset; the coco_panoptic.py code is provided for the community only. 44 | ```bash 45 | cd ~ 46 | git clone https://github.com/cocodataset/cocoapi.git 47 | cd cocoapi/PythonAPI/ 48 | python setup.py install 49 | cd ~/scene_generation 50 | ``` 51 | 52 | ### 5. Train 53 | ```bash 54 | $ python train.py 55 | ``` 56 | 57 | ### 6. Encode the Appearance attributes 58 | ```bash 59 | python scripts/encode_features.py --checkpoint TRAINED_MODEL_CHECKPOINT 60 | ``` 61 | 62 | ### 7. Sample Images 63 | ```bash 64 | python scripts/sample_images.py --checkpoint TRAINED_MODEL_CHECKPOINT --batch_size 32 --output_dir OUTPUT_DIR 65 | ``` 66 | 67 | ### 8. Or download trained models 68 | Download [these](https://drive.google.com/drive/folders/1_E56YskDXdmq06FRsIiPAedpBovYOO8X?usp=sharing) files into models/ 69 | 70 | 71 | ### 9. Play with the GUI 72 | The GUI was built as a proof of concept; use it at your own risk (an example of the scene JSON it sends to the model is sketched at the end of this README): 73 | ```bash 74 | python scripts/gui/simple-server.py --checkpoint YOUR_MODEL_CHECKPOINT --output_dir [DIR_NAME] --draw_scene_graphs 0 75 | ``` 76 | 77 | ### 10. Results 78 | Results were measured by sampling images from the validation set and then running the following scripts (a usage sketch of the bundled `scripts/inception_score.py` is given at the end of this README): 79 | 1. FID - https://github.com/bioinf-jku/TTUR (TensorFlow implementation) 80 | 2. Inception - https://github.com/openai/improved-gan/blob/master/inception_score/model.py (TensorFlow implementation) 81 | 3. Diversity - https://github.com/richzhang/PerceptualSimilarity (PyTorch implementation) 82 | 4. Accuracy - Training code is provided in `train_accuracy_net.py` and a trained model is included. Adding the argument `--accuracy_model_path MODEL_PATH` will output the accuracy of the objects. 83 | 84 | #### Reproduce the comparison figure (Figure 3) 85 | Run this command 86 | ```bash 87 | $ python scripts/sample_images.py --checkpoint TRAINED_MODEL_CHECKPOINT --output_dir OUTPUT_DIR 88 | ``` 89 | 90 | with these arguments: 91 | * (c) - Ground truth layout: --use_gt_boxes 1 --use_gt_masks 1 92 | * (d) - Ground truth location attributes: --use_gt_attr 1 93 | * (e) - Ground truth appearance attributes: --use_gt_textures 1 94 | * (f) - Scene Graph only - No extra attributes needed 95 | 96 | ## Citation 97 | 98 | If you find this code useful in your research, please cite: 99 | ``` 100 | @InProceedings{Ashual_2019_ICCV, 101 | author = {Ashual, Oron and Wolf, Lior}, 102 | title = {Specifying Object Attributes and Relations in Interactive Scene Generation}, 103 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 104 | month = {October}, 105 | year = {2019} 106 | } 107 | ``` 108 | 109 | ## Acknowledgement 110 | Our project borrows some source files from [sg2im](https://github.com/google/sg2im). We thank the authors.
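#### Usage sketches (illustrative, not part of the original scripts)

The `__main__` block of `scripts/inception_score.py` loads an image folder but leaves the scoring loop commented out. Below is a minimal sketch of how the bundled `InceptionScore` class could be driven over a folder of sampled images; the folder path, batch size and number of splits are placeholder assumptions, not values taken from the repository.

```python
import torch
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

from scripts.inception_score import InceptionScore

# Same normalization as the script's __main__ block: images are mapped to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# 'OUTPUT_DIR' is a placeholder for a folder of images produced by scripts/sample_images.py,
# arranged in the sub-directory layout that torchvision's ImageFolder expects.
dataset = dset.ImageFolder('OUTPUT_DIR', transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)

scorer = InceptionScore(cuda=torch.cuda.is_available(), batch_size=32, resize=True)  # resize 128x128 up to 299x299
scorer.clean()  # reset the running buffer of Inception predictions
with torch.no_grad():
    for imgs, _ in loader:  # ImageFolder yields (image, label); the label is ignored here
        scorer(imgs)        # accumulates softmax predictions for this batch
mean, std = scorer.compute_score(splits=1)
print('Inception score: {:.2f} +/- {:.2f}'.format(mean, std))
```

Because `InceptionScore` keeps a running prediction buffer, several loaders can be accumulated into one estimate before calling `compute_score`.

The GUI server (`scripts/gui/simple-server.py`) passes a JSON string to `json_to_img` in `scripts/gui/model.py`, which parses it with `json_to_scene_graph`. The sketch below shows the shape of that payload. The object names must come from the model vocabulary, the coordinates are fractions of the canvas, and the value ranges for `size` (10 bins), `location` (5x5 grid cell) and `feature` (appearance-cluster index) are assumptions inferred from `create_attributes_file.py` and the clustered feature files, not documented limits; the concrete values below are only illustrations.

```python
import json

scene = {
    'image_id': 0,
    'objects': [
        # 'text': category name; left/top/width/height: normalized box; size/location/feature: attribute indices
        {'text': 'sky-other', 'left': 0.0, 'top': 0.0, 'width': 1.0, 'height': 0.5,
         'size': 9, 'location': 2, 'feature': 0},
        {'text': 'grass', 'left': 0.0, 'top': 0.5, 'width': 1.0, 'height': 0.5,
         'size': 9, 'location': 22, 'feature': 0},
        {'text': 'sheep', 'left': 0.3, 'top': 0.5, 'width': 0.3, 'height': 0.3,
         'size': 4, 'location': 12, 'feature': 0},
    ],
}
payload = json.dumps(scene)
# json_to_img(payload, model) in scripts/gui/model.py returns paths to the generated image and layout;
# the relationship between each pair of consecutive objects is derived automatically from their positions.
```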
-------------------------------------------------------------------------------- /scripts/create_attributes_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | from torch.utils.data import DataLoader 5 | from scene_generation.data.coco import CocoSceneGraphDataset, coco_collate_fn 6 | from scene_generation.data.coco_panoptic import CocoPanopticSceneGraphDataset, coco_panoptic_collate_fn 7 | from scene_generation.utils import int_tuple, bool_flag 8 | from scene_generation.utils import str_tuple 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | 13 | # Shared dataset options 14 | parser.add_argument('--dataset', default='coco') 15 | parser.add_argument('--image_size', default=(128, 128), type=int_tuple) 16 | parser.add_argument('--min_object_size', default=0.02, type=float) 17 | parser.add_argument('--min_objects_per_image', default=3, type=int) 18 | parser.add_argument('--max_objects_per_image', default=8, type=int) 19 | parser.add_argument('--instance_whitelist', default=None, type=str_tuple) 20 | parser.add_argument('--stuff_whitelist', default=None, type=str_tuple) 21 | parser.add_argument('--coco_include_other', default=False, type=bool_flag) 22 | parser.add_argument('--is_panoptic', default=False, type=bool_flag) 23 | parser.add_argument('--batch_size', default=24, type=int) 24 | parser.add_argument('--shuffle', default=False, type=bool_flag) 25 | parser.add_argument('--loader_num_workers', default=4, type=int) 26 | parser.add_argument('--num_samples', default=None, type=int) 27 | parser.add_argument('--object_size', default=64, type=int) 28 | parser.add_argument('--grid_size', default=25, type=int) 29 | parser.add_argument('--size_attribute_len', default=10, type=int) 30 | 31 | parser.add_argument('--output_dir', default='models') 32 | 33 | 34 | COCO_DIR = os.path.expanduser('datasets/coco') 35 | parser.add_argument('--coco_image_dir', 36 | default=os.path.join(COCO_DIR, 'images/train2017')) 37 | parser.add_argument('--instances_json', 38 | default=os.path.join(COCO_DIR, 'annotations/instances_train2017.json')) 39 | parser.add_argument('--stuff_json', 40 | default=os.path.join(COCO_DIR, 'annotations/stuff_train2017.json')) 41 | parser.add_argument('--coco_panoptic_train', default=os.path.join(COCO_DIR, 'annotations/panoptic_train2017.json')) 42 | parser.add_argument('--coco_panoptic_segmentation_train', 43 | default=os.path.join(COCO_DIR, 'panoptic/annotations/panoptic_train2017')) 44 | 45 | 46 | 47 | def build_coco_dset(args): 48 | dset_kwargs = { 49 | 'image_dir': args.coco_image_dir, 50 | 'instances_json': args.instances_json, 51 | 'stuff_json': args.stuff_json, 52 | 'image_size': args.image_size, 53 | 'mask_size': 32, 54 | 'max_samples': args.num_samples, 55 | 'min_object_size': args.min_object_size, 56 | 'min_objects_per_image': args.min_objects_per_image, 57 | 'max_objects_per_image': args.max_objects_per_image, 58 | 'instance_whitelist': args.instance_whitelist, 59 | 'stuff_whitelist': args.stuff_whitelist, 60 | 'include_other': args.coco_include_other, 61 | 'test_part': False, 62 | 'sample_attributes': False, 63 | 'grid_size': args.grid_size 64 | } 65 | dset = CocoSceneGraphDataset(**dset_kwargs) 66 | return dset 67 | 68 | 69 | def build_coco_panoptic_dset(args): 70 | dset_kwargs = { 71 | 'image_dir': args.coco_image_dir, 72 | 'instances_json': args.instances_json, 73 | 'panoptic': args.coco_panoptic_train, 74 | 'panoptic_segmentation': args.coco_panoptic_segmentation_train, 75 | 
'stuff_json': args.stuff_json, 76 | 'image_size': args.image_size, 77 | 'mask_size': 32, 78 | 'max_samples': args.num_samples, 79 | 'min_object_size': args.min_object_size, 80 | 'min_objects_per_image': args.min_objects_per_image, 81 | 'max_objects_per_image': args.max_objects_per_image, 82 | 'instance_whitelist': args.instance_whitelist, 83 | 'stuff_whitelist': args.stuff_whitelist, 84 | 'include_other': args.coco_include_other, 85 | 'test_part': False, 86 | 'sample_attributes': args.sample_attributes, 87 | 'grid_size': args.grid_size 88 | } 89 | dset = CocoPanopticSceneGraphDataset(**dset_kwargs) 90 | return dset 91 | 92 | 93 | def build_loader(args, is_panoptic): 94 | if is_panoptic: 95 | dset = build_coco_panoptic_dset(args) 96 | collate_fn = coco_panoptic_collate_fn 97 | else: 98 | dset = build_coco_dset(args) 99 | collate_fn = coco_collate_fn 100 | 101 | loader_kwargs = { 102 | 'batch_size': args.batch_size, 103 | 'num_workers': args.loader_num_workers, 104 | 'shuffle': args.shuffle, 105 | 'collate_fn': collate_fn, 106 | } 107 | loader = DataLoader(dset, **loader_kwargs) 108 | return loader 109 | 110 | 111 | if __name__ == '__main__': 112 | args = parser.parse_args() 113 | 114 | print('Loading dataset') 115 | loader = build_loader(args, args.is_panoptic) 116 | vocab = loader.dataset.vocab 117 | idx_to_name = vocab['my_idx_to_obj'] 118 | sample_attributes = {'location': {}, 'size': {}} 119 | for obj_name in idx_to_name: 120 | sample_attributes['location'][obj_name] = [0] * args.grid_size 121 | sample_attributes['size'][obj_name] = [0] * args.size_attribute_len 122 | 123 | print('Iterating objects') 124 | for _, objs, _, _, _, _, _, attributes in loader: 125 | for obj, attribute in zip(objs, attributes): 126 | obj = obj.item() 127 | if obj == 0: 128 | continue 129 | obj_name = idx_to_name[obj - 1] 130 | size_index = attribute.int().tolist()[:args.size_attribute_len].index(1) 131 | location_index = attribute.int().tolist()[args.size_attribute_len:].index(1) 132 | sample_attributes['size'][obj_name][size_index] += 1 133 | sample_attributes['location'][obj_name][location_index] += 1 134 | 135 | attributes_file = './models/attributes_{}_{}.pickle'.format(args.size_attribute_len, args.grid_size) 136 | print('Saving attributes file to {}'.format(attributes_file)) 137 | pickle.dump(sample_attributes, open(attributes_file, 'wb')) 138 | -------------------------------------------------------------------------------- /scene_generation/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import socket 4 | from datetime import datetime 5 | 6 | from scene_generation.utils import int_tuple, str_tuple, bool_flag 7 | 8 | COCO_DIR = os.path.expanduser('datasets/coco') 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | # Optimization hyperparameters 13 | parser.add_argument('--batch_size', default=12, type=int) 14 | parser.add_argument('--num_iterations', default=1000000, type=int) 15 | parser.add_argument('--learning_rate', default=1e-4, type=float) 16 | parser.add_argument('--mask_learning_rate', default=1e-5, type=float) 17 | 18 | # Dataset options 19 | parser.add_argument('--image_size', default='128,128', type=int_tuple) 20 | parser.add_argument('--num_train_samples', default=None, type=int) 21 | parser.add_argument('--num_val_samples', default=1024, type=int) 22 | parser.add_argument('--shuffle_val', default=True, type=bool_flag) 23 | parser.add_argument('--loader_num_workers', default=4, type=int) 24 | 
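# COCO file locations; the non-panoptic defaults match the directory layout created by scripts/download_coco.sh.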
parser.add_argument('--coco_train_image_dir', 25 | default=os.path.join(COCO_DIR, 'images/train2017')) 26 | parser.add_argument('--coco_val_image_dir', 27 | default=os.path.join(COCO_DIR, 'images/val2017')) 28 | parser.add_argument('--coco_train_instances_json', 29 | default=os.path.join(COCO_DIR, 'annotations/instances_train2017.json')) 30 | parser.add_argument('--coco_train_stuff_json', 31 | default=os.path.join(COCO_DIR, 'annotations/stuff_train2017.json')) 32 | parser.add_argument('--coco_val_instances_json', 33 | default=os.path.join(COCO_DIR, 'annotations/instances_val2017.json')) 34 | parser.add_argument('--coco_val_stuff_json', 35 | default=os.path.join(COCO_DIR, 'annotations/stuff_val2017.json')) 36 | parser.add_argument('--coco_panoptic_train', default=os.path.join(COCO_DIR, 'annotations/panoptic_train2017.json')) 37 | parser.add_argument('--coco_panoptic_val', default=os.path.join(COCO_DIR, 'annotations/panoptic_val2017.json')) 38 | parser.add_argument('--coco_panoptic_segmentation_train', default=os.path.join(COCO_DIR, 'panoptic/annotations/panoptic_train2017')) 39 | parser.add_argument('--coco_panoptic_segmentation_val', default=os.path.join(COCO_DIR, 'panoptic/annotations/panoptic_val2017')) 40 | parser.add_argument('--instance_whitelist', default=None, type=str_tuple) 41 | parser.add_argument('--stuff_whitelist', default=None, type=str_tuple) 42 | parser.add_argument('--coco_include_other', default=False, type=bool_flag) 43 | parser.add_argument('--min_object_size', default=0.02, type=float) 44 | parser.add_argument('--min_objects_per_image', default=3, type=int) 45 | parser.add_argument('--max_objects_per_image', default=8, type=int) 46 | parser.add_argument('--coco_stuff_only', default=True, type=bool_flag) # Train over images that have at least one stuff 47 | parser.add_argument('--is_panoptic', default=False, type=bool_flag) 48 | 49 | # Generator options 50 | parser.add_argument('--mask_size', default=32, type=int) 51 | parser.add_argument('--embedding_dim', default=128, type=int) 52 | parser.add_argument('--gconv_dim', default=128, type=int) 53 | parser.add_argument('--gconv_hidden_dim', default=512, type=int) 54 | parser.add_argument('--gconv_num_layers', default=5, type=int) 55 | parser.add_argument('--mlp_normalization', default='none', type=str) 56 | parser.add_argument('--activation', default='leakyrelu-0.2') 57 | parser.add_argument('--pool_size', default=100, type=int) 58 | parser.add_argument('--output_nc', default=3, type=int) 59 | parser.add_argument('--n_downsample_global', default=4, type=int) 60 | parser.add_argument('--box_dim', default=128, type=int) 61 | parser.add_argument('--use_attributes', default=True, type=bool_flag) 62 | parser.add_argument('--beta1', default=0.5, type=float) 63 | parser.add_argument('--box_noise_dim', default=64, type=int) 64 | parser.add_argument('--mask_noise_dim', default=64, type=int) 65 | 66 | # Appearance Generator options 67 | parser.add_argument('--rep_size', default=32, type=int) 68 | parser.add_argument('--appearance_normalization', default='batch') 69 | 70 | # Generator losses 71 | parser.add_argument('--l1_pixel_loss_weight', default=.0, type=float) 72 | parser.add_argument('--bbox_pred_loss_weight', default=10, type=float) 73 | parser.add_argument('--vgg_features_weight', default=10.0, type=float) 74 | parser.add_argument('--d_img_weight', default=1.0, type=float) 75 | parser.add_argument('--d_img_features_weight', default=10.0, type=float) 76 | parser.add_argument('--d_mask_weight', default=1.0, type=float) 77 | 
parser.add_argument('--d_mask_features_weight', default=10.0, type=float) 78 | parser.add_argument('--d_obj_weight', default=0.1, type=float) 79 | parser.add_argument('--ac_loss_weight', default=0.1, type=float) 80 | 81 | # Image discriminator 82 | parser.add_argument('--ndf', default=64, type=int) 83 | parser.add_argument('--num_D', default=2, type=int) 84 | parser.add_argument('--norm_D', default='instance', type=str) 85 | parser.add_argument('--n_layers_D', default=3, type=int) 86 | parser.add_argument('--no_lsgan', default=False, type=bool_flag) # Default is LSGAN (no_lsgan == False) 87 | 88 | # Mask Discriminator 89 | parser.add_argument('--ndf_mask', default=64, type=int) 90 | parser.add_argument('--num_D_mask', default=1, type=int) 91 | parser.add_argument('--norm_D_mask', default='instance', type=str) 92 | parser.add_argument('--n_layers_D_mask', default=2, type=int) 93 | 94 | # Object discriminator 95 | parser.add_argument('--gan_loss_type', default='gan') 96 | parser.add_argument('--d_normalization', default='batch') 97 | parser.add_argument('--d_padding', default='valid') 98 | parser.add_argument('--d_activation', default='leakyrelu-0.2') 99 | parser.add_argument('--d_obj_arch', default='C4-64-2,C4-128-2,C4-256-2') 100 | parser.add_argument('--crop_size', default=32, type=int) 101 | 102 | # Output options 103 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 104 | log_dir = os.path.join(os.getcwd(), 'output', current_time + '_' + socket.gethostname()) 105 | parser.add_argument('--print_every', default=100, type=int) 106 | parser.add_argument('--checkpoint_every', default=10000, type=int) 107 | parser.add_argument('--output_dir', default=log_dir) 108 | parser.add_argument('--checkpoint_name', default='checkpoint') 109 | parser.add_argument('--restore_from_checkpoint', default=False, type=bool_flag) 110 | 111 | 112 | def get_args(): 113 | return parser.parse_args() 114 | -------------------------------------------------------------------------------- /scripts/encode_features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from sklearn.cluster import KMeans 7 | from sklearn.manifold import TSNE 8 | from torch.utils.data import DataLoader 9 | 10 | from scene_generation.bilinear import crop_bbox_batch 11 | from scene_generation.data.coco import CocoSceneGraphDataset, coco_collate_fn 12 | from scene_generation.data.coco_panoptic import CocoPanopticSceneGraphDataset, coco_panoptic_collate_fn 13 | from scene_generation.model import Model 14 | from scene_generation.utils import int_tuple, bool_flag 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--checkpoint', required=True) 18 | parser.add_argument('--model_mode', default='eval', choices=['train', 'eval']) 19 | 20 | # Shared dataset options 21 | parser.add_argument('--image_size', default=(128, 128), type=int_tuple) 22 | parser.add_argument('--batch_size', default=64, type=int) 23 | parser.add_argument('--shuffle', default=False, type=bool_flag) 24 | parser.add_argument('--loader_num_workers', default=4, type=int) 25 | parser.add_argument('--num_samples', default=None, type=int) 26 | parser.add_argument('--object_size', default=64, type=int) 27 | 28 | # For COCO 29 | COCO_DIR = os.path.expanduser('~/data3/data/coco') 30 | parser.add_argument('--coco_image_dir', default=os.path.join(COCO_DIR, 'images/train2017')) 31 | parser.add_argument('--instances_json', default=os.path.join(COCO_DIR, 
'annotations/instances_train2017.json')) 32 | parser.add_argument('--stuff_json', default=os.path.join(COCO_DIR, 'annotations/stuff_train2017.json')) 33 | 34 | 35 | def build_coco_dset(args, checkpoint): 36 | checkpoint_args = checkpoint['args'] 37 | print('include other: ', checkpoint_args.get('coco_include_other')) 38 | dset_kwargs = { 39 | 'image_dir': args.coco_image_dir, 40 | 'instances_json': args.instances_json, 41 | 'stuff_json': args.stuff_json, 42 | 'image_size': args.image_size, 43 | 'mask_size': checkpoint_args['mask_size'], 44 | 'max_samples': args.num_samples, 45 | 'min_object_size': checkpoint_args['min_object_size'], 46 | 'min_objects_per_image': checkpoint_args['min_objects_per_image'], 47 | 'instance_whitelist': checkpoint_args['instance_whitelist'], 48 | 'stuff_whitelist': checkpoint_args['stuff_whitelist'], 49 | 'include_other': checkpoint_args.get('coco_include_other', True), 50 | } 51 | dset = CocoSceneGraphDataset(**dset_kwargs) 52 | return dset 53 | 54 | 55 | def build_model(args, checkpoint): 56 | kwargs = checkpoint['model_kwargs'] 57 | model = Model(**kwargs) 58 | model.load_state_dict(checkpoint['model_state']) 59 | if args.model_mode == 'eval': 60 | model.eval() 61 | elif args.model_mode == 'train': 62 | model.train() 63 | model.image_size = args.image_size 64 | model.cuda() 65 | return model 66 | 67 | 68 | def build_loader(args, checkpoint): 69 | dset = build_coco_dset(args, checkpoint) 70 | collate_fn = coco_collate_fn 71 | 72 | loader_kwargs = { 73 | 'batch_size': args.batch_size, 74 | 'num_workers': args.loader_num_workers, 75 | 'shuffle': args.shuffle, 76 | 'collate_fn': collate_fn, 77 | } 78 | loader = DataLoader(dset, **loader_kwargs) 79 | return loader 80 | 81 | 82 | def cluster(features, num_objs, n_clusters, save_path): 83 | name = 'features' 84 | centers = {} 85 | for label in range(num_objs): 86 | feat = features[label] 87 | if feat.shape[0]: 88 | n_feat_clusters = min(feat.shape[0], n_clusters) 89 | if n_feat_clusters < n_clusters: 90 | print(label) 91 | kmeans = KMeans(n_clusters=n_feat_clusters, random_state=0).fit(feat) 92 | if n_feat_clusters == 1: 93 | centers[label] = kmeans.cluster_centers_ 94 | else: 95 | one_dimension_centers = TSNE(n_components=1).fit_transform(kmeans.cluster_centers_) 96 | args = np.argsort(one_dimension_centers.reshape(-1)) 97 | centers[label] = kmeans.cluster_centers_[args] 98 | save_name = os.path.join(save_path, name + '_clustered_%03d.npy' % n_clusters) 99 | np.save(save_name, centers) 100 | print('saving to %s' % save_name) 101 | 102 | 103 | def main(opt): 104 | name = 'features' 105 | checkpoint = torch.load(opt.checkpoint) 106 | rep_size = checkpoint['model_kwargs']['rep_size'] 107 | vocab = checkpoint['model_kwargs']['vocab'] 108 | num_objs = len(vocab['object_to_idx']) 109 | model = build_model(opt, checkpoint) 110 | loader = build_loader(opt, checkpoint) 111 | 112 | save_path = os.path.dirname(opt.checkpoint) 113 | 114 | ########### Encode features ########### 115 | counter = 0 116 | max_counter = 1000000000 117 | print('begin') 118 | with torch.no_grad(): 119 | features = {} 120 | for label in range(num_objs): 121 | features[label] = np.zeros((0, rep_size)) 122 | for i, data in enumerate(loader): 123 | if counter >= max_counter: 124 | break 125 | imgs = data[0].cuda() 126 | objs = data[1] 127 | objs = [j.item() for j in objs] 128 | boxes = data[2].cuda() 129 | obj_to_img = data[5].cuda() 130 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, opt.object_size) 131 | feat = 
model.repr_net(model.image_encoder(crops)).cpu() 132 | for ind, label in enumerate(objs): 133 | features[label] = np.append(features[label], feat[ind].view(1, -1), axis=0) 134 | counter += len(objs) 135 | 136 | # print('%d / %d images' % (i + 1, dataset_size)) 137 | save_name = os.path.join(save_path, name + '.npy') 138 | np.save(save_name, features) 139 | 140 | ############## Clustering ########### 141 | print('begin clustering') 142 | load_name = os.path.join(save_path, name + '.npy') 143 | features = np.load(load_name).item() 144 | cluster(features, num_objs, 100, save_path) 145 | cluster(features, num_objs, 10, save_path) 146 | cluster(features, num_objs, 1, save_path) 147 | 148 | 149 | if __name__ == '__main__': 150 | opt = parser.parse_args() 151 | opt.object_size = 64 152 | main(opt) 153 | -------------------------------------------------------------------------------- /scene_generation/graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from scene_generation.layers import build_mlp 21 | 22 | """ 23 | PyTorch modules for dealing with graphs. 24 | """ 25 | 26 | 27 | def _init_weights(module): 28 | if hasattr(module, 'weight'): 29 | if isinstance(module, nn.Linear): 30 | nn.init.kaiming_normal_(module.weight) 31 | 32 | 33 | class GraphTripleConv(nn.Module): 34 | """ 35 | A single layer of scene graph convolution. 
36 | """ 37 | 38 | def __init__(self, input_dim, attributes_dim=0, output_dim=None, hidden_dim=512, 39 | pooling='avg', mlp_normalization='none'): 40 | super(GraphTripleConv, self).__init__() 41 | if output_dim is None: 42 | output_dim = input_dim 43 | self.input_dim = input_dim 44 | self.output_dim = output_dim 45 | self.hidden_dim = hidden_dim 46 | 47 | assert pooling in ['sum', 'avg'], 'Invalid pooling "%s"' % pooling 48 | self.pooling = pooling 49 | net1_layers = [3 * input_dim + 2 * attributes_dim, hidden_dim, 2 * hidden_dim + output_dim] 50 | net1_layers = [l for l in net1_layers if l is not None] 51 | self.net1 = build_mlp(net1_layers, batch_norm=mlp_normalization) 52 | self.net1.apply(_init_weights) 53 | 54 | net2_layers = [hidden_dim, hidden_dim, output_dim] 55 | self.net2 = build_mlp(net2_layers, batch_norm=mlp_normalization) 56 | self.net2.apply(_init_weights) 57 | 58 | def forward(self, obj_vecs, pred_vecs, edges): 59 | """ 60 | Inputs: 61 | - obj_vecs: FloatTensor of shape (O, D) giving vectors for all objects 62 | - pred_vecs: FloatTensor of shape (T, D) giving vectors for all predicates 63 | - edges: LongTensor of shape (T, 2) where edges[k] = [i, j] indicates the 64 | presence of a triple [obj_vecs[i], pred_vecs[k], obj_vecs[j]] 65 | 66 | Outputs: 67 | - new_obj_vecs: FloatTensor of shape (O, D) giving new vectors for objects 68 | - new_pred_vecs: FloatTensor of shape (T, D) giving new vectors for predicates 69 | """ 70 | dtype, device = obj_vecs.dtype, obj_vecs.device 71 | O, T = obj_vecs.size(0), pred_vecs.size(0) 72 | Din, H, Dout = self.input_dim, self.hidden_dim, self.output_dim 73 | 74 | # Break apart indices for subjects and objects; these have shape (T,) 75 | s_idx = edges[:, 0].contiguous() 76 | o_idx = edges[:, 1].contiguous() 77 | 78 | # Get current vectors for subjects and objects; these have shape (T, Din) 79 | cur_s_vecs = obj_vecs[s_idx] 80 | cur_o_vecs = obj_vecs[o_idx] 81 | 82 | # Get current vectors for triples; shape is (T, 3 * Din) 83 | # Pass through net1 to get new triple vecs; shape is (T, 2 * H + Dout) 84 | cur_t_vecs = torch.cat([cur_s_vecs, pred_vecs, cur_o_vecs], dim=1) 85 | new_t_vecs = self.net1(cur_t_vecs) 86 | 87 | # Break apart into new s, p, and o vecs; s and o vecs have shape (T, H) and 88 | # p vecs have shape (T, Dout) 89 | new_s_vecs = new_t_vecs[:, :H] 90 | new_p_vecs = new_t_vecs[:, H:(H + Dout)] 91 | new_o_vecs = new_t_vecs[:, (H + Dout):(2 * H + Dout)] 92 | 93 | # Allocate space for pooled object vectors of shape (O, H) 94 | pooled_obj_vecs = torch.zeros(O, H, dtype=dtype, device=device) 95 | 96 | # Use scatter_add to sum vectors for objects that appear in multiple triples; 97 | # we first need to expand the indices to have shape (T, D) 98 | s_idx_exp = s_idx.view(-1, 1).expand_as(new_s_vecs) 99 | o_idx_exp = o_idx.view(-1, 1).expand_as(new_o_vecs) 100 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, s_idx_exp, new_s_vecs) 101 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, o_idx_exp, new_o_vecs) 102 | 103 | if self.pooling == 'avg': 104 | # Figure out how many times each object has appeared, again using 105 | # some scatter_add trickery. 
106 | obj_counts = torch.zeros(O, dtype=dtype, device=device) 107 | ones = torch.ones(T, dtype=dtype, device=device) 108 | obj_counts = obj_counts.scatter_add(0, s_idx, ones) 109 | obj_counts = obj_counts.scatter_add(0, o_idx, ones) 110 | 111 | # Divide the new object vectors by the number of times they 112 | # appeared, but first clamp at 1 to avoid dividing by zero; 113 | # objects that appear in no triples will have output vector 0 114 | # so this will not affect them. 115 | obj_counts = obj_counts.clamp(min=1) 116 | pooled_obj_vecs = pooled_obj_vecs / obj_counts.view(-1, 1) 117 | 118 | # Send pooled object vectors through net2 to get output object vectors, 119 | # of shape (O, Dout) 120 | new_obj_vecs = self.net2(pooled_obj_vecs) 121 | 122 | return new_obj_vecs, new_p_vecs 123 | 124 | 125 | class GraphTripleConvNet(nn.Module): 126 | """ A sequence of scene graph convolution layers """ 127 | 128 | def __init__(self, input_dim, num_layers=5, hidden_dim=512, pooling='avg', 129 | mlp_normalization='none'): 130 | super(GraphTripleConvNet, self).__init__() 131 | 132 | self.num_layers = num_layers 133 | self.gconvs = nn.ModuleList() 134 | gconv_kwargs = { 135 | 'input_dim': input_dim, 136 | 'hidden_dim': hidden_dim, 137 | 'pooling': pooling, 138 | 'mlp_normalization': mlp_normalization, 139 | } 140 | for _ in range(self.num_layers): 141 | self.gconvs.append(GraphTripleConv(**gconv_kwargs)) 142 | 143 | def forward(self, obj_vecs, pred_vecs, edges): 144 | for i in range(self.num_layers): 145 | gconv = self.gconvs[i] 146 | obj_vecs, pred_vecs = gconv(obj_vecs, pred_vecs, edges) 147 | return obj_vecs, pred_vecs 148 | -------------------------------------------------------------------------------- /scripts/gui/model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import os 5 | from datetime import datetime 6 | 7 | import numpy as np 8 | import torch 9 | from imageio import imwrite 10 | 11 | import scene_generation.vis as vis 12 | from scene_generation.data.utils import imagenet_deprocess_batch 13 | from scene_generation.model import Model 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--checkpoint', required=True) 17 | parser.add_argument('--output_dir', default='outputs') 18 | parser.add_argument('--draw_scene_graphs', type=int, default=0) 19 | parser.add_argument('--device', default='gpu', choices=['cpu', 'gpu']) 20 | args = parser.parse_args() 21 | 22 | 23 | def get_model(): 24 | if not os.path.isfile(args.checkpoint): 25 | print('ERROR: Checkpoint file "%s" not found' % args.checkpoint) 26 | print('Maybe you forgot to download pretraind models? 
Try running:') 27 | print('bash scripts/download_models.sh') 28 | return 29 | 30 | output_dir = os.path.join('scripts', 'gui', 'images', args.output_dir) 31 | if not os.path.isdir(output_dir): 32 | print('Output directory "%s" does not exist; creating it' % args.output_dir) 33 | os.makedirs(output_dir) 34 | 35 | if args.device == 'cpu': 36 | device = torch.device('cpu') 37 | elif args.device == 'gpu': 38 | device = torch.device('cuda:0') 39 | if not torch.cuda.is_available(): 40 | print('WARNING: CUDA not available; falling back to CPU') 41 | device = torch.device('cpu') 42 | 43 | # Load the model, with a bit of care in case there are no GPUs 44 | map_location = 'cpu' if device == torch.device('cpu') else None 45 | checkpoint = torch.load(args.checkpoint, map_location=map_location) 46 | dirname = os.path.dirname(args.checkpoint) 47 | features_path = os.path.join(dirname, 'features_clustered_100.npy') 48 | features_path_one = os.path.join(dirname, 'features_clustered_001.npy') 49 | features = np.load(features_path, allow_pickle=True).item() 50 | features_one = np.load(features_path_one, allow_pickle=True).item() 51 | model = Model(**checkpoint['model_kwargs']) 52 | model_state = checkpoint['model_state'] 53 | model.load_state_dict(model_state) 54 | model.features = features 55 | model.features_one = features_one 56 | model.colors = torch.randint(0, 256, [172, 3]).float() 57 | model.colors[0, :] = 256 58 | model.eval() 59 | model.to(device) 60 | return model 61 | 62 | 63 | def json_to_img(scene_graph, model): 64 | output_dir = args.output_dir 65 | scene_graphs = json_to_scene_graph(scene_graph) 66 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 67 | 68 | # Run the model forward 69 | with torch.no_grad(): 70 | (imgs, boxes_pred, masks_pred, layout, layout_pred, _), objs = model.forward_json(scene_graphs) 71 | imgs = imagenet_deprocess_batch(imgs) 72 | 73 | # Save the generated image 74 | for i in range(imgs.shape[0]): 75 | img_np = imgs[i].numpy().transpose(1, 2, 0).astype('uint8') 76 | img_path = os.path.join('scripts', 'gui', 'images', output_dir, 'img{}.png'.format(current_time)) 77 | imwrite(img_path, img_np) 78 | return_img_path = os.path.join('images', output_dir, 'img{}.png'.format(current_time)) 79 | 80 | # Save the generated layout image 81 | for i in range(imgs.shape[0]): 82 | img_layout_np = one_hot_to_rgb(layout_pred[:, :172, :, :], model.colors)[0].numpy().transpose(1, 2, 0).astype( 83 | 'uint8') 84 | obj_colors = [] 85 | for obj in objs[:-1]: 86 | new_color = torch.cat([model.colors[obj] / 256, torch.ones(1)]) 87 | obj_colors.append(new_color) 88 | 89 | img_layout_path = os.path.join('scripts', 'gui', 'images', output_dir, 'img_layout{}.png'.format(current_time)) 90 | vis.add_boxes_to_layout(img_layout_np, scene_graphs[i]['objects'], boxes_pred, img_layout_path, 91 | colors=obj_colors) 92 | return_img_layout_path = os.path.join('images', output_dir, 'img_layout{}.png'.format(current_time)) 93 | 94 | # Draw and save the scene graph 95 | if args.draw_scene_graphs: 96 | for i, sg in enumerate(scene_graphs): 97 | sg_img = vis.draw_scene_graph(sg['objects'], sg['relationships']) 98 | sg_img_path = os.path.join('scripts', 'gui', 'images', output_dir, 'sg{}.png'.format(current_time)) 99 | imwrite(sg_img_path, sg_img) 100 | sg_img_path = os.path.join('images', output_dir, 'sg{}.png'.format(current_time)) 101 | 102 | return return_img_path, return_img_layout_path 103 | 104 | 105 | def one_hot_to_rgb(one_hot, colors): 106 | one_hot_3d = torch.einsum('abcd,be->aecd', 
(one_hot.cpu(), colors.cpu())) 107 | one_hot_3d *= (255.0 / one_hot_3d.max()) 108 | return one_hot_3d 109 | 110 | 111 | def json_to_scene_graph(json_text): 112 | scene = json.loads(json_text) 113 | if len(scene) == 0: 114 | return [] 115 | image_id = scene['image_id'] 116 | scene = scene['objects'] 117 | objects = [i['text'] for i in scene] 118 | relationships = [] 119 | size = [] 120 | location = [] 121 | features = [] 122 | for i in range(0, len(objects)): 123 | obj_s = scene[i] 124 | # Check for inside / surrounding 125 | 126 | sx0 = obj_s['left'] 127 | sy0 = obj_s['top'] 128 | sx1 = obj_s['width'] + sx0 129 | sy1 = obj_s['height'] + sy0 130 | 131 | margin = (obj_s['size'] + 1) / 10 / 2 132 | mean_x_s = 0.5 * (sx0 + sx1) 133 | mean_y_s = 0.5 * (sy0 + sy1) 134 | 135 | sx0 = max(0, mean_x_s - margin) 136 | sx1 = min(1, mean_x_s + margin) 137 | sy0 = max(0, mean_y_s - margin) 138 | sy1 = min(1, mean_y_s + margin) 139 | 140 | size.append(obj_s['size']) 141 | location.append(obj_s['location']) 142 | 143 | features.append(obj_s['feature']) 144 | if i == len(objects) - 1: 145 | continue 146 | 147 | obj_o = scene[i + 1] 148 | ox0 = obj_o['left'] 149 | oy0 = obj_o['top'] 150 | ox1 = obj_o['width'] + ox0 151 | oy1 = obj_o['height'] + oy0 152 | 153 | mean_x_o = 0.5 * (ox0 + ox1) 154 | mean_y_o = 0.5 * (oy0 + oy1) 155 | d_x = mean_x_s - mean_x_o 156 | d_y = mean_y_s - mean_y_o 157 | theta = math.atan2(d_y, d_x) 158 | 159 | margin = (obj_o['size'] + 1) / 10 / 2 160 | ox0 = max(0, mean_x_o - margin) 161 | ox1 = min(1, mean_x_o + margin) 162 | oy0 = max(0, mean_y_o - margin) 163 | oy1 = min(1, mean_y_o + margin) 164 | 165 | if sx0 < ox0 and sx1 > ox1 and sy0 < oy0 and sy1 > oy1: 166 | p = 'surrounding' 167 | elif sx0 > ox0 and sx1 < ox1 and sy0 > oy0 and sy1 < oy1: 168 | p = 'inside' 169 | elif theta >= 3 * math.pi / 4 or theta <= -3 * math.pi / 4: 170 | p = 'left of' 171 | elif -3 * math.pi / 4 <= theta < -math.pi / 4: 172 | p = 'above' 173 | elif -math.pi / 4 <= theta < math.pi / 4: 174 | p = 'right of' 175 | elif math.pi / 4 <= theta < 3 * math.pi / 4: 176 | p = 'below' 177 | relationships.append([i, p, i + 1]) 178 | 179 | return [{'objects': objects, 'relationships': relationships, 'attributes': {'size': size, 'location': location}, 180 | 'features': features, 'image_id': image_id}] 181 | -------------------------------------------------------------------------------- /scripts/gui/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | background: linear-gradient(120deg, #252525 0%, #39404c 100%); 3 | padding: 70px 70px 0px 70px; 4 | } 5 | 6 | h1, h2 { 7 | font-family: 'Open Sans', sans-serif; 8 | color: #dbab21; 9 | text-align: center; 10 | font-weight: 100; 11 | letter-spacing: 4px; 12 | } 13 | 14 | .main { 15 | display: flex; 16 | flex-direction: row; 17 | justify-content: space-between; 18 | } 19 | 20 | .buttons { 21 | display: flex; 22 | flex-direction: row; 23 | justify-content: left; 24 | } 25 | 26 | .button { 27 | margin: 20px; 28 | } 29 | 30 | .main-movable { 31 | height: 512px; 32 | width: 512px; 33 | background-color: #FEFEFE; 34 | background-image: -webkit-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB), -webkit-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB); 35 | background-image: -moz-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB), -moz-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 
75%, #CBCBCB 75%, #CBCBCB); 36 | background-image: -o-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB), -o-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB); 37 | background-image: -ms-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB), -ms-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB); 38 | background-image: linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB), linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB); 39 | -webkit-background-size: 30px 30px; 40 | -moz-background-size: 30px 30px; 41 | background-size: 30px 30px; 42 | background-position: 0 0, 15px 15px; 43 | } 44 | 45 | 46 | .resize-drag { 47 | color: black; 48 | font-size: 30px; 49 | font-family: 'Open Sans', sans-serif; 50 | border-radius: 8px; 51 | padding: 20px; 52 | margin: 0; 53 | touch-action: none; 54 | text-align: center; 55 | 56 | width: 120px; 57 | /* This makes things *much* easier */ 58 | box-sizing: border-box; 59 | } 60 | 61 | #resize-container { 62 | display: inline-block; 63 | width: 100%; 64 | height: 100%; 65 | position: relative; 66 | } 67 | 68 | .line1 { 69 | left: 20%; 70 | } 71 | 72 | .line2 { 73 | left: 40%; 74 | } 75 | 76 | .line3 { 77 | left: 60%; 78 | } 79 | 80 | .line4 { 81 | left: 80%; 82 | } 83 | 84 | .line5 { 85 | bottom: 20%; 86 | } 87 | 88 | .line6 { 89 | bottom: 40%; 90 | } 91 | 92 | .line7 { 93 | bottom: 60%; 94 | } 95 | 96 | .line8 { 97 | bottom: 80%; 98 | } 99 | 100 | 101 | .line-horizontal { 102 | content: ""; 103 | position: absolute; 104 | z-index: 1; 105 | border-left: 2px dotted #000000; 106 | top: 0; 107 | bottom: 0; 108 | } 109 | 110 | .line-vertical { 111 | content: ""; 112 | position: absolute; 113 | z-index: 1; 114 | border-top: 1px dotted #000000; 115 | left: 0; 116 | right: 0; 117 | } 118 | 119 | #layout_pred, #img_pred { 120 | width: 512px; 121 | height: 512px; 122 | } 123 | 124 | .selected { 125 | color: red; 126 | } 127 | 128 | /* Range Slider */ 129 | 130 | *, *:before, *:after { 131 | box-sizing: border-box; 132 | } 133 | 134 | 135 | .range-slider { 136 | margin: 15px 0 0 0; 137 | } 138 | 139 | .range-slider { 140 | width: 50%; 141 | } 142 | 143 | .range-slider__range { 144 | -webkit-appearance: none; 145 | width: 512px; 146 | height: 10px; 147 | border-radius: 5px; 148 | background: #d7dcdf; 149 | outline: none; 150 | } 151 | 152 | .range-slider__range::-webkit-slider-thumb { 153 | -webkit-appearance: none; 154 | appearance: none; 155 | width: 20px; 156 | height: 20px; 157 | border-radius: 50%; 158 | background: #dbab21; 159 | cursor: pointer; 160 | transition: background .15s ease-in-out; 161 | } 162 | 163 | .range-slider__range::-webkit-slider-thumb:hover { 164 | background: #47370d; 165 | } 166 | 167 | .range-slider__range:active::-webkit-slider-thumb { 168 | background: #47370d; 169 | } 170 | 171 | .range-slider__range::-moz-range-thumb { 172 | width: 20px; 173 | height: 20px; 174 | border: 0; 175 | border-radius: 50%; 176 | background: #47370d; 177 | cursor: pointer; 178 | transition: background .15s ease-in-out; 179 | } 180 | 181 | .range-slider__range::-moz-range-thumb:hover { 182 | background: #47370d; 183 | } 184 | 185 | .range-slider__range:active::-moz-range-thumb { 186 | background: #47370d; 187 | } 188 | 189 | .range-slider__range:focus::-webkit-slider-thumb { 190 | box-shadow: 0 0 0 3px #fff, 0 0 0 6px #47370d; 
191 | } 192 | 193 | .range-slider__value { 194 | display: inline-block; 195 | position: relative; 196 | width: 60px; 197 | color: rgba(0, 0, 0, 0.7); 198 | line-height: 20px; 199 | text-align: center; 200 | border-radius: 3px; 201 | background: #dbab21; 202 | padding: 5px 10px; 203 | margin-left: 8px; 204 | font-family: 'Open Sans', sans-serif; 205 | } 206 | 207 | .range-slider__value:after { 208 | position: absolute; 209 | top: 8px; 210 | left: -7px; 211 | width: 0; 212 | height: 0; 213 | border-top: 7px solid transparent; 214 | border-right: 7px solid #2c3e50; 215 | border-bottom: 7px solid transparent; 216 | content: ''; 217 | } 218 | 219 | ::-moz-range-track { 220 | background: #df2e46; 221 | border: 0; 222 | } 223 | 224 | input::-moz-focus-inner, 225 | input::-moz-focus-outer { 226 | border: 0; 227 | } 228 | 229 | #table { 230 | color: white; 231 | } 232 | 233 | /* navigation bar */ 234 | ul { 235 | list-style: none; 236 | font-family: 'Open Sans', sans-serif; 237 | width: 100%; 238 | display: flex; 239 | justify-content: space-around; 240 | } 241 | 242 | .containerMain { 243 | display: flex; 244 | flex-direction: column; 245 | } 246 | 247 | .container { 248 | position: relative; 249 | z-index: 10; 250 | display: flex; 251 | flex-direction: column; 252 | justify-content: center; 253 | align-items: start; 254 | } 255 | 256 | .nav { 257 | display: inline-block; 258 | text-align: center; 259 | margin: 0 0; 260 | padding-bottom: 20px; 261 | width: 950px; 262 | } 263 | 264 | .nav li:hover { 265 | background: #FFCB25; 266 | } 267 | 268 | .nav > ul { 269 | list-style: none; 270 | padding: 0; 271 | margin: 0; 272 | background: #dbab21; 273 | border-radius: 5px; 274 | color: rgba(0, 0, 0, 0.7); 275 | z-index: 11; 276 | display: flex; 277 | justify-content: space-evenly; 278 | } 279 | 280 | .nav > ul > li { 281 | float: left; 282 | width: 90px; 283 | height: 30px; 284 | line-height: 30px; 285 | position: relative; 286 | font-size: 20px; 287 | cursor: pointer; 288 | } 289 | 290 | ul.drop-menu { 291 | position: absolute; 292 | top: 100%; 293 | left: 0%; 294 | padding: 0; 295 | display: flex; 296 | flex-direction: column; 297 | flex-wrap: wrap; 298 | max-height: 430px; 299 | width: 150px; 300 | text-align: start; 301 | } 302 | 303 | ul.drop-menu li { 304 | background: #a7a8a2; 305 | z-index: 12; 306 | padding: 3px; 307 | } 308 | 309 | ul.drop-menu li:hover { 310 | background: #ffcb25; 311 | } 312 | 313 | ul.drop-menu li { 314 | display: none; 315 | } 316 | 317 | li:hover > ul.drop-menu li { 318 | display: block; 319 | } 320 | 321 | .item { 322 | position: relative; 323 | padding-bottom: 20px; 324 | } 325 | 326 | .adjust { 327 | padding-bottom: 0; 328 | padding-top: 20px; 329 | } 330 | 331 | .nav-arrow { 332 | border: 1px solid transparent; 333 | background: transparent; 334 | border-radius: 3px; 335 | color: white; 336 | transition: background 0.5s ease; 337 | font-size: 24px; 338 | font-family: 'Open Sans', sans-serif; 339 | } -------------------------------------------------------------------------------- /scripts/gui/index_panoptic.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
[index_panoptic.html: the page markup was stripped in this dump; only text fragments survive, e.g. a table with the columns Object / Size / Location / feature and a slider whose value starts at 0.]
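The payload that render_button() in scripts/gui/index.js (further below) sends to get_data is decoded by json_to_scene_graph() shown above. A minimal sketch of that round trip, with made-up object names and values, assuming json_to_scene_graph is in scope (e.g. imported from the GUI model module):

import json

payload = {
    'image_id': -1,  # -1 means no preselected background appearance feature
    'objects': [
        {'text': 'sky', 'left': 0.0, 'top': 0.0, 'width': 1.0, 'height': 0.5,
         'size': 5, 'location': 7, 'feature': -1},
        {'text': 'grass', 'left': 0.0, 'top': 0.5, 'width': 1.0, 'height': 0.5,
         'size': 5, 'location': 17, 'feature': -1},
    ],
}
sg = json_to_scene_graph(json.dumps(payload))[0]
# Consecutive objects are related through the angle between their box centres:
# d_x = 0.0 and d_y = 0.25 - 0.75 < 0, so theta = atan2(d_y, d_x) = -pi/2,
# which falls in [-3*pi/4, -pi/4) and maps to the predicate 'above'.
assert sg['objects'] == ['sky', 'grass']
assert sg['relationships'] == [[0, 'above', 1]]  # "sky above grass"
# 'location' encodes the box centre on a 5x5 grid: round(cx * 4) + 5 * round(cy * 4).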
219 | 220 | -------------------------------------------------------------------------------- /scene_generation/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision import models 5 | from torch.autograd import Variable 6 | 7 | 8 | def get_gan_losses(gan_type): 9 | """ 10 | Returns the generator and discriminator loss for a particular GAN type. 11 | 12 | The returned functions have the following API: 13 | loss_g = g_loss(scores_fake) 14 | loss_d = d_loss(scores_real, scores_fake) 15 | """ 16 | if gan_type == 'gan': 17 | return gan_g_loss, gan_d_loss 18 | elif gan_type == 'wgan': 19 | return wgan_g_loss, wgan_d_loss 20 | elif gan_type == 'lsgan': 21 | return lsgan_g_loss, lsgan_d_loss 22 | else: 23 | raise ValueError('Unrecognized GAN type "%s"' % gan_type) 24 | 25 | 26 | def bce_loss(input, target): 27 | """ 28 | Numerically stable version of the binary cross-entropy loss function. 29 | 30 | As per https://github.com/pytorch/pytorch/issues/751 31 | See the TensorFlow docs for a derivation of this formula: 32 | https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits 33 | 34 | Inputs: 35 | - input: PyTorch Tensor of shape (N, ) giving scores. 36 | - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets. 37 | 38 | Returns: 39 | - A PyTorch Tensor containing the mean BCE loss over the minibatch of 40 | input data. 41 | """ 42 | neg_abs = -input.abs() 43 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 44 | return loss.mean() 45 | 46 | 47 | def _make_targets(x, y): 48 | """ 49 | Inputs: 50 | - x: PyTorch Tensor 51 | - y: Python scalar 52 | 53 | Outputs: 54 | - out: PyTorch Variable with same shape and dtype as x, but filled with y 55 | """ 56 | return torch.full_like(x, y) 57 | 58 | 59 | def gan_g_loss(scores_fake): 60 | """ 61 | Input: 62 | - scores_fake: Tensor of shape (N,) containing scores for fake samples 63 | 64 | Output: 65 | - loss: Variable of shape (,) giving GAN generator loss 66 | """ 67 | if scores_fake.dim() > 1: 68 | scores_fake = scores_fake.view(-1) 69 | y_fake = _make_targets(scores_fake, 1) 70 | return bce_loss(scores_fake, y_fake) 71 | 72 | 73 | def gan_d_loss(scores_real, scores_fake): 74 | """ 75 | Input: 76 | - scores_real: Tensor of shape (N,) giving scores for real samples 77 | - scores_fake: Tensor of shape (N,) giving scores for fake samples 78 | 79 | Output: 80 | - loss: Tensor of shape (,) giving GAN discriminator loss 81 | """ 82 | assert scores_real.size() == scores_fake.size() 83 | if scores_real.dim() > 1: 84 | scores_real = scores_real.view(-1) 85 | scores_fake = scores_fake.view(-1) 86 | y_real = _make_targets(scores_real, 1) 87 | y_fake = _make_targets(scores_fake, 0) 88 | loss_real = bce_loss(scores_real, y_real) 89 | loss_fake = bce_loss(scores_fake, y_fake) 90 | return loss_real + loss_fake 91 | 92 | 93 | def wgan_g_loss(scores_fake): 94 | """ 95 | Input: 96 | - scores_fake: Tensor of shape (N,) containing scores for fake samples 97 | 98 | Output: 99 | - loss: Tensor of shape (,) giving WGAN generator loss 100 | """ 101 | return -scores_fake.mean() 102 | 103 | 104 | def wgan_d_loss(scores_real, scores_fake): 105 | """ 106 | Input: 107 | - scores_real: Tensor of shape (N,) giving scores for real samples 108 | - scores_fake: Tensor of shape (N,) giving scores for fake samples 109 | 110 | Output: 111 | - loss: Tensor of shape (,) giving WGAN discriminator 
loss 112 | """ 113 | return scores_fake.mean() - scores_real.mean() 114 | 115 | 116 | def lsgan_g_loss(scores_fake): 117 | if scores_fake.dim() > 1: 118 | scores_fake = scores_fake.view(-1) 119 | y_fake = _make_targets(scores_fake, 1) 120 | return F.mse_loss(scores_fake.sigmoid(), y_fake) 121 | 122 | 123 | def lsgan_d_loss(scores_real, scores_fake): 124 | assert scores_real.size() == scores_fake.size() 125 | if scores_real.dim() > 1: 126 | scores_real = scores_real.view(-1) 127 | scores_fake = scores_fake.view(-1) 128 | y_real = _make_targets(scores_real, 1) 129 | y_fake = _make_targets(scores_fake, 0) 130 | loss_real = F.mse_loss(scores_real.sigmoid(), y_real) 131 | loss_fake = F.mse_loss(scores_fake.sigmoid(), y_fake) 132 | return loss_real + loss_fake 133 | 134 | 135 | class GANLoss(nn.Module): 136 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor): 137 | super(GANLoss, self).__init__() 138 | self.real_label = target_real_label 139 | self.fake_label = target_fake_label 140 | self.real_label_var = None 141 | self.fake_label_var = None 142 | self.Tensor = tensor 143 | if use_lsgan: 144 | self.loss = nn.MSELoss() 145 | else: 146 | self.loss = nn.BCELoss() 147 | 148 | def get_target_tensor(self, input, target_is_real): 149 | if target_is_real: 150 | create_label = ((self.real_label_var is None) or 151 | (self.real_label_var.numel() != input.numel())) 152 | if create_label: 153 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 154 | self.real_label_var = Variable(real_tensor, requires_grad=False) 155 | target_tensor = self.real_label_var 156 | else: 157 | create_label = ((self.fake_label_var is None) or 158 | (self.fake_label_var.numel() != input.numel())) 159 | if create_label: 160 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 161 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 162 | target_tensor = self.fake_label_var 163 | return target_tensor 164 | 165 | def __call__(self, input, target_is_real): 166 | if isinstance(input[0], list): 167 | loss = 0 168 | for input_i in input: 169 | pred = input_i[-1] 170 | target_tensor = self.get_target_tensor(pred, target_is_real) 171 | loss += self.loss(pred, target_tensor) 172 | return loss 173 | else: 174 | target_tensor = self.get_target_tensor(input[-1], target_is_real) 175 | return self.loss(input[-1], target_tensor) 176 | 177 | 178 | # VGG Features matching 179 | class Vgg19(torch.nn.Module): 180 | def __init__(self, requires_grad=False): 181 | super(Vgg19, self).__init__() 182 | vgg_pretrained_features = models.vgg19(pretrained=True).features 183 | self.slice1 = torch.nn.Sequential() 184 | self.slice2 = torch.nn.Sequential() 185 | self.slice3 = torch.nn.Sequential() 186 | self.slice4 = torch.nn.Sequential() 187 | self.slice5 = torch.nn.Sequential() 188 | for x in range(2): 189 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 190 | for x in range(2, 7): 191 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 192 | for x in range(7, 12): 193 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 194 | for x in range(12, 21): 195 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 196 | for x in range(21, 30): 197 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 198 | if not requires_grad: 199 | for param in self.parameters(): 200 | param.requires_grad = False 201 | 202 | def forward(self, X): 203 | h_relu1 = self.slice1(X) 204 | h_relu2 = self.slice2(h_relu1) 205 | h_relu3 = 
self.slice3(h_relu2) 206 | h_relu4 = self.slice4(h_relu3) 207 | h_relu5 = self.slice5(h_relu4) 208 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 209 | return out 210 | 211 | 212 | class VGGLoss(nn.Module): 213 | def __init__(self): 214 | super(VGGLoss, self).__init__() 215 | self.vgg = Vgg19().cuda() 216 | self.criterion = nn.L1Loss() 217 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 218 | 219 | def forward(self, x, y): 220 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 221 | loss = 0 222 | for i in range(len(x_vgg)): 223 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 224 | return loss -------------------------------------------------------------------------------- /scene_generation/vis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | import tempfile 19 | 20 | import matplotlib.pyplot as plt 21 | import numpy as np 22 | import torch 23 | from imageio import imread 24 | from matplotlib.patches import Rectangle 25 | 26 | """ 27 | Utilities for making visualizations. 28 | """ 29 | 30 | 31 | def draw_layout(vocab, objs, boxes, masks=None, size=256, 32 | show_boxes=False, bgcolor=(0, 0, 0)): 33 | if bgcolor == 'white': 34 | bgcolor = (255, 255, 255) 35 | 36 | cmap = plt.get_cmap('rainbow') 37 | colors = cmap(np.linspace(0, 1, len(objs))) 38 | 39 | with torch.no_grad(): 40 | objs = objs.cpu().clone() 41 | boxes = boxes.cpu().clone() 42 | boxes *= size 43 | 44 | if masks is not None: 45 | masks = masks.cpu().clone() 46 | 47 | bgcolor = np.asarray(bgcolor) 48 | bg = np.ones((size, size, 1)) * bgcolor 49 | plt.imshow(bg.astype(np.uint8)) 50 | 51 | plt.gca().set_xlim(0, size) 52 | plt.gca().set_ylim(size, 0) 53 | plt.gca().set_aspect(1.0, adjustable='box') 54 | 55 | for i, obj in enumerate(objs): 56 | name = vocab['object_idx_to_name'][obj] 57 | if name == '__image__': 58 | continue 59 | box = boxes[i] 60 | 61 | if masks is None: 62 | continue 63 | mask = masks[i].numpy() 64 | mask /= mask.max() 65 | 66 | r, g, b, a = colors[i] 67 | colored_mask = mask[:, :, None] * np.asarray(colors[i]) 68 | 69 | x0, y0, x1, y1 = box 70 | plt.imshow(colored_mask, extent=(x0, x1, y1, y0), 71 | interpolation='bicubic', alpha=1.0) 72 | 73 | if show_boxes: 74 | for i, obj in enumerate(objs): 75 | name = vocab['object_idx_to_name'][obj] 76 | if name == '__image__': 77 | continue 78 | box = boxes[i] 79 | 80 | draw_box(box, colors[i], name) 81 | 82 | 83 | def add_boxes_to_layout(img, objs, boxes, image_path, size=256, colors=None): 84 | if colors is None: 85 | cmap = plt.get_cmap('rainbow') 86 | colors = cmap(np.linspace(0, 1, len(objs))) 87 | plt.clf() 88 | with torch.no_grad(): 89 | boxes = boxes.cpu().clone() 90 | boxes *= size 91 | 92 | plt.imshow(img) 93 | 94 | plt.gca().set_xlim(0, size) 95 | plt.gca().set_ylim(size, 0) 96 | plt.gca().set_aspect(1.0, adjustable='box') 97 
| 98 | for i, obj in enumerate(objs): 99 | if obj == '__image__': 100 | continue 101 | draw_box(boxes[i], colors[i], obj, alpha=0.8) 102 | plt.axis('off') 103 | plt.savefig(image_path, bbox_inches='tight', pad_inches=0) 104 | 105 | 106 | def draw_box(box, color, text=None, alpha=1.0): 107 | """ 108 | Draw a bounding box using pyplot, optionally with a text box label. 109 | 110 | Inputs: 111 | - box: Tensor or list with 4 elements: [x0, y0, x1, y1] in [0, W] x [0, H] 112 | coordinate system. 113 | - color: pyplot color to use for the box. 114 | - text: (Optional) String; if provided then draw a label for this box. 115 | """ 116 | TEXT_BOX_HEIGHT = 10 117 | if torch.is_tensor(box) and box.dim() == 2: 118 | box = box.view(-1) 119 | assert box.size(0) == 4 120 | x0, y0, x1, y1 = box 121 | assert y1 > y0, box 122 | assert x1 > x0, box 123 | w, h = x1 - x0, y1 - y0 124 | rect = Rectangle((x0, y0), w, h, fc='none', lw=2, ec=color, alpha=alpha) 125 | plt.gca().add_patch(rect) 126 | if text is not None: 127 | text_rect = Rectangle((x0, y0), w, TEXT_BOX_HEIGHT, fc=color, alpha=0.5) 128 | plt.gca().add_patch(text_rect) 129 | tx = 0.5 * (x0 + x1) 130 | ty = y0 + TEXT_BOX_HEIGHT / 2.0 131 | plt.text(tx, ty, text, va='center', ha='center') 132 | 133 | 134 | def draw_scene_graph(objs, triples, vocab=None, **kwargs): 135 | """ 136 | Use GraphViz to draw a scene graph. If vocab is not passed then we assume 137 | that objs and triples are python lists containing strings for object and 138 | relationship names. 139 | 140 | Using this requires that GraphViz is installed. On Ubuntu 16.04 this is easy: 141 | sudo apt-get install graphviz 142 | """ 143 | output_filename = kwargs.pop('output_filename', 'graph.png') 144 | orientation = kwargs.pop('orientation', 'V') 145 | edge_width = kwargs.pop('edge_width', 6) 146 | arrow_size = kwargs.pop('arrow_size', 1.5) 147 | binary_edge_weight = kwargs.pop('binary_edge_weight', 1.2) 148 | ignore_dummies = kwargs.pop('ignore_dummies', True) 149 | 150 | if orientation not in ['V', 'H']: 151 | raise ValueError('Invalid orientation "%s"' % orientation) 152 | rankdir = {'H': 'LR', 'V': 'TD'}[orientation] 153 | 154 | if vocab is not None: 155 | # Decode object and relationship names 156 | assert torch.is_tensor(objs) 157 | assert torch.is_tensor(triples) 158 | objs_list, triples_list = [], [] 159 | idx_to_obj = ['__image__'] + vocab['my_idx_to_obj'] 160 | for i in range(objs.size(0)): 161 | objs_list.append(idx_to_obj[objs[i].item()]) 162 | for i in range(triples.size(0)): 163 | s = triples[i, 0].item() 164 | p = vocab['pred_idx_to_name'][triples[i, 1].item()] 165 | o = triples[i, 2].item() 166 | triples_list.append([s, p, o]) 167 | objs, triples = objs_list, triples_list 168 | 169 | # General setup, and style for object nodes 170 | lines = [ 171 | 'digraph{', 172 | 'graph [size="5,3",ratio="compress",dpi="300",bgcolor="transparent"]', 173 | 'rankdir=%s' % rankdir, 174 | 'nodesep="0.5"', 175 | 'ranksep="0.5"', 176 | 'node [shape="box",style="rounded,filled",fontsize="48",color="none"]', 177 | 'node [fillcolor="lightpink1"]', 178 | ] 179 | # Output nodes for objects 180 | for i, obj in enumerate(objs): 181 | if ignore_dummies and obj == '__image__': 182 | continue 183 | lines.append('%d [label="%s"]' % (i, obj)) 184 | 185 | # Output relationships 186 | next_node_id = len(objs) 187 | lines.append('node [fillcolor="lightblue1"]') 188 | for s, p, o in triples: 189 | if ignore_dummies and p == '__in_image__': 190 | continue 191 | lines += [ 192 | '%d [label="%s"]' % (next_node_id, 
p), 193 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % ( 194 | s, next_node_id, edge_width, arrow_size, binary_edge_weight), 195 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % ( 196 | next_node_id, o, edge_width, arrow_size, binary_edge_weight) 197 | ] 198 | next_node_id += 1 199 | lines.append('}') 200 | 201 | # Now it gets slightly hacky. Write the graphviz spec to a temporary 202 | # text file 203 | ff, dot_filename = tempfile.mkstemp() 204 | with open(dot_filename, 'w') as f: 205 | for line in lines: 206 | f.write('%s\n' % line) 207 | os.close(ff) 208 | 209 | # Shell out to invoke graphviz; this will save the resulting image to disk, 210 | # so we read it, delete it, then return it. 211 | output_format = os.path.splitext(output_filename)[1][1:] 212 | os.system('dot -T%s %s > %s' % (output_format, dot_filename, output_filename)) 213 | os.remove(dot_filename) 214 | img = imread(output_filename) 215 | os.remove(output_filename) 216 | 217 | return img 218 | 219 | 220 | if __name__ == '__main__': 221 | o_idx_to_name = ['cat', 'dog', 'hat', 'skateboard'] 222 | p_idx_to_name = ['riding', 'wearing', 'on', 'next to', 'above'] 223 | o_name_to_idx = {s: i for i, s in enumerate(o_idx_to_name)} 224 | p_name_to_idx = {s: i for i, s in enumerate(p_idx_to_name)} 225 | vocab = { 226 | 'object_idx_to_name': o_idx_to_name, 227 | 'object_name_to_idx': o_name_to_idx, 228 | 'pred_idx_to_name': p_idx_to_name, 229 | 'pred_name_to_idx': p_name_to_idx, 230 | } 231 | 232 | objs = [ 233 | 'cat', 234 | 'cat', 235 | 'skateboard', 236 | 'hat', 237 | ] 238 | objs = torch.LongTensor([o_name_to_idx[o] for o in objs]) 239 | triples = [ 240 | [0, 'next to', 1], 241 | [0, 'riding', 2], 242 | [1, 'wearing', 3], 243 | [3, 'above', 2], 244 | ] 245 | triples = [[s, p_name_to_idx[p], o] for s, p, o in triples] 246 | triples = torch.LongTensor(triples) 247 | 248 | draw_scene_graph(objs, triples, vocab, orientation='V') 249 | -------------------------------------------------------------------------------- /scripts/gui/index.js: -------------------------------------------------------------------------------- 1 | window.onload = function () { 2 | 3 | var rangeSlider = function () { 4 | var slider = $('.range-slider'), 5 | range = $('.range-slider__range'), 6 | value = $('.range-slider__value'); 7 | 8 | slider.each(function () { 9 | 10 | value.each(function () { 11 | var value = $(this).prev().attr('value'); 12 | $(this).html(value); 13 | }); 14 | 15 | range.on('input', function () { 16 | var id = this.getAttribute('data-id'); 17 | if (id) { 18 | document.getElementById(id).setAttribute('data-style', this.value); 19 | } else { 20 | this.setAttribute('image-id', this.value); 21 | } 22 | $(this).next(value).html(this.value); 23 | render_button(); 24 | }); 25 | }); 26 | }; 27 | 28 | rangeSlider(); 29 | 30 | function dragMoveListener(event) { 31 | console.log('dragMobve'); 32 | var target = event.target, 33 | // keep the dragged position in the data-x/data-y attributes 34 | x = (parseFloat(target.getAttribute('data-x')) || 0) + event.dx, 35 | y = (parseFloat(target.getAttribute('data-y')) || 0) + event.dy; 36 | 37 | // translate the element 38 | target.style.webkitTransform = 39 | target.style.transform = 40 | 'translate(' + x + 'px, ' + y + 'px)'; 41 | 42 | // update the posiion attributes 43 | target.setAttribute('data-x', x); 44 | target.setAttribute('data-y', y); 45 | selectItem(event, event.target); 46 | // render_button(); 47 | } 48 | 49 | // this is used later in the resizing and gesture demos 50 | window.dragMoveListener 
= dragMoveListener; 51 | 52 | interact('.resize-drag') 53 | .draggable({ 54 | onmove: window.dragMoveListener, 55 | onend: render_button, 56 | restrict: { 57 | restriction: 'parent', 58 | elementRect: {top: 0, left: 0, bottom: 1, right: 1} 59 | }, 60 | }) 61 | 62 | .on('tap', function (event) { 63 | console.log('tap'); 64 | var target = event.target; 65 | var size = parseInt(target.getAttribute('data-size')); 66 | var new_size = (size + 1) % 10; 67 | target.setAttribute('data-size', new_size); 68 | target.style.fontSize = sizeToFont(new_size); 69 | // $(event.currentTarget).remove(); 70 | selectItem(event, event.target); 71 | render_button(); 72 | event.preventDefault(); 73 | }) 74 | .on('hold', function (event) { 75 | console.log('hold'); 76 | $(event.currentTarget).remove(); 77 | render_button(); 78 | event.preventDefault(); 79 | }); 80 | 81 | function selectItem(event, target, should_deselect) { 82 | event.stopPropagation(); 83 | var hasClass = $(target).hasClass('selected'); 84 | $(".resize-drag").removeClass("selected"); 85 | $('#range-slider').attr('data-id', ''); 86 | if (should_deselect && hasClass) { 87 | } else { 88 | $(target).addClass("selected"); 89 | $('#range-slider').attr('data-id', target.id); 90 | var style = target.getAttribute('data-style'); 91 | style = style ? style : -1; 92 | $('#range-slider').val(style); 93 | $('.range-slider__value').text(style.toString()); 94 | } 95 | } 96 | 97 | $(".resize-drag").click(function (e) { 98 | $(".resize-drag").removeClass("selected"); 99 | $(this).addClass("selected"); 100 | e.stopPropagation(); 101 | }); 102 | 103 | function guidGenerator() { 104 | var S4 = function () { 105 | return (((1 + Math.random()) * 0x10000) | 0).toString(16).substring(1); 106 | }; 107 | return (S4() + S4() + "-" + S4() + "-" + S4() + "-" + S4() + "-" + S4() + S4() + S4()); 108 | } 109 | 110 | function stuff_add(evt) { 111 | evt.stopPropagation(); 112 | var newContent = document.createTextNode(evt.currentTarget.textContent); 113 | var node = document.createElement("DIV"); 114 | node.className = "resize-drag"; 115 | node.id = guidGenerator(); 116 | node.appendChild(newContent); 117 | var init_size = 0; 118 | node.setAttribute('data-size', init_size); 119 | node.style.fontSize = sizeToFont(init_size); 120 | document.getElementById("resize-container").appendChild(node); 121 | render_button(); 122 | } 123 | 124 | function sizeToFont(size) { 125 | return size * 8 + 20; 126 | } 127 | 128 | function refresh_image(response) { 129 | response = JSON.parse(response); 130 | document.getElementById("img_pred").src = response.img_pred; 131 | document.getElementById("layout_pred").src = response.layout_pred; 132 | } 133 | 134 | function addRow(obj, size, location, feature) { 135 | return; 136 | // Get a reference to the table 137 | let tableRef = document.getElementById('table').getElementsByTagName('tbody')[0]; 138 | 139 | // Insert a row at the end of the table 140 | let newRow = tableRef.insertRow(-1); 141 | 142 | // Insert a cell in the row at index 0 143 | newRow.insertCell(0).appendChild(document.createTextNode(obj)); 144 | newRow.insertCell(1).appendChild(document.createTextNode(size + '')); 145 | newRow.insertCell(2).appendChild(document.createTextNode(location + '')); 146 | newRow.insertCell(3).appendChild(document.createTextNode(feature + '')); 147 | } 148 | 149 | function render_button() { 150 | console.log('render'); 151 | var allObjects = []; 152 | $("tbody").children().remove(); 153 | var container = document.getElementById("resize-container"); 154 | var 
container_rect = interact.getElementRect(container); 155 | var containerOffsetLeft = container_rect.left; 156 | var containerOffsetTop = container_rect.top; 157 | var containerWidth = container_rect.width; 158 | var containerHeight = container_rect.height; 159 | var children = document.getElementsByClassName('resize-drag'); 160 | if (children.length < 3) { 161 | return; 162 | } 163 | for (var i = 0; i < children.length; i++) { 164 | var rect = interact.getElementRect(children[i]); 165 | var height = rect.height / containerHeight; 166 | var width = rect.width / containerWidth; 167 | var left = (rect.left - containerOffsetLeft) / containerWidth; 168 | var top = (rect.top - containerOffsetTop) / containerHeight; 169 | var sx0 = left; 170 | var sy0 = top; 171 | var sx1 = width + left; 172 | var sy1 = height + sy0; 173 | var mean_x_s = 0.5 * (sx0 + sx1); 174 | var mean_y_s = 0.5 * (sy0 + sy1); 175 | var grid = 25 / 5; 176 | var location = Math.round(mean_x_s * (grid - 1)) + grid * Math.round(mean_y_s * (grid - 1)); 177 | var size = parseInt(children[i].getAttribute('data-size')); 178 | var text = children[i].innerText; 179 | var style = children[i].getAttribute('data-style') ? 180 | parseInt(children[i].getAttribute('data-style')) : 181 | -1; 182 | allObjects.push({ 183 | 'height': height, 184 | 'width': width, 185 | 'left': left, 186 | 'top': top, 187 | 'text': text, 188 | 'feature': style, 189 | 'size': size, 190 | 'location': location, 191 | }); 192 | console.log(size, location, text); 193 | addRow(text, size, location, style); 194 | } 195 | console.log(allObjects); 196 | var image_id = document.getElementById('range-slider').getAttribute('image-id'); 197 | var image_feature = image_id ? parseInt(image_id) : -1; 198 | addRow('background', '-', '-', image_feature); 199 | var results = {'image_id': image_feature, 'objects': allObjects}; 200 | var url = 'get_data?data=' + JSON.stringify(results); 201 | var xmlHttp = new XMLHttpRequest(); 202 | xmlHttp.onreadystatechange = function () { 203 | if (xmlHttp.readyState == 4 && xmlHttp.status == 200) 204 | refresh_image(xmlHttp.responseText); 205 | }; 206 | xmlHttp.open("GET", url, true); // true for asynchronous 207 | xmlHttp.send(null); 208 | } 209 | 210 | document.querySelectorAll("ul.drop-menu > li").forEach(function (e) { 211 | e.addEventListener("click", stuff_add) 212 | }); 213 | $(window).click(function (devt) { 214 | if (!devt.target.getAttribute('data-size') && !devt.target.getAttribute('max')) { 215 | $(".resize-drag").removeClass("selected"); 216 | var image_style = $('#range-slider').attr('image-id'); 217 | image_style = image_style ? parseInt(image_style) : -1; 218 | $('#range-slider').val(image_style); 219 | $('.range-slider__value').text(image_style.toString()); 220 | $('#range-slider').attr('data-id', ''); 221 | } 222 | }); 223 | }; 224 | 225 | -------------------------------------------------------------------------------- /scene_generation/layout.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn.functional as F 20 | 21 | """ 22 | Functions for computing image layouts from object vectors, bounding boxes, 23 | and segmentation masks. These are used to compute coarse scene layouts which 24 | are then fed as input to the cascaded refinement network. 25 | """ 26 | 27 | 28 | def boxes_to_layout(vecs, boxes, obj_to_img, H, W=None, pooling='sum'): 29 | """ 30 | Inputs: 31 | - vecs: Tensor of shape (O, D) giving vectors 32 | - boxes: Tensor of shape (O, 4) giving bounding boxes in the format 33 | [x0, y0, x1, y1] in the [0, 1] coordinate space 34 | - obj_to_img: LongTensor of shape (O,) mapping each element of vecs to 35 | an image, where each element is in the range [0, N). If obj_to_img[i] = j 36 | then vecs[i] belongs to image j. 37 | - H, W: Size of the output 38 | 39 | Returns: 40 | - out: Tensor of shape (N, D, H, W) 41 | """ 42 | O, D = vecs.size() 43 | if W is None: 44 | W = H 45 | 46 | grid = _boxes_to_grid(boxes, H, W) 47 | 48 | # If we don't add extra spatial dimensions here then out-of-bounds 49 | # elements won't be automatically set to 0 50 | img_in = vecs.view(O, D, 1, 1).expand(O, D, 8, 8) 51 | sampled = F.grid_sample(img_in, grid) # (O, D, H, W) 52 | 53 | # Explicitly masking makes everything quite a bit slower. 54 | # If we rely on implicit masking the interpolated boxes end up 55 | # blurred around the edges, but it should be fine. 56 | # mask = ((X < 0) + (X > 1) + (Y < 0) + (Y > 1)).clamp(max=1) 57 | # sampled[mask[:, None]] = 0 58 | 59 | out = _pool_samples(sampled, None, obj_to_img, pooling=pooling) # no clean masks when only boxes are given 60 | 61 | return out 62 | 63 | 64 | def masks_to_layout(vecs, boxes, masks, obj_to_img, H, W=None, pooling='sum', test_mode=False): 65 | """ 66 | Inputs: 67 | - vecs: Tensor of shape (O, D) giving vectors 68 | - boxes: Tensor of shape (O, 4) giving bounding boxes in the format 69 | [x0, y0, x1, y1] in the [0, 1] coordinate space 70 | - masks: Tensor of shape (O, M, M) giving binary masks for each object 71 | - obj_to_img: LongTensor of shape (O,) mapping objects to images 72 | - H, W: Size of the output image.
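- test_mode: if True, each object's mask is also warped on its own and passed to _pool_samples, which then composites overlapping objects per pixel instead of summing them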
73 | 74 | Returns: 75 | - out: Tensor of shape (N, D, H, W) 76 | """ 77 | O, D = vecs.size() 78 | M = masks.size(1) 79 | assert masks.size() == (O, M, M) 80 | if W is None: 81 | W = H 82 | 83 | grid = _boxes_to_grid(boxes, H, W) 84 | 85 | img_in = vecs.view(O, D, 1, 1) * masks.float().view(O, 1, M, M) 86 | sampled = F.grid_sample(img_in, grid) 87 | if test_mode: 88 | clean_mask_sampled = F.grid_sample(masks.float().view(O, 1, M, M), grid) 89 | else: 90 | clean_mask_sampled = None 91 | 92 | out = _pool_samples(sampled, clean_mask_sampled, obj_to_img, pooling=pooling) 93 | return out 94 | 95 | 96 | def _boxes_to_grid(boxes, H, W): 97 | """ 98 | Input: 99 | - boxes: FloatTensor of shape (O, 4) giving boxes in the [x0, y0, x1, y1] 100 | format in the [0, 1] coordinate space 101 | - H, W: Scalars giving size of output 102 | 103 | Returns: 104 | - grid: FloatTensor of shape (O, H, W, 2) suitable for passing to grid_sample 105 | """ 106 | O = boxes.size(0) 107 | 108 | boxes = boxes.view(O, 4, 1, 1) 109 | 110 | # All these are (O, 1, 1) 111 | x0, y0 = boxes[:, 0], boxes[:, 1] 112 | ww, hh = boxes[:, 2] - x0, boxes[:, 3] - y0 113 | 114 | X = torch.linspace(0, 1, steps=W).view(1, 1, W).to(boxes) 115 | Y = torch.linspace(0, 1, steps=H).view(1, H, 1).to(boxes) 116 | 117 | X = (X - x0) / ww # (O, 1, W) 118 | Y = (Y - y0) / hh # (O, H, 1) 119 | 120 | # Stack does not broadcast its arguments so we need to expand explicitly 121 | X = X.expand(O, H, W) 122 | Y = Y.expand(O, H, W) 123 | grid = torch.stack([X, Y], dim=3) # (O, H, W, 2) 124 | 125 | # Right now grid is in [0, 1] space; transform to [-1, 1] 126 | grid = grid.mul(2).sub(1) 127 | 128 | return grid 129 | 130 | 131 | def _pool_samples(samples, clean_mask_sampled, obj_to_img, pooling='sum'): 132 | """ 133 | Input: 134 | - samples: FloatTensor of shape (O, D, H, W) 135 | - obj_to_img: LongTensor of shape (O,) with each element in the range 136 | [0, N) mapping elements of samples to output images 137 | 138 | Output: 139 | - pooled: FloatTensor of shape (N, D, H, W) 140 | """ 141 | dtype, device = samples.dtype, samples.device 142 | O, D, H, W = samples.size() 143 | N = obj_to_img.data.max().item() + 1 144 | 145 | # Use scatter_add to sum the sampled outputs for each image 146 | # out = torch.zeros(N, D, H, W, dtype=dtype, device=device) 147 | # idx = obj_to_img.view(O, 1, 1, 1).expand(O, D, H, W) 148 | # out = out.scatter_add(0, idx, samples) 149 | obj_to_img_list = [i.item() for i in list(obj_to_img)] 150 | all_out = [] 151 | if clean_mask_sampled is None: 152 | for i in range(N): 153 | start = obj_to_img_list.index(i) 154 | end = len(obj_to_img_list) - obj_to_img_list[::-1].index(i) 155 | all_out.append(torch.sum(samples[start:end, :, :, :], dim=0)) 156 | else: 157 | _, d, h, w = samples.shape 158 | for i in range(N): 159 | start = obj_to_img_list.index(i) 160 | end = len(obj_to_img_list) - obj_to_img_list[::-1].index(i) 161 | mass = [torch.sum(samples[j, :, :, :]).item() for j in range(start, end)] 162 | argsort = np.argsort(mass) 163 | result = torch.zeros((d, h, w), device=samples.device, dtype=samples.dtype) 164 | result_clean = torch.zeros((h, w), device=samples.device, dtype=samples.dtype) 165 | for j in argsort: 166 | masked_mask = (result_clean == 0).float() * (clean_mask_sampled[start + j, 0] > 0.5).float() 167 | result_clean += masked_mask 168 | result += samples[start + j] * masked_mask 169 | all_out.append(result) 170 | out = torch.stack(all_out) 171 | 172 | if pooling == 'avg': 173 | # Divide each output mask by the number of objects; 
use scatter_add again 174 | # to count the number of objects per image. 175 | ones = torch.ones(O, dtype=dtype, device=device) 176 | obj_counts = torch.zeros(N, dtype=dtype, device=device) 177 | obj_counts = obj_counts.scatter_add(0, obj_to_img, ones) 178 | # print(obj_counts) 179 | obj_counts = obj_counts.clamp(min=1) 180 | out = out / obj_counts.view(N, 1, 1, 1) 181 | elif pooling != 'sum': 182 | raise ValueError('Invalid pooling "%s"' % pooling) 183 | 184 | return out 185 | 186 | 187 | if __name__ == '__main__': 188 | vecs = torch.FloatTensor([ 189 | [1, 0, 0], [0, 1, 0], [0, 0, 1], 190 | [1, 0, 0], [0, 1, 0], [0, 0, 1], 191 | ]) 192 | boxes = torch.FloatTensor([ 193 | [0.25, 0.125, 0.5, 0.875], 194 | [0, 0, 1, 0.25], 195 | [0.6125, 0, 0.875, 1], 196 | [0, 0.8, 1, 1.0], 197 | [0.25, 0.125, 0.5, 0.875], 198 | [0.6125, 0, 0.875, 1], 199 | ]) 200 | obj_to_img = torch.LongTensor([0, 0, 0, 1, 1, 1]) 201 | # vecs = torch.FloatTensor([[[1]]]) 202 | # boxes = torch.FloatTensor([[[0.25, 0.25, 0.75, 0.75]]]) 203 | vecs, boxes = vecs.cuda(), boxes.cuda() 204 | obj_to_img = obj_to_img.cuda() 205 | out = boxes_to_layout(vecs, boxes, obj_to_img, 256, pooling='sum') 206 | 207 | from torchvision.utils import save_image 208 | 209 | save_image(out.data, 'out.png') 210 | 211 | masks = torch.FloatTensor([ 212 | [ 213 | [0, 0, 1, 0, 0], 214 | [0, 1, 1, 1, 0], 215 | [1, 1, 1, 1, 1], 216 | [0, 1, 1, 1, 0], 217 | [0, 0, 1, 0, 0], 218 | ], 219 | [ 220 | [0, 0, 1, 0, 0], 221 | [0, 1, 0, 1, 0], 222 | [1, 0, 0, 0, 1], 223 | [0, 1, 0, 1, 0], 224 | [0, 0, 1, 0, 0], 225 | ], 226 | [ 227 | [0, 0, 1, 0, 0], 228 | [0, 1, 1, 1, 0], 229 | [1, 1, 1, 1, 1], 230 | [0, 1, 1, 1, 0], 231 | [0, 0, 1, 0, 0], 232 | ], 233 | [ 234 | [0, 0, 1, 0, 0], 235 | [0, 1, 1, 1, 0], 236 | [1, 1, 1, 1, 1], 237 | [0, 1, 1, 1, 0], 238 | [0, 0, 1, 0, 0], 239 | ], 240 | [ 241 | [0, 0, 1, 0, 0], 242 | [0, 1, 1, 1, 0], 243 | [1, 1, 1, 1, 1], 244 | [0, 1, 1, 1, 0], 245 | [0, 0, 1, 0, 0], 246 | ], 247 | [ 248 | [0, 0, 1, 0, 0], 249 | [0, 1, 1, 1, 0], 250 | [1, 1, 1, 1, 1], 251 | [0, 1, 1, 1, 0], 252 | [0, 0, 1, 0, 0], 253 | ] 254 | ]) 255 | masks = masks.cuda() 256 | out = masks_to_layout(vecs, boxes, masks, obj_to_img, 256) 257 | save_image(out.data, 'out_masks.png') 258 | -------------------------------------------------------------------------------- /scene_generation/discriminators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from scene_generation.bilinear import crop_bbox_batch 7 | from scene_generation.layers import GlobalAvgPool, build_cnn, get_norm_layer 8 | 9 | 10 | class AcDiscriminator(nn.Module): 11 | def __init__(self, vocab, arch, normalization='none', activation='relu', padding='same', pooling='avg'): 12 | super(AcDiscriminator, self).__init__() 13 | self.vocab = vocab 14 | 15 | cnn_kwargs = { 16 | 'arch': arch, 17 | 'normalization': normalization, 18 | 'activation': activation, 19 | 'pooling': pooling, 20 | 'padding': padding, 21 | } 22 | cnn, D = build_cnn(**cnn_kwargs) 23 | self.cnn = nn.Sequential(cnn, GlobalAvgPool(), nn.Linear(D, 1024)) 24 | num_objects = len(vocab['object_to_idx']) 25 | 26 | self.real_classifier = nn.Linear(1024, 1) 27 | self.obj_classifier = nn.Linear(1024, num_objects) 28 | 29 | def forward(self, x, y): 30 | if x.dim() == 3: 31 | x = x[:, None] 32 | vecs = self.cnn(x) 33 | real_scores = self.real_classifier(vecs) 34 | obj_scores = self.obj_classifier(vecs) 35 | 
ac_loss = F.cross_entropy(obj_scores, y) 36 | return real_scores, ac_loss 37 | 38 | 39 | class AcCropDiscriminator(nn.Module): 40 | def __init__(self, vocab, arch, normalization='none', activation='relu', 41 | object_size=64, padding='same', pooling='avg'): 42 | super(AcCropDiscriminator, self).__init__() 43 | self.vocab = vocab 44 | self.discriminator = AcDiscriminator(vocab, arch, normalization, 45 | activation, padding, pooling) 46 | self.object_size = object_size 47 | 48 | def forward(self, imgs, objs, boxes, obj_to_img): 49 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size) 50 | real_scores, ac_loss = self.discriminator(crops, objs) 51 | return real_scores, ac_loss, crops 52 | 53 | 54 | ############################################################################### 55 | # Functions 56 | ############################################################################### 57 | def weights_init(m): 58 | classname = m.__class__.__name__ 59 | if classname.find('Conv') != -1: 60 | m.weight.data.normal_(0.0, 0.02) 61 | elif classname.find('BatchNorm2d') != -1: 62 | m.weight.data.normal_(1.0, 0.02) 63 | m.bias.data.fill_(0) 64 | 65 | 66 | def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1): 67 | norm_layer = get_norm_layer(norm_type=norm) 68 | netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D) 69 | # print(netD) 70 | assert (torch.cuda.is_available()) 71 | netD.cuda() 72 | netD.apply(weights_init) 73 | return netD 74 | 75 | 76 | def define_mask_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, 77 | num_objects=None): 78 | norm_layer = get_norm_layer(norm_type=norm) 79 | netD = MultiscaleMaskDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, 80 | num_objects) 81 | assert (torch.cuda.is_available()) 82 | netD.cuda() 83 | netD.apply(weights_init) 84 | return netD 85 | 86 | 87 | class MultiscaleMaskDiscriminator(nn.Module): 88 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 89 | use_sigmoid=False, num_D=3, num_objects=None): 90 | super(MultiscaleMaskDiscriminator, self).__init__() 91 | self.num_D = num_D 92 | self.n_layers = n_layers 93 | 94 | for i in range(num_D): 95 | netD = NLayerMaskDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, num_objects) 96 | for j in range(n_layers + 2): 97 | setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j))) 98 | 99 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 100 | 101 | def singleD_forward(self, model, input, cond): 102 | result = [input] 103 | for i in range(len(model) - 2): 104 | # print(result[-1].shape) 105 | result.append(model[i](result[-1])) 106 | 107 | a, b, c, d = result[-1].shape 108 | cond = cond.view(a, -1, 1, 1).expand(-1, -1, c, d) 109 | concat = torch.cat([result[-1], cond], dim=1) 110 | result.append(model[len(model) - 2](concat)) 111 | result.append(model[len(model) - 1](result[-1])) 112 | return result[1:] 113 | 114 | def forward(self, input, cond): 115 | num_D = self.num_D 116 | result = [] 117 | input_downsampled = input 118 | for i in range(num_D): 119 | model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in 120 | range(self.n_layers + 2)] 121 | result.append(self.singleD_forward(model, input_downsampled, cond)) 122 | if i != (num_D - 1): 123 | input_downsampled = self.downsample(input_downsampled) 124 | return result 125 | 126 | 127 | # Defines the PatchGAN 
discriminator with the specified arguments. 128 | class NLayerMaskDiscriminator(nn.Module): 129 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, 130 | num_objects=None): 131 | super(NLayerMaskDiscriminator, self).__init__() 132 | self.n_layers = n_layers 133 | 134 | kw = 3 135 | padw = int(np.ceil((kw - 1.0) / 2)) 136 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 137 | 138 | nf = ndf 139 | for n in range(1, n_layers): 140 | nf_prev = nf 141 | nf = min(nf * 2, 512) 142 | sequence += [[ 143 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 144 | norm_layer(nf), nn.LeakyReLU(0.2, True) 145 | ]] 146 | 147 | nf_prev = nf 148 | nf = min(nf * 2, 512) 149 | nf_prev += num_objects 150 | sequence += [[ 151 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 152 | norm_layer(nf), 153 | nn.LeakyReLU(0.2, True) 154 | ]] 155 | 156 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 157 | 158 | if use_sigmoid: 159 | sequence += [[nn.Sigmoid()]] 160 | 161 | for n in range(len(sequence)): 162 | setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) 163 | 164 | def forward(self, input): 165 | res = [input] 166 | for n in range(self.n_layers + 2): 167 | model = getattr(self, 'model' + str(n)) 168 | res.append(model(res[-1])) 169 | return res[1:] 170 | 171 | 172 | class MultiscaleDiscriminator(nn.Module): 173 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 174 | use_sigmoid=False, num_D=3): 175 | super(MultiscaleDiscriminator, self).__init__() 176 | self.num_D = num_D 177 | self.n_layers = n_layers 178 | 179 | for i in range(num_D): 180 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid) 181 | for j in range(n_layers + 2): 182 | setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j))) 183 | 184 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 185 | 186 | def singleD_forward(self, model, input): 187 | result = [input] 188 | for i in range(len(model)): 189 | result.append(model[i](result[-1])) 190 | return result[1:] 191 | 192 | def forward(self, input): 193 | num_D = self.num_D 194 | result = [] 195 | input_downsampled = input 196 | for i in range(num_D): 197 | model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in 198 | range(self.n_layers + 2)] 199 | result.append(self.singleD_forward(model, input_downsampled)) 200 | if i != (num_D - 1): 201 | input_downsampled = self.downsample(input_downsampled) 202 | return result 203 | 204 | 205 | # Defines the PatchGAN discriminator with the specified arguments. 
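# As with NLayerMaskDiscriminator above, each conv block of NLayerDiscriminator below is
# registered as model0..modelN so that forward() returns every intermediate activation;
# GANLoss in scene_generation/losses.py reads only the last element of each per-scale list
# (the patch score map), while the earlier feature maps remain available for
# feature-matching-style losses. MultiscaleDiscriminator feeds num_D copies of this network
# with progressively average-pooled inputs; the mask variant differs mainly by using 3x3
# kernels and reserving num_objects extra input channels before its stride-1 block, which
# MultiscaleMaskDiscriminator fills with a per-object class-conditioning vector.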
206 | class NLayerDiscriminator(nn.Module): 207 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 208 | super(NLayerDiscriminator, self).__init__() 209 | self.n_layers = n_layers 210 | 211 | kw = 4 212 | padw = int(np.ceil((kw - 1.0) / 2)) 213 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 214 | 215 | nf = ndf 216 | for n in range(1, n_layers): 217 | nf_prev = nf 218 | nf = min(nf * 2, 512) 219 | sequence += [[ 220 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 221 | norm_layer(nf), nn.LeakyReLU(0.2, True) 222 | ]] 223 | 224 | nf_prev = nf 225 | nf = min(nf * 2, 512) 226 | sequence += [[ 227 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 228 | norm_layer(nf), 229 | nn.LeakyReLU(0.2, True) 230 | ]] 231 | 232 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 233 | 234 | if use_sigmoid: 235 | sequence += [[nn.Sigmoid()]] 236 | 237 | for n in range(len(sequence)): 238 | setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) 239 | 240 | def forward(self, input): 241 | res = [input] 242 | for n in range(self.n_layers + 2): 243 | model = getattr(self, 'model' + str(n)) 244 | res.append(model(res[-1])) 245 | return res[1:] 246 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import os 3 | import json 4 | from collections import defaultdict 5 | import random 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from scene_generation.args import get_args 10 | from scene_generation.data.coco import CocoSceneGraphDataset, coco_collate_fn 11 | from scene_generation.data.coco_panoptic import CocoPanopticSceneGraphDataset, coco_panoptic_collate_fn 12 | from scene_generation.metrics import jaccard 13 | from scene_generation.trainer import Trainer 14 | 15 | from scripts.inception_score import InceptionScore 16 | 17 | 18 | def build_coco_dsets(args): 19 | dset_kwargs = { 20 | 'image_dir': args.coco_train_image_dir, 21 | 'instances_json': args.coco_train_instances_json, 22 | 'stuff_json': args.coco_train_stuff_json, 23 | 'image_size': args.image_size, 24 | 'mask_size': args.mask_size, 25 | 'max_samples': args.num_train_samples, 26 | 'min_object_size': args.min_object_size, 27 | 'min_objects_per_image': args.min_objects_per_image, 28 | 'instance_whitelist': args.instance_whitelist, 29 | 'stuff_whitelist': args.stuff_whitelist, 30 | 'include_other': args.coco_include_other, 31 | } 32 | if args.is_panoptic: 33 | dset_kwargs['panoptic'] = args.coco_panoptic_train 34 | dset_kwargs['panoptic_segmentation'] = args.coco_panoptic_segmentation_train 35 | train_dset = CocoPanopticSceneGraphDataset(**dset_kwargs) 36 | else: 37 | train_dset = CocoSceneGraphDataset(**dset_kwargs) 38 | num_objs = train_dset.total_objects() 39 | num_imgs = len(train_dset) 40 | print('Training dataset has %d images and %d objects' % (num_imgs, num_objs)) 41 | print('(%.2f objects per image)' % (float(num_objs) / num_imgs)) 42 | 43 | dset_kwargs['image_dir'] = args.coco_val_image_dir 44 | dset_kwargs['instances_json'] = args.coco_val_instances_json 45 | dset_kwargs['stuff_json'] = args.coco_val_stuff_json 46 | dset_kwargs['max_samples'] = args.num_val_samples 47 | if args.is_panoptic: 48 | dset_kwargs['panoptic'] = args.coco_panoptic_val 49 | dset_kwargs['panoptic_segmentation'] = 
args.coco_panoptic_segmentation_val 50 | val_dset = CocoPanopticSceneGraphDataset(**dset_kwargs) 51 | else: 52 | val_dset = CocoSceneGraphDataset(**dset_kwargs) 53 | 54 | assert train_dset.vocab == val_dset.vocab 55 | vocab = json.loads(json.dumps(train_dset.vocab)) 56 | 57 | return vocab, train_dset, val_dset 58 | 59 | 60 | def build_loaders(args): 61 | vocab, train_dset, val_dset = build_coco_dsets(args) 62 | if args.is_panoptic: 63 | collate_fn = coco_panoptic_collate_fn 64 | else: 65 | collate_fn = coco_collate_fn 66 | 67 | loader_kwargs = { 68 | 'batch_size': args.batch_size, 69 | 'num_workers': args.loader_num_workers, 70 | 'shuffle': True, 71 | 'collate_fn': collate_fn, 72 | } 73 | train_loader = DataLoader(train_dset, **loader_kwargs) 74 | 75 | loader_kwargs['shuffle'] = args.shuffle_val 76 | val_loader = DataLoader(val_dset, **loader_kwargs) 77 | return vocab, train_loader, val_loader 78 | 79 | 80 | def check_model(args, loader, model, inception_score, use_gt): 81 | fid = None 82 | num_samples = 0 83 | total_iou = 0 84 | total_boxes = 0 85 | inception_score.clean() 86 | with torch.no_grad(): 87 | for batch in loader: 88 | batch = [tensor.cuda() for tensor in batch] 89 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = batch 90 | 91 | # Run the model as it has been run during training 92 | if use_gt: 93 | model_out = model(imgs, objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=masks, attributes=attributes, 94 | test_mode=True, use_gt_box=True) 95 | else: 96 | attributes = torch.zeros_like(attributes) 97 | model_out = model(imgs, objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=None, attributes=attributes, 98 | test_mode=True, use_gt_box=False) 99 | imgs_pred, boxes_pred, masks_pred, _, pred_layout, _ = model_out 100 | 101 | iou, _, _ = jaccard(boxes_pred, boxes) 102 | total_iou += iou 103 | total_boxes += boxes_pred.size(0) 104 | inception_score(imgs_pred) 105 | 106 | num_samples += imgs.size(0) 107 | if num_samples >= args.num_val_samples: 108 | break 109 | 110 | inception_mean, inception_std = inception_score.compute_score(splits=5) 111 | 112 | avg_iou = total_iou / total_boxes 113 | 114 | out = [avg_iou, inception_mean, inception_std, fid] 115 | 116 | return tuple(out) 117 | 118 | 119 | def get_checkpoint(args, vocab): 120 | if args.restore_from_checkpoint: 121 | restore_path = '%s_with_model.pt' % args.checkpoint_name 122 | restore_path = os.path.join(args.output_dir, restore_path) 123 | assert restore_path is not None 124 | assert os.path.isfile(restore_path) 125 | print('Restoring from checkpoint:') 126 | print(restore_path) 127 | checkpoint = torch.load(restore_path) 128 | t = checkpoint['counters']['t'] 129 | epoch = checkpoint['counters']['epoch'] 130 | else: 131 | t, epoch = 0, 0 132 | checkpoint = { 133 | 'args': args.__dict__, 134 | 'vocab': vocab, 135 | 'model_kwargs': {}, 136 | 'd_obj_kwargs': {}, 137 | 'd_mask_kwargs': {}, 138 | 'd_img_kwargs': {}, 139 | 'd_global_mask_kwargs': {}, 140 | 'losses_ts': [], 141 | 'losses': defaultdict(list), 142 | 'd_losses': defaultdict(list), 143 | 'checkpoint_ts': [], 144 | 'train_inception': [], 145 | 'val_losses': defaultdict(list), 146 | 'val_inception': [], 147 | 'norm_d': [], 148 | 'norm_g': [], 149 | 'counters': { 150 | 't': None, 151 | 'epoch': None, 152 | }, 153 | 'model_state': None, 'model_best_state': None, 154 | 'optim_state': None, 'optim_best_state': None, 155 | 'd_obj_state': None, 'd_obj_best_state': None, 156 | 'd_obj_optim_state': None, 'd_obj_optim_best_state': None, 157 | 
'd_img_state': None, 'd_img_best_state': None, 158 | 'd_img_optim_state': None, 'd_img_optim_best_state': None, 159 | 'd_mask_state': None, 'd_mask_best_state': None, 160 | 'd_mask_optim_state': None, 'd_mask_optim_best_state': None, 161 | 'best_t': [], 162 | } 163 | return t, epoch, checkpoint 164 | 165 | 166 | def main(args): 167 | print(args) 168 | vocab, train_loader, val_loader = build_loaders(args) 169 | t, epoch, checkpoint = get_checkpoint(args, vocab) 170 | trainer = Trainer(args, vocab, checkpoint) 171 | if args.restore_from_checkpoint: 172 | trainer.restore_checkpoint(checkpoint) 173 | else: 174 | with open(os.path.join(args.output_dir, 'args.json'), 'w') as outfile: 175 | json.dump(vars(args), outfile) 176 | 177 | inception_score = InceptionScore(cuda=True, batch_size=args.batch_size, resize=True) 178 | train_results = check_model(args, val_loader, trainer.model, inception_score, use_gt=True) 179 | t_avg_iou, t_inception_mean, t_inception_std, _ = train_results 180 | index = int(t / args.print_every) 181 | trainer.writer.add_scalar('checkpoint/{}'.format('train_iou'), t_avg_iou, index) 182 | trainer.writer.add_scalar('checkpoint/{}'.format('train_inception_mean'), t_inception_mean, index) 183 | trainer.writer.add_scalar('checkpoint/{}'.format('train_inception_std'), t_inception_std, index) 184 | print(t_avg_iou, t_inception_mean, t_inception_std) 185 | 186 | while t < args.num_iterations: 187 | epoch += 1 188 | print('Starting epoch %d' % epoch) 189 | 190 | for batch in train_loader: 191 | t += 1 192 | batch = [tensor.cuda() for tensor in batch] 193 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = batch 194 | 195 | use_gt = random.randint(0, 1) != 0 196 | if not use_gt: 197 | attributes = torch.zeros_like(attributes) 198 | model_out = trainer.model(imgs, objs, triples, obj_to_img, 199 | boxes_gt=boxes, masks_gt=masks, attributes=attributes) 200 | imgs_pred, boxes_pred, masks_pred, layout, layout_pred, layout_wrong = model_out 201 | 202 | layout_one_hot = layout[:, :trainer.num_obj, :, :] 203 | layout_pred_one_hot = layout_pred[:, :trainer.num_obj, :, :] 204 | 205 | trainer.train_generator(imgs, imgs_pred, masks, masks_pred, layout, 206 | objs, boxes, boxes_pred, obj_to_img, use_gt) 207 | 208 | imgs_pred_detach = imgs_pred.detach() 209 | masks_pred_detach = masks_pred.detach() 210 | boxes_pred_detach = boxes.detach() 211 | layout_detach = layout.detach() 212 | layout_wrong_detach = layout_wrong.detach() 213 | trainer.train_mask_discriminator(masks, masks_pred_detach, objs) 214 | trainer.train_obj_discriminator(imgs, imgs_pred_detach, objs, boxes, boxes_pred_detach, obj_to_img) 215 | trainer.train_image_discriminator(imgs, imgs_pred_detach, layout_detach, layout_wrong_detach) 216 | 217 | if t % args.print_every == 0 or t == 1: 218 | trainer.write_losses(checkpoint, t) 219 | trainer.write_images(t, imgs, imgs_pred, layout_one_hot, layout_pred_one_hot) 220 | 221 | if t % args.checkpoint_every == 0: 222 | print('begin check model train') 223 | train_results = check_model(args, val_loader, trainer.model, inception_score, use_gt=True) 224 | print('begin check model val') 225 | val_results = check_model(args, val_loader, trainer.model, inception_score, use_gt=False) 226 | trainer.save_checkpoint(checkpoint, t, args, epoch, train_results, val_results) 227 | 228 | 229 | if __name__ == '__main__': 230 | args = get_args() 231 | main(args) 232 | -------------------------------------------------------------------------------- /scripts/gui/index.html: 
--------------------------------------------------------------------------------
[index.html markup lost in extraction; only the visible page text survives: the title "INTERACTIVE SCENE GENERATION" and the author credit "ORON ASHUAL & LIOR WOLF".]
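The data collected on this GUI page presumably ends up in the scene-graph dictionary format documented in Model.encode_scene_graphs (scene_generation/model.py, reproduced further down in this dump). A minimal sketch of such a payload, reusing the objects and relationships from that docstring; the attribute bins, feature indices, and image_id value are illustrative assumptions inferred from the code, not values shipped with the repository:

# Hypothetical scene-graph payload for Model.forward_json. Only the key names,
# the 10 size bins / 25 location bins, and the "-1 means default appearance"
# convention come from encode_scene_graphs; every concrete value here is made up.
scene_graph = {
    'image_id': 0,                       # appended to 'features' for the dummy __image__ object
    'objects': ['cat', 'dog', 'sky'],
    'relationships': [
        [0, 'next to', 1],               # cat next to dog
        [0, 'beneath', 2],               # cat beneath sky
        [2, 'above', 1],                 # sky above dog
    ],
    'attributes': {
        'size': [4, 5, 9],               # one size bin in [0, 9] per object
        'location': [12, 14, 2],         # one location bin in [0, 24] per object
    },
    'features': [-1, -1, -1],            # -1 selects a default appearance archetype per object
}

# forward_json appends the __image__ object and its __in_image__ relationships,
# encodes everything to tensors, and runs the model in test mode:
# (imgs_pred, boxes_pred, masks_pred, _, layout, _), objs = model.forward_json(scene_graph)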
280 | 281 | 282 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /scene_generation/bilinear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | 20 | """ 21 | Functions for performing differentiable bilinear cropping of images, for use in 22 | the object discriminator 23 | """ 24 | 25 | 26 | def crop_bbox_batch(feats, bbox, bbox_to_feats, HH, WW=None, backend='cudnn'): 27 | """ 28 | Inputs: 29 | - feats: FloatTensor of shape (N, C, H, W) 30 | - bbox: FloatTensor of shape (B, 4) giving bounding box coordinates 31 | - bbox_to_feats: LongTensor of shape (B,) mapping boxes to feature maps; 32 | each element is in the range [0, N) and bbox_to_feats[b] = i means that 33 | bbox[b] will be cropped from feats[i]. 34 | - HH, WW: Size of the output crops 35 | 36 | Returns: 37 | - crops: FloatTensor of shape (B, C, HH, WW) where crops[i] uses bbox[i] to 38 | crop from feats[bbox_to_feats[i]]. 
39 | """ 40 | if backend == 'cudnn': 41 | return crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW) 42 | N, C, H, W = feats.size() 43 | B = bbox.size(0) 44 | if WW is None: WW = HH 45 | dtype, device = feats.dtype, feats.device 46 | crops = torch.zeros(B, C, HH, WW, dtype=dtype, device=device) 47 | for i in range(N): 48 | idx = (bbox_to_feats.data == i).nonzero() 49 | if idx.dim() == 0: 50 | continue 51 | idx = idx.view(-1) 52 | n = idx.size(0) 53 | cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous() 54 | cur_bbox = bbox[idx] 55 | cur_crops = crop_bbox(cur_feats, cur_bbox, HH, WW) 56 | crops[idx] = cur_crops 57 | return crops 58 | 59 | 60 | def _invperm(p): 61 | N = p.size(0) 62 | eye = torch.arange(0, N).type_as(p) 63 | pp = (eye[:, None] == p).nonzero()[:, 1] 64 | return pp 65 | 66 | 67 | def crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW=None): 68 | N, C, H, W = feats.size() 69 | B = bbox.size(0) 70 | if WW is None: WW = HH 71 | dtype = feats.data.type() 72 | 73 | feats_flat, bbox_flat, all_idx = [], [], [] 74 | for i in range(N): 75 | idx = (bbox_to_feats.data == i).nonzero() 76 | if idx.dim() == 0: 77 | continue 78 | idx = idx.view(-1) 79 | n = idx.size(0) 80 | cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous() 81 | cur_bbox = bbox[idx] 82 | 83 | feats_flat.append(cur_feats) 84 | bbox_flat.append(cur_bbox) 85 | all_idx.append(idx) 86 | 87 | feats_flat = torch.cat(feats_flat, dim=0) 88 | bbox_flat = torch.cat(bbox_flat, dim=0) 89 | crops = crop_bbox(feats_flat, bbox_flat, HH, WW, backend='cudnn') 90 | 91 | # If the crops were sequential (all_idx is identity permutation) then we can 92 | # simply return them; otherwise we need to permute crops by the inverse 93 | # permutation from all_idx. 94 | all_idx = torch.cat(all_idx, dim=0) 95 | eye = torch.arange(0, B).type_as(all_idx) 96 | if (all_idx == eye).all(): 97 | return crops 98 | return crops[_invperm(all_idx)] 99 | 100 | 101 | def crop_bbox(feats, bbox, HH, WW=None, backend='cudnn'): 102 | """ 103 | Take differentiable crops of feats specified by bbox. 104 | 105 | Inputs: 106 | - feats: Tensor of shape (N, C, H, W) 107 | - bbox: Bounding box coordinates of shape (N, 4) in the format 108 | [x0, y0, x1, y1] in the [0, 1] coordinate space. 109 | - HH, WW: Size of the output crops. 110 | 111 | Returns: 112 | - crops: Tensor of shape (N, C, HH, WW) where crops[i] is the portion of 113 | feats[i] specified by bbox[i], reshaped to (HH, WW) using bilinear sampling. 114 | """ 115 | N = feats.size(0) 116 | assert bbox.size(0) == N 117 | assert bbox.size(1) == 4 118 | if WW is None: WW = HH 119 | if backend == 'cudnn': 120 | # Change box from [0, 1] to [-1, 1] coordinate system 121 | bbox = 2 * bbox - 1 122 | x0, y0 = bbox[:, 0], bbox[:, 1] 123 | x1, y1 = bbox[:, 2], bbox[:, 3] 124 | X = tensor_linspace(x0, x1, steps=WW).view(N, 1, WW).expand(N, HH, WW) 125 | Y = tensor_linspace(y0, y1, steps=HH).view(N, HH, 1).expand(N, HH, WW) 126 | if backend == 'jj': 127 | return bilinear_sample(feats, X, Y) 128 | elif backend == 'cudnn': 129 | grid = torch.stack([X, Y], dim=3) 130 | return F.grid_sample(feats, grid) 131 | 132 | 133 | def uncrop_bbox(feats, bbox, H, W=None, fill_value=0): 134 | """ 135 | Inverse operation to crop_bbox; construct output images where the feature maps 136 | from feats have been reshaped and placed into the positions specified by bbox. 
137 | 138 | Inputs: 139 | - feats: Tensor of shape (N, C, HH, WW) 140 | - bbox: Bounding box coordinates of shape (N, 4) in the format 141 | [x0, y0, x1, y1] in the [0, 1] coordinate space. 142 | - H, W: Size of output. 143 | - fill_value: Portions of the output image that are outside the bounding box 144 | will be filled with this value. 145 | 146 | Returns: 147 | - out: Tensor of shape (N, C, H, W) where the portion of out[i] given by 148 | bbox[i] contains feats[i], reshaped using bilinear sampling. 149 | """ 150 | N, C = feats.size(0), feats.size(1) 151 | assert bbox.size(0) == N 152 | assert bbox.size(1) == 4 153 | if W is None: H = W 154 | 155 | x0, y0 = bbox[:, 0], bbox[:, 1] 156 | x1, y1 = bbox[:, 2] + bbox[:, 0], bbox[:, 3] + bbox[:, 1] 157 | ww = x1 - x0 158 | hh = y1 - y0 159 | 160 | x0 = x0.contiguous().view(N, 1).expand(N, H) 161 | x1 = x1.contiguous().view(N, 1).expand(N, H) 162 | ww = ww.view(N, 1).expand(N, H) 163 | 164 | y0 = y0.contiguous().view(N, 1).expand(N, W) 165 | y1 = y1.contiguous().view(N, 1).expand(N, W) 166 | hh = hh.view(N, 1).expand(N, W) 167 | 168 | X = torch.linspace(0, 1, steps=W).view(1, W).expand(N, W).to(feats) 169 | Y = torch.linspace(0, 1, steps=H).view(1, H).expand(N, H).to(feats) 170 | 171 | X = (X - x0) / ww 172 | Y = (Y - y0) / hh 173 | 174 | # For ByteTensors, (x + y).clamp(max=1) gives logical_or 175 | X_out_mask = ((X < 0) + (X > 1)).view(N, 1, W).expand(N, H, W) 176 | Y_out_mask = ((Y < 0) + (Y > 1)).view(N, H, 1).expand(N, H, W) 177 | out_mask = (X_out_mask + Y_out_mask).clamp(max=1) 178 | out_mask = out_mask.view(N, 1, H, W).expand(N, C, H, W) 179 | 180 | X = X.view(N, 1, W).expand(N, H, W) 181 | Y = Y.view(N, H, 1).expand(N, H, W) 182 | 183 | out = bilinear_sample(feats, X, Y) 184 | out[out_mask] = fill_value 185 | return out 186 | 187 | 188 | def bilinear_sample(feats, X, Y): 189 | """ 190 | Perform bilinear sampling on the features in feats using the sampling grid 191 | given by X and Y. 192 | 193 | Inputs: 194 | - feats: Tensor holding input feature map, of shape (N, C, H, W) 195 | - X, Y: Tensors holding x and y coordinates of the sampling 196 | grids; both have shape shape (N, HH, WW) and have elements in the range [0, 1]. 197 | Returns: 198 | - out: Tensor of shape (B, C, HH, WW) where out[i] is computed 199 | by sampling from feats[idx[i]] using the sampling grid (X[i], Y[i]). 200 | """ 201 | N, C, H, W = feats.size() 202 | assert X.size() == Y.size() 203 | assert X.size(0) == N 204 | _, HH, WW = X.size() 205 | 206 | X = X.mul(W) 207 | Y = Y.mul(H) 208 | 209 | # Get the x and y coordinates for the four samples 210 | x0 = X.floor().clamp(min=0, max=W - 1) 211 | x1 = (x0 + 1).clamp(min=0, max=W - 1) 212 | y0 = Y.floor().clamp(min=0, max=H - 1) 213 | y1 = (y0 + 1).clamp(min=0, max=H - 1) 214 | 215 | # In numpy we could do something like feats[i, :, y0, x0] to pull out 216 | # the elements of feats at coordinates y0 and x0, but PyTorch doesn't 217 | # yet support this style of indexing. Instead we have to use the gather 218 | # method, which only allows us to index along one dimension at a time; 219 | # therefore we will collapse the features (BB, C, H, W) into (BB, C, H * W) 220 | # and index along the last dimension. Below we generate linear indices into 221 | # the collapsed last dimension for each of the four combinations we need. 
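# For example, with W = 4 the sample at (y0 = 2, x0 = 1) lands at flat index
# 2 * 4 + 1 = 9 of the collapsed H * W dimension; the expressions below apply
# the same formula elementwise to the full (N, HH, WW) coordinate grids.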
222 | y0x0_idx = (W * y0 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW) 223 | y1x0_idx = (W * y1 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW) 224 | y0x1_idx = (W * y0 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW) 225 | y1x1_idx = (W * y1 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW) 226 | 227 | # Actually use gather to pull out the values from feats corresponding 228 | # to our four samples, then reshape them to (BB, C, HH, WW) 229 | feats_flat = feats.view(N, C, H * W) 230 | v1 = feats_flat.gather(2, y0x0_idx.long()).view(N, C, HH, WW) 231 | v2 = feats_flat.gather(2, y1x0_idx.long()).view(N, C, HH, WW) 232 | v3 = feats_flat.gather(2, y0x1_idx.long()).view(N, C, HH, WW) 233 | v4 = feats_flat.gather(2, y1x1_idx.long()).view(N, C, HH, WW) 234 | 235 | # Compute the weights for the four samples 236 | w1 = ((x1 - X) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW) 237 | w2 = ((x1 - X) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW) 238 | w3 = ((X - x0) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW) 239 | w4 = ((X - x0) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW) 240 | 241 | # Multiply the samples by the weights to give our interpolated results. 242 | out = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 243 | return out 244 | 245 | 246 | def tensor_linspace(start, end, steps=10): 247 | """ 248 | Vectorized version of torch.linspace. 249 | 250 | Inputs: 251 | - start: Tensor of any shape 252 | - end: Tensor of the same shape as start 253 | - steps: Integer 254 | 255 | Returns: 256 | - out: Tensor of shape start.size() + (steps,), such that 257 | out.select(-1, 0) == start, out.select(-1, -1) == end, 258 | and the other elements of out linearly interpolate between 259 | start and end. 260 | """ 261 | assert start.size() == end.size() 262 | view_size = start.size() + (1,) 263 | w_size = (1,) * start.dim() + (steps,) 264 | out_size = start.size() + (steps,) 265 | 266 | start_w = torch.linspace(1, 0, steps=steps).to(start) 267 | start_w = start_w.view(w_size).expand(out_size) 268 | end_w = torch.linspace(0, 1, steps=steps).to(start) 269 | end_w = end_w.view(w_size).expand(out_size) 270 | 271 | start = start.contiguous().view(view_size).expand(out_size) 272 | end = end.contiguous().view(view_size).expand(out_size) 273 | 274 | out = start_w * start + end_w * end 275 | return out 276 | 277 | 278 | if __name__ == '__main__': 279 | import numpy as np 280 | from scipy.misc import imread, imsave, imresize 281 | 282 | cat = imresize(imread('cat.jpg'), (256, 256), anti_aliasing=True) 283 | dog = imresize(imread('dog.jpg'), (256, 256), anti_aliasing=True) 284 | feats = torch.stack([ 285 | torch.from_numpy(cat.transpose(2, 0, 1).astype(np.float32)), 286 | torch.from_numpy(dog.transpose(2, 0, 1).astype(np.float32))], 287 | dim=0) 288 | 289 | boxes = torch.FloatTensor([ 290 | [0, 0, 1, 1], 291 | [0.25, 0.25, 0.75, 0.75], 292 | [0, 0, 0.5, 0.5], 293 | ]) 294 | 295 | box_to_feats = torch.LongTensor([1, 0, 1]).cuda() 296 | 297 | feats, boxes = feats.cuda(), boxes.cuda() 298 | crops = crop_bbox_batch_cudnn(feats, boxes, box_to_feats, 128) 299 | for i in range(crops.size(0)): 300 | crop_np = crops.data[i].cpu().numpy().transpose(1, 2, 0).astype(np.uint8) 301 | imsave('out%d.png' % i, crop_np) 302 | -------------------------------------------------------------------------------- /scripts/train_accuracy_net.py: -------------------------------------------------------------------------------- 1 | """ From https://github.com/Prakashvanapalli/pytorch_classifiers 2 | """ 3 | import argparse 4 | import json 
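# This script fine-tunes a stock torchvision ResNet on per-object crops that
# crop_bbox_batch (imported below from scene_generation.bilinear) cuts out of
# each COCO image; scripts/sample_images.py later reloads the saved weights to
# score how recognizable the generated objects are.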
5 | import os 6 | import time 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import torchvision 12 | from torch.optim import lr_scheduler 13 | from torch.utils.data import DataLoader 14 | 15 | from scene_generation.bilinear import crop_bbox_batch 16 | from scene_generation.data.coco import CocoSceneGraphDataset, coco_collate_fn 17 | from scene_generation.utils import int_tuple, bool_flag, str_tuple 18 | 19 | parser = argparse.ArgumentParser(description='Training a pytorch model to classify different plants') 20 | parser.add_argument('-idl', '--input_data_loc', help='', default='data/training_data') 21 | parser.add_argument('-mo', '--model_name', default="resnet101") 22 | parser.add_argument('-f', '--freeze_layers', default=True, action='store_false', help='Bool type') 23 | parser.add_argument('-fi', '--freeze_initial_layers', default=True, action='store_false', help='Bool type') 24 | parser.add_argument('-ep', '--epochs', default=20, type=int) 25 | parser.add_argument('-b', '--batch_size', default=4, type=int) 26 | parser.add_argument('-is', '--input_shape', default=224, type=int) 27 | parser.add_argument('-sl', '--save_loc', default="models/") 28 | parser.add_argument("-g", '--use_gpu', default=True, action='store_false', help='Bool type gpu') 29 | parser.add_argument("-p", '--use_parallel', default=False, action='store_false', help='Bool type to use_parallel') 30 | 31 | parser.add_argument('--dataset', default='coco', type=str) 32 | parser.add_argument('--mask_size', default=32, type=int) 33 | parser.add_argument('--image_size', default='256,256', type=int_tuple) 34 | parser.add_argument('--num_train_samples', default=None, type=int) 35 | parser.add_argument('--num_val_samples', default=1024, type=int) 36 | parser.add_argument('--shuffle_val', default=True, type=bool_flag) 37 | parser.add_argument('--loader_num_workers', default=4, type=int) 38 | parser.add_argument('--include_relationships', default=True, type=bool_flag) 39 | 40 | # COCO-specific options 41 | COCO_DIR = os.path.expanduser('datasets/coco') 42 | parser.add_argument('--coco_train_image_dir', 43 | default=os.path.join(COCO_DIR, 'images/train2017')) 44 | parser.add_argument('--coco_val_image_dir', 45 | default=os.path.join(COCO_DIR, 'images/val2017')) 46 | parser.add_argument('--coco_train_instances_json', 47 | default=os.path.join(COCO_DIR, 'annotations/instances_train2017.json')) 48 | parser.add_argument('--coco_train_stuff_json', 49 | default=os.path.join(COCO_DIR, 'annotations/stuff_train2017.json')) 50 | parser.add_argument('--coco_val_instances_json', 51 | default=os.path.join(COCO_DIR, 'annotations/instances_val2017.json')) 52 | parser.add_argument('--coco_val_stuff_json', 53 | default=os.path.join(COCO_DIR, 'annotations/stuff_val2017.json')) 54 | parser.add_argument('--instance_whitelist', default=None, type=str_tuple) 55 | parser.add_argument('--stuff_whitelist', default=None, type=str_tuple) 56 | parser.add_argument('--coco_include_other', default=False, type=bool_flag) 57 | parser.add_argument('--min_object_size', default=0.02, type=float) 58 | parser.add_argument('--min_objects_per_image', default=3, type=int) 59 | parser.add_argument('--coco_stuff_only', default=True, type=bool_flag) 60 | 61 | 62 | def all_pretrained_models(n_class, name="resnet101", pretrained=True): 63 | if pretrained: 64 | weights = "imagenet" 65 | else: 66 | weights = False 67 | 68 | if name == "resnet18": 69 | print("[Building resnet18]") 70 | model_conv = torchvision.models.resnet18(pretrained=weights) 
71 | elif name == "resnet34": 72 | print("[Building resnet34]") 73 | model_conv = torchvision.models.resnet34(pretrained=weights) 74 | elif name == "resnet50": 75 | print("[Building resnet50]") 76 | model_conv = torchvision.models.resnet50(pretrained=weights) 77 | elif name == "resnet101": 78 | print("[Building resnet101]") 79 | model_conv = torchvision.models.resnet101(pretrained=weights) 80 | elif name == "resnet152": 81 | print("[Building resnet152]") 82 | model_conv = torchvision.models.resnet152(pretrained=weights) 83 | else: 84 | raise ValueError 85 | 86 | for i, param in model_conv.named_parameters(): 87 | param.requires_grad = False 88 | 89 | num_ftrs = model_conv.fc.in_features 90 | model_conv.fc = nn.Linear(num_ftrs, n_class) 91 | 92 | if "resnet" in name: 93 | print("[Resnet: Freezing layers only till layer1 including]") 94 | ct = [] 95 | for name, child in model_conv.named_children(): 96 | if "layer1" in ct: 97 | for params in child.parameters(): 98 | params.requires_grad = True 99 | ct.append(name) 100 | 101 | return model_conv 102 | 103 | 104 | def build_coco_dsets(args): 105 | dset_kwargs = { 106 | 'image_dir': args.coco_train_image_dir, 107 | 'instances_json': args.coco_train_instances_json, 108 | 'stuff_json': args.coco_train_stuff_json, 109 | 'stuff_only': args.coco_stuff_only, 110 | 'image_size': args.image_size, 111 | 'mask_size': args.mask_size, 112 | 'max_samples': args.num_train_samples, 113 | 'min_object_size': args.min_object_size, 114 | 'min_objects_per_image': args.min_objects_per_image, 115 | 'instance_whitelist': args.instance_whitelist, 116 | 'stuff_whitelist': args.stuff_whitelist, 117 | 'include_other': args.coco_include_other, 118 | 'include_relationships': args.include_relationships, 119 | 'no__img__': True 120 | } 121 | train_dset = CocoSceneGraphDataset(**dset_kwargs) 122 | num_objs = train_dset.total_objects() 123 | num_imgs = len(train_dset) 124 | print('Training dataset has %d images and %d objects' % (num_imgs, num_objs)) 125 | print('(%.2f objects per image)' % (float(num_objs) / num_imgs)) 126 | 127 | dset_kwargs['image_dir'] = args.coco_val_image_dir 128 | dset_kwargs['instances_json'] = args.coco_val_instances_json 129 | dset_kwargs['stuff_json'] = args.coco_val_stuff_json 130 | dset_kwargs['max_samples'] = args.num_val_samples 131 | val_dset = CocoSceneGraphDataset(**dset_kwargs) 132 | 133 | assert train_dset.vocab == val_dset.vocab 134 | vocab = json.loads(json.dumps(train_dset.vocab)) 135 | 136 | return vocab, train_dset, val_dset 137 | 138 | 139 | def build_loaders(args): 140 | vocab, train_dset, val_dset = build_coco_dsets(args) 141 | collate_fn = coco_collate_fn 142 | 143 | loader_kwargs = { 144 | 'batch_size': args.batch_size, 145 | 'num_workers': args.loader_num_workers, 146 | 'shuffle': True, 147 | 'collate_fn': collate_fn, 148 | } 149 | train_loader = DataLoader(train_dset, **loader_kwargs) 150 | 151 | loader_kwargs['shuffle'] = args.shuffle_val 152 | val_loader = DataLoader(val_dset, **loader_kwargs) 153 | return vocab, train_loader, val_loader 154 | 155 | 156 | def train_model(model, test_dataloader, val_dataloader, criterion, optimizer, scheduler, use_gpu, num_epochs=10, 157 | input_shape=224): 158 | since = time.time() 159 | 160 | best_model_wts = model.state_dict() 161 | best_acc = 0.0 162 | device = 'cuda' if use_gpu else 'cpu' 163 | 164 | for epoch in range(num_epochs): 165 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 166 | print('-' * 10) 167 | 168 | # Each epoch has a training and validation phase 169 | for phase in 
['train', 'val']: 170 | if phase == 'train': 171 | scheduler.step() 172 | model.train(True) # Set model to training mode 173 | dataloader = test_dataloader 174 | else: 175 | model.train(False) # Set model to evaluate mode 176 | dataloader = val_dataloader 177 | 178 | running_loss = 0.0 179 | running_corrects = 0 180 | objects_len = 0 181 | 182 | # Iterate over data. 183 | for data in dataloader: 184 | # get the inputs 185 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = data 186 | imgs = imgs.to(device) 187 | boxes = boxes.to(device) 188 | obj_to_img = obj_to_img.to(device) 189 | labels = objs.to(device) 190 | 191 | objects_len += obj_to_img.shape[0] 192 | 193 | with torch.no_grad(): 194 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, input_shape) 195 | 196 | # zero the parameter gradients 197 | optimizer.zero_grad() 198 | 199 | # forward 200 | outputs = model(crops) 201 | if type(outputs) == tuple: 202 | outputs, _ = outputs 203 | _, preds = torch.max(outputs, 1) 204 | loss = criterion(outputs, labels) 205 | 206 | # backward + optimize only if in training phase 207 | if phase == 'train': 208 | loss.backward() 209 | optimizer.step() 210 | 211 | # statistics 212 | running_loss += loss.item() 213 | running_corrects += torch.sum(preds == labels) 214 | 215 | epoch_loss = running_loss / objects_len 216 | epoch_acc = running_corrects.item() / objects_len 217 | 218 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) 219 | 220 | # deep copy the model 221 | if phase == 'val' and epoch_acc > best_acc: 222 | best_acc = epoch_acc 223 | best_model_wts = model.state_dict() 224 | 225 | print() 226 | 227 | time_elapsed = time.time() - since 228 | print('Training complete in {:.0f}m {:.0f}s'.format( 229 | time_elapsed // 60, time_elapsed % 60)) 230 | print('Best val Acc: {:4f}'.format(best_acc)) 231 | 232 | # load best model weights 233 | model.load_state_dict(best_model_wts) 234 | return model 235 | 236 | 237 | def load_model(model_path): 238 | model_name = 'resnet101' 239 | model = all_pretrained_models(172, name=model_name).cuda() 240 | model.load_state_dict(torch.load(model_path)) 241 | model.eval() 242 | return model 243 | 244 | 245 | if __name__ == '__main__': 246 | args = parser.parse_args() 247 | print(args) 248 | device = 'cuda' if args.use_gpu else 'cpu' 249 | vocab, train_loader, val_loader = build_loaders(args) 250 | num_objs = 172 # len(vocab['object_to_idx']) 251 | print("[Load the model...]") 252 | # Parameters of newly constructed modules have requires_grad=True by default 253 | print( 254 | "Loading model using class: {}, use_gpu: {}, freeze_layers: {}, freeze_initial_layers: {}, name_of_model: {}".format( 255 | num_objs, args.use_gpu, args.freeze_layers, args.freeze_initial_layers, args.model_name)) 256 | model_conv = all_pretrained_models(num_objs, name=args.model_name).to(device) 257 | if args.use_parallel: 258 | print("[Using all the available GPUs]") 259 | model_conv = nn.DataParallel(model_conv, device_ids=[0, 1]) 260 | 261 | print("[Using CrossEntropyLoss...]") 262 | criterion = nn.CrossEntropyLoss() 263 | 264 | print("[Using small learning rate with momentum...]") 265 | optimizer_conv = optim.SGD(list(filter(lambda p: p.requires_grad, model_conv.parameters())), lr=0.001, momentum=0.9) 266 | 267 | print("[Creating Learning rate scheduler...]") 268 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1) 269 | 270 | print("[Training the model begun ....]") 271 | model_ft = train_model(model_conv, 
train_loader, val_loader, criterion, optimizer_conv, exp_lr_scheduler, 272 | args.use_gpu, num_epochs=args.epochs, input_shape=args.input_shape) 273 | 274 | print("[Save the best model]") 275 | model_save_loc = './{}_{}_classes.pth'.format(args.model_name, num_objs) 276 | torch.save(model_ft.state_dict(), model_save_loc) 277 | -------------------------------------------------------------------------------- /scene_generation/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import functools 18 | 19 | import torch.nn as nn 20 | from torch.nn.functional import interpolate 21 | 22 | 23 | def get_normalization_2d(channels, normalization): 24 | if normalization == 'instance': 25 | return nn.InstanceNorm2d(channels) 26 | elif normalization == 'batch': 27 | return nn.BatchNorm2d(channels) 28 | elif normalization == 'none': 29 | return None 30 | else: 31 | raise ValueError('Unrecognized normalization type "%s"' % normalization) 32 | 33 | 34 | def get_activation(name): 35 | kwargs = {} 36 | if name.lower().startswith('leakyrelu'): 37 | if '-' in name: 38 | slope = float(name.split('-')[1]) 39 | kwargs = {'negative_slope': slope} 40 | name = 'leakyrelu' 41 | activations = { 42 | 'relu': nn.ReLU, 43 | 'leakyrelu': nn.LeakyReLU, 44 | } 45 | if name.lower() not in activations: 46 | raise ValueError('Invalid activation "%s"' % name) 47 | return activations[name.lower()](**kwargs) 48 | 49 | 50 | def _init_conv(layer, method): 51 | if not isinstance(layer, nn.Conv2d): 52 | return 53 | if method == 'default': 54 | return 55 | elif method == 'kaiming-normal': 56 | nn.init.kaiming_normal(layer.weight) 57 | elif method == 'kaiming-uniform': 58 | nn.init.kaiming_uniform(layer.weight) 59 | 60 | 61 | class Flatten(nn.Module): 62 | def forward(self, x): 63 | return x.view(x.size(0), -1) 64 | 65 | def __repr__(self): 66 | return 'Flatten()' 67 | 68 | 69 | class Unflatten(nn.Module): 70 | def __init__(self, size): 71 | super(Unflatten, self).__init__() 72 | self.size = size 73 | 74 | def forward(self, x): 75 | return x.view(*self.size) 76 | 77 | def __repr__(self): 78 | size_str = ', '.join('%d' % d for d in self.size) 79 | return 'Unflatten(%s)' % size_str 80 | 81 | 82 | class GlobalAvgPool(nn.Module): 83 | def forward(self, x): 84 | N, C = x.size(0), x.size(1) 85 | return x.view(N, C, -1).mean(dim=2) 86 | 87 | 88 | class ResidualBlock(nn.Module): 89 | def __init__(self, channels, normalization='batch', activation='relu', 90 | padding='same', kernel_size=3, init='default'): 91 | super(ResidualBlock, self).__init__() 92 | 93 | K = kernel_size 94 | P = _get_padding(K, padding) 95 | C = channels 96 | self.padding = P 97 | layers = [ 98 | get_normalization_2d(C, normalization), 99 | get_activation(activation), 100 | nn.Conv2d(C, C, kernel_size=K, padding=P), 101 | get_normalization_2d(C, normalization), 102 | 
get_activation(activation), 103 | nn.Conv2d(C, C, kernel_size=K, padding=P), 104 | ] 105 | layers = [layer for layer in layers if layer is not None] 106 | for layer in layers: 107 | _init_conv(layer, method=init) 108 | self.net = nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | P = self.padding 112 | shortcut = x 113 | if P == 0: 114 | shortcut = x[:, :, P:-P, P:-P] 115 | y = self.net(x) 116 | return shortcut + self.net(x) 117 | 118 | 119 | def _get_padding(K, mode): 120 | """ Helper method to compute padding size """ 121 | if mode == 'valid': 122 | return 0 123 | elif mode == 'same': 124 | assert K % 2 == 1, 'Invalid kernel size %d for "same" padding' % K 125 | return (K - 1) // 2 126 | 127 | 128 | def build_cnn(arch, normalization='batch', activation='relu', padding='same', 129 | pooling='max', init='default'): 130 | """ 131 | Build a CNN from an architecture string, which is a list of layer 132 | specification strings. The overall architecture can be given as a list or as 133 | a comma-separated string. 134 | 135 | All convolutions *except for the first* are preceeded by normalization and 136 | nonlinearity. 137 | 138 | All other layers support the following: 139 | - IX: Indicates that the number of input channels to the network is X. 140 | Can only be used at the first layer; if not present then we assume 141 | 3 input channels. 142 | - CK-X: KxK convolution with X output channels 143 | - CK-X-S: KxK convolution with X output channels and stride S 144 | - R: Residual block keeping the same number of channels 145 | - UX: Nearest-neighbor upsampling with factor X 146 | - PX: Spatial pooling with factor X 147 | - FC-X-Y: Flatten followed by fully-connected layer 148 | 149 | Returns a tuple of: 150 | - cnn: An nn.Sequential 151 | - channels: Number of output channels 152 | """ 153 | if isinstance(arch, str): 154 | arch = arch.split(',') 155 | cur_C = 3 156 | if len(arch) > 0 and arch[0][0] == 'I': 157 | cur_C = int(arch[0][1:]) 158 | arch = arch[1:] 159 | 160 | first_conv = True 161 | flat = False 162 | layers = [] 163 | for i, s in enumerate(arch): 164 | if s[0] == 'C': 165 | if not first_conv: 166 | layers.append(get_normalization_2d(cur_C, normalization)) 167 | layers.append(get_activation(activation)) 168 | first_conv = False 169 | vals = [int(i) for i in s[1:].split('-')] 170 | if len(vals) == 2: 171 | K, next_C = vals 172 | stride = 1 173 | elif len(vals) == 3: 174 | K, next_C, stride = vals 175 | # K, next_C = (int(i) for i in s[1:].split('-')) 176 | P = _get_padding(K, padding) 177 | conv = nn.Conv2d(cur_C, next_C, kernel_size=K, padding=P, stride=stride) 178 | layers.append(conv) 179 | _init_conv(layers[-1], init) 180 | cur_C = next_C 181 | elif s[0] == 'R': 182 | norm = 'none' if first_conv else normalization 183 | res = ResidualBlock(cur_C, normalization=norm, activation=activation, 184 | padding=padding, init=init) 185 | layers.append(res) 186 | first_conv = False 187 | elif s[0] == 'U': 188 | factor = int(s[1:]) 189 | layers.append(Interpolate(scale_factor=factor, mode='nearest')) 190 | elif s[0] == 'P': 191 | factor = int(s[1:]) 192 | if pooling == 'max': 193 | pool = nn.MaxPool2d(kernel_size=factor, stride=factor) 194 | elif pooling == 'avg': 195 | pool = nn.AvgPool2d(kernel_size=factor, stride=factor) 196 | layers.append(pool) 197 | elif s[:2] == 'FC': 198 | _, Din, Dout = s.split('-') 199 | Din, Dout = int(Din), int(Dout) 200 | if not flat: 201 | layers.append(Flatten()) 202 | flat = True 203 | layers.append(nn.Linear(Din, Dout)) 204 | if i + 1 < len(arch): 
205 | layers.append(get_activation(activation)) 206 | cur_C = Dout 207 | else: 208 | raise ValueError('Invalid layer "%s"' % s) 209 | layers = [layer for layer in layers if layer is not None] 210 | # for layer in layers: 211 | # print(layer) 212 | return nn.Sequential(*layers), cur_C 213 | 214 | 215 | def build_mlp(dim_list, activation='relu', batch_norm='none', 216 | dropout=0, final_nonlinearity=True): 217 | layers = [] 218 | for i in range(len(dim_list) - 1): 219 | dim_in, dim_out = dim_list[i], dim_list[i + 1] 220 | layers.append(nn.Linear(dim_in, dim_out)) 221 | final_layer = (i == len(dim_list) - 2) 222 | if not final_layer or final_nonlinearity: 223 | if batch_norm == 'batch': 224 | layers.append(nn.BatchNorm1d(dim_out)) 225 | if activation == 'relu': 226 | layers.append(nn.ReLU()) 227 | elif activation == 'leakyrelu': 228 | layers.append(nn.LeakyReLU()) 229 | if dropout > 0: 230 | layers.append(nn.Dropout(p=dropout)) 231 | return nn.Sequential(*layers) 232 | 233 | 234 | class ResnetBlock(nn.Module): 235 | def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False): 236 | super(ResnetBlock, self).__init__() 237 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout) 238 | 239 | def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout): 240 | conv_block = [] 241 | p = 0 242 | if padding_type == 'reflect': 243 | conv_block += [nn.ReflectionPad2d(1)] 244 | elif padding_type == 'replicate': 245 | conv_block += [nn.ReplicationPad2d(1)] 246 | elif padding_type == 'zero': 247 | p = 1 248 | else: 249 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 250 | 251 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), 252 | norm_layer(dim), 253 | activation] 254 | if use_dropout: 255 | conv_block += [nn.Dropout(0.5)] 256 | 257 | p = 0 258 | if padding_type == 'reflect': 259 | conv_block += [nn.ReflectionPad2d(1)] 260 | elif padding_type == 'replicate': 261 | conv_block += [nn.ReplicationPad2d(1)] 262 | elif padding_type == 'zero': 263 | p = 1 264 | else: 265 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 266 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), 267 | norm_layer(dim)] 268 | 269 | return nn.Sequential(*conv_block) 270 | 271 | def forward(self, x): 272 | out = x + self.conv_block(x) 273 | return out 274 | 275 | 276 | class ConditionalBatchNorm2d(nn.Module): 277 | def __init__(self, num_features, num_classes): 278 | super(ConditionalBatchNorm2d).__init__() 279 | self.num_features = num_features 280 | self.bn = nn.BatchNorm2d(num_features, affine=False) 281 | self.embed = nn.Embedding(num_classes, num_features * 2) 282 | self.embed.weight.data[:, :num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) 283 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 284 | 285 | def forward(self, x, y): 286 | out = self.bn(x) 287 | gamma, beta = self.embed(y).chunk(2, 1) 288 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 289 | return out 290 | 291 | 292 | def get_norm_layer(norm_type='instance'): 293 | if norm_type == 'batch': 294 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 295 | elif norm_type == 'instance': 296 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 297 | elif norm_type == 'conditional': 298 | norm_layer = functools.partial(ConditionalBatchNorm2d) 299 | else: 300 | raise 
NotImplementedError('normalization layer [%s] is not found' % norm_type) 301 | return norm_layer 302 | 303 | 304 | class Interpolate(nn.Module): 305 | def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None): 306 | super(Interpolate, self).__init__() 307 | self.size = size 308 | self.scale_factor = scale_factor 309 | self.mode = mode 310 | self.align_corners = align_corners 311 | 312 | def forward(self, x): 313 | return interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, 314 | align_corners=self.align_corners) 315 | -------------------------------------------------------------------------------- /scene_generation/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from scene_generation.bilinear import crop_bbox_batch 5 | from scene_generation.generators import mask_net, AppearanceEncoder, define_G 6 | from scene_generation.graph import GraphTripleConv, GraphTripleConvNet 7 | from scene_generation.layers import build_mlp 8 | from scene_generation.layout import masks_to_layout 9 | from scene_generation.utils import VectorPool 10 | 11 | 12 | class Model(nn.Module): 13 | def __init__(self, vocab, image_size=(64, 64), embedding_dim=128, 14 | gconv_dim=128, gconv_hidden_dim=512, 15 | gconv_pooling='avg', gconv_num_layers=5, 16 | mask_size=32, mlp_normalization='none', appearance_normalization='', activation='', 17 | n_downsample_global=4, box_dim=128, 18 | use_attributes=False, box_noise_dim=64, 19 | mask_noise_dim=64, pool_size=100, rep_size=32): 20 | super(Model, self).__init__() 21 | 22 | self.vocab = vocab 23 | self.image_size = image_size 24 | self.use_attributes = use_attributes 25 | self.box_noise_dim = box_noise_dim 26 | self.mask_noise_dim = mask_noise_dim 27 | self.object_size = 64 28 | self.fake_pool = VectorPool(pool_size) 29 | 30 | self.num_objs = len(vocab['object_to_idx']) 31 | self.num_preds = len(vocab['pred_idx_to_name']) 32 | self.obj_embeddings = nn.Embedding(self.num_objs, embedding_dim) 33 | self.pred_embeddings = nn.Embedding(self.num_preds, embedding_dim) 34 | 35 | if use_attributes: 36 | attributes_dim = vocab['num_attributes'] 37 | else: 38 | attributes_dim = 0 39 | if gconv_num_layers == 0: 40 | self.gconv = nn.Linear(embedding_dim, gconv_dim) 41 | elif gconv_num_layers > 0: 42 | gconv_kwargs = { 43 | 'input_dim': embedding_dim, 44 | 'attributes_dim': attributes_dim, 45 | 'output_dim': gconv_dim, 46 | 'hidden_dim': gconv_hidden_dim, 47 | 'pooling': gconv_pooling, 48 | 'mlp_normalization': mlp_normalization, 49 | } 50 | self.gconv = GraphTripleConv(**gconv_kwargs) 51 | 52 | self.gconv_net = None 53 | if gconv_num_layers > 1: 54 | gconv_kwargs = { 55 | 'input_dim': gconv_dim, 56 | 'hidden_dim': gconv_hidden_dim, 57 | 'pooling': gconv_pooling, 58 | 'num_layers': gconv_num_layers - 1, 59 | 'mlp_normalization': mlp_normalization, 60 | } 61 | self.gconv_net = GraphTripleConvNet(**gconv_kwargs) 62 | 63 | box_net_dim = 4 64 | self.box_dim = box_dim 65 | box_net_layers = [self.box_dim, gconv_hidden_dim, box_net_dim] 66 | self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization) 67 | 68 | self.g_mask_dim = gconv_dim + mask_noise_dim 69 | self.mask_net = mask_net(self.g_mask_dim, mask_size) 70 | 71 | self.repr_input = self.g_mask_dim 72 | rep_size = rep_size 73 | rep_hidden_size = 64 74 | repr_layers = [self.repr_input, rep_hidden_size, rep_size] 75 | self.repr_net = build_mlp(repr_layers, batch_norm=mlp_normalization) 76 | 
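# repr_net projects each object's appearance code down to rep_size: during
# training its input is image_encoder applied to ground-truth crops, while at
# inference create_components_vecs feeds it the mask vector and then overwrites
# the result with any stored per-object feature (see create_components_vecs below).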
77 | appearance_encoder_kwargs = { 78 | 'vocab': vocab, 79 | 'arch': 'C4-64-2,C4-128-2,C4-256-2', 80 | 'normalization': appearance_normalization, 81 | 'activation': activation, 82 | 'padding': 'valid', 83 | 'vecs_size': self.g_mask_dim 84 | } 85 | self.image_encoder = AppearanceEncoder(**appearance_encoder_kwargs) 86 | 87 | netG_input_nc = self.num_objs + rep_size 88 | output_nc = 3 89 | ngf = 64 90 | n_blocks_global = 9 91 | norm = 'instance' 92 | self.layout_to_image = define_G(netG_input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm) 93 | 94 | def forward(self, gt_imgs, objs, triples, obj_to_img, boxes_gt=None, masks_gt=None, attributes=None, 95 | test_mode=False, use_gt_box=False, features=None): 96 | O, T = objs.size(0), triples.size(0) 97 | obj_vecs, pred_vecs = self.scene_graph_to_vectors(objs, triples, attributes) 98 | 99 | box_vecs, mask_vecs, scene_layout_vecs, wrong_layout_vecs = \ 100 | self.create_components_vecs(gt_imgs, boxes_gt, obj_to_img, objs, obj_vecs, features) 101 | 102 | # Generate Boxes 103 | boxes_pred = self.box_net(box_vecs) 104 | 105 | # Generate Masks 106 | mask_scores = self.mask_net(mask_vecs.view(O, -1, 1, 1)) 107 | masks_pred = mask_scores.squeeze(1).sigmoid() 108 | 109 | H, W = self.image_size 110 | 111 | if test_mode: 112 | boxes = boxes_gt if use_gt_box else boxes_pred 113 | masks = masks_gt if masks_gt is not None else masks_pred 114 | gt_layout = None 115 | pred_layout = masks_to_layout(scene_layout_vecs, boxes, masks, obj_to_img, H, W, test_mode=True) 116 | wrong_layout = None 117 | imgs_pred = self.layout_to_image(pred_layout) 118 | else: 119 | gt_layout = masks_to_layout(scene_layout_vecs, boxes_gt, masks_gt, obj_to_img, H, W, test_mode=False) 120 | pred_layout = masks_to_layout(scene_layout_vecs, boxes_gt, masks_pred, obj_to_img, H, W, test_mode=False) 121 | wrong_layout = masks_to_layout(wrong_layout_vecs, boxes_gt, masks_gt, obj_to_img, H, W, test_mode=False) 122 | 123 | imgs_pred = self.layout_to_image(gt_layout) 124 | return imgs_pred, boxes_pred, masks_pred, gt_layout, pred_layout, wrong_layout 125 | 126 | def scene_graph_to_vectors(self, objs, triples, attributes): 127 | s, p, o = triples.chunk(3, dim=1) 128 | s, p, o = [x.squeeze(1) for x in [s, p, o]] 129 | edges = torch.stack([s, o], dim=1) 130 | 131 | obj_vecs = self.obj_embeddings(objs) 132 | pred_vecs = self.pred_embeddings(p) 133 | if self.use_attributes: 134 | obj_vecs = torch.cat([obj_vecs, attributes], dim=1) 135 | 136 | if isinstance(self.gconv, nn.Linear): 137 | obj_vecs = self.gconv(obj_vecs) 138 | else: 139 | obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges) 140 | if self.gconv_net is not None: 141 | obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges) 142 | 143 | return obj_vecs, pred_vecs 144 | 145 | def create_components_vecs(self, imgs, boxes, obj_to_img, objs, obj_vecs, features): 146 | O = objs.size(0) 147 | box_vecs = obj_vecs 148 | mask_vecs = obj_vecs 149 | layout_noise = torch.randn((1, self.mask_noise_dim), dtype=mask_vecs.dtype, device=mask_vecs.device) \ 150 | .repeat((O, 1)) \ 151 | .view(O, self.mask_noise_dim) 152 | mask_vecs = torch.cat([mask_vecs, layout_noise], dim=1) 153 | 154 | # create encoding 155 | if features is None: 156 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size) 157 | obj_repr = self.repr_net(self.image_encoder(crops)) 158 | else: 159 | # Only in inference time 160 | obj_repr = self.repr_net(mask_vecs) 161 | for ind, feature in enumerate(features): 162 | if feature is not None: 163 | 
obj_repr[ind, :] = feature 164 | # create one-hot vector for label map 165 | one_hot_size = (O, self.num_objs) 166 | one_hot_obj = torch.zeros(one_hot_size, dtype=obj_repr.dtype, device=obj_repr.device) 167 | one_hot_obj = one_hot_obj.scatter_(1, objs.view(-1, 1).long(), 1.0) 168 | layout_vecs = torch.cat([one_hot_obj, obj_repr], dim=1) 169 | 170 | wrong_objs_rep = self.fake_pool.query(objs, obj_repr) 171 | wrong_layout_vecs = torch.cat([one_hot_obj, wrong_objs_rep], dim=1) 172 | return box_vecs, mask_vecs, layout_vecs, wrong_layout_vecs 173 | 174 | def encode_scene_graphs(self, scene_graphs, rand=False): 175 | """ 176 | Encode one or more scene graphs using this model's vocabulary. Inputs to 177 | this method are scene graphs represented as dictionaries like the following: 178 | 179 | { 180 | "objects": ["cat", "dog", "sky"], 181 | "relationships": [ 182 | [0, "next to", 1], 183 | [0, "beneath", 2], 184 | [2, "above", 1], 185 | ] 186 | } 187 | 188 | This scene graph has three relationshps: cat next to dog, cat beneath sky, 189 | and sky above dog. 190 | 191 | Inputs: 192 | - scene_graphs: A dictionary giving a single scene graph, or a list of 193 | dictionaries giving a sequence of scene graphs. 194 | 195 | Returns a tuple of LongTensors (objs, triples, obj_to_img) that have the 196 | same semantics as self.forward. The returned LongTensors will be on the 197 | same device as the model parameters. 198 | """ 199 | if isinstance(scene_graphs, dict): 200 | # We just got a single scene graph, so promote it to a list 201 | scene_graphs = [scene_graphs] 202 | device = next(self.parameters()).device 203 | objs, triples, obj_to_img = [], [], [] 204 | all_attributes = [] 205 | all_features = [] 206 | obj_offset = 0 207 | for i, sg in enumerate(scene_graphs): 208 | attributes = torch.zeros([len(sg['objects']) + 1, 25 + 10], dtype=torch.float, device=device) 209 | # Insert dummy __image__ object and __in_image__ relationships 210 | sg['objects'].append('__image__') 211 | sg['features'].append(sg['image_id']) 212 | image_idx = len(sg['objects']) - 1 213 | for j in range(image_idx): 214 | sg['relationships'].append([j, '__in_image__', image_idx]) 215 | 216 | for obj in sg['objects']: 217 | obj_idx = self.vocab['object_to_idx'][str(self.vocab['object_name_to_idx'][obj])] 218 | if obj_idx is None: 219 | raise ValueError('Object "%s" not in vocab' % obj) 220 | objs.append(obj_idx) 221 | obj_to_img.append(i) 222 | if self.features is not None: 223 | for obj_name, feat_num in zip(objs, sg['features']): 224 | if feat_num == -1: 225 | feat = self.features_one[obj_name][0] 226 | else: 227 | feat = self.features[obj_name][min(feat_num, 99), :] 228 | feat = torch.from_numpy(feat).type(torch.float32).to(device) 229 | all_features.append(feat) 230 | for s, p, o in sg['relationships']: 231 | pred_idx = self.vocab['pred_name_to_idx'].get(p, None) 232 | if pred_idx is None: 233 | raise ValueError('Relationship "%s" not in vocab' % p) 234 | triples.append([s + obj_offset, pred_idx, o + obj_offset]) 235 | for i, size_attr in enumerate(sg['attributes']['size']): 236 | attributes[i, size_attr] = 1 237 | # in image size 238 | attributes[-1, 9] = 1 239 | for i, location_attr in enumerate(sg['attributes']['location']): 240 | attributes[i, location_attr + 10] = 1 241 | # in image location 242 | attributes[-1, 12 + 10] = 1 243 | obj_offset += len(sg['objects']) 244 | all_attributes.append(attributes) 245 | objs = torch.tensor(objs, dtype=torch.int64, device=device) 246 | triples = torch.tensor(triples, dtype=torch.int64, 
device=device) 247 | obj_to_img = torch.tensor(obj_to_img, dtype=torch.int64, device=device) 248 | attributes = torch.cat(all_attributes) 249 | features = all_features 250 | return objs, triples, obj_to_img, attributes, features 251 | 252 | def forward_json(self, scene_graphs): 253 | """ Convenience method that combines encode_scene_graphs and forward. """ 254 | objs, triples, obj_to_img, attributes, features = self.encode_scene_graphs(scene_graphs) 255 | return self.forward(None, objs, triples, obj_to_img, attributes=attributes, test_mode=True, 256 | use_gt_box=False, features=features), objs 257 | -------------------------------------------------------------------------------- /scripts/sample_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from random import randint 4 | 5 | import numpy as np 6 | import torch 7 | from scipy.misc import imsave 8 | from torch.utils.data import DataLoader 9 | 10 | from scene_generation.data import imagenet_deprocess_batch 11 | from scene_generation.data.coco import CocoSceneGraphDataset, coco_collate_fn 12 | from scene_generation.data.coco_panoptic import CocoPanopticSceneGraphDataset, coco_panoptic_collate_fn 13 | from scene_generation.data.utils import split_graph_batch 14 | from scene_generation.metrics import jaccard 15 | from scene_generation.model import Model 16 | from scene_generation.utils import int_tuple, bool_flag 17 | from scene_generation.bilinear import crop_bbox_batch 18 | from scripts.train_accuracy_net import all_pretrained_models 19 | 20 | 21 | def load_model(model_path): 22 | model_name = 'resnet101' 23 | model = all_pretrained_models(172, name=model_name).cuda() 24 | model.load_state_dict(torch.load(model_path)) 25 | model.eval() 26 | return model 27 | 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--checkpoint', required=True) 31 | parser.add_argument('--model_mode', default='eval', choices=['train', 'eval']) 32 | parser.add_argument('--accuracy_model_path', default=None) 33 | 34 | # Shared dataset options 35 | parser.add_argument('--dataset', default='coco') 36 | parser.add_argument('--image_size', default=(128, 128), type=int_tuple) 37 | parser.add_argument('--batch_size', default=24, type=int) 38 | parser.add_argument('--shuffle', default=False, type=bool_flag) 39 | parser.add_argument('--loader_num_workers', default=4, type=int) 40 | parser.add_argument('--num_samples', default=10000, type=int) 41 | parser.add_argument('--save_gt_imgs', default=False, type=bool_flag) 42 | parser.add_argument('--save_graphs', default=False, type=bool_flag) 43 | parser.add_argument('--use_gt_boxes', default=False, type=bool_flag) 44 | parser.add_argument('--use_gt_masks', default=False, type=bool_flag) 45 | parser.add_argument('--use_gt_attr', default=False, type=bool_flag) 46 | parser.add_argument('--use_gt_textures', default=False, type=bool_flag) 47 | parser.add_argument('--save_layout', default=False, type=bool_flag) 48 | parser.add_argument('--sample_attributes', default=False, type=bool_flag) 49 | parser.add_argument('--object_size', default=64, type=int) 50 | parser.add_argument('--grid_size', default=25, type=int) 51 | 52 | parser.add_argument('--output_dir', default='output') 53 | 54 | COCO_DIR = os.path.expanduser('../data/coco') 55 | parser.add_argument('--coco_image_dir', 56 | default=os.path.join(COCO_DIR, 'images/val2017')) 57 | parser.add_argument('--instances_json', 58 | default=os.path.join(COCO_DIR, 
'annotations/instances_val2017.json')) 59 | parser.add_argument('--stuff_json', 60 | default=os.path.join(COCO_DIR, 'annotations/stuff_val2017.json')) 61 | 62 | 63 | def build_coco_dset(args, checkpoint): 64 | checkpoint_args = checkpoint['args'] 65 | print('include other: ', checkpoint_args.get('coco_include_other')) 66 | # When using GT masks, use the full output image resolution as the mask size; otherwise use the mask size the checkpoint was trained with 67 | mask_size = args.image_size[0] if args.use_gt_masks else checkpoint_args['mask_size'] 68 | dset_kwargs = { 69 | 'image_dir': args.coco_image_dir, 70 | 'instances_json': args.instances_json, 71 | 'stuff_json': args.stuff_json, 72 | 'image_size': args.image_size, 73 | 'mask_size': mask_size, 74 | 'max_samples': args.num_samples, 75 | 'min_object_size': checkpoint_args['min_object_size'], 76 | 'min_objects_per_image': checkpoint_args['min_objects_per_image'], 77 | 'instance_whitelist': checkpoint_args['instance_whitelist'], 78 | 'stuff_whitelist': checkpoint_args['stuff_whitelist'], 79 | 'include_other': checkpoint_args.get('coco_include_other', True), 80 | 'test_part': True, 81 | 'sample_attributes': args.sample_attributes, 82 | 'grid_size': args.grid_size 83 | } 84 | dset = CocoSceneGraphDataset(**dset_kwargs) 85 | return dset 86 | 87 | 88 | def build_coco_panoptic_dset(args, checkpoint): 89 | checkpoint_args = checkpoint['args'] 90 | print('include other: ', checkpoint_args.get('coco_include_other')) 91 | # When using GT masks, use the full output image resolution as the mask size; otherwise use the mask size the checkpoint was trained with 92 | mask_size = args.image_size[0] if args.use_gt_masks else checkpoint_args['mask_size'] 93 | dset_kwargs = { 94 | 'image_dir': args.coco_image_dir, 95 | 'instances_json': args.instances_json, 96 | 'panoptic': checkpoint_args['coco_panoptic_val'], 97 | 'panoptic_segmentation': checkpoint_args['coco_panoptic_segmentation_val'], 98 | 'stuff_json': args.stuff_json, 99 | 'image_size': args.image_size, 100 | 'mask_size': mask_size, 101 | 'max_samples': args.num_samples, 102 | 'min_object_size': checkpoint_args['min_object_size'], 103 | 'min_objects_per_image': checkpoint_args['min_objects_per_image'], 104 | 'instance_whitelist': checkpoint_args['instance_whitelist'], 105 | 'stuff_whitelist': checkpoint_args['stuff_whitelist'], 106 | 'include_other': checkpoint_args.get('coco_include_other', True), 107 | 'test_part': True, 108 | 'sample_attributes': args.sample_attributes, 109 | 'grid_size': args.grid_size 110 | } 111 | dset = CocoPanopticSceneGraphDataset(**dset_kwargs) 112 | return dset 113 | 114 | 115 | def build_loader(args, checkpoint, is_panoptic): 116 | if is_panoptic: 117 | dset = build_coco_panoptic_dset(args, checkpoint) 118 | collate_fn = coco_panoptic_collate_fn 119 | else: 120 | dset = build_coco_dset(args, checkpoint) 121 | collate_fn = coco_collate_fn 122 | 123 | loader_kwargs = { 124 | 'batch_size': args.batch_size, 125 | 'num_workers': args.loader_num_workers, 126 | 'shuffle': args.shuffle, 127 | 'collate_fn': collate_fn, 128 | } 129 | loader = DataLoader(dset, **loader_kwargs) 130 | return loader 131 | 132 | 133 | def build_model(args, checkpoint): 134 | kwargs = checkpoint['model_kwargs'] 135 | model = Model(**kwargs) 136 | model_state = checkpoint['model_state'] 137 | model.load_state_dict(model_state) 138 | if args.model_mode == 'eval': 139 | model.eval() 140 | elif args.model_mode == 'train': 141 | model.train() 142 | model.image_size = args.image_size 143 | model.cuda() 144 | return model 145 | 146 | 147 | def makedir(base, name, flag=True): 148 | dir_name = None 149 | if flag: 150 | dir_name = os.path.join(base, name) 151 | if not os.path.isdir(dir_name): 152 | 
os.makedirs(dir_name) 153 | return dir_name 154 | 155 | 156 | def one_hot_to_rgb(layout_pred, colors, num_objs): 157 | one_hot = layout_pred[:, :num_objs, :, :] 158 | one_hot_3d = torch.einsum('abcd,be->aecd', [one_hot.cpu(), colors]) 159 | one_hot_3d *= (255.0 / one_hot_3d.max()) 160 | return one_hot_3d 161 | 162 | 163 | def run_model(args, checkpoint, output_dir, loader=None): 164 | if args.save_graphs: 165 | from scene_generation.vis import draw_scene_graph 166 | dirname = os.path.dirname(args.checkpoint) 167 | features = None 168 | if not args.use_gt_textures: 169 | features_path = os.path.join(dirname, 'features_clustered_001.npy') 170 | print(features_path) 171 | if os.path.isfile(features_path): 172 | features = np.load(features_path, allow_pickle=True).item() 173 | else: 174 | raise ValueError('No features file') 175 | with torch.no_grad(): 176 | vocab = checkpoint['model_kwargs']['vocab'] 177 | model = build_model(args, checkpoint) 178 | if loader is None: 179 | loader = build_loader(args, checkpoint, vocab['is_panoptic']) 180 | accuracy_model = None 181 | if args.accuracy_model_path is not None and os.path.isfile(args.accuracy_model_path): 182 | accuracy_model = load_model(args.accuracy_model_path) 183 | 184 | img_dir = makedir(output_dir, 'images') 185 | graph_dir = makedir(output_dir, 'graphs', args.save_graphs) 186 | gt_img_dir = makedir(output_dir, 'images_gt', args.save_gt_imgs) 187 | layout_dir = makedir(output_dir, 'layouts', args.save_layout) 188 | 189 | img_idx = 0 190 | total_iou = 0 191 | total_boxes = 0 192 | r_05 = 0 193 | r_03 = 0 194 | corrects = 0 195 | real_objects_count = 0 196 | num_objs = model.num_objs 197 | colors = torch.randint(0, 256, [num_objs, 3]).float() 198 | for batch in loader: 199 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = [x.cuda() for x in batch] 200 | 201 | imgs_gt = imagenet_deprocess_batch(imgs) 202 | 203 | if args.use_gt_masks: 204 | masks_gt = masks 205 | else: 206 | masks_gt = None 207 | if args.use_gt_textures: 208 | all_features = None 209 | else: 210 | all_features = [] 211 | for obj_name in objs: 212 | obj_feature = features[obj_name.item()] 213 | random_index = randint(0, obj_feature.shape[0] - 1) 214 | feat = torch.from_numpy(obj_feature[random_index, :]).type(torch.float32).cuda() 215 | all_features.append(feat) 216 | if not args.use_gt_attr: 217 | attributes = torch.zeros_like(attributes) 218 | 219 | # Run the model with predicted masks 220 | model_out = model(imgs, objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=masks_gt, attributes=attributes, 221 | test_mode=True, use_gt_box=args.use_gt_boxes, features=all_features) 222 | imgs_pred, boxes_pred, masks_pred, _, layout, _ = model_out 223 | 224 | if accuracy_model is not None: 225 | if args.use_gt_boxes: 226 | crops = crop_bbox_batch(imgs_pred, boxes, obj_to_img, 224) 227 | else: 228 | crops = crop_bbox_batch(imgs_pred, boxes_pred, obj_to_img, 224) 229 | 230 | outputs = accuracy_model(crops) 231 | if type(outputs) == tuple: 232 | outputs, _ = outputs 233 | _, preds = torch.max(outputs, 1) 234 | 235 | # statistics 236 | for pred, label in zip(preds, objs): 237 | if label.item() != 0: 238 | real_objects_count += 1 239 | corrects += 1 if pred.item() == label.item() else 0 240 | 241 | # Remove the __image__ object 242 | boxes_pred_no_image = [] 243 | boxes_gt_no_image = [] 244 | for o_index in range(len(obj_to_img)): 245 | if o_index < len(obj_to_img) - 1 and obj_to_img[o_index] == obj_to_img[o_index + 1]: 246 | 
boxes_pred_no_image.append(boxes_pred[o_index]) 247 | boxes_gt_no_image.append(boxes[o_index]) 248 | boxes_pred_no_image = torch.stack(boxes_pred_no_image) 249 | boxes_gt_no_image = torch.stack(boxes_gt_no_image) 250 | 251 | iou, bigger_05, bigger_03 = jaccard(boxes_pred_no_image, boxes_gt_no_image) 252 | total_iou += iou 253 | r_05 += bigger_05 254 | r_03 += bigger_03 255 | total_boxes += boxes_pred_no_image.size(0) 256 | imgs_pred = imagenet_deprocess_batch(imgs_pred) 257 | 258 | obj_data = [objs, boxes_pred, masks_pred] 259 | _, obj_data = split_graph_batch(triples, obj_data, obj_to_img, triple_to_img) 260 | objs, boxes_pred, masks_pred = obj_data 261 | 262 | obj_data_gt = [boxes.data] 263 | if masks is not None: 264 | obj_data_gt.append(masks.data) 265 | triples, obj_data_gt = split_graph_batch(triples, obj_data_gt, obj_to_img, triple_to_img) 266 | layouts_3d = one_hot_to_rgb(layout, colors, num_objs) 267 | for i in range(imgs_pred.size(0)): 268 | img_filename = '%04d.png' % img_idx 269 | if args.save_gt_imgs: 270 | img_gt = imgs_gt[i].numpy().transpose(1, 2, 0) 271 | img_gt_path = os.path.join(gt_img_dir, img_filename) 272 | imsave(img_gt_path, img_gt) 273 | if args.save_layout: 274 | layout_3d = layouts_3d[i].numpy().transpose(1, 2, 0) 275 | layout_path = os.path.join(layout_dir, img_filename) 276 | imsave(layout_path, layout_3d) 277 | 278 | img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0) 279 | img_path = os.path.join(img_dir, img_filename) 280 | imsave(img_path, img_pred_np) 281 | 282 | if args.save_graphs: 283 | graph_img = draw_scene_graph(objs[i], triples[i], vocab) 284 | graph_path = os.path.join(graph_dir, img_filename) 285 | imsave(graph_path, graph_img) 286 | 287 | img_idx += 1 288 | 289 | print('Saved %d images' % img_idx) 290 | avg_iou = total_iou / total_boxes 291 | print('avg_iou {}'.format(avg_iou.item())) 292 | print('r0.5 {}'.format(r_05 / total_boxes)) 293 | print('r0.3 {}'.format(r_03 / total_boxes)) 294 | if accuracy_model is not None: 295 | print('Accuracy {}'.format(corrects / real_objects_count)) 296 | 297 | 298 | if __name__ == '__main__': 299 | args = parser.parse_args() 300 | if args.checkpoint is None: 301 | raise ValueError('Must specify --checkpoint') 302 | 303 | checkpoint = torch.load(args.checkpoint) 304 | print('Loading model from ', args.checkpoint) 305 | run_model(args, checkpoint, args.output_dir) 306 | --------------------------------------------------------------------------------
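Usage sketch for Model.forward_json (not part of the repository sources): the scene-graph dictionary consumed by encode_scene_graphs needs 'attributes', 'features' and 'image_id' keys in addition to the 'objects' and 'relationships' shown in its docstring. The snippet below is a minimal, hypothetical example; it assumes a trained checkpoint path, that the object and predicate names exist in the checkpoint's vocab, that the clustered appearance features have already been attached to the model as model.features / model.features_one, and that the forward output unpacks the same way as in run_model above.

import torch
from scene_generation.model import Model

checkpoint = torch.load('path/to/checkpoint.pt')  # hypothetical checkpoint path
model = Model(**checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['model_state'])
model.eval().cuda()

scene_graph = {
    'objects': ['sky', 'grass', 'sheep'],                 # must exist in vocab['object_name_to_idx']
    'relationships': [[0, 'above', 1], [2, 'above', 1]],  # predicates must exist in vocab['pred_name_to_idx']
    'attributes': {
        'size': [9, 9, 4],        # one size bin (0-9) per object
        'location': [3, 21, 12],  # one location bin (0-24) per object
    },
    'features': [-1, -1, -1],     # -1 selects the default appearance archetype per object
    'image_id': 0,                # used as the appearance index of the dummy __image__ object
}
(img_pred, boxes_pred, masks_pred, _, layout, _), objs = model.forward_json([scene_graph])

imagenet_deprocess_batch (imported in sample_images.py above) can then convert img_pred back to uint8 images for saving.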