├── scene_generation
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   └── utils.py
│   ├── metrics.py
│   ├── utils.py
│   ├── generators.py
│   ├── args.py
│   ├── graph.py
│   ├── losses.py
│   ├── vis.py
│   ├── layout.py
│   ├── discriminators.py
│   ├── bilinear.py
│   ├── layers.py
│   └── model.py
├── images
│   ├── arch.png
│   └── scene_generation.png
├── requirements.txt
├── scripts
│   ├── download_coco.sh
│   ├── gui
│   │   ├── simple-server.py
│   │   ├── model.py
│   │   ├── index.css
│   │   ├── index_panoptic.html
│   │   ├── index.js
│   │   └── index.html
│   ├── inception_score.py
│   ├── create_attributes_file.py
│   ├── encode_features.py
│   ├── train_accuracy_net.py
│   └── sample_images.py
├── README.md
├── train.py
└── LICENSE
/scene_generation/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/images/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashual/scene_generation/HEAD/images/arch.png
--------------------------------------------------------------------------------
/images/scene_generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashual/scene_generation/HEAD/images/scene_generation.png
--------------------------------------------------------------------------------
/scene_generation/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import imagenet_preprocess, imagenet_deprocess
2 | from .utils import imagenet_deprocess_batch
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cloudpickle==1.2.2
2 | imageio==2.5.0
3 | matplotlib==3.1.1
4 | numpy==1.17.2
5 | Pillow==6.2.0
6 | scikit-image==0.15.0
7 | scipy==1.3.1
8 | pytorch==1.0.0
9 | tensorboardX==1.8
10 | torchvision==0.2.2
11 | graphviz==2.40.1
--------------------------------------------------------------------------------
/scripts/download_coco.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eu
2 |
3 | COCO_DIR=datasets/coco
4 | mkdir -p $COCO_DIR
5 |
6 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -O $COCO_DIR/annotations_trainval2017.zip
7 | wget http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip -O $COCO_DIR/stuff_annotations_trainval2017.zip
8 | wget http://images.cocodataset.org/zips/train2017.zip -O $COCO_DIR/train2017.zip
9 | wget http://images.cocodataset.org/zips/val2017.zip -O $COCO_DIR/val2017.zip
10 |
11 | unzip $COCO_DIR/annotations_trainval2017.zip -d $COCO_DIR
12 | unzip $COCO_DIR/stuff_annotations_trainval2017.zip -d $COCO_DIR
13 | unzip $COCO_DIR/train2017.zip -d $COCO_DIR/images
14 | unzip $COCO_DIR/val2017.zip -d $COCO_DIR/images
15 |
--------------------------------------------------------------------------------
/scene_generation/metrics.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 |
19 |
20 | def intersection(bbox_pred, bbox_gt):
21 | max_xy = torch.min(bbox_pred[:, 2:], bbox_gt[:, 2:])
22 | min_xy = torch.max(bbox_pred[:, :2], bbox_gt[:, :2])
23 | inter = torch.clamp((max_xy - min_xy), min=0)
24 | return inter[:, 0] * inter[:, 1]
25 |
26 |
27 | def jaccard(bbox_pred, bbox_gt):
28 | inter = intersection(bbox_pred, bbox_gt)
29 | area_pred = (bbox_pred[:, 2] - bbox_pred[:, 0]) * (bbox_pred[:, 3] -
30 | bbox_pred[:, 1])
31 | area_gt = (bbox_gt[:, 2] - bbox_gt[:, 0]) * (bbox_gt[:, 3] -
32 | bbox_gt[:, 1])
33 | union = area_pred + area_gt - inter
34 | iou = torch.div(inter, union)
35 | return torch.sum(iou), (iou > 0.5).sum().item(), (iou > 0.3).sum().item()
36 |
--------------------------------------------------------------------------------
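
A minimal usage sketch for `jaccard` from `scene_generation/metrics.py` above (not part of the repository), assuming boxes are row-aligned `(x0, y0, x1, y1)` tensors:

```python
import torch

from scene_generation.metrics import jaccard

# Two (pred, gt) box pairs in (x0, y0, x1, y1) format; the values are arbitrary.
bbox_pred = torch.tensor([[0.0, 0.0, 0.5, 0.5],
                          [0.2, 0.2, 0.8, 0.8]])
bbox_gt = torch.tensor([[0.0, 0.0, 0.5, 0.5],
                        [0.0, 0.0, 0.4, 0.4]])

iou_sum, hits_05, hits_03 = jaccard(bbox_pred, bbox_gt)
# iou_sum: summed IoU over all pairs; hits_05 / hits_03: number of pairs with IoU > 0.5 / 0.3
```
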
/scripts/gui/simple-server.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import json
3 | import os
4 | from http.server import HTTPServer, BaseHTTPRequestHandler
5 | from urllib.parse import unquote
6 |
7 | from scripts.gui.model import get_model, json_to_img
8 |
9 | model = get_model()
10 |
11 |
12 | class StaticServer(BaseHTTPRequestHandler):
13 |
14 | def do_GET(self):
15 | root = os.path.dirname(os.path.realpath(__file__))
16 | path = unquote(self.path)
17 | if path == '/':
18 | filename = root + '/index.html'
19 | elif path.startswith('/get_data?'):
20 | img_path, layout_path = json_to_img(path.split('/get_data?data=')[1], model)
21 | self.send_response(200)
22 | self.send_header('Content-type', 'application/json')
23 | self.end_headers()
24 | self.wfile.write(str.encode(json.dumps({'img_pred': img_path, 'layout_pred': layout_path})))
25 | return
26 | else:
27 | filename = root + self.path
28 |
29 | self.send_response(200)
30 | if filename[-4:] == '.css':
31 | self.send_header('Content-type', 'text/css')
32 | elif filename[-5:] == '.json':
33 | self.send_header('Content-type', 'application/json')
34 | elif filename[-3:] == '.js':
35 | self.send_header('Content-type', 'application/javascript')
36 | elif filename[-4:] == '.ico':
37 | return
38 | # self.send_header('Content-type', 'image/x-icon')
39 | else:
40 | self.send_header('Content-type', 'text/html')
41 | self.end_headers()
42 | with open(filename, 'rb') as fh:
43 | html = fh.read()
44 | # html = bytes(html, 'utf8')
45 | self.wfile.write(html)
46 |
47 |
48 | def run(server_class=HTTPServer, handler_class=StaticServer, port=8000):
49 | print('Loading model')
50 | server_address = ('', port)
51 | httpd = server_class(server_address, handler_class)
52 | print('Starting httpd on port {}'.format(port))
53 | httpd.serve_forever()
54 |
55 |
56 | run()
57 |
--------------------------------------------------------------------------------
/scene_generation/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import random
18 |
19 | import torch
20 |
21 |
22 | def int_tuple(s):
23 | return tuple(int(i) for i in s.split(','))
24 |
25 |
26 | def float_tuple(s):
27 | return tuple(float(i) for i in s.split(','))
28 |
29 |
30 | def str_tuple(s):
31 | return tuple(s.split(','))
32 |
33 |
34 | def bool_flag(s):
35 | if s == '1':
36 | return True
37 | elif s == '0':
38 | return False
39 | msg = 'Invalid value "%s" for bool flag (should be 0 or 1)'
40 | raise ValueError(msg % s)
41 |
42 |
43 | class LossManager(object):
44 | def __init__(self):
45 | self.total_loss = None
46 | self.all_losses = {}
47 |
48 | def add_loss(self, loss, name, weight=1.0, use_loss=True):
49 | cur_loss = loss * weight
50 | if use_loss:
51 | if self.total_loss is not None:
52 | self.total_loss += cur_loss
53 | else:
54 | self.total_loss = cur_loss
55 |
56 | self.all_losses[name] = cur_loss.data.cpu().item()
57 |
58 | def items(self):
59 | return self.all_losses.items()
60 |
61 |
62 | class VectorPool:
63 | def __init__(self, pool_size):
64 | self.pool_size = pool_size
65 | self.vectors = {}
66 |
67 | def query(self, objs, vectors):
68 | if self.pool_size == 0:
69 | return vectors
70 | return_vectors = []
71 | for obj, vector in zip(objs, vectors):
72 | obj = obj.item()
73 | vector = vector.cpu().clone().detach()
74 | if obj not in self.vectors:
75 | self.vectors[obj] = []
76 | obj_pool_size = len(self.vectors[obj])
77 | if obj_pool_size == 0:
78 | return_vectors.append(vector)
79 | self.vectors[obj].append(vector)
80 | elif obj_pool_size < self.pool_size:
81 | random_id = random.randint(0, obj_pool_size - 1)
82 | self.vectors[obj].append(vector)
83 | return_vectors.append(self.vectors[obj][random_id])
84 | else:
85 | random_id = random.randint(0, obj_pool_size - 1)
86 | tmp = self.vectors[obj][random_id]
87 | self.vectors[obj][random_id] = vector
88 | return_vectors.append(tmp)
89 | return_vectors = torch.stack(return_vectors).to(vectors.device)
90 | return return_vectors
91 |
--------------------------------------------------------------------------------
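
A small sketch (not part of the repository) of how `LossManager` and `VectorPool` above are driven; the tensors are placeholders:

```python
import torch

from scene_generation.utils import LossManager, VectorPool

losses = LossManager()
losses.add_loss(torch.tensor(0.7), 'bbox_pred', weight=10.0)
losses.add_loss(torch.tensor(0.3), 'ac_loss', weight=0.1)
print(dict(losses.items()))   # name -> weighted scalar value
print(losses.total_loss)      # weighted sum used for the backward pass

# Keep a per-category pool of appearance vectors and query it back.
pool = VectorPool(pool_size=100)
objs = torch.tensor([3, 5])          # object category ids
vectors = torch.randn(2, 32)         # e.g. rep_size == 32
mixed = pool.query(objs, vectors)    # same shape; entries may be swapped with pooled vectors
```
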
/scene_generation/data/utils.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import torch
3 | import numpy as np
4 | import torchvision.transforms as T
5 |
6 | MEAN = [0.5, 0.5, 0.5]
7 | STD = [0.5, 0.5, 0.5]
8 |
9 | INV_MEAN = [-m for m in MEAN]
10 | INV_STD = [1.0 / s for s in STD]
11 |
12 |
13 | def imagenet_preprocess():
14 | return T.Normalize(mean=MEAN, std=STD)
15 |
16 |
17 | def rescale(x):
18 | lo, hi = x.min(), x.max()
19 | return x.sub(lo).div(hi - lo)
20 |
21 |
22 | def imagenet_deprocess(rescale_image=True):
23 | transforms = [
24 | T.Normalize(mean=[0, 0, 0], std=INV_STD),
25 | T.Normalize(mean=INV_MEAN, std=[1.0, 1.0, 1.0]),
26 | ]
27 | if rescale_image:
28 | transforms.append(rescale)
29 | return T.Compose(transforms)
30 |
31 |
32 | def imagenet_deprocess_batch(imgs, rescale=True):
33 | """
34 | Input:
35 | - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images
36 |
37 | Output:
38 | - imgs_de: FloatTensor of shape (N, C, H, W) giving deprocessed images
39 | in the range [0, 255]
40 | """
41 | if isinstance(imgs, torch.autograd.Variable):
42 | imgs = imgs.data
43 | imgs = imgs.cpu().clone()
44 | deprocess_fn = imagenet_deprocess(rescale_image=rescale)
45 | imgs_de = []
46 | for i in range(imgs.size(0)):
47 | img_de = deprocess_fn(imgs[i])[None]
48 | img_de = img_de.mul(255).clamp(0, 255)
49 | imgs_de.append(img_de)
50 | imgs_de = torch.cat(imgs_de, dim=0)
51 | return imgs_de
52 |
53 |
54 | class Resize(object):
55 | def __init__(self, size, interp=PIL.Image.BILINEAR):
56 | if isinstance(size, tuple):
57 | H, W = size
58 | self.size = (W, H)
59 | else:
60 | self.size = (size, size)
61 | self.interp = interp
62 |
63 | def __call__(self, img):
64 | return img.resize(self.size, self.interp)
65 |
66 |
67 | def unpack_var(v):
68 | if isinstance(v, torch.autograd.Variable):
69 | return v.data
70 | return v
71 |
72 |
73 | def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img):
74 | triples = unpack_var(triples)
75 | obj_data = [unpack_var(o) for o in obj_data]
76 | obj_to_img = unpack_var(obj_to_img)
77 | triple_to_img = unpack_var(triple_to_img)
78 |
79 | triples_out = []
80 | obj_data_out = [[] for _ in obj_data]
81 | obj_offset = 0
82 | N = obj_to_img.max() + 1
83 | for i in range(N):
84 | o_idxs = (obj_to_img == i).nonzero().view(-1)
85 | t_idxs = (triple_to_img == i).nonzero().view(-1)
86 |
87 | cur_triples = triples[t_idxs].clone()
88 | cur_triples[:, 0] -= obj_offset
89 | cur_triples[:, 2] -= obj_offset
90 | triples_out.append(cur_triples)
91 |
92 | for j, o_data in enumerate(obj_data):
93 | cur_o_data = None
94 | if o_data is not None:
95 | cur_o_data = o_data[o_idxs]
96 | obj_data_out[j].append(cur_o_data)
97 |
98 | obj_offset += o_idxs.size(0)
99 |
100 | return triples_out, obj_data_out
101 |
102 |
103 | def rgb2id(color):
104 | if isinstance(color, np.ndarray) and len(color.shape) == 3:
105 | if color.dtype == np.uint8:
106 | color = color.astype(np.uint32)
107 | return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
108 | return color[0] + 256 * color[1] + 256 * 256 * color[2]
109 |
--------------------------------------------------------------------------------
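
A round-trip sketch (not part of the repository) for the normalization helpers above, using a synthetic batch:

```python
import torch

from scene_generation.data import imagenet_preprocess, imagenet_deprocess_batch

imgs = torch.rand(4, 3, 64, 64)                    # fake RGB batch in [0, 1]
preprocess = imagenet_preprocess()                 # maps [0, 1] to roughly [-1, 1]
imgs_pre = torch.stack([preprocess(img) for img in imgs])
imgs_de = imagenet_deprocess_batch(imgs_pre)       # float values clamped to [0, 255]
print(imgs_de.shape, float(imgs_de.min()), float(imgs_de.max()))
```
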
/scene_generation/generators.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from scene_generation.layers import GlobalAvgPool, build_cnn, ResnetBlock, get_norm_layer, Interpolate
5 |
6 |
7 | def weights_init(m):
8 | classname = m.__class__.__name__
9 | if classname.find('Conv') != -1:
10 | m.weight.data.normal_(0.0, 0.02)
11 | elif classname.find('BatchNorm2d') != -1:
12 | m.weight.data.normal_(1.0, 0.02)
13 | m.bias.data.fill_(0)
14 |
15 |
16 | def mask_net(dim, mask_size):
17 | output_dim = 1
18 | layers, cur_size = [], 1
19 | while cur_size < mask_size:
20 | layers.append(Interpolate(scale_factor=2, mode='nearest'))
21 | layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=1))
22 | layers.append(nn.BatchNorm2d(dim))
23 | layers.append(nn.ReLU())
24 | cur_size *= 2
25 | if cur_size != mask_size:
26 | raise ValueError('Mask size must be a power of 2')
27 | layers.append(nn.Conv2d(dim, output_dim, kernel_size=1))
28 | return nn.Sequential(*layers)
29 |
30 |
31 | class AppearanceEncoder(nn.Module):
32 | def __init__(self, vocab, arch, normalization='none', activation='relu',
33 | padding='same', vecs_size=1024, pooling='avg'):
34 | super(AppearanceEncoder, self).__init__()
35 | self.vocab = vocab
36 |
37 | cnn_kwargs = {
38 | 'arch': arch,
39 | 'normalization': normalization,
40 | 'activation': activation,
41 | 'pooling': pooling,
42 | 'padding': padding,
43 | }
44 | cnn, channels = build_cnn(**cnn_kwargs)
45 | self.cnn = nn.Sequential(cnn, GlobalAvgPool(), nn.Linear(channels, vecs_size))
46 |
47 | def forward(self, crops):
48 | return self.cnn(crops)
49 |
50 |
51 | def define_G(input_nc, output_nc, ngf, n_downsample_global=3, n_blocks_global=9, norm='instance'):
52 | norm_layer = get_norm_layer(norm_type=norm)
53 | netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
54 | assert (torch.cuda.is_available())
55 | netG.cuda()
56 | netG.apply(weights_init)
57 | return netG
58 |
59 | ##############################################################################
60 | # Generator
61 | ##############################################################################
62 | class GlobalGenerator(nn.Module):
63 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
64 | padding_type='reflect'):
65 | assert (n_blocks >= 0)
66 | super(GlobalGenerator, self).__init__()
67 | activation = nn.ReLU(True)
68 |
69 | model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
70 | ### downsample
71 | for i in range(n_downsampling):
72 | mult = 2 ** i
73 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
74 | norm_layer(ngf * mult * 2), activation]
75 |
76 | ### resnet blocks
77 | mult = 2 ** n_downsampling
78 | for i in range(n_blocks):
79 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
80 |
81 | ### upsample
82 | for i in range(n_downsampling):
83 | mult = 2 ** (n_downsampling - i)
84 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
85 | output_padding=1),
86 | norm_layer(int(ngf * mult / 2)), activation]
87 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
88 | self.model = nn.Sequential(*model)
89 |
90 | def forward(self, input):
91 | return self.model(input)
--------------------------------------------------------------------------------
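
A shape-check sketch (not part of the repository) for `mask_net` above, using the defaults that appear in `args.py` (a 128-dimensional object feature and `mask_size=32`):

```python
import torch

from scene_generation.generators import mask_net

net = mask_net(dim=128, mask_size=32)
obj_feats = torch.randn(8, 128, 1, 1)   # one 1x1 feature map per object
mask_logits = net(obj_feats)            # doubled spatially per block until mask_size is reached
print(mask_logits.shape)                # torch.Size([8, 1, 32, 32])
```
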
/scripts/inception_score.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data
6 | import torchvision.datasets as dset
7 | import torchvision.transforms as transforms
8 | from scipy.stats import entropy
9 | from torch import nn
10 | from torch.nn import functional as F
11 | from torchvision.models.inception import inception_v3
12 | from scene_generation.layers import Interpolate
13 |
14 |
15 | class InceptionScore(nn.Module):
16 | def __init__(self, cuda=True, batch_size=32, resize=False):
17 | super(InceptionScore, self).__init__()
18 | assert batch_size > 0
19 | self.resize = resize
20 | self.batch_size = batch_size
21 | self.cuda = cuda
22 | # Set up dtype
23 | self.device = 'cuda' if cuda else 'cpu'
24 | if not cuda and torch.cuda.is_available():
25 | print("WARNING: You have a CUDA device, so you should probably set cuda=True")
26 |
27 | # Load inception model
28 | self.inception_model = inception_v3(pretrained=True, transform_input=False).to(self.device)
29 | self.inception_model.eval()
30 | self.up = Interpolate(size=(299, 299), mode='bilinear').to(self.device)
31 | self.clean()
32 |
33 | def clean(self):
34 | self.preds = np.zeros((0, 1000))
35 |
36 | def get_pred(self, x):
37 | if self.resize:
38 | x = self.up(x)
39 | x = self.inception_model(x)
40 | return F.softmax(x, dim=1).data.cpu().numpy()
41 |
42 | def forward(self, imgs):
43 | # Get predictions
44 | preds_imgs = self.get_pred(imgs.to(self.device))
45 | self.preds = np.append(self.preds, preds_imgs, axis=0)
46 |
47 | def compute_score(self, splits=1):
48 | # Now compute the mean kl-div
49 | split_scores = []
50 | preds = self.preds
51 | N = self.preds.shape[0]
52 | for k in range(splits):
53 | part = preds[k * (N // splits): (k + 1) * (N // splits), :]
54 | py = np.mean(part, axis=0)
55 | scores = []
56 | for i in range(part.shape[0]):
57 | pyx = part[i, :]
58 | scores.append(entropy(pyx, py))
59 | split_scores.append(np.exp(np.mean(scores)))
60 |
61 | return np.mean(split_scores), np.std(split_scores)
62 |
63 |
64 | # def inception_score():
65 | # """Computes the inception score of the generated images imgs
66 | #
67 | # imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
68 | # cuda -- whether or not to run on GPU
69 | # batch_size -- batch size for feeding into Inception v3
70 | # splits -- number of splits
71 | # """
72 |
73 |
74 | if __name__ == '__main__':
75 | class IgnoreLabelDataset(torch.utils.data.Dataset):
76 | def __init__(self, orig):
77 | self.orig = orig
78 |
79 | def __getitem__(self, index):
80 | return self.orig[index][0]
81 |
82 | def __len__(self):
83 | return len(self.orig)
84 |
85 |
86 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
87 | parser.add_argument('--dir', default='', type=str)
88 | parser.add_argument('--splits', default=1, type=int)
89 | args = parser.parse_args()
90 | imagenet_data = dset.ImageFolder(args.dir, transform=transforms.Compose([
91 | transforms.ToTensor(),
92 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
93 | ]))
94 |
95 | # data_loader = torch.utils.data.DataLoader(imagenet_data,
96 | # batch_size=4,
97 | # shuffle=False,
98 | # num_workers=4)
99 |
100 | print("Calculating Inception Score...")
101 | # print(inception_score(imagenet_data, cuda=True, batch_size=32, resize=True, splits=args.splits))
102 |
--------------------------------------------------------------------------------
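
The `__main__` block above builds the dataset but leaves the scoring call commented out. A sketch (not part of the repository) of driving the `InceptionScore` module end to end, assuming the generated images share one resolution (as produced by `sample_images.py`) and are laid out as `torchvision.datasets.ImageFolder` expects; `OUTPUT_DIR` is a placeholder:

```python
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

from scripts.inception_score import InceptionScore

dataset = dset.ImageFolder('OUTPUT_DIR', transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)

scorer = InceptionScore(cuda=torch.cuda.is_available(), resize=True)
with torch.no_grad():
    for imgs, _ in loader:
        scorer(imgs)                  # accumulates Inception softmax predictions
mean, std = scorer.compute_score(splits=1)
print('Inception score: {:.2f} +/- {:.2f}'.format(mean, std))
```
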
/README.md:
--------------------------------------------------------------------------------
1 | # Specifying Object Attributes and Relations in Interactive Scene Generation
2 | A PyTorch implementation of the paper [Specifying Object Attributes and Relations in Interactive Scene Generation](https://arxiv.org/abs/1909.05379)
3 |
4 |
5 | ## Paper
6 | [Specifying Object Attributes and Relations in Interactive Scene Generation](https://arxiv.org/abs/1909.05379)
7 |
8 | [Oron Ashual](https://www.linkedin.com/in/oronashual/)<sup>1</sup>, [Lior Wolf](https://www.cs.tau.ac.il/~wolf/)<sup>1,2</sup>
9 | <sup>1</sup> Tel-Aviv University, <sup>2</sup> Facebook AI Research
10 | The IEEE International Conference on Computer Vision ([ICCV](http://iccv2019.thecvf.com/)), 2019 (Oral)
11 |
12 | ## Network Architecture
13 | 
14 |
15 | ## YouTube
16 |
17 |
18 |
19 |
20 | ## Usage
21 |
22 | ### 1. Create a virtual environment (optional)
23 | All code was developed and tested on Ubuntu 18.04 with Python 3.6 (Anaconda) and PyTorch 1.0.
24 |
25 | ```bash
26 | conda create -n scene_generation python=3.7
27 | conda activate scene_generation
28 | ```
29 | ### 2. Clone the repository
30 | ```bash
31 | cd ~
32 | git clone https://github.com/ashual/scene_generation.git
33 | cd scene_generation
34 | ```
35 |
36 | ### 3. Install dependencies
37 | ```bash
38 | conda install --file requirements.txt -c conda-forge -c pytorch
39 | ```
40 | * Install a PyTorch version that matches your CUDA toolkit.
41 |
42 | ### 4. Install COCO API
43 | Note: we did not train our models on the COCO panoptic dataset; the coco_panoptic.py code is provided for the community only.
44 | ```bash
45 | cd ~
46 | git clone https://github.com/cocodataset/cocoapi.git
47 | cd cocoapi/PythonAPI/
48 | python setup.py install
49 | cd ~/scene_generation
50 | ```
51 |
52 | ### 5. Train
53 | ```bash
54 | $ python train.py
55 | ```
56 |
57 | ### 6. Encode the Appearance attributes
58 | ```bash
59 | python scripts/encode_features.py --checkpoint TRAINED_MODEL_CHECKPOINT
60 | ```
61 |
62 | ### 7. Sample Images
63 | ```bash
64 | python scripts/sample_images.py --checkpoint TRAINED_MODEL_CHECKPOINT --batch_size 32 --output_dir OUTPUT_DIR
65 | ```
66 |
67 | ### 8. Or download trained models
68 | Download [these](https://drive.google.com/drive/folders/1_E56YskDXdmq06FRsIiPAedpBovYOO8X?usp=sharing) files into `models/`
69 |
70 |
71 | ### 9. Play with the GUI
72 | The GUI was built as a proof of concept (POC). Use it at your own risk:
73 | ```bash
74 | python scripts/gui/simple-server.py --checkpoint YOUR_MODEL_CHECKPOINT --output_dir [DIR_NAME] --draw_scene_graphs 0
75 | ```
76 |
77 | ### 10. Results
78 | Results were measured by sampling images from the validation set and then running the following evaluation code:
79 | 1. FID - https://github.com/bioinf-jku/TTUR (TensorFlow implementation)
80 | 2. Inception - https://github.com/openai/improved-gan/blob/master/inception_score/model.py (TensorFlow implementation)
81 | 3. Diversity - https://github.com/richzhang/PerceptualSimilarity (PyTorch implementation)
82 | 4. Accuracy - Training code is attached as `train_accuracy_net.py`. A trained model is provided. Adding the argument `--accuracy_model_path MODEL_PATH` will output the accuracy of the objects.
83 |
84 | #### Reproduce the comparison figure (Figure 3)
85 | Run this command:
86 | ```bash
87 | $ python scripts/sample_images.py --checkpoint TRAINED_MODEL_CHECKPOINT --output_dir OUTPUT_DIR
88 | ```
89 |
90 | with these arguments:
91 | * (c) - Ground truth layout: --use_gt_boxes 1 --use_gt_masks 1
92 | * (d) - Ground truth location attributes: --use_gt_attr 1
93 | * (e) - Ground truth appearance attributes: --use_gt_textures 1
94 | * (f) - Scene Graph only - No extra attributes needed
95 |
96 | ## Citation
97 |
98 | If you find this code useful in your research, please cite:
99 | ```
100 | @InProceedings{Ashual_2019_ICCV,
101 | author = {Ashual, Oron and Wolf, Lior},
102 | title = {Specifying Object Attributes and Relations in Interactive Scene Generation},
103 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
104 | month = {October},
105 | year = {2019}
106 | }
107 | ```
108 |
109 | ## Acknowledgement
110 | Our project borrows some source files from [sg2im](https://github.com/google/sg2im). We thank the authors.
--------------------------------------------------------------------------------
/scripts/create_attributes_file.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | from torch.utils.data import DataLoader
5 | from scene_generation.data.coco import CocoSceneGraphDataset, coco_collate_fn
6 | from scene_generation.data.coco_panoptic import CocoPanopticSceneGraphDataset, coco_panoptic_collate_fn
7 | from scene_generation.utils import int_tuple, bool_flag
8 | from scene_generation.utils import str_tuple
9 |
10 |
11 | parser = argparse.ArgumentParser()
12 |
13 | # Shared dataset options
14 | parser.add_argument('--dataset', default='coco')
15 | parser.add_argument('--image_size', default=(128, 128), type=int_tuple)
16 | parser.add_argument('--min_object_size', default=0.02, type=float)
17 | parser.add_argument('--min_objects_per_image', default=3, type=int)
18 | parser.add_argument('--max_objects_per_image', default=8, type=int)
19 | parser.add_argument('--instance_whitelist', default=None, type=str_tuple)
20 | parser.add_argument('--stuff_whitelist', default=None, type=str_tuple)
21 | parser.add_argument('--coco_include_other', default=False, type=bool_flag)
22 | parser.add_argument('--is_panoptic', default=False, type=bool_flag)
23 | parser.add_argument('--batch_size', default=24, type=int)
24 | parser.add_argument('--shuffle', default=False, type=bool_flag)
25 | parser.add_argument('--loader_num_workers', default=4, type=int)
26 | parser.add_argument('--num_samples', default=None, type=int)
27 | parser.add_argument('--object_size', default=64, type=int)
28 | parser.add_argument('--grid_size', default=25, type=int)
29 | parser.add_argument('--size_attribute_len', default=10, type=int)
30 |
31 | parser.add_argument('--output_dir', default='models')
32 |
33 |
34 | COCO_DIR = os.path.expanduser('datasets/coco')
35 | parser.add_argument('--coco_image_dir',
36 | default=os.path.join(COCO_DIR, 'images/train2017'))
37 | parser.add_argument('--instances_json',
38 | default=os.path.join(COCO_DIR, 'annotations/instances_train2017.json'))
39 | parser.add_argument('--stuff_json',
40 | default=os.path.join(COCO_DIR, 'annotations/stuff_train2017.json'))
41 | parser.add_argument('--coco_panoptic_train', default=os.path.join(COCO_DIR, 'annotations/panoptic_train2017.json'))
42 | parser.add_argument('--coco_panoptic_segmentation_train',
43 | default=os.path.join(COCO_DIR, 'panoptic/annotations/panoptic_train2017'))
44 |
45 |
46 |
47 | def build_coco_dset(args):
48 | dset_kwargs = {
49 | 'image_dir': args.coco_image_dir,
50 | 'instances_json': args.instances_json,
51 | 'stuff_json': args.stuff_json,
52 | 'image_size': args.image_size,
53 | 'mask_size': 32,
54 | 'max_samples': args.num_samples,
55 | 'min_object_size': args.min_object_size,
56 | 'min_objects_per_image': args.min_objects_per_image,
57 | 'max_objects_per_image': args.max_objects_per_image,
58 | 'instance_whitelist': args.instance_whitelist,
59 | 'stuff_whitelist': args.stuff_whitelist,
60 | 'include_other': args.coco_include_other,
61 | 'test_part': False,
62 | 'sample_attributes': False,
63 | 'grid_size': args.grid_size
64 | }
65 | dset = CocoSceneGraphDataset(**dset_kwargs)
66 | return dset
67 |
68 |
69 | def build_coco_panoptic_dset(args):
70 | dset_kwargs = {
71 | 'image_dir': args.coco_image_dir,
72 | 'instances_json': args.instances_json,
73 | 'panoptic': args.coco_panoptic_train,
74 | 'panoptic_segmentation': args.coco_panoptic_segmentation_train,
75 | 'stuff_json': args.stuff_json,
76 | 'image_size': args.image_size,
77 | 'mask_size': 32,
78 | 'max_samples': args.num_samples,
79 | 'min_object_size': args.min_object_size,
80 | 'min_objects_per_image': args.min_objects_per_image,
81 | 'max_objects_per_image': args.max_objects_per_image,
82 | 'instance_whitelist': args.instance_whitelist,
83 | 'stuff_whitelist': args.stuff_whitelist,
84 | 'include_other': args.coco_include_other,
85 | 'test_part': False,
86 | 'sample_attributes': False,  # the parser defines no --sample_attributes flag; match build_coco_dset
87 | 'grid_size': args.grid_size
88 | }
89 | dset = CocoPanopticSceneGraphDataset(**dset_kwargs)
90 | return dset
91 |
92 |
93 | def build_loader(args, is_panoptic):
94 | if is_panoptic:
95 | dset = build_coco_panoptic_dset(args)
96 | collate_fn = coco_panoptic_collate_fn
97 | else:
98 | dset = build_coco_dset(args)
99 | collate_fn = coco_collate_fn
100 |
101 | loader_kwargs = {
102 | 'batch_size': args.batch_size,
103 | 'num_workers': args.loader_num_workers,
104 | 'shuffle': args.shuffle,
105 | 'collate_fn': collate_fn,
106 | }
107 | loader = DataLoader(dset, **loader_kwargs)
108 | return loader
109 |
110 |
111 | if __name__ == '__main__':
112 | args = parser.parse_args()
113 |
114 | print('Loading dataset')
115 | loader = build_loader(args, args.is_panoptic)
116 | vocab = loader.dataset.vocab
117 | idx_to_name = vocab['my_idx_to_obj']
118 | sample_attributes = {'location': {}, 'size': {}}
119 | for obj_name in idx_to_name:
120 | sample_attributes['location'][obj_name] = [0] * args.grid_size
121 | sample_attributes['size'][obj_name] = [0] * args.size_attribute_len
122 |
123 | print('Iterating objects')
124 | for _, objs, _, _, _, _, _, attributes in loader:
125 | for obj, attribute in zip(objs, attributes):
126 | obj = obj.item()
127 | if obj == 0:
128 | continue
129 | obj_name = idx_to_name[obj - 1]
130 | size_index = attribute.int().tolist()[:args.size_attribute_len].index(1)
131 | location_index = attribute.int().tolist()[args.size_attribute_len:].index(1)
132 | sample_attributes['size'][obj_name][size_index] += 1
133 | sample_attributes['location'][obj_name][location_index] += 1
134 |
135 | attributes_file = './models/attributes_{}_{}.pickle'.format(args.size_attribute_len, args.grid_size)
136 | print('Saving attributes file to {}'.format(attributes_file))
137 | pickle.dump(sample_attributes, open(attributes_file, 'wb'))
138 |
--------------------------------------------------------------------------------
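
For reference, a sketch (not part of the repository) of reading the resulting attributes file back, with the default `size_attribute_len=10` and `grid_size=25`:

```python
import pickle

with open('./models/attributes_10_25.pickle', 'rb') as f:
    sample_attributes = pickle.load(f)

print(list(sample_attributes.keys()))               # ['location', 'size']
# sample_attributes['size'][obj_name]     -> list of 10 counts, one per size bin
# sample_attributes['location'][obj_name] -> list of 25 counts, one per location cell
```
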
/scene_generation/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import socket
4 | from datetime import datetime
5 |
6 | from scene_generation.utils import int_tuple, str_tuple, bool_flag
7 |
8 | COCO_DIR = os.path.expanduser('datasets/coco')
9 |
10 | parser = argparse.ArgumentParser()
11 |
12 | # Optimization hyperparameters
13 | parser.add_argument('--batch_size', default=12, type=int)
14 | parser.add_argument('--num_iterations', default=1000000, type=int)
15 | parser.add_argument('--learning_rate', default=1e-4, type=float)
16 | parser.add_argument('--mask_learning_rate', default=1e-5, type=float)
17 |
18 | # Dataset options
19 | parser.add_argument('--image_size', default='128,128', type=int_tuple)
20 | parser.add_argument('--num_train_samples', default=None, type=int)
21 | parser.add_argument('--num_val_samples', default=1024, type=int)
22 | parser.add_argument('--shuffle_val', default=True, type=bool_flag)
23 | parser.add_argument('--loader_num_workers', default=4, type=int)
24 | parser.add_argument('--coco_train_image_dir',
25 | default=os.path.join(COCO_DIR, 'images/train2017'))
26 | parser.add_argument('--coco_val_image_dir',
27 | default=os.path.join(COCO_DIR, 'images/val2017'))
28 | parser.add_argument('--coco_train_instances_json',
29 | default=os.path.join(COCO_DIR, 'annotations/instances_train2017.json'))
30 | parser.add_argument('--coco_train_stuff_json',
31 | default=os.path.join(COCO_DIR, 'annotations/stuff_train2017.json'))
32 | parser.add_argument('--coco_val_instances_json',
33 | default=os.path.join(COCO_DIR, 'annotations/instances_val2017.json'))
34 | parser.add_argument('--coco_val_stuff_json',
35 | default=os.path.join(COCO_DIR, 'annotations/stuff_val2017.json'))
36 | parser.add_argument('--coco_panoptic_train', default=os.path.join(COCO_DIR, 'annotations/panoptic_train2017.json'))
37 | parser.add_argument('--coco_panoptic_val', default=os.path.join(COCO_DIR, 'annotations/panoptic_val2017.json'))
38 | parser.add_argument('--coco_panoptic_segmentation_train', default=os.path.join(COCO_DIR, 'panoptic/annotations/panoptic_train2017'))
39 | parser.add_argument('--coco_panoptic_segmentation_val', default=os.path.join(COCO_DIR, 'panoptic/annotations/panoptic_val2017'))
40 | parser.add_argument('--instance_whitelist', default=None, type=str_tuple)
41 | parser.add_argument('--stuff_whitelist', default=None, type=str_tuple)
42 | parser.add_argument('--coco_include_other', default=False, type=bool_flag)
43 | parser.add_argument('--min_object_size', default=0.02, type=float)
44 | parser.add_argument('--min_objects_per_image', default=3, type=int)
45 | parser.add_argument('--max_objects_per_image', default=8, type=int)
46 | parser.add_argument('--coco_stuff_only', default=True, type=bool_flag) # Train over images that have at least one stuff
47 | parser.add_argument('--is_panoptic', default=False, type=bool_flag)
48 |
49 | # Generator options
50 | parser.add_argument('--mask_size', default=32, type=int)
51 | parser.add_argument('--embedding_dim', default=128, type=int)
52 | parser.add_argument('--gconv_dim', default=128, type=int)
53 | parser.add_argument('--gconv_hidden_dim', default=512, type=int)
54 | parser.add_argument('--gconv_num_layers', default=5, type=int)
55 | parser.add_argument('--mlp_normalization', default='none', type=str)
56 | parser.add_argument('--activation', default='leakyrelu-0.2')
57 | parser.add_argument('--pool_size', default=100, type=int)
58 | parser.add_argument('--output_nc', default=3, type=int)
59 | parser.add_argument('--n_downsample_global', default=4, type=int)
60 | parser.add_argument('--box_dim', default=128, type=int)
61 | parser.add_argument('--use_attributes', default=True, type=bool_flag)
62 | parser.add_argument('--beta1', default=0.5, type=float)
63 | parser.add_argument('--box_noise_dim', default=64, type=int)
64 | parser.add_argument('--mask_noise_dim', default=64, type=int)
65 |
66 | # Appearance Generator options
67 | parser.add_argument('--rep_size', default=32, type=int)
68 | parser.add_argument('--appearance_normalization', default='batch')
69 |
70 | # Generator losses
71 | parser.add_argument('--l1_pixel_loss_weight', default=.0, type=float)
72 | parser.add_argument('--bbox_pred_loss_weight', default=10, type=float)
73 | parser.add_argument('--vgg_features_weight', default=10.0, type=float)
74 | parser.add_argument('--d_img_weight', default=1.0, type=float)
75 | parser.add_argument('--d_img_features_weight', default=10.0, type=float)
76 | parser.add_argument('--d_mask_weight', default=1.0, type=float)
77 | parser.add_argument('--d_mask_features_weight', default=10.0, type=float)
78 | parser.add_argument('--d_obj_weight', default=0.1, type=float)
79 | parser.add_argument('--ac_loss_weight', default=0.1, type=float)
80 |
81 | # Image discriminator
82 | parser.add_argument('--ndf', default=64, type=int)
83 | parser.add_argument('--num_D', default=2, type=int)
84 | parser.add_argument('--norm_D', default='instance', type=str)
85 | parser.add_argument('--n_layers_D', default=3, type=int)
86 | parser.add_argument('--no_lsgan', default=False, type=bool_flag) # Default is LSGAN (no_lsgan == False)
87 |
88 | # Mask Discriminator
89 | parser.add_argument('--ndf_mask', default=64, type=int)
90 | parser.add_argument('--num_D_mask', default=1, type=int)
91 | parser.add_argument('--norm_D_mask', default='instance', type=str)
92 | parser.add_argument('--n_layers_D_mask', default=2, type=int)
93 |
94 | # Object discriminator
95 | parser.add_argument('--gan_loss_type', default='gan')
96 | parser.add_argument('--d_normalization', default='batch')
97 | parser.add_argument('--d_padding', default='valid')
98 | parser.add_argument('--d_activation', default='leakyrelu-0.2')
99 | parser.add_argument('--d_obj_arch', default='C4-64-2,C4-128-2,C4-256-2')
100 | parser.add_argument('--crop_size', default=32, type=int)
101 |
102 | # Output options
103 | current_time = datetime.now().strftime('%b%d_%H-%M-%S')
104 | log_dir = os.path.join(os.getcwd(), 'output', current_time + '_' + socket.gethostname())
105 | parser.add_argument('--print_every', default=100, type=int)
106 | parser.add_argument('--checkpoint_every', default=10000, type=int)
107 | parser.add_argument('--output_dir', default=log_dir)
108 | parser.add_argument('--checkpoint_name', default='checkpoint')
109 | parser.add_argument('--restore_from_checkpoint', default=False, type=bool_flag)
110 |
111 |
112 | def get_args():
113 | return parser.parse_args()
114 |
--------------------------------------------------------------------------------
/scripts/encode_features.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | from sklearn.cluster import KMeans
7 | from sklearn.manifold import TSNE
8 | from torch.utils.data import DataLoader
9 |
10 | from scene_generation.bilinear import crop_bbox_batch
11 | from scene_generation.data.coco import CocoSceneGraphDataset, coco_collate_fn
12 | from scene_generation.data.coco_panoptic import CocoPanopticSceneGraphDataset, coco_panoptic_collate_fn
13 | from scene_generation.model import Model
14 | from scene_generation.utils import int_tuple, bool_flag
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--checkpoint', required=True)
18 | parser.add_argument('--model_mode', default='eval', choices=['train', 'eval'])
19 |
20 | # Shared dataset options
21 | parser.add_argument('--image_size', default=(128, 128), type=int_tuple)
22 | parser.add_argument('--batch_size', default=64, type=int)
23 | parser.add_argument('--shuffle', default=False, type=bool_flag)
24 | parser.add_argument('--loader_num_workers', default=4, type=int)
25 | parser.add_argument('--num_samples', default=None, type=int)
26 | parser.add_argument('--object_size', default=64, type=int)
27 |
28 | # For COCO
29 | COCO_DIR = os.path.expanduser('datasets/coco')
30 | parser.add_argument('--coco_image_dir', default=os.path.join(COCO_DIR, 'images/train2017'))
31 | parser.add_argument('--instances_json', default=os.path.join(COCO_DIR, 'annotations/instances_train2017.json'))
32 | parser.add_argument('--stuff_json', default=os.path.join(COCO_DIR, 'annotations/stuff_train2017.json'))
33 |
34 |
35 | def build_coco_dset(args, checkpoint):
36 | checkpoint_args = checkpoint['args']
37 | print('include other: ', checkpoint_args.get('coco_include_other'))
38 | dset_kwargs = {
39 | 'image_dir': args.coco_image_dir,
40 | 'instances_json': args.instances_json,
41 | 'stuff_json': args.stuff_json,
42 | 'image_size': args.image_size,
43 | 'mask_size': checkpoint_args['mask_size'],
44 | 'max_samples': args.num_samples,
45 | 'min_object_size': checkpoint_args['min_object_size'],
46 | 'min_objects_per_image': checkpoint_args['min_objects_per_image'],
47 | 'instance_whitelist': checkpoint_args['instance_whitelist'],
48 | 'stuff_whitelist': checkpoint_args['stuff_whitelist'],
49 | 'include_other': checkpoint_args.get('coco_include_other', True),
50 | }
51 | dset = CocoSceneGraphDataset(**dset_kwargs)
52 | return dset
53 |
54 |
55 | def build_model(args, checkpoint):
56 | kwargs = checkpoint['model_kwargs']
57 | model = Model(**kwargs)
58 | model.load_state_dict(checkpoint['model_state'])
59 | if args.model_mode == 'eval':
60 | model.eval()
61 | elif args.model_mode == 'train':
62 | model.train()
63 | model.image_size = args.image_size
64 | model.cuda()
65 | return model
66 |
67 |
68 | def build_loader(args, checkpoint):
69 | dset = build_coco_dset(args, checkpoint)
70 | collate_fn = coco_collate_fn
71 |
72 | loader_kwargs = {
73 | 'batch_size': args.batch_size,
74 | 'num_workers': args.loader_num_workers,
75 | 'shuffle': args.shuffle,
76 | 'collate_fn': collate_fn,
77 | }
78 | loader = DataLoader(dset, **loader_kwargs)
79 | return loader
80 |
81 |
82 | def cluster(features, num_objs, n_clusters, save_path):
83 | name = 'features'
84 | centers = {}
85 | for label in range(num_objs):
86 | feat = features[label]
87 | if feat.shape[0]:
88 | n_feat_clusters = min(feat.shape[0], n_clusters)
89 | if n_feat_clusters < n_clusters:
90 | print(label)
91 | kmeans = KMeans(n_clusters=n_feat_clusters, random_state=0).fit(feat)
92 | if n_feat_clusters == 1:
93 | centers[label] = kmeans.cluster_centers_
94 | else:
95 | one_dimension_centers = TSNE(n_components=1).fit_transform(kmeans.cluster_centers_)
96 | args = np.argsort(one_dimension_centers.reshape(-1))
97 | centers[label] = kmeans.cluster_centers_[args]
98 | save_name = os.path.join(save_path, name + '_clustered_%03d.npy' % n_clusters)
99 | np.save(save_name, centers)
100 | print('saving to %s' % save_name)
101 |
102 |
103 | def main(opt):
104 | name = 'features'
105 | checkpoint = torch.load(opt.checkpoint)
106 | rep_size = checkpoint['model_kwargs']['rep_size']
107 | vocab = checkpoint['model_kwargs']['vocab']
108 | num_objs = len(vocab['object_to_idx'])
109 | model = build_model(opt, checkpoint)
110 | loader = build_loader(opt, checkpoint)
111 |
112 | save_path = os.path.dirname(opt.checkpoint)
113 |
114 | ########### Encode features ###########
115 | counter = 0
116 | max_counter = 1000000000
117 | print('begin')
118 | with torch.no_grad():
119 | features = {}
120 | for label in range(num_objs):
121 | features[label] = np.zeros((0, rep_size))
122 | for i, data in enumerate(loader):
123 | if counter >= max_counter:
124 | break
125 | imgs = data[0].cuda()
126 | objs = data[1]
127 | objs = [j.item() for j in objs]
128 | boxes = data[2].cuda()
129 | obj_to_img = data[5].cuda()
130 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, opt.object_size)
131 | feat = model.repr_net(model.image_encoder(crops)).cpu()
132 | for ind, label in enumerate(objs):
133 | features[label] = np.append(features[label], feat[ind].view(1, -1), axis=0)
134 | counter += len(objs)
135 |
136 | # print('%d / %d images' % (i + 1, dataset_size))
137 | save_name = os.path.join(save_path, name + '.npy')
138 | np.save(save_name, features)
139 |
140 | ############## Clustering ###########
141 | print('begin clustering')
142 | load_name = os.path.join(save_path, name + '.npy')
143 | features = np.load(load_name, allow_pickle=True).item()
144 | cluster(features, num_objs, 100, save_path)
145 | cluster(features, num_objs, 10, save_path)
146 | cluster(features, num_objs, 1, save_path)
147 |
148 |
149 | if __name__ == '__main__':
150 | opt = parser.parse_args()
151 | opt.object_size = 64
152 | main(opt)
153 |
--------------------------------------------------------------------------------
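
A sketch (not part of the repository) of loading the clustered appearance features this script writes next to the checkpoint; the GUI loads `features_clustered_100.npy` the same way:

```python
import os

import numpy as np

checkpoint_dir = 'models'   # placeholder: the directory that holds your checkpoint
features = np.load(os.path.join(checkpoint_dir, 'features_clustered_100.npy'),
                   allow_pickle=True).item()
for label, centers in features.items():
    print(label, centers.shape)   # (n_clusters_for_label, rep_size)
```
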
/scene_generation/graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 | from scene_generation.layers import build_mlp
21 |
22 | """
23 | PyTorch modules for dealing with graphs.
24 | """
25 |
26 |
27 | def _init_weights(module):
28 | if hasattr(module, 'weight'):
29 | if isinstance(module, nn.Linear):
30 | nn.init.kaiming_normal_(module.weight)
31 |
32 |
33 | class GraphTripleConv(nn.Module):
34 | """
35 | A single layer of scene graph convolution.
36 | """
37 |
38 | def __init__(self, input_dim, attributes_dim=0, output_dim=None, hidden_dim=512,
39 | pooling='avg', mlp_normalization='none'):
40 | super(GraphTripleConv, self).__init__()
41 | if output_dim is None:
42 | output_dim = input_dim
43 | self.input_dim = input_dim
44 | self.output_dim = output_dim
45 | self.hidden_dim = hidden_dim
46 |
47 | assert pooling in ['sum', 'avg'], 'Invalid pooling "%s"' % pooling
48 | self.pooling = pooling
49 | net1_layers = [3 * input_dim + 2 * attributes_dim, hidden_dim, 2 * hidden_dim + output_dim]
50 | net1_layers = [l for l in net1_layers if l is not None]
51 | self.net1 = build_mlp(net1_layers, batch_norm=mlp_normalization)
52 | self.net1.apply(_init_weights)
53 |
54 | net2_layers = [hidden_dim, hidden_dim, output_dim]
55 | self.net2 = build_mlp(net2_layers, batch_norm=mlp_normalization)
56 | self.net2.apply(_init_weights)
57 |
58 | def forward(self, obj_vecs, pred_vecs, edges):
59 | """
60 | Inputs:
61 | - obj_vecs: FloatTensor of shape (O, D) giving vectors for all objects
62 | - pred_vecs: FloatTensor of shape (T, D) giving vectors for all predicates
63 | - edges: LongTensor of shape (T, 2) where edges[k] = [i, j] indicates the
64 | presence of a triple [obj_vecs[i], pred_vecs[k], obj_vecs[j]]
65 |
66 | Outputs:
67 | - new_obj_vecs: FloatTensor of shape (O, D) giving new vectors for objects
68 | - new_pred_vecs: FloatTensor of shape (T, D) giving new vectors for predicates
69 | """
70 | dtype, device = obj_vecs.dtype, obj_vecs.device
71 | O, T = obj_vecs.size(0), pred_vecs.size(0)
72 | Din, H, Dout = self.input_dim, self.hidden_dim, self.output_dim
73 |
74 | # Break apart indices for subjects and objects; these have shape (T,)
75 | s_idx = edges[:, 0].contiguous()
76 | o_idx = edges[:, 1].contiguous()
77 |
78 | # Get current vectors for subjects and objects; these have shape (T, Din)
79 | cur_s_vecs = obj_vecs[s_idx]
80 | cur_o_vecs = obj_vecs[o_idx]
81 |
82 | # Get current vectors for triples; shape is (T, 3 * Din)
83 | # Pass through net1 to get new triple vecs; shape is (T, 2 * H + Dout)
84 | cur_t_vecs = torch.cat([cur_s_vecs, pred_vecs, cur_o_vecs], dim=1)
85 | new_t_vecs = self.net1(cur_t_vecs)
86 |
87 | # Break apart into new s, p, and o vecs; s and o vecs have shape (T, H) and
88 | # p vecs have shape (T, Dout)
89 | new_s_vecs = new_t_vecs[:, :H]
90 | new_p_vecs = new_t_vecs[:, H:(H + Dout)]
91 | new_o_vecs = new_t_vecs[:, (H + Dout):(2 * H + Dout)]
92 |
93 | # Allocate space for pooled object vectors of shape (O, H)
94 | pooled_obj_vecs = torch.zeros(O, H, dtype=dtype, device=device)
95 |
96 | # Use scatter_add to sum vectors for objects that appear in multiple triples;
97 | # we first need to expand the indices to have shape (T, D)
98 | s_idx_exp = s_idx.view(-1, 1).expand_as(new_s_vecs)
99 | o_idx_exp = o_idx.view(-1, 1).expand_as(new_o_vecs)
100 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, s_idx_exp, new_s_vecs)
101 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, o_idx_exp, new_o_vecs)
102 |
103 | if self.pooling == 'avg':
104 | # Figure out how many times each object has appeared, again using
105 | # some scatter_add trickery.
106 | obj_counts = torch.zeros(O, dtype=dtype, device=device)
107 | ones = torch.ones(T, dtype=dtype, device=device)
108 | obj_counts = obj_counts.scatter_add(0, s_idx, ones)
109 | obj_counts = obj_counts.scatter_add(0, o_idx, ones)
110 |
111 | # Divide the new object vectors by the number of times they
112 | # appeared, but first clamp at 1 to avoid dividing by zero;
113 | # objects that appear in no triples will have output vector 0
114 | # so this will not affect them.
115 | obj_counts = obj_counts.clamp(min=1)
116 | pooled_obj_vecs = pooled_obj_vecs / obj_counts.view(-1, 1)
117 |
118 | # Send pooled object vectors through net2 to get output object vectors,
119 | # of shape (O, Dout)
120 | new_obj_vecs = self.net2(pooled_obj_vecs)
121 |
122 | return new_obj_vecs, new_p_vecs
123 |
124 |
125 | class GraphTripleConvNet(nn.Module):
126 | """ A sequence of scene graph convolution layers """
127 |
128 | def __init__(self, input_dim, num_layers=5, hidden_dim=512, pooling='avg',
129 | mlp_normalization='none'):
130 | super(GraphTripleConvNet, self).__init__()
131 |
132 | self.num_layers = num_layers
133 | self.gconvs = nn.ModuleList()
134 | gconv_kwargs = {
135 | 'input_dim': input_dim,
136 | 'hidden_dim': hidden_dim,
137 | 'pooling': pooling,
138 | 'mlp_normalization': mlp_normalization,
139 | }
140 | for _ in range(self.num_layers):
141 | self.gconvs.append(GraphTripleConv(**gconv_kwargs))
142 |
143 | def forward(self, obj_vecs, pred_vecs, edges):
144 | for i in range(self.num_layers):
145 | gconv = self.gconvs[i]
146 | obj_vecs, pred_vecs = gconv(obj_vecs, pred_vecs, edges)
147 | return obj_vecs, pred_vecs
148 |
--------------------------------------------------------------------------------
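
A shape sketch (not part of the repository) of a `GraphTripleConvNet` pass with the dimensions used in `args.py` (`gconv_dim=128`, `gconv_hidden_dim=512`, 5 layers); the vectors are random:

```python
import torch

from scene_generation.graph import GraphTripleConvNet

O, T, D = 6, 4, 128                                     # objects, triples, embedding dim
obj_vecs = torch.randn(O, D)
pred_vecs = torch.randn(T, D)
edges = torch.tensor([[0, 1], [1, 2], [3, 4], [4, 5]])  # (subject idx, object idx) per triple

gconv = GraphTripleConvNet(input_dim=D, num_layers=5, hidden_dim=512)
new_obj_vecs, new_pred_vecs = gconv(obj_vecs, pred_vecs, edges)
print(new_obj_vecs.shape, new_pred_vecs.shape)          # (6, 128), (4, 128)
```
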
/scripts/gui/model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import math
4 | import os
5 | from datetime import datetime
6 |
7 | import numpy as np
8 | import torch
9 | from imageio import imwrite
10 |
11 | import scene_generation.vis as vis
12 | from scene_generation.data.utils import imagenet_deprocess_batch
13 | from scene_generation.model import Model
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--checkpoint', required=True)
17 | parser.add_argument('--output_dir', default='outputs')
18 | parser.add_argument('--draw_scene_graphs', type=int, default=0)
19 | parser.add_argument('--device', default='gpu', choices=['cpu', 'gpu'])
20 | args = parser.parse_args()
21 |
22 |
23 | def get_model():
24 | if not os.path.isfile(args.checkpoint):
25 | print('ERROR: Checkpoint file "%s" not found' % args.checkpoint)
26 | print('Maybe you forgot to download the pretrained models? See step 8 of the README')
27 | print('for the download link, and place the files under models/')
28 | return
29 |
30 | output_dir = os.path.join('scripts', 'gui', 'images', args.output_dir)
31 | if not os.path.isdir(output_dir):
32 | print('Output directory "%s" does not exist; creating it' % args.output_dir)
33 | os.makedirs(output_dir)
34 |
35 | if args.device == 'cpu':
36 | device = torch.device('cpu')
37 | elif args.device == 'gpu':
38 | device = torch.device('cuda:0')
39 | if not torch.cuda.is_available():
40 | print('WARNING: CUDA not available; falling back to CPU')
41 | device = torch.device('cpu')
42 |
43 | # Load the model, with a bit of care in case there are no GPUs
44 | map_location = 'cpu' if device == torch.device('cpu') else None
45 | checkpoint = torch.load(args.checkpoint, map_location=map_location)
46 | dirname = os.path.dirname(args.checkpoint)
47 | features_path = os.path.join(dirname, 'features_clustered_100.npy')
48 | features_path_one = os.path.join(dirname, 'features_clustered_001.npy')
49 | features = np.load(features_path, allow_pickle=True).item()
50 | features_one = np.load(features_path_one, allow_pickle=True).item()
51 | model = Model(**checkpoint['model_kwargs'])
52 | model_state = checkpoint['model_state']
53 | model.load_state_dict(model_state)
54 | model.features = features
55 | model.features_one = features_one
56 | model.colors = torch.randint(0, 256, [172, 3]).float()
57 | model.colors[0, :] = 256
58 | model.eval()
59 | model.to(device)
60 | return model
61 |
62 |
63 | def json_to_img(scene_graph, model):
64 | output_dir = args.output_dir
65 | scene_graphs = json_to_scene_graph(scene_graph)
66 | current_time = datetime.now().strftime('%b%d_%H-%M-%S')
67 |
68 | # Run the model forward
69 | with torch.no_grad():
70 | (imgs, boxes_pred, masks_pred, layout, layout_pred, _), objs = model.forward_json(scene_graphs)
71 | imgs = imagenet_deprocess_batch(imgs)
72 |
73 | # Save the generated image
74 | for i in range(imgs.shape[0]):
75 | img_np = imgs[i].numpy().transpose(1, 2, 0).astype('uint8')
76 | img_path = os.path.join('scripts', 'gui', 'images', output_dir, 'img{}.png'.format(current_time))
77 | imwrite(img_path, img_np)
78 | return_img_path = os.path.join('images', output_dir, 'img{}.png'.format(current_time))
79 |
80 | # Save the generated layout image
81 | for i in range(imgs.shape[0]):
82 | img_layout_np = one_hot_to_rgb(layout_pred[:, :172, :, :], model.colors)[0].numpy().transpose(1, 2, 0).astype(
83 | 'uint8')
84 | obj_colors = []
85 | for obj in objs[:-1]:
86 | new_color = torch.cat([model.colors[obj] / 256, torch.ones(1)])
87 | obj_colors.append(new_color)
88 |
89 | img_layout_path = os.path.join('scripts', 'gui', 'images', output_dir, 'img_layout{}.png'.format(current_time))
90 | vis.add_boxes_to_layout(img_layout_np, scene_graphs[i]['objects'], boxes_pred, img_layout_path,
91 | colors=obj_colors)
92 | return_img_layout_path = os.path.join('images', output_dir, 'img_layout{}.png'.format(current_time))
93 |
94 | # Draw and save the scene graph
95 | if args.draw_scene_graphs:
96 | for i, sg in enumerate(scene_graphs):
97 | sg_img = vis.draw_scene_graph(sg['objects'], sg['relationships'])
98 | sg_img_path = os.path.join('scripts', 'gui', 'images', output_dir, 'sg{}.png'.format(current_time))
99 | imwrite(sg_img_path, sg_img)
100 | sg_img_path = os.path.join('images', output_dir, 'sg{}.png'.format(current_time))
101 |
102 | return return_img_path, return_img_layout_path
103 |
104 |
105 | def one_hot_to_rgb(one_hot, colors):
106 | one_hot_3d = torch.einsum('abcd,be->aecd', (one_hot.cpu(), colors.cpu()))
107 | one_hot_3d *= (255.0 / one_hot_3d.max())
108 | return one_hot_3d
109 |
110 |
111 | def json_to_scene_graph(json_text):
112 | scene = json.loads(json_text)
113 | if len(scene) == 0:
114 | return []
115 | image_id = scene['image_id']
116 | scene = scene['objects']
117 | objects = [i['text'] for i in scene]
118 | relationships = []
119 | size = []
120 | location = []
121 | features = []
122 | for i in range(0, len(objects)):
123 | obj_s = scene[i]
124 | # Check for inside / surrounding
125 |
126 | sx0 = obj_s['left']
127 | sy0 = obj_s['top']
128 | sx1 = obj_s['width'] + sx0
129 | sy1 = obj_s['height'] + sy0
130 |
131 | margin = (obj_s['size'] + 1) / 10 / 2
132 | mean_x_s = 0.5 * (sx0 + sx1)
133 | mean_y_s = 0.5 * (sy0 + sy1)
134 |
135 | sx0 = max(0, mean_x_s - margin)
136 | sx1 = min(1, mean_x_s + margin)
137 | sy0 = max(0, mean_y_s - margin)
138 | sy1 = min(1, mean_y_s + margin)
139 |
140 | size.append(obj_s['size'])
141 | location.append(obj_s['location'])
142 |
143 | features.append(obj_s['feature'])
144 | if i == len(objects) - 1:
145 | continue
146 |
147 | obj_o = scene[i + 1]
148 | ox0 = obj_o['left']
149 | oy0 = obj_o['top']
150 | ox1 = obj_o['width'] + ox0
151 | oy1 = obj_o['height'] + oy0
152 |
153 | mean_x_o = 0.5 * (ox0 + ox1)
154 | mean_y_o = 0.5 * (oy0 + oy1)
155 | d_x = mean_x_s - mean_x_o
156 | d_y = mean_y_s - mean_y_o
157 | theta = math.atan2(d_y, d_x)
158 |
159 | margin = (obj_o['size'] + 1) / 10 / 2
160 | ox0 = max(0, mean_x_o - margin)
161 | ox1 = min(1, mean_x_o + margin)
162 | oy0 = max(0, mean_y_o - margin)
163 | oy1 = min(1, mean_y_o + margin)
164 |
165 | if sx0 < ox0 and sx1 > ox1 and sy0 < oy0 and sy1 > oy1:
166 | p = 'surrounding'
167 | elif sx0 > ox0 and sx1 < ox1 and sy0 > oy0 and sy1 < oy1:
168 | p = 'inside'
169 | elif theta >= 3 * math.pi / 4 or theta <= -3 * math.pi / 4:
170 | p = 'left of'
171 | elif -3 * math.pi / 4 <= theta < -math.pi / 4:
172 | p = 'above'
173 | elif -math.pi / 4 <= theta < math.pi / 4:
174 | p = 'right of'
175 | elif math.pi / 4 <= theta < 3 * math.pi / 4:
176 | p = 'below'
177 | relationships.append([i, p, i + 1])
178 |
179 | return [{'objects': objects, 'relationships': relationships, 'attributes': {'size': size, 'location': location},
180 | 'features': features, 'image_id': image_id}]
181 |
--------------------------------------------------------------------------------
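
For reference, `json_to_scene_graph` above expects a payload with an `image_id` and a list of objects carrying `text`, `left`, `top`, `width`, `height`, `size`, `location`, and `feature` fields. A hedged sketch (not part of the repository; the values are illustrative, and the bin interpretation follows the defaults in `args.py`):

```python
import json

scene_graph = json.dumps({
    'image_id': 0,
    'objects': [
        # Coordinates are fractions of the canvas; 'size' is a 0-9 size bin,
        # 'location' a 0-24 location bin, 'feature' picks a clustered appearance archetype.
        {'text': 'sky', 'left': 0.0, 'top': 0.0, 'width': 1.0, 'height': 0.5,
         'size': 9, 'location': 2, 'feature': 0},
        {'text': 'grass', 'left': 0.0, 'top': 0.5, 'width': 1.0, 'height': 0.5,
         'size': 9, 'location': 22, 'feature': 0},
    ],
})
# img_path, layout_path = json_to_img(scene_graph, model)   # with model = get_model()
```
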
/scripts/gui/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | background: linear-gradient(120deg, #252525 0%, #39404c 100%);
3 | padding: 70px 70px 0px 70px;
4 | }
5 |
6 | h1, h2 {
7 | font-family: 'Open Sans', sans-serif;
8 | color: #dbab21;
9 | text-align: center;
10 | font-weight: 100;
11 | letter-spacing: 4px;
12 | }
13 |
14 | .main {
15 | display: flex;
16 | flex-direction: row;
17 | justify-content: space-between;
18 | }
19 |
20 | .buttons {
21 | display: flex;
22 | flex-direction: row;
23 | justify-content: left;
24 | }
25 |
26 | .button {
27 | margin: 20px;
28 | }
29 |
30 | .main-movable {
31 | height: 512px;
32 | width: 512px;
33 | background-color: #FEFEFE;
34 | background-image: -webkit-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB), -webkit-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB);
35 | background-image: -moz-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB), -moz-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB);
36 | background-image: -o-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB), -o-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB);
37 | background-image: -ms-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB), -ms-linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB);
38 | background-image: linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB), linear-gradient(45deg, #CBCBCB 25%, transparent 25%, transparent 75%, #CBCBCB 75%, #CBCBCB);
39 | -webkit-background-size: 30px 30px;
40 | -moz-background-size: 30px 30px;
41 | background-size: 30px 30px;
42 | background-position: 0 0, 15px 15px;
43 | }
44 |
45 |
46 | .resize-drag {
47 | color: black;
48 | font-size: 30px;
49 | font-family: 'Open Sans', sans-serif;
50 | border-radius: 8px;
51 | padding: 20px;
52 | margin: 0;
53 | touch-action: none;
54 | text-align: center;
55 |
56 | width: 120px;
57 | /* This makes things *much* easier */
58 | box-sizing: border-box;
59 | }
60 |
61 | #resize-container {
62 | display: inline-block;
63 | width: 100%;
64 | height: 100%;
65 | position: relative;
66 | }
67 |
68 | .line1 {
69 | left: 20%;
70 | }
71 |
72 | .line2 {
73 | left: 40%;
74 | }
75 |
76 | .line3 {
77 | left: 60%;
78 | }
79 |
80 | .line4 {
81 | left: 80%;
82 | }
83 |
84 | .line5 {
85 | bottom: 20%;
86 | }
87 |
88 | .line6 {
89 | bottom: 40%;
90 | }
91 |
92 | .line7 {
93 | bottom: 60%;
94 | }
95 |
96 | .line8 {
97 | bottom: 80%;
98 | }
99 |
100 |
101 | .line-horizontal {
102 | content: "";
103 | position: absolute;
104 | z-index: 1;
105 | border-left: 2px dotted #000000;
106 | top: 0;
107 | bottom: 0;
108 | }
109 |
110 | .line-vertical {
111 | content: "";
112 | position: absolute;
113 | z-index: 1;
114 | border-top: 1px dotted #000000;
115 | left: 0;
116 | right: 0;
117 | }
118 |
119 | #layout_pred, #img_pred {
120 | width: 512px;
121 | height: 512px;
122 | }
123 |
124 | .selected {
125 | color: red;
126 | }
127 |
128 | /* Range Slider */
129 |
130 | *, *:before, *:after {
131 | box-sizing: border-box;
132 | }
133 |
134 |
135 | .range-slider {
136 | margin: 15px 0 0 0;
137 | }
138 |
139 | .range-slider {
140 | width: 50%;
141 | }
142 |
143 | .range-slider__range {
144 | -webkit-appearance: none;
145 | width: 512px;
146 | height: 10px;
147 | border-radius: 5px;
148 | background: #d7dcdf;
149 | outline: none;
150 | }
151 |
152 | .range-slider__range::-webkit-slider-thumb {
153 | -webkit-appearance: none;
154 | appearance: none;
155 | width: 20px;
156 | height: 20px;
157 | border-radius: 50%;
158 | background: #dbab21;
159 | cursor: pointer;
160 | transition: background .15s ease-in-out;
161 | }
162 |
163 | .range-slider__range::-webkit-slider-thumb:hover {
164 | background: #47370d;
165 | }
166 |
167 | .range-slider__range:active::-webkit-slider-thumb {
168 | background: #47370d;
169 | }
170 |
171 | .range-slider__range::-moz-range-thumb {
172 | width: 20px;
173 | height: 20px;
174 | border: 0;
175 | border-radius: 50%;
176 | background: #47370d;
177 | cursor: pointer;
178 | transition: background .15s ease-in-out;
179 | }
180 |
181 | .range-slider__range::-moz-range-thumb:hover {
182 | background: #47370d;
183 | }
184 |
185 | .range-slider__range:active::-moz-range-thumb {
186 | background: #47370d;
187 | }
188 |
189 | .range-slider__range:focus::-webkit-slider-thumb {
190 | box-shadow: 0 0 0 3px #fff, 0 0 0 6px #47370d;
191 | }
192 |
193 | .range-slider__value {
194 | display: inline-block;
195 | position: relative;
196 | width: 60px;
197 | color: rgba(0, 0, 0, 0.7);
198 | line-height: 20px;
199 | text-align: center;
200 | border-radius: 3px;
201 | background: #dbab21;
202 | padding: 5px 10px;
203 | margin-left: 8px;
204 | font-family: 'Open Sans', sans-serif;
205 | }
206 |
207 | .range-slider__value:after {
208 | position: absolute;
209 | top: 8px;
210 | left: -7px;
211 | width: 0;
212 | height: 0;
213 | border-top: 7px solid transparent;
214 | border-right: 7px solid #2c3e50;
215 | border-bottom: 7px solid transparent;
216 | content: '';
217 | }
218 |
219 | ::-moz-range-track {
220 | background: #df2e46;
221 | border: 0;
222 | }
223 |
224 | input::-moz-focus-inner,
225 | input::-moz-focus-outer {
226 | border: 0;
227 | }
228 |
229 | #table {
230 | color: white;
231 | }
232 |
233 | /* navigation bar */
234 | ul {
235 | list-style: none;
236 | font-family: 'Open Sans', sans-serif;
237 | width: 100%;
238 | display: flex;
239 | justify-content: space-around;
240 | }
241 |
242 | .containerMain {
243 | display: flex;
244 | flex-direction: column;
245 | }
246 |
247 | .container {
248 | position: relative;
249 | z-index: 10;
250 | display: flex;
251 | flex-direction: column;
252 | justify-content: center;
253 | align-items: start;
254 | }
255 |
256 | .nav {
257 | display: inline-block;
258 | text-align: center;
259 | margin: 0 0;
260 | padding-bottom: 20px;
261 | width: 950px;
262 | }
263 |
264 | .nav li:hover {
265 | background: #FFCB25;
266 | }
267 |
268 | .nav > ul {
269 | list-style: none;
270 | padding: 0;
271 | margin: 0;
272 | background: #dbab21;
273 | border-radius: 5px;
274 | color: rgba(0, 0, 0, 0.7);
275 | z-index: 11;
276 | display: flex;
277 | justify-content: space-evenly;
278 | }
279 |
280 | .nav > ul > li {
281 | float: left;
282 | width: 90px;
283 | height: 30px;
284 | line-height: 30px;
285 | position: relative;
286 | font-size: 20px;
287 | cursor: pointer;
288 | }
289 |
290 | ul.drop-menu {
291 | position: absolute;
292 | top: 100%;
293 | left: 0%;
294 | padding: 0;
295 | display: flex;
296 | flex-direction: column;
297 | flex-wrap: wrap;
298 | max-height: 430px;
299 | width: 150px;
300 | text-align: start;
301 | }
302 |
303 | ul.drop-menu li {
304 | background: #a7a8a2;
305 | z-index: 12;
306 | padding: 3px;
307 | }
308 |
309 | ul.drop-menu li:hover {
310 | background: #ffcb25;
311 | }
312 |
313 | ul.drop-menu li {
314 | display: none;
315 | }
316 |
317 | li:hover > ul.drop-menu li {
318 | display: block;
319 | }
320 |
321 | .item {
322 | position: relative;
323 | padding-bottom: 20px;
324 | }
325 |
326 | .adjust {
327 | padding-bottom: 0;
328 | padding-top: 20px;
329 | }
330 |
331 | .nav-arrow {
332 | border: 1px solid transparent;
333 | background: transparent;
334 | border-radius: 3px;
335 | color: white;
336 | transition: background 0.5s ease;
337 | font-size: 24px;
338 | font-family: 'Open Sans', sans-serif;
339 | }
--------------------------------------------------------------------------------
/scripts/gui/index_panoptic.html:
--------------------------------------------------------------------------------
[Markup not preserved in this dump; only stray text nodes of index_panoptic.html survived.
The page is the panoptic-model variant of index.html: a navigation bar of category
drop-down menus for adding objects (PEOPLE, ANIMALS, VEHICLES, FOOD, SPORTS, INTERIOR,
EXTERIOR, MATERIALS), the draggable layout canvas, a results table with
Object / Size / Location / Feature columns, and the appearance range slider (initial value 0).]
--------------------------------------------------------------------------------
/scene_generation/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torchvision import models
5 | from torch.autograd import Variable
6 |
7 |
8 | def get_gan_losses(gan_type):
9 | """
10 | Returns the generator and discriminator loss for a particular GAN type.
11 |
12 | The returned functions have the following API:
13 | loss_g = g_loss(scores_fake)
14 | loss_d = d_loss(scores_real, scores_fake)
15 | """
16 | if gan_type == 'gan':
17 | return gan_g_loss, gan_d_loss
18 | elif gan_type == 'wgan':
19 | return wgan_g_loss, wgan_d_loss
20 | elif gan_type == 'lsgan':
21 | return lsgan_g_loss, lsgan_d_loss
22 | else:
23 | raise ValueError('Unrecognized GAN type "%s"' % gan_type)
24 |
25 |
26 | def bce_loss(input, target):
27 | """
28 | Numerically stable version of the binary cross-entropy loss function.
29 |
30 | As per https://github.com/pytorch/pytorch/issues/751
31 | See the TensorFlow docs for a derivation of this formula:
32 | https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
33 |
34 | Inputs:
35 | - input: PyTorch Tensor of shape (N, ) giving scores.
36 | - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets.
37 |
38 | Returns:
39 | - A PyTorch Tensor containing the mean BCE loss over the minibatch of
40 | input data.
41 | """
42 | neg_abs = -input.abs()
43 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
44 | return loss.mean()
45 |
46 |
47 | def _make_targets(x, y):
48 | """
49 | Inputs:
50 | - x: PyTorch Tensor
51 | - y: Python scalar
52 |
53 | Outputs:
54 | - out: PyTorch Variable with same shape and dtype as x, but filled with y
55 | """
56 | return torch.full_like(x, y)
57 |
58 |
59 | def gan_g_loss(scores_fake):
60 | """
61 | Input:
62 | - scores_fake: Tensor of shape (N,) containing scores for fake samples
63 |
64 | Output:
65 | - loss: Variable of shape (,) giving GAN generator loss
66 | """
67 | if scores_fake.dim() > 1:
68 | scores_fake = scores_fake.view(-1)
69 | y_fake = _make_targets(scores_fake, 1)
70 | return bce_loss(scores_fake, y_fake)
71 |
72 |
73 | def gan_d_loss(scores_real, scores_fake):
74 | """
75 | Input:
76 | - scores_real: Tensor of shape (N,) giving scores for real samples
77 | - scores_fake: Tensor of shape (N,) giving scores for fake samples
78 |
79 | Output:
80 | - loss: Tensor of shape (,) giving GAN discriminator loss
81 | """
82 | assert scores_real.size() == scores_fake.size()
83 | if scores_real.dim() > 1:
84 | scores_real = scores_real.view(-1)
85 | scores_fake = scores_fake.view(-1)
86 | y_real = _make_targets(scores_real, 1)
87 | y_fake = _make_targets(scores_fake, 0)
88 | loss_real = bce_loss(scores_real, y_real)
89 | loss_fake = bce_loss(scores_fake, y_fake)
90 | return loss_real + loss_fake
91 |
92 |
93 | def wgan_g_loss(scores_fake):
94 | """
95 | Input:
96 | - scores_fake: Tensor of shape (N,) containing scores for fake samples
97 |
98 | Output:
99 | - loss: Tensor of shape (,) giving WGAN generator loss
100 | """
101 | return -scores_fake.mean()
102 |
103 |
104 | def wgan_d_loss(scores_real, scores_fake):
105 | """
106 | Input:
107 | - scores_real: Tensor of shape (N,) giving scores for real samples
108 | - scores_fake: Tensor of shape (N,) giving scores for fake samples
109 |
110 | Output:
111 | - loss: Tensor of shape (,) giving WGAN discriminator loss
112 | """
113 | return scores_fake.mean() - scores_real.mean()
114 |
115 |
116 | def lsgan_g_loss(scores_fake):
117 | if scores_fake.dim() > 1:
118 | scores_fake = scores_fake.view(-1)
119 | y_fake = _make_targets(scores_fake, 1)
120 | return F.mse_loss(scores_fake.sigmoid(), y_fake)
121 |
122 |
123 | def lsgan_d_loss(scores_real, scores_fake):
124 | assert scores_real.size() == scores_fake.size()
125 | if scores_real.dim() > 1:
126 | scores_real = scores_real.view(-1)
127 | scores_fake = scores_fake.view(-1)
128 | y_real = _make_targets(scores_real, 1)
129 | y_fake = _make_targets(scores_fake, 0)
130 | loss_real = F.mse_loss(scores_real.sigmoid(), y_real)
131 | loss_fake = F.mse_loss(scores_fake.sigmoid(), y_fake)
132 | return loss_real + loss_fake
133 |
134 |
135 | class GANLoss(nn.Module):
136 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor):
137 | super(GANLoss, self).__init__()
138 | self.real_label = target_real_label
139 | self.fake_label = target_fake_label
140 | self.real_label_var = None
141 | self.fake_label_var = None
142 | self.Tensor = tensor
143 | if use_lsgan:
144 | self.loss = nn.MSELoss()
145 | else:
146 | self.loss = nn.BCELoss()
147 |
148 | def get_target_tensor(self, input, target_is_real):
149 | if target_is_real:
150 | create_label = ((self.real_label_var is None) or
151 | (self.real_label_var.numel() != input.numel()))
152 | if create_label:
153 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
154 | self.real_label_var = Variable(real_tensor, requires_grad=False)
155 | target_tensor = self.real_label_var
156 | else:
157 | create_label = ((self.fake_label_var is None) or
158 | (self.fake_label_var.numel() != input.numel()))
159 | if create_label:
160 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
161 | self.fake_label_var = Variable(fake_tensor, requires_grad=False)
162 | target_tensor = self.fake_label_var
163 | return target_tensor
164 |
165 | def __call__(self, input, target_is_real):
166 | if isinstance(input[0], list):
167 | loss = 0
168 | for input_i in input:
169 | pred = input_i[-1]
170 | target_tensor = self.get_target_tensor(pred, target_is_real)
171 | loss += self.loss(pred, target_tensor)
172 | return loss
173 | else:
174 | target_tensor = self.get_target_tensor(input[-1], target_is_real)
175 | return self.loss(input[-1], target_tensor)
176 |
177 |
178 | # VGG Features matching
179 | class Vgg19(torch.nn.Module):
180 | def __init__(self, requires_grad=False):
181 | super(Vgg19, self).__init__()
182 | vgg_pretrained_features = models.vgg19(pretrained=True).features
183 | self.slice1 = torch.nn.Sequential()
184 | self.slice2 = torch.nn.Sequential()
185 | self.slice3 = torch.nn.Sequential()
186 | self.slice4 = torch.nn.Sequential()
187 | self.slice5 = torch.nn.Sequential()
188 | for x in range(2):
189 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
190 | for x in range(2, 7):
191 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
192 | for x in range(7, 12):
193 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
194 | for x in range(12, 21):
195 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
196 | for x in range(21, 30):
197 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
198 | if not requires_grad:
199 | for param in self.parameters():
200 | param.requires_grad = False
201 |
202 | def forward(self, X):
203 | h_relu1 = self.slice1(X)
204 | h_relu2 = self.slice2(h_relu1)
205 | h_relu3 = self.slice3(h_relu2)
206 | h_relu4 = self.slice4(h_relu3)
207 | h_relu5 = self.slice5(h_relu4)
208 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
209 | return out
210 |
211 |
212 | class VGGLoss(nn.Module):
213 | def __init__(self):
214 | super(VGGLoss, self).__init__()
215 | self.vgg = Vgg19().cuda()
216 | self.criterion = nn.L1Loss()
217 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
218 |
219 | def forward(self, x, y):
220 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
221 | loss = 0
222 | for i in range(len(x_vgg)):
223 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
224 | return loss
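For orientation, a minimal usage sketch (not part of the repository) showing how the loss helpers above fit together; the score tensors are random stand-ins for discriminator outputs:

import torch
from scene_generation.losses import get_gan_losses

g_loss, d_loss = get_gan_losses('gan')     # also accepts 'wgan' or 'lsgan'

scores_real = torch.randn(8)               # stand-in: D(real images)
scores_fake = torch.randn(8)               # stand-in: D(generated images)

loss_d = d_loss(scores_real, scores_fake)  # minimised when updating the discriminator
loss_g = g_loss(scores_fake)               # minimised when updating the generator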
--------------------------------------------------------------------------------
/scene_generation/vis.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 | import tempfile
19 |
20 | import matplotlib.pyplot as plt
21 | import numpy as np
22 | import torch
23 | from imageio import imread
24 | from matplotlib.patches import Rectangle
25 |
26 | """
27 | Utilities for making visualizations.
28 | """
29 |
30 |
31 | def draw_layout(vocab, objs, boxes, masks=None, size=256,
32 | show_boxes=False, bgcolor=(0, 0, 0)):
33 | if bgcolor == 'white':
34 | bgcolor = (255, 255, 255)
35 |
36 | cmap = plt.get_cmap('rainbow')
37 | colors = cmap(np.linspace(0, 1, len(objs)))
38 |
39 | with torch.no_grad():
40 | objs = objs.cpu().clone()
41 | boxes = boxes.cpu().clone()
42 | boxes *= size
43 |
44 | if masks is not None:
45 | masks = masks.cpu().clone()
46 |
47 | bgcolor = np.asarray(bgcolor)
48 | bg = np.ones((size, size, 1)) * bgcolor
49 | plt.imshow(bg.astype(np.uint8))
50 |
51 | plt.gca().set_xlim(0, size)
52 | plt.gca().set_ylim(size, 0)
53 | plt.gca().set_aspect(1.0, adjustable='box')
54 |
55 | for i, obj in enumerate(objs):
56 | name = vocab['object_idx_to_name'][obj]
57 | if name == '__image__':
58 | continue
59 | box = boxes[i]
60 |
61 | if masks is None:
62 | continue
63 | mask = masks[i].numpy()
64 | mask /= mask.max()
65 |
66 | r, g, b, a = colors[i]
67 | colored_mask = mask[:, :, None] * np.asarray(colors[i])
68 |
69 | x0, y0, x1, y1 = box
70 | plt.imshow(colored_mask, extent=(x0, x1, y1, y0),
71 | interpolation='bicubic', alpha=1.0)
72 |
73 | if show_boxes:
74 | for i, obj in enumerate(objs):
75 | name = vocab['object_idx_to_name'][obj]
76 | if name == '__image__':
77 | continue
78 | box = boxes[i]
79 |
80 | draw_box(box, colors[i], name)
81 |
82 |
83 | def add_boxes_to_layout(img, objs, boxes, image_path, size=256, colors=None):
84 | if colors is None:
85 | cmap = plt.get_cmap('rainbow')
86 | colors = cmap(np.linspace(0, 1, len(objs)))
87 | plt.clf()
88 | with torch.no_grad():
89 | boxes = boxes.cpu().clone()
90 | boxes *= size
91 |
92 | plt.imshow(img)
93 |
94 | plt.gca().set_xlim(0, size)
95 | plt.gca().set_ylim(size, 0)
96 | plt.gca().set_aspect(1.0, adjustable='box')
97 |
98 | for i, obj in enumerate(objs):
99 | if obj == '__image__':
100 | continue
101 | draw_box(boxes[i], colors[i], obj, alpha=0.8)
102 | plt.axis('off')
103 | plt.savefig(image_path, bbox_inches='tight', pad_inches=0)
104 |
105 |
106 | def draw_box(box, color, text=None, alpha=1.0):
107 | """
108 | Draw a bounding box using pyplot, optionally with a text box label.
109 |
110 | Inputs:
111 | - box: Tensor or list with 4 elements: [x0, y0, x1, y1] in [0, W] x [0, H]
112 | coordinate system.
113 | - color: pyplot color to use for the box.
114 | - text: (Optional) String; if provided then draw a label for this box.
115 | """
116 | TEXT_BOX_HEIGHT = 10
117 | if torch.is_tensor(box) and box.dim() == 2:
118 | box = box.view(-1)
119 | assert box.size(0) == 4
120 | x0, y0, x1, y1 = box
121 | assert y1 > y0, box
122 | assert x1 > x0, box
123 | w, h = x1 - x0, y1 - y0
124 | rect = Rectangle((x0, y0), w, h, fc='none', lw=2, ec=color, alpha=alpha)
125 | plt.gca().add_patch(rect)
126 | if text is not None:
127 | text_rect = Rectangle((x0, y0), w, TEXT_BOX_HEIGHT, fc=color, alpha=0.5)
128 | plt.gca().add_patch(text_rect)
129 | tx = 0.5 * (x0 + x1)
130 | ty = y0 + TEXT_BOX_HEIGHT / 2.0
131 | plt.text(tx, ty, text, va='center', ha='center')
132 |
133 |
134 | def draw_scene_graph(objs, triples, vocab=None, **kwargs):
135 | """
136 | Use GraphViz to draw a scene graph. If vocab is not passed then we assume
137 | that objs and triples are python lists containing strings for object and
138 | relationship names.
139 |
140 | Using this requires that GraphViz is installed. On Ubuntu 16.04 this is easy:
141 | sudo apt-get install graphviz
142 | """
143 | output_filename = kwargs.pop('output_filename', 'graph.png')
144 | orientation = kwargs.pop('orientation', 'V')
145 | edge_width = kwargs.pop('edge_width', 6)
146 | arrow_size = kwargs.pop('arrow_size', 1.5)
147 | binary_edge_weight = kwargs.pop('binary_edge_weight', 1.2)
148 | ignore_dummies = kwargs.pop('ignore_dummies', True)
149 |
150 | if orientation not in ['V', 'H']:
151 | raise ValueError('Invalid orientation "%s"' % orientation)
152 | rankdir = {'H': 'LR', 'V': 'TD'}[orientation]
153 |
154 | if vocab is not None:
155 | # Decode object and relationship names
156 | assert torch.is_tensor(objs)
157 | assert torch.is_tensor(triples)
158 | objs_list, triples_list = [], []
159 | idx_to_obj = ['__image__'] + vocab['my_idx_to_obj']
160 | for i in range(objs.size(0)):
161 | objs_list.append(idx_to_obj[objs[i].item()])
162 | for i in range(triples.size(0)):
163 | s = triples[i, 0].item()
164 | p = vocab['pred_idx_to_name'][triples[i, 1].item()]
165 | o = triples[i, 2].item()
166 | triples_list.append([s, p, o])
167 | objs, triples = objs_list, triples_list
168 |
169 | # General setup, and style for object nodes
170 | lines = [
171 | 'digraph{',
172 | 'graph [size="5,3",ratio="compress",dpi="300",bgcolor="transparent"]',
173 | 'rankdir=%s' % rankdir,
174 | 'nodesep="0.5"',
175 | 'ranksep="0.5"',
176 | 'node [shape="box",style="rounded,filled",fontsize="48",color="none"]',
177 | 'node [fillcolor="lightpink1"]',
178 | ]
179 | # Output nodes for objects
180 | for i, obj in enumerate(objs):
181 | if ignore_dummies and obj == '__image__':
182 | continue
183 | lines.append('%d [label="%s"]' % (i, obj))
184 |
185 | # Output relationships
186 | next_node_id = len(objs)
187 | lines.append('node [fillcolor="lightblue1"]')
188 | for s, p, o in triples:
189 | if ignore_dummies and p == '__in_image__':
190 | continue
191 | lines += [
192 | '%d [label="%s"]' % (next_node_id, p),
193 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % (
194 | s, next_node_id, edge_width, arrow_size, binary_edge_weight),
195 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % (
196 | next_node_id, o, edge_width, arrow_size, binary_edge_weight)
197 | ]
198 | next_node_id += 1
199 | lines.append('}')
200 |
201 | # Now it gets slightly hacky. Write the graphviz spec to a temporary
202 | # text file
203 | ff, dot_filename = tempfile.mkstemp()
204 | with open(dot_filename, 'w') as f:
205 | for line in lines:
206 | f.write('%s\n' % line)
207 | os.close(ff)
208 |
209 | # Shell out to invoke graphviz; this will save the resulting image to disk,
210 | # so we read it, delete it, then return it.
211 | output_format = os.path.splitext(output_filename)[1][1:]
212 | os.system('dot -T%s %s > %s' % (output_format, dot_filename, output_filename))
213 | os.remove(dot_filename)
214 | img = imread(output_filename)
215 | os.remove(output_filename)
216 |
217 | return img
218 |
219 |
220 | if __name__ == '__main__':
221 | o_idx_to_name = ['cat', 'dog', 'hat', 'skateboard']
222 | p_idx_to_name = ['riding', 'wearing', 'on', 'next to', 'above']
223 | o_name_to_idx = {s: i for i, s in enumerate(o_idx_to_name)}
224 | p_name_to_idx = {s: i for i, s in enumerate(p_idx_to_name)}
225 | vocab = {
226 | 'object_idx_to_name': o_idx_to_name,
227 | 'object_name_to_idx': o_name_to_idx,
228 | 'pred_idx_to_name': p_idx_to_name,
229 | 'pred_name_to_idx': p_name_to_idx,
230 | }
231 |
232 | objs = [
233 | 'cat',
234 | 'cat',
235 | 'skateboard',
236 | 'hat',
237 | ]
238 | objs = torch.LongTensor([o_name_to_idx[o] for o in objs])
239 | triples = [
240 | [0, 'next to', 1],
241 | [0, 'riding', 2],
242 | [1, 'wearing', 3],
243 | [3, 'above', 2],
244 | ]
245 | triples = [[s, p_name_to_idx[p], o] for s, p, o in triples]
246 | triples = torch.LongTensor(triples)
247 |
248 | draw_scene_graph(objs, triples, vocab, orientation='V')
249 |
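Besides the tensor-plus-vocab path exercised in the __main__ block above, draw_scene_graph also accepts plain Python lists when vocab is None. A minimal sketch (GraphViz's dot must be on the PATH):

from scene_generation.vis import draw_scene_graph

objs = ['cat', 'skateboard']
triples = [[0, 'riding', 1]]   # [subject index, predicate string, object index]
img = draw_scene_graph(objs, triples, output_filename='sg.png')
print(img.shape)               # numpy array read back from the rendered image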
--------------------------------------------------------------------------------
/scripts/gui/index.js:
--------------------------------------------------------------------------------
1 | window.onload = function () {
2 |
3 | var rangeSlider = function () {
4 | var slider = $('.range-slider'),
5 | range = $('.range-slider__range'),
6 | value = $('.range-slider__value');
7 |
8 | slider.each(function () {
9 |
10 | value.each(function () {
11 | var value = $(this).prev().attr('value');
12 | $(this).html(value);
13 | });
14 |
15 | range.on('input', function () {
16 | var id = this.getAttribute('data-id');
17 | if (id) {
18 | document.getElementById(id).setAttribute('data-style', this.value);
19 | } else {
20 | this.setAttribute('image-id', this.value);
21 | }
22 | $(this).next(value).html(this.value);
23 | render_button();
24 | });
25 | });
26 | };
27 |
28 | rangeSlider();
29 |
30 | function dragMoveListener(event) {
31 |         console.log('dragMove');
32 | var target = event.target,
33 | // keep the dragged position in the data-x/data-y attributes
34 | x = (parseFloat(target.getAttribute('data-x')) || 0) + event.dx,
35 | y = (parseFloat(target.getAttribute('data-y')) || 0) + event.dy;
36 |
37 | // translate the element
38 | target.style.webkitTransform =
39 | target.style.transform =
40 | 'translate(' + x + 'px, ' + y + 'px)';
41 |
42 |             // update the position attributes
43 | target.setAttribute('data-x', x);
44 | target.setAttribute('data-y', y);
45 | selectItem(event, event.target);
46 | // render_button();
47 | }
48 |
49 | // this is used later in the resizing and gesture demos
50 | window.dragMoveListener = dragMoveListener;
51 |
52 | interact('.resize-drag')
53 | .draggable({
54 | onmove: window.dragMoveListener,
55 | onend: render_button,
56 | restrict: {
57 | restriction: 'parent',
58 | elementRect: {top: 0, left: 0, bottom: 1, right: 1}
59 | },
60 | })
61 |
62 | .on('tap', function (event) {
63 | console.log('tap');
64 | var target = event.target;
65 | var size = parseInt(target.getAttribute('data-size'));
66 | var new_size = (size + 1) % 10;
67 | target.setAttribute('data-size', new_size);
68 | target.style.fontSize = sizeToFont(new_size);
69 | // $(event.currentTarget).remove();
70 | selectItem(event, event.target);
71 | render_button();
72 | event.preventDefault();
73 | })
74 | .on('hold', function (event) {
75 | console.log('hold');
76 | $(event.currentTarget).remove();
77 | render_button();
78 | event.preventDefault();
79 | });
80 |
81 | function selectItem(event, target, should_deselect) {
82 | event.stopPropagation();
83 | var hasClass = $(target).hasClass('selected');
84 | $(".resize-drag").removeClass("selected");
85 | $('#range-slider').attr('data-id', '');
86 | if (should_deselect && hasClass) {
87 | } else {
88 | $(target).addClass("selected");
89 | $('#range-slider').attr('data-id', target.id);
90 | var style = target.getAttribute('data-style');
91 | style = style ? style : -1;
92 | $('#range-slider').val(style);
93 | $('.range-slider__value').text(style.toString());
94 | }
95 | }
96 |
97 | $(".resize-drag").click(function (e) {
98 | $(".resize-drag").removeClass("selected");
99 | $(this).addClass("selected");
100 | e.stopPropagation();
101 | });
102 |
103 | function guidGenerator() {
104 | var S4 = function () {
105 | return (((1 + Math.random()) * 0x10000) | 0).toString(16).substring(1);
106 | };
107 | return (S4() + S4() + "-" + S4() + "-" + S4() + "-" + S4() + "-" + S4() + S4() + S4());
108 | }
109 |
110 | function stuff_add(evt) {
111 | evt.stopPropagation();
112 | var newContent = document.createTextNode(evt.currentTarget.textContent);
113 | var node = document.createElement("DIV");
114 | node.className = "resize-drag";
115 | node.id = guidGenerator();
116 | node.appendChild(newContent);
117 | var init_size = 0;
118 | node.setAttribute('data-size', init_size);
119 | node.style.fontSize = sizeToFont(init_size);
120 | document.getElementById("resize-container").appendChild(node);
121 | render_button();
122 | }
123 |
124 | function sizeToFont(size) {
125 |         return (size * 8 + 20) + 'px';  // style.fontSize needs a CSS unit
126 | }
127 |
128 | function refresh_image(response) {
129 | response = JSON.parse(response);
130 | document.getElementById("img_pred").src = response.img_pred;
131 | document.getElementById("layout_pred").src = response.layout_pred;
132 | }
133 |
134 | function addRow(obj, size, location, feature) {
135 | return;
136 | // Get a reference to the table
137 | let tableRef = document.getElementById('table').getElementsByTagName('tbody')[0];
138 |
139 | // Insert a row at the end of the table
140 | let newRow = tableRef.insertRow(-1);
141 |
142 | // Insert a cell in the row at index 0
143 | newRow.insertCell(0).appendChild(document.createTextNode(obj));
144 | newRow.insertCell(1).appendChild(document.createTextNode(size + ''));
145 | newRow.insertCell(2).appendChild(document.createTextNode(location + ''));
146 | newRow.insertCell(3).appendChild(document.createTextNode(feature + ''));
147 | }
148 |
149 | function render_button() {
150 | console.log('render');
151 | var allObjects = [];
152 | $("tbody").children().remove();
153 | var container = document.getElementById("resize-container");
154 | var container_rect = interact.getElementRect(container);
155 | var containerOffsetLeft = container_rect.left;
156 | var containerOffsetTop = container_rect.top;
157 | var containerWidth = container_rect.width;
158 | var containerHeight = container_rect.height;
159 | var children = document.getElementsByClassName('resize-drag');
160 | if (children.length < 3) {
161 | return;
162 | }
163 | for (var i = 0; i < children.length; i++) {
164 | var rect = interact.getElementRect(children[i]);
165 | var height = rect.height / containerHeight;
166 | var width = rect.width / containerWidth;
167 | var left = (rect.left - containerOffsetLeft) / containerWidth;
168 | var top = (rect.top - containerOffsetTop) / containerHeight;
169 | var sx0 = left;
170 | var sy0 = top;
171 | var sx1 = width + left;
172 | var sy1 = height + sy0;
173 | var mean_x_s = 0.5 * (sx0 + sx1);
174 | var mean_y_s = 0.5 * (sy0 + sy1);
175 | var grid = 25 / 5;
176 | var location = Math.round(mean_x_s * (grid - 1)) + grid * Math.round(mean_y_s * (grid - 1));
177 | var size = parseInt(children[i].getAttribute('data-size'));
178 | var text = children[i].innerText;
179 | var style = children[i].getAttribute('data-style') ?
180 | parseInt(children[i].getAttribute('data-style')) :
181 | -1;
182 | allObjects.push({
183 | 'height': height,
184 | 'width': width,
185 | 'left': left,
186 | 'top': top,
187 | 'text': text,
188 | 'feature': style,
189 | 'size': size,
190 | 'location': location,
191 | });
192 | console.log(size, location, text);
193 | addRow(text, size, location, style);
194 | }
195 | console.log(allObjects);
196 | var image_id = document.getElementById('range-slider').getAttribute('image-id');
197 | var image_feature = image_id ? parseInt(image_id) : -1;
198 | addRow('background', '-', '-', image_feature);
199 | var results = {'image_id': image_feature, 'objects': allObjects};
200 | var url = 'get_data?data=' + JSON.stringify(results);
201 | var xmlHttp = new XMLHttpRequest();
202 | xmlHttp.onreadystatechange = function () {
203 | if (xmlHttp.readyState == 4 && xmlHttp.status == 200)
204 | refresh_image(xmlHttp.responseText);
205 | };
206 | xmlHttp.open("GET", url, true); // true for asynchronous
207 | xmlHttp.send(null);
208 | }
209 |
210 | document.querySelectorAll("ul.drop-menu > li").forEach(function (e) {
211 | e.addEventListener("click", stuff_add)
212 | });
213 | $(window).click(function (devt) {
214 | if (!devt.target.getAttribute('data-size') && !devt.target.getAttribute('max')) {
215 | $(".resize-drag").removeClass("selected");
216 | var image_style = $('#range-slider').attr('image-id');
217 | image_style = image_style ? parseInt(image_style) : -1;
218 | $('#range-slider').val(image_style);
219 | $('.range-slider__value').text(image_style.toString());
220 | $('#range-slider').attr('data-id', '');
221 | }
222 | });
223 | };
224 |
225 |
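To make render_button's location encoding concrete: the centre of each object's box is snapped to a 5x5 grid and flattened row-major into an index in [0, 24]. A small Python re-implementation of that arithmetic (an illustration only, not code from the repository):

def location_index(left, top, width, height, grid=5):
    # centre of the object's box in [0, 1] canvas coordinates
    cx = left + width / 2.0
    cy = top + height / 2.0
    # snap each coordinate to one of `grid` cells, then flatten row-major
    return round(cx * (grid - 1)) + grid * round(cy * (grid - 1))

print(location_index(left=0.1, top=0.1, width=0.2, height=0.2))  # 6: second row, second column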
--------------------------------------------------------------------------------
/scene_generation/layout.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 | import torch
19 | import torch.nn.functional as F
20 |
21 | """
22 | Functions for computing image layouts from object vectors, bounding boxes,
23 | and segmentation masks. These are used to compute coarse scene layouts which
24 | are then fed as input to the cascaded refinement network.
25 | """
26 |
27 |
28 | def boxes_to_layout(vecs, boxes, obj_to_img, H, W=None, pooling='sum'):
29 | """
30 | Inputs:
31 | - vecs: Tensor of shape (O, D) giving vectors
32 | - boxes: Tensor of shape (O, 4) giving bounding boxes in the format
33 | [x0, y0, x1, y1] in the [0, 1] coordinate space
34 | - obj_to_img: LongTensor of shape (O,) mapping each element of vecs to
35 | an image, where each element is in the range [0, N). If obj_to_img[i] = j
36 | then vecs[i] belongs to image j.
37 | - H, W: Size of the output
38 |
39 | Returns:
40 | - out: Tensor of shape (N, D, H, W)
41 | """
42 | O, D = vecs.size()
43 | if W is None:
44 | W = H
45 |
46 | grid = _boxes_to_grid(boxes, H, W)
47 |
48 | # If we don't add extra spatial dimensions here then out-of-bounds
49 | # elements won't be automatically set to 0
50 | img_in = vecs.view(O, D, 1, 1).expand(O, D, 8, 8)
51 | sampled = F.grid_sample(img_in, grid) # (O, D, H, W)
52 |
53 | # Explicitly masking makes everything quite a bit slower.
54 | # If we rely on implicit masking the interpolated boxes end up
55 | # blurred around the edges, but it should be fine.
56 | # mask = ((X < 0) + (X > 1) + (Y < 0) + (Y > 1)).clamp(max=1)
57 | # sampled[mask[:, None]] = 0
58 |
59 |     out = _pool_samples(sampled, None, obj_to_img, pooling=pooling)  # no clean masks outside test mode
60 |
61 | return out
62 |
63 |
64 | def masks_to_layout(vecs, boxes, masks, obj_to_img, H, W=None, pooling='sum', test_mode=False):
65 | """
66 | Inputs:
67 | - vecs: Tensor of shape (O, D) giving vectors
68 | - boxes: Tensor of shape (O, 4) giving bounding boxes in the format
69 | [x0, y0, x1, y1] in the [0, 1] coordinate space
70 | - masks: Tensor of shape (O, M, M) giving binary masks for each object
71 | - obj_to_img: LongTensor of shape (O,) mapping objects to images
72 | - H, W: Size of the output image.
73 |
74 | Returns:
75 | - out: Tensor of shape (N, D, H, W)
76 | """
77 | O, D = vecs.size()
78 | M = masks.size(1)
79 | assert masks.size() == (O, M, M)
80 | if W is None:
81 | W = H
82 |
83 | grid = _boxes_to_grid(boxes, H, W)
84 |
85 | img_in = vecs.view(O, D, 1, 1) * masks.float().view(O, 1, M, M)
86 | sampled = F.grid_sample(img_in, grid)
87 | if test_mode:
88 | clean_mask_sampled = F.grid_sample(masks.float().view(O, 1, M, M), grid)
89 | else:
90 | clean_mask_sampled = None
91 |
92 | out = _pool_samples(sampled, clean_mask_sampled, obj_to_img, pooling=pooling)
93 | return out
94 |
95 |
96 | def _boxes_to_grid(boxes, H, W):
97 | """
98 | Input:
99 | - boxes: FloatTensor of shape (O, 4) giving boxes in the [x0, y0, x1, y1]
100 | format in the [0, 1] coordinate space
101 | - H, W: Scalars giving size of output
102 |
103 | Returns:
104 | - grid: FloatTensor of shape (O, H, W, 2) suitable for passing to grid_sample
105 | """
106 | O = boxes.size(0)
107 |
108 | boxes = boxes.view(O, 4, 1, 1)
109 |
110 | # All these are (O, 1, 1)
111 | x0, y0 = boxes[:, 0], boxes[:, 1]
112 | ww, hh = boxes[:, 2] - x0, boxes[:, 3] - y0
113 |
114 | X = torch.linspace(0, 1, steps=W).view(1, 1, W).to(boxes)
115 | Y = torch.linspace(0, 1, steps=H).view(1, H, 1).to(boxes)
116 |
117 | X = (X - x0) / ww # (O, 1, W)
118 | Y = (Y - y0) / hh # (O, H, 1)
119 |
120 | # Stack does not broadcast its arguments so we need to expand explicitly
121 | X = X.expand(O, H, W)
122 | Y = Y.expand(O, H, W)
123 | grid = torch.stack([X, Y], dim=3) # (O, H, W, 2)
124 |
125 | # Right now grid is in [0, 1] space; transform to [-1, 1]
126 | grid = grid.mul(2).sub(1)
127 |
128 | return grid
129 |
130 |
131 | def _pool_samples(samples, clean_mask_sampled, obj_to_img, pooling='sum'):
132 | """
133 | Input:
134 | - samples: FloatTensor of shape (O, D, H, W)
135 | - obj_to_img: LongTensor of shape (O,) with each element in the range
136 | [0, N) mapping elements of samples to output images
137 |
138 | Output:
139 | - pooled: FloatTensor of shape (N, D, H, W)
140 | """
141 | dtype, device = samples.dtype, samples.device
142 | O, D, H, W = samples.size()
143 | N = obj_to_img.data.max().item() + 1
144 |
145 | # Use scatter_add to sum the sampled outputs for each image
146 | # out = torch.zeros(N, D, H, W, dtype=dtype, device=device)
147 | # idx = obj_to_img.view(O, 1, 1, 1).expand(O, D, H, W)
148 | # out = out.scatter_add(0, idx, samples)
149 | obj_to_img_list = [i.item() for i in list(obj_to_img)]
150 | all_out = []
151 | if clean_mask_sampled is None:
152 | for i in range(N):
153 | start = obj_to_img_list.index(i)
154 | end = len(obj_to_img_list) - obj_to_img_list[::-1].index(i)
155 | all_out.append(torch.sum(samples[start:end, :, :, :], dim=0))
156 | else:
157 | _, d, h, w = samples.shape
158 | for i in range(N):
159 | start = obj_to_img_list.index(i)
160 | end = len(obj_to_img_list) - obj_to_img_list[::-1].index(i)
161 | mass = [torch.sum(samples[j, :, :, :]).item() for j in range(start, end)]
162 | argsort = np.argsort(mass)
163 | result = torch.zeros((d, h, w), device=samples.device, dtype=samples.dtype)
164 | result_clean = torch.zeros((h, w), device=samples.device, dtype=samples.dtype)
165 | for j in argsort:
166 | masked_mask = (result_clean == 0).float() * (clean_mask_sampled[start + j, 0] > 0.5).float()
167 | result_clean += masked_mask
168 | result += samples[start + j] * masked_mask
169 | all_out.append(result)
170 | out = torch.stack(all_out)
171 |
172 | if pooling == 'avg':
173 | # Divide each output mask by the number of objects; use scatter_add again
174 | # to count the number of objects per image.
175 | ones = torch.ones(O, dtype=dtype, device=device)
176 | obj_counts = torch.zeros(N, dtype=dtype, device=device)
177 | obj_counts = obj_counts.scatter_add(0, obj_to_img, ones)
178 | # print(obj_counts)
179 | obj_counts = obj_counts.clamp(min=1)
180 | out = out / obj_counts.view(N, 1, 1, 1)
181 | elif pooling != 'sum':
182 | raise ValueError('Invalid pooling "%s"' % pooling)
183 |
184 | return out
185 |
186 |
187 | if __name__ == '__main__':
188 | vecs = torch.FloatTensor([
189 | [1, 0, 0], [0, 1, 0], [0, 0, 1],
190 | [1, 0, 0], [0, 1, 0], [0, 0, 1],
191 | ])
192 | boxes = torch.FloatTensor([
193 | [0.25, 0.125, 0.5, 0.875],
194 | [0, 0, 1, 0.25],
195 | [0.6125, 0, 0.875, 1],
196 | [0, 0.8, 1, 1.0],
197 | [0.25, 0.125, 0.5, 0.875],
198 | [0.6125, 0, 0.875, 1],
199 | ])
200 | obj_to_img = torch.LongTensor([0, 0, 0, 1, 1, 1])
201 | # vecs = torch.FloatTensor([[[1]]])
202 | # boxes = torch.FloatTensor([[[0.25, 0.25, 0.75, 0.75]]])
203 | vecs, boxes = vecs.cuda(), boxes.cuda()
204 | obj_to_img = obj_to_img.cuda()
205 | out = boxes_to_layout(vecs, boxes, obj_to_img, 256, pooling='sum')
206 |
207 | from torchvision.utils import save_image
208 |
209 | save_image(out.data, 'out.png')
210 |
211 | masks = torch.FloatTensor([
212 | [
213 | [0, 0, 1, 0, 0],
214 | [0, 1, 1, 1, 0],
215 | [1, 1, 1, 1, 1],
216 | [0, 1, 1, 1, 0],
217 | [0, 0, 1, 0, 0],
218 | ],
219 | [
220 | [0, 0, 1, 0, 0],
221 | [0, 1, 0, 1, 0],
222 | [1, 0, 0, 0, 1],
223 | [0, 1, 0, 1, 0],
224 | [0, 0, 1, 0, 0],
225 | ],
226 | [
227 | [0, 0, 1, 0, 0],
228 | [0, 1, 1, 1, 0],
229 | [1, 1, 1, 1, 1],
230 | [0, 1, 1, 1, 0],
231 | [0, 0, 1, 0, 0],
232 | ],
233 | [
234 | [0, 0, 1, 0, 0],
235 | [0, 1, 1, 1, 0],
236 | [1, 1, 1, 1, 1],
237 | [0, 1, 1, 1, 0],
238 | [0, 0, 1, 0, 0],
239 | ],
240 | [
241 | [0, 0, 1, 0, 0],
242 | [0, 1, 1, 1, 0],
243 | [1, 1, 1, 1, 1],
244 | [0, 1, 1, 1, 0],
245 | [0, 0, 1, 0, 0],
246 | ],
247 | [
248 | [0, 0, 1, 0, 0],
249 | [0, 1, 1, 1, 0],
250 | [1, 1, 1, 1, 1],
251 | [0, 1, 1, 1, 0],
252 | [0, 0, 1, 0, 0],
253 | ]
254 | ])
255 | masks = masks.cuda()
256 | out = masks_to_layout(vecs, boxes, masks, obj_to_img, 256)
257 | save_image(out.data, 'out_masks.png')
258 |
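The __main__ block above needs a GPU; a CPU-only sketch of the same API for quick experimentation (shapes only, the values are arbitrary):

import torch
from scene_generation.layout import boxes_to_layout

vecs = torch.eye(3)[:2]                           # (O, D) = (2, 3) object embeddings
boxes = torch.tensor([[0.25, 0.25, 0.75, 0.75],
                      [0.00, 0.50, 1.00, 1.00]])  # [x0, y0, x1, y1] in [0, 1]
obj_to_img = torch.zeros(2, dtype=torch.long)     # both objects belong to image 0
out = boxes_to_layout(vecs, boxes, obj_to_img, H=64)
print(out.shape)                                  # torch.Size([1, 3, 64, 64])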
--------------------------------------------------------------------------------
/scene_generation/discriminators.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from scene_generation.bilinear import crop_bbox_batch
7 | from scene_generation.layers import GlobalAvgPool, build_cnn, get_norm_layer
8 |
9 |
10 | class AcDiscriminator(nn.Module):
11 | def __init__(self, vocab, arch, normalization='none', activation='relu', padding='same', pooling='avg'):
12 | super(AcDiscriminator, self).__init__()
13 | self.vocab = vocab
14 |
15 | cnn_kwargs = {
16 | 'arch': arch,
17 | 'normalization': normalization,
18 | 'activation': activation,
19 | 'pooling': pooling,
20 | 'padding': padding,
21 | }
22 | cnn, D = build_cnn(**cnn_kwargs)
23 | self.cnn = nn.Sequential(cnn, GlobalAvgPool(), nn.Linear(D, 1024))
24 | num_objects = len(vocab['object_to_idx'])
25 |
26 | self.real_classifier = nn.Linear(1024, 1)
27 | self.obj_classifier = nn.Linear(1024, num_objects)
28 |
29 | def forward(self, x, y):
30 | if x.dim() == 3:
31 | x = x[:, None]
32 | vecs = self.cnn(x)
33 | real_scores = self.real_classifier(vecs)
34 | obj_scores = self.obj_classifier(vecs)
35 | ac_loss = F.cross_entropy(obj_scores, y)
36 | return real_scores, ac_loss
37 |
38 |
39 | class AcCropDiscriminator(nn.Module):
40 | def __init__(self, vocab, arch, normalization='none', activation='relu',
41 | object_size=64, padding='same', pooling='avg'):
42 | super(AcCropDiscriminator, self).__init__()
43 | self.vocab = vocab
44 | self.discriminator = AcDiscriminator(vocab, arch, normalization,
45 | activation, padding, pooling)
46 | self.object_size = object_size
47 |
48 | def forward(self, imgs, objs, boxes, obj_to_img):
49 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
50 | real_scores, ac_loss = self.discriminator(crops, objs)
51 | return real_scores, ac_loss, crops
52 |
53 |
54 | ###############################################################################
55 | # Functions
56 | ###############################################################################
57 | def weights_init(m):
58 | classname = m.__class__.__name__
59 | if classname.find('Conv') != -1:
60 | m.weight.data.normal_(0.0, 0.02)
61 | elif classname.find('BatchNorm2d') != -1:
62 | m.weight.data.normal_(1.0, 0.02)
63 | m.bias.data.fill_(0)
64 |
65 |
66 | def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1):
67 | norm_layer = get_norm_layer(norm_type=norm)
68 | netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D)
69 | # print(netD)
70 | assert (torch.cuda.is_available())
71 | netD.cuda()
72 | netD.apply(weights_init)
73 | return netD
74 |
75 |
76 | def define_mask_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1,
77 | num_objects=None):
78 | norm_layer = get_norm_layer(norm_type=norm)
79 | netD = MultiscaleMaskDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D,
80 | num_objects)
81 | assert (torch.cuda.is_available())
82 | netD.cuda()
83 | netD.apply(weights_init)
84 | return netD
85 |
86 |
87 | class MultiscaleMaskDiscriminator(nn.Module):
88 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
89 | use_sigmoid=False, num_D=3, num_objects=None):
90 | super(MultiscaleMaskDiscriminator, self).__init__()
91 | self.num_D = num_D
92 | self.n_layers = n_layers
93 |
94 | for i in range(num_D):
95 | netD = NLayerMaskDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, num_objects)
96 | for j in range(n_layers + 2):
97 | setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
98 |
99 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
100 |
101 | def singleD_forward(self, model, input, cond):
102 | result = [input]
103 | for i in range(len(model) - 2):
104 | # print(result[-1].shape)
105 | result.append(model[i](result[-1]))
106 |
107 | a, b, c, d = result[-1].shape
108 | cond = cond.view(a, -1, 1, 1).expand(-1, -1, c, d)
109 | concat = torch.cat([result[-1], cond], dim=1)
110 | result.append(model[len(model) - 2](concat))
111 | result.append(model[len(model) - 1](result[-1]))
112 | return result[1:]
113 |
114 | def forward(self, input, cond):
115 | num_D = self.num_D
116 | result = []
117 | input_downsampled = input
118 | for i in range(num_D):
119 | model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in
120 | range(self.n_layers + 2)]
121 | result.append(self.singleD_forward(model, input_downsampled, cond))
122 | if i != (num_D - 1):
123 | input_downsampled = self.downsample(input_downsampled)
124 | return result
125 |
126 |
127 | # Defines the PatchGAN discriminator with the specified arguments.
128 | class NLayerMaskDiscriminator(nn.Module):
129 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False,
130 | num_objects=None):
131 | super(NLayerMaskDiscriminator, self).__init__()
132 | self.n_layers = n_layers
133 |
134 | kw = 3
135 | padw = int(np.ceil((kw - 1.0) / 2))
136 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
137 |
138 | nf = ndf
139 | for n in range(1, n_layers):
140 | nf_prev = nf
141 | nf = min(nf * 2, 512)
142 | sequence += [[
143 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
144 | norm_layer(nf), nn.LeakyReLU(0.2, True)
145 | ]]
146 |
147 | nf_prev = nf
148 | nf = min(nf * 2, 512)
149 | nf_prev += num_objects
150 | sequence += [[
151 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
152 | norm_layer(nf),
153 | nn.LeakyReLU(0.2, True)
154 | ]]
155 |
156 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
157 |
158 | if use_sigmoid:
159 | sequence += [[nn.Sigmoid()]]
160 |
161 | for n in range(len(sequence)):
162 | setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
163 |
164 | def forward(self, input):
165 | res = [input]
166 | for n in range(self.n_layers + 2):
167 | model = getattr(self, 'model' + str(n))
168 | res.append(model(res[-1]))
169 | return res[1:]
170 |
171 |
172 | class MultiscaleDiscriminator(nn.Module):
173 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
174 | use_sigmoid=False, num_D=3):
175 | super(MultiscaleDiscriminator, self).__init__()
176 | self.num_D = num_D
177 | self.n_layers = n_layers
178 |
179 | for i in range(num_D):
180 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid)
181 | for j in range(n_layers + 2):
182 | setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
183 |
184 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
185 |
186 | def singleD_forward(self, model, input):
187 | result = [input]
188 | for i in range(len(model)):
189 | result.append(model[i](result[-1]))
190 | return result[1:]
191 |
192 | def forward(self, input):
193 | num_D = self.num_D
194 | result = []
195 | input_downsampled = input
196 | for i in range(num_D):
197 | model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in
198 | range(self.n_layers + 2)]
199 | result.append(self.singleD_forward(model, input_downsampled))
200 | if i != (num_D - 1):
201 | input_downsampled = self.downsample(input_downsampled)
202 | return result
203 |
204 |
205 | # Defines the PatchGAN discriminator with the specified arguments.
206 | class NLayerDiscriminator(nn.Module):
207 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
208 | super(NLayerDiscriminator, self).__init__()
209 | self.n_layers = n_layers
210 |
211 | kw = 4
212 | padw = int(np.ceil((kw - 1.0) / 2))
213 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
214 |
215 | nf = ndf
216 | for n in range(1, n_layers):
217 | nf_prev = nf
218 | nf = min(nf * 2, 512)
219 | sequence += [[
220 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
221 | norm_layer(nf), nn.LeakyReLU(0.2, True)
222 | ]]
223 |
224 | nf_prev = nf
225 | nf = min(nf * 2, 512)
226 | sequence += [[
227 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
228 | norm_layer(nf),
229 | nn.LeakyReLU(0.2, True)
230 | ]]
231 |
232 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
233 |
234 | if use_sigmoid:
235 | sequence += [[nn.Sigmoid()]]
236 |
237 | for n in range(len(sequence)):
238 | setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
239 |
240 | def forward(self, input):
241 | res = [input]
242 | for n in range(self.n_layers + 2):
243 | model = getattr(self, 'model' + str(n))
244 | res.append(model(res[-1]))
245 | return res[1:]
246 |
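A CPU-only sketch (not from the repository) of wiring MultiscaleDiscriminator to the GANLoss criterion defined in losses.py; define_D / define_mask_D perform the same construction but additionally apply weights_init and move the network to the GPU:

import torch
from scene_generation.discriminators import MultiscaleDiscriminator
from scene_generation.losses import GANLoss

netD = MultiscaleDiscriminator(input_nc=3, ndf=64, n_layers=3, num_D=2)
criterion = GANLoss(use_lsgan=True)

imgs = torch.randn(2, 3, 128, 128)       # stand-in for a batch of images
outs = netD(imgs)                        # one list of per-layer features per scale
print(len(outs), outs[0][-1].shape)      # 2 scales; last entry is the patch score map
print(criterion(outs, True))             # LSGAN loss summed over the scales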
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | import os
3 | import json
4 | from collections import defaultdict
5 | import random
6 | import torch
7 | from torch.utils.data import DataLoader
8 |
9 | from scene_generation.args import get_args
10 | from scene_generation.data.coco import CocoSceneGraphDataset, coco_collate_fn
11 | from scene_generation.data.coco_panoptic import CocoPanopticSceneGraphDataset, coco_panoptic_collate_fn
12 | from scene_generation.metrics import jaccard
13 | from scene_generation.trainer import Trainer
14 |
15 | from scripts.inception_score import InceptionScore
16 |
17 |
18 | def build_coco_dsets(args):
19 | dset_kwargs = {
20 | 'image_dir': args.coco_train_image_dir,
21 | 'instances_json': args.coco_train_instances_json,
22 | 'stuff_json': args.coco_train_stuff_json,
23 | 'image_size': args.image_size,
24 | 'mask_size': args.mask_size,
25 | 'max_samples': args.num_train_samples,
26 | 'min_object_size': args.min_object_size,
27 | 'min_objects_per_image': args.min_objects_per_image,
28 | 'instance_whitelist': args.instance_whitelist,
29 | 'stuff_whitelist': args.stuff_whitelist,
30 | 'include_other': args.coco_include_other,
31 | }
32 | if args.is_panoptic:
33 | dset_kwargs['panoptic'] = args.coco_panoptic_train
34 | dset_kwargs['panoptic_segmentation'] = args.coco_panoptic_segmentation_train
35 | train_dset = CocoPanopticSceneGraphDataset(**dset_kwargs)
36 | else:
37 | train_dset = CocoSceneGraphDataset(**dset_kwargs)
38 | num_objs = train_dset.total_objects()
39 | num_imgs = len(train_dset)
40 | print('Training dataset has %d images and %d objects' % (num_imgs, num_objs))
41 | print('(%.2f objects per image)' % (float(num_objs) / num_imgs))
42 |
43 | dset_kwargs['image_dir'] = args.coco_val_image_dir
44 | dset_kwargs['instances_json'] = args.coco_val_instances_json
45 | dset_kwargs['stuff_json'] = args.coco_val_stuff_json
46 | dset_kwargs['max_samples'] = args.num_val_samples
47 | if args.is_panoptic:
48 | dset_kwargs['panoptic'] = args.coco_panoptic_val
49 | dset_kwargs['panoptic_segmentation'] = args.coco_panoptic_segmentation_val
50 | val_dset = CocoPanopticSceneGraphDataset(**dset_kwargs)
51 | else:
52 | val_dset = CocoSceneGraphDataset(**dset_kwargs)
53 |
54 | assert train_dset.vocab == val_dset.vocab
55 | vocab = json.loads(json.dumps(train_dset.vocab))
56 |
57 | return vocab, train_dset, val_dset
58 |
59 |
60 | def build_loaders(args):
61 | vocab, train_dset, val_dset = build_coco_dsets(args)
62 | if args.is_panoptic:
63 | collate_fn = coco_panoptic_collate_fn
64 | else:
65 | collate_fn = coco_collate_fn
66 |
67 | loader_kwargs = {
68 | 'batch_size': args.batch_size,
69 | 'num_workers': args.loader_num_workers,
70 | 'shuffle': True,
71 | 'collate_fn': collate_fn,
72 | }
73 | train_loader = DataLoader(train_dset, **loader_kwargs)
74 |
75 | loader_kwargs['shuffle'] = args.shuffle_val
76 | val_loader = DataLoader(val_dset, **loader_kwargs)
77 | return vocab, train_loader, val_loader
78 |
79 |
80 | def check_model(args, loader, model, inception_score, use_gt):
81 | fid = None
82 | num_samples = 0
83 | total_iou = 0
84 | total_boxes = 0
85 | inception_score.clean()
86 | with torch.no_grad():
87 | for batch in loader:
88 | batch = [tensor.cuda() for tensor in batch]
89 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = batch
90 |
91 | # Run the model as it has been run during training
92 | if use_gt:
93 | model_out = model(imgs, objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=masks, attributes=attributes,
94 | test_mode=True, use_gt_box=True)
95 | else:
96 | attributes = torch.zeros_like(attributes)
97 | model_out = model(imgs, objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=None, attributes=attributes,
98 | test_mode=True, use_gt_box=False)
99 | imgs_pred, boxes_pred, masks_pred, _, pred_layout, _ = model_out
100 |
101 | iou, _, _ = jaccard(boxes_pred, boxes)
102 | total_iou += iou
103 | total_boxes += boxes_pred.size(0)
104 | inception_score(imgs_pred)
105 |
106 | num_samples += imgs.size(0)
107 | if num_samples >= args.num_val_samples:
108 | break
109 |
110 | inception_mean, inception_std = inception_score.compute_score(splits=5)
111 |
112 | avg_iou = total_iou / total_boxes
113 |
114 | out = [avg_iou, inception_mean, inception_std, fid]
115 |
116 | return tuple(out)
117 |
118 |
119 | def get_checkpoint(args, vocab):
120 | if args.restore_from_checkpoint:
121 | restore_path = '%s_with_model.pt' % args.checkpoint_name
122 | restore_path = os.path.join(args.output_dir, restore_path)
123 | assert restore_path is not None
124 | assert os.path.isfile(restore_path)
125 | print('Restoring from checkpoint:')
126 | print(restore_path)
127 | checkpoint = torch.load(restore_path)
128 | t = checkpoint['counters']['t']
129 | epoch = checkpoint['counters']['epoch']
130 | else:
131 | t, epoch = 0, 0
132 | checkpoint = {
133 | 'args': args.__dict__,
134 | 'vocab': vocab,
135 | 'model_kwargs': {},
136 | 'd_obj_kwargs': {},
137 | 'd_mask_kwargs': {},
138 | 'd_img_kwargs': {},
139 | 'd_global_mask_kwargs': {},
140 | 'losses_ts': [],
141 | 'losses': defaultdict(list),
142 | 'd_losses': defaultdict(list),
143 | 'checkpoint_ts': [],
144 | 'train_inception': [],
145 | 'val_losses': defaultdict(list),
146 | 'val_inception': [],
147 | 'norm_d': [],
148 | 'norm_g': [],
149 | 'counters': {
150 | 't': None,
151 | 'epoch': None,
152 | },
153 | 'model_state': None, 'model_best_state': None,
154 | 'optim_state': None, 'optim_best_state': None,
155 | 'd_obj_state': None, 'd_obj_best_state': None,
156 | 'd_obj_optim_state': None, 'd_obj_optim_best_state': None,
157 | 'd_img_state': None, 'd_img_best_state': None,
158 | 'd_img_optim_state': None, 'd_img_optim_best_state': None,
159 | 'd_mask_state': None, 'd_mask_best_state': None,
160 | 'd_mask_optim_state': None, 'd_mask_optim_best_state': None,
161 | 'best_t': [],
162 | }
163 | return t, epoch, checkpoint
164 |
165 |
166 | def main(args):
167 | print(args)
168 | vocab, train_loader, val_loader = build_loaders(args)
169 | t, epoch, checkpoint = get_checkpoint(args, vocab)
170 | trainer = Trainer(args, vocab, checkpoint)
171 | if args.restore_from_checkpoint:
172 | trainer.restore_checkpoint(checkpoint)
173 | else:
174 | with open(os.path.join(args.output_dir, 'args.json'), 'w') as outfile:
175 | json.dump(vars(args), outfile)
176 |
177 | inception_score = InceptionScore(cuda=True, batch_size=args.batch_size, resize=True)
178 | train_results = check_model(args, val_loader, trainer.model, inception_score, use_gt=True)
179 | t_avg_iou, t_inception_mean, t_inception_std, _ = train_results
180 | index = int(t / args.print_every)
181 | trainer.writer.add_scalar('checkpoint/{}'.format('train_iou'), t_avg_iou, index)
182 | trainer.writer.add_scalar('checkpoint/{}'.format('train_inception_mean'), t_inception_mean, index)
183 | trainer.writer.add_scalar('checkpoint/{}'.format('train_inception_std'), t_inception_std, index)
184 | print(t_avg_iou, t_inception_mean, t_inception_std)
185 |
186 | while t < args.num_iterations:
187 | epoch += 1
188 | print('Starting epoch %d' % epoch)
189 |
190 | for batch in train_loader:
191 | t += 1
192 | batch = [tensor.cuda() for tensor in batch]
193 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = batch
194 |
195 | use_gt = random.randint(0, 1) != 0
196 | if not use_gt:
197 | attributes = torch.zeros_like(attributes)
198 | model_out = trainer.model(imgs, objs, triples, obj_to_img,
199 | boxes_gt=boxes, masks_gt=masks, attributes=attributes)
200 | imgs_pred, boxes_pred, masks_pred, layout, layout_pred, layout_wrong = model_out
201 |
202 | layout_one_hot = layout[:, :trainer.num_obj, :, :]
203 | layout_pred_one_hot = layout_pred[:, :trainer.num_obj, :, :]
204 |
205 | trainer.train_generator(imgs, imgs_pred, masks, masks_pred, layout,
206 | objs, boxes, boxes_pred, obj_to_img, use_gt)
207 |
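# Detach the generator outputs so the discriminator updates below do not
# backpropagate into the generator.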
208 | imgs_pred_detach = imgs_pred.detach()
209 | masks_pred_detach = masks_pred.detach()
210 | boxes_pred_detach = boxes.detach()
211 | layout_detach = layout.detach()
212 | layout_wrong_detach = layout_wrong.detach()
213 | trainer.train_mask_discriminator(masks, masks_pred_detach, objs)
214 | trainer.train_obj_discriminator(imgs, imgs_pred_detach, objs, boxes, boxes_pred_detach, obj_to_img)
215 | trainer.train_image_discriminator(imgs, imgs_pred_detach, layout_detach, layout_wrong_detach)
216 |
217 | if t % args.print_every == 0 or t == 1:
218 | trainer.write_losses(checkpoint, t)
219 | trainer.write_images(t, imgs, imgs_pred, layout_one_hot, layout_pred_one_hot)
220 |
221 | if t % args.checkpoint_every == 0:
222 | print('begin check model train')
223 | train_results = check_model(args, val_loader, trainer.model, inception_score, use_gt=True)
224 | print('begin check model val')
225 | val_results = check_model(args, val_loader, trainer.model, inception_score, use_gt=False)
226 | trainer.save_checkpoint(checkpoint, t, args, epoch, train_results, val_results)
227 |
228 |
229 | if __name__ == '__main__':
230 | args = get_args()
231 | main(args)
232 |
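A minimal sketch (not part of train.py) of inspecting a checkpoint written by this script; the file name follows the '<checkpoint_name>_with_model.pt' pattern used in get_checkpoint, and the keys match the dictionary defined there. The path below is a placeholder.

import torch

ckpt = torch.load('output/model_with_model.pt', map_location='cpu')  # placeholder path
print(ckpt['counters'])                     # {'t': ..., 'epoch': ...}
print(len(ckpt['vocab']['object_to_idx']))  # number of object categories
state_dict = ckpt['model_state']            # weights for scene_generation.model.Model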
--------------------------------------------------------------------------------
/scripts/gui/index.html:
--------------------------------------------------------------------------------
(index.html markup was stripped in this dump; only the visible page text is recoverable:)
INTERACTIVE SCENE GENERATION - ORON ASHUAL & LIOR WOLF
1. SELECT 3 OBJECTS OR MORE (object categories: PEOPLE, ANIMALS, VEHICLES, FOOD, SPORTS, INTERIOR, EXTERIOR, MATERIALS)
2. LOCATE AND RESIZE YOUR OBJECTS
3. ADJUST THE APPEARANCE
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/scene_generation/bilinear.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import torch
18 | import torch.nn.functional as F
19 |
20 | """
21 | Functions for performing differentiable bilinear cropping of images, for use in
22 | the object discriminator
23 | """
24 |
25 |
26 | def crop_bbox_batch(feats, bbox, bbox_to_feats, HH, WW=None, backend='cudnn'):
27 | """
28 | Inputs:
29 | - feats: FloatTensor of shape (N, C, H, W)
30 | - bbox: FloatTensor of shape (B, 4) giving bounding box coordinates
31 | - bbox_to_feats: LongTensor of shape (B,) mapping boxes to feature maps;
32 | each element is in the range [0, N) and bbox_to_feats[b] = i means that
33 | bbox[b] will be cropped from feats[i].
34 | - HH, WW: Size of the output crops
35 |
36 | Returns:
37 | - crops: FloatTensor of shape (B, C, HH, WW) where crops[i] uses bbox[i] to
38 | crop from feats[bbox_to_feats[i]].
39 | """
40 | if backend == 'cudnn':
41 | return crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW)
42 | N, C, H, W = feats.size()
43 | B = bbox.size(0)
44 | if WW is None: WW = HH
45 | dtype, device = feats.dtype, feats.device
46 | crops = torch.zeros(B, C, HH, WW, dtype=dtype, device=device)
47 | for i in range(N):
48 | idx = (bbox_to_feats.data == i).nonzero()
49 | if idx.dim() == 0:
50 | continue
51 | idx = idx.view(-1)
52 | n = idx.size(0)
53 | cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous()
54 | cur_bbox = bbox[idx]
55 | cur_crops = crop_bbox(cur_feats, cur_bbox, HH, WW)
56 | crops[idx] = cur_crops
57 | return crops
58 |
59 |
60 | def _invperm(p):
61 | N = p.size(0)
62 | eye = torch.arange(0, N).type_as(p)
63 | pp = (eye[:, None] == p).nonzero()[:, 1]
64 | return pp
65 |
66 |
67 | def crop_bbox_batch_cudnn(feats, bbox, bbox_to_feats, HH, WW=None):
68 | N, C, H, W = feats.size()
69 | B = bbox.size(0)
70 | if WW is None: WW = HH
71 | dtype = feats.data.type()
72 |
73 | feats_flat, bbox_flat, all_idx = [], [], []
74 | for i in range(N):
75 | idx = (bbox_to_feats.data == i).nonzero()
76 | if idx.dim() == 0:
77 | continue
78 | idx = idx.view(-1)
79 | n = idx.size(0)
80 | cur_feats = feats[i].view(1, C, H, W).expand(n, C, H, W).contiguous()
81 | cur_bbox = bbox[idx]
82 |
83 | feats_flat.append(cur_feats)
84 | bbox_flat.append(cur_bbox)
85 | all_idx.append(idx)
86 |
87 | feats_flat = torch.cat(feats_flat, dim=0)
88 | bbox_flat = torch.cat(bbox_flat, dim=0)
89 | crops = crop_bbox(feats_flat, bbox_flat, HH, WW, backend='cudnn')
90 |
91 | # If the crops were sequential (all_idx is identity permutation) then we can
92 | # simply return them; otherwise we need to permute crops by the inverse
93 | # permutation from all_idx.
94 | all_idx = torch.cat(all_idx, dim=0)
95 | eye = torch.arange(0, B).type_as(all_idx)
96 | if (all_idx == eye).all():
97 | return crops
98 | return crops[_invperm(all_idx)]
99 |
100 |
101 | def crop_bbox(feats, bbox, HH, WW=None, backend='cudnn'):
102 | """
103 | Take differentiable crops of feats specified by bbox.
104 |
105 | Inputs:
106 | - feats: Tensor of shape (N, C, H, W)
107 | - bbox: Bounding box coordinates of shape (N, 4) in the format
108 | [x0, y0, x1, y1] in the [0, 1] coordinate space.
109 | - HH, WW: Size of the output crops.
110 |
111 | Returns:
112 | - crops: Tensor of shape (N, C, HH, WW) where crops[i] is the portion of
113 | feats[i] specified by bbox[i], reshaped to (HH, WW) using bilinear sampling.
114 | """
115 | N = feats.size(0)
116 | assert bbox.size(0) == N
117 | assert bbox.size(1) == 4
118 | if WW is None: WW = HH
119 | if backend == 'cudnn':
120 | # Change box from [0, 1] to [-1, 1] coordinate system
121 | bbox = 2 * bbox - 1
122 | x0, y0 = bbox[:, 0], bbox[:, 1]
123 | x1, y1 = bbox[:, 2], bbox[:, 3]
124 | X = tensor_linspace(x0, x1, steps=WW).view(N, 1, WW).expand(N, HH, WW)
125 | Y = tensor_linspace(y0, y1, steps=HH).view(N, HH, 1).expand(N, HH, WW)
126 | if backend == 'jj':
127 | return bilinear_sample(feats, X, Y)
128 | elif backend == 'cudnn':
129 | grid = torch.stack([X, Y], dim=3)
130 | return F.grid_sample(feats, grid)
131 |
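# Illustrative example (not in the original file): with the default 'cudnn'
# backend this routes through F.grid_sample, so it also runs on CPU.
#   feats = torch.randn(2, 3, 32, 32)
#   bbox = torch.tensor([[0.25, 0.25, 0.75, 0.75]] * 2)   # central box in [0, 1] coords
#   crops = crop_bbox(feats, bbox, HH=16)                 # -> shape (2, 3, 16, 16)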
132 |
133 | def uncrop_bbox(feats, bbox, H, W=None, fill_value=0):
134 | """
135 | Inverse operation to crop_bbox; construct output images where the feature maps
136 | from feats have been reshaped and placed into the positions specified by bbox.
137 |
138 | Inputs:
139 | - feats: Tensor of shape (N, C, HH, WW)
140 | - bbox: Bounding box coordinates of shape (N, 4) in the format
141 | [x0, y0, x1, y1] in the [0, 1] coordinate space.
142 | - H, W: Size of output.
143 | - fill_value: Portions of the output image that are outside the bounding box
144 | will be filled with this value.
145 |
146 | Returns:
147 | - out: Tensor of shape (N, C, H, W) where the portion of out[i] given by
148 | bbox[i] contains feats[i], reshaped using bilinear sampling.
149 | """
150 | N, C = feats.size(0), feats.size(1)
151 | assert bbox.size(0) == N
152 | assert bbox.size(1) == 4
153 | if W is None: W = H
154 |
155 | x0, y0 = bbox[:, 0], bbox[:, 1]
156 | x1, y1 = bbox[:, 2], bbox[:, 3]
157 | ww = x1 - x0
158 | hh = y1 - y0
159 |
160 | x0 = x0.contiguous().view(N, 1).expand(N, H)
161 | x1 = x1.contiguous().view(N, 1).expand(N, H)
162 | ww = ww.view(N, 1).expand(N, H)
163 |
164 | y0 = y0.contiguous().view(N, 1).expand(N, W)
165 | y1 = y1.contiguous().view(N, 1).expand(N, W)
166 | hh = hh.view(N, 1).expand(N, W)
167 |
168 | X = torch.linspace(0, 1, steps=W).view(1, W).expand(N, W).to(feats)
169 | Y = torch.linspace(0, 1, steps=H).view(1, H).expand(N, H).to(feats)
170 |
171 | X = (X - x0) / ww
172 | Y = (Y - y0) / hh
173 |
174 | # For ByteTensors, (x + y).clamp(max=1) gives logical_or
175 | X_out_mask = ((X < 0) + (X > 1)).view(N, 1, W).expand(N, H, W)
176 | Y_out_mask = ((Y < 0) + (Y > 1)).view(N, H, 1).expand(N, H, W)
177 | out_mask = (X_out_mask + Y_out_mask).clamp(max=1)
178 | out_mask = out_mask.view(N, 1, H, W).expand(N, C, H, W)
179 |
180 | X = X.view(N, 1, W).expand(N, H, W)
181 | Y = Y.view(N, H, 1).expand(N, H, W)
182 |
183 | out = bilinear_sample(feats, X, Y)
184 | out[out_mask] = fill_value
185 | return out
186 |
187 |
188 | def bilinear_sample(feats, X, Y):
189 | """
190 | Perform bilinear sampling on the features in feats using the sampling grid
191 | given by X and Y.
192 |
193 | Inputs:
194 | - feats: Tensor holding input feature map, of shape (N, C, H, W)
195 | - X, Y: Tensors holding x and y coordinates of the sampling
196 | grids; both have shape shape (N, HH, WW) and have elements in the range [0, 1].
197 | Returns:
198 | - out: Tensor of shape (B, C, HH, WW) where out[i] is computed
199 | by sampling from feats[idx[i]] using the sampling grid (X[i], Y[i]).
200 | """
201 | N, C, H, W = feats.size()
202 | assert X.size() == Y.size()
203 | assert X.size(0) == N
204 | _, HH, WW = X.size()
205 |
206 | X = X.mul(W)
207 | Y = Y.mul(H)
208 |
209 | # Get the x and y coordinates for the four samples
210 | x0 = X.floor().clamp(min=0, max=W - 1)
211 | x1 = (x0 + 1).clamp(min=0, max=W - 1)
212 | y0 = Y.floor().clamp(min=0, max=H - 1)
213 | y1 = (y0 + 1).clamp(min=0, max=H - 1)
214 |
215 | # In numpy we could do something like feats[i, :, y0, x0] to pull out
216 | # the elements of feats at coordinates y0 and x0, but PyTorch doesn't
217 | # yet support this style of indexing. Instead we have to use the gather
218 | # method, which only allows us to index along one dimension at a time;
219 | # therefore we will collapse the features (BB, C, H, W) into (BB, C, H * W)
220 | # and index along the last dimension. Below we generate linear indices into
221 | # the collapsed last dimension for each of the four combinations we need.
222 | y0x0_idx = (W * y0 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW)
223 | y1x0_idx = (W * y1 + x0).view(N, 1, HH * WW).expand(N, C, HH * WW)
224 | y0x1_idx = (W * y0 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW)
225 | y1x1_idx = (W * y1 + x1).view(N, 1, HH * WW).expand(N, C, HH * WW)
226 |
227 | # Actually use gather to pull out the values from feats corresponding
228 | # to our four samples, then reshape them to (BB, C, HH, WW)
229 | feats_flat = feats.view(N, C, H * W)
230 | v1 = feats_flat.gather(2, y0x0_idx.long()).view(N, C, HH, WW)
231 | v2 = feats_flat.gather(2, y1x0_idx.long()).view(N, C, HH, WW)
232 | v3 = feats_flat.gather(2, y0x1_idx.long()).view(N, C, HH, WW)
233 | v4 = feats_flat.gather(2, y1x1_idx.long()).view(N, C, HH, WW)
234 |
235 | # Compute the weights for the four samples
236 | w1 = ((x1 - X) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW)
237 | w2 = ((x1 - X) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW)
238 | w3 = ((X - x0) * (y1 - Y)).view(N, 1, HH, WW).expand(N, C, HH, WW)
239 | w4 = ((X - x0) * (Y - y0)).view(N, 1, HH, WW).expand(N, C, HH, WW)
240 |
241 | # Multiply the samples by the weights to give our interpolated results.
242 | out = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
243 | return out
244 |
245 |
246 | def tensor_linspace(start, end, steps=10):
247 | """
248 | Vectorized version of torch.linspace.
249 |
250 | Inputs:
251 | - start: Tensor of any shape
252 | - end: Tensor of the same shape as start
253 | - steps: Integer
254 |
255 | Returns:
256 | - out: Tensor of shape start.size() + (steps,), such that
257 | out.select(-1, 0) == start, out.select(-1, -1) == end,
258 | and the other elements of out linearly interpolate between
259 | start and end.
260 | """
261 | assert start.size() == end.size()
262 | view_size = start.size() + (1,)
263 | w_size = (1,) * start.dim() + (steps,)
264 | out_size = start.size() + (steps,)
265 |
266 | start_w = torch.linspace(1, 0, steps=steps).to(start)
267 | start_w = start_w.view(w_size).expand(out_size)
268 | end_w = torch.linspace(0, 1, steps=steps).to(start)
269 | end_w = end_w.view(w_size).expand(out_size)
270 |
271 | start = start.contiguous().view(view_size).expand(out_size)
272 | end = end.contiguous().view(view_size).expand(out_size)
273 |
274 | out = start_w * start + end_w * end
275 | return out
276 |
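# Worked example (illustrative): tensor_linspace(torch.zeros(2), torch.ones(2), steps=5)
# returns a (2, 5) tensor whose rows are [0.00, 0.25, 0.50, 0.75, 1.00].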
277 |
278 | if __name__ == '__main__':
279 | import numpy as np
280 | from imageio import imread, imwrite
281 | from skimage.transform import resize
282 | cat = resize(imread('cat.jpg'), (256, 256), preserve_range=True, anti_aliasing=True)
283 | dog = resize(imread('dog.jpg'), (256, 256), preserve_range=True, anti_aliasing=True)
284 | feats = torch.stack([
285 | torch.from_numpy(cat.transpose(2, 0, 1).astype(np.float32)),
286 | torch.from_numpy(dog.transpose(2, 0, 1).astype(np.float32))],
287 | dim=0)
288 |
289 | boxes = torch.FloatTensor([
290 | [0, 0, 1, 1],
291 | [0.25, 0.25, 0.75, 0.75],
292 | [0, 0, 0.5, 0.5],
293 | ])
294 |
295 | box_to_feats = torch.LongTensor([1, 0, 1]).cuda()
296 |
297 | feats, boxes = feats.cuda(), boxes.cuda()
298 | crops = crop_bbox_batch_cudnn(feats, boxes, box_to_feats, 128)
299 | for i in range(crops.size(0)):
300 | crop_np = crops.data[i].cpu().numpy().transpose(1, 2, 0).astype(np.uint8)
301 | imwrite('out%d.png' % i, crop_np)
302 |
--------------------------------------------------------------------------------
/scripts/train_accuracy_net.py:
--------------------------------------------------------------------------------
1 | """ From https://github.com/Prakashvanapalli/pytorch_classifiers
2 | """
3 | import argparse
4 | import json
5 | import os
6 | import time
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import torchvision
12 | from torch.optim import lr_scheduler
13 | from torch.utils.data import DataLoader
14 |
15 | from scene_generation.bilinear import crop_bbox_batch
16 | from scene_generation.data.coco import CocoSceneGraphDataset, coco_collate_fn
17 | from scene_generation.utils import int_tuple, bool_flag, str_tuple
18 |
19 | parser = argparse.ArgumentParser(description='Train a classifier on COCO object crops (the accuracy net)')
20 | parser.add_argument('-idl', '--input_data_loc', help='', default='data/training_data')
21 | parser.add_argument('-mo', '--model_name', default="resnet101")
22 | parser.add_argument('-f', '--freeze_layers', default=True, action='store_false', help='Bool type')
23 | parser.add_argument('-fi', '--freeze_initial_layers', default=True, action='store_false', help='Bool type')
24 | parser.add_argument('-ep', '--epochs', default=20, type=int)
25 | parser.add_argument('-b', '--batch_size', default=4, type=int)
26 | parser.add_argument('-is', '--input_shape', default=224, type=int)
27 | parser.add_argument('-sl', '--save_loc', default="models/")
28 | parser.add_argument("-g", '--use_gpu', default=True, action='store_false', help='Bool type gpu')
29 | parser.add_argument("-p", '--use_parallel', default=False, action='store_false', help='Bool type to use_parallel')
30 |
31 | parser.add_argument('--dataset', default='coco', type=str)
32 | parser.add_argument('--mask_size', default=32, type=int)
33 | parser.add_argument('--image_size', default='256,256', type=int_tuple)
34 | parser.add_argument('--num_train_samples', default=None, type=int)
35 | parser.add_argument('--num_val_samples', default=1024, type=int)
36 | parser.add_argument('--shuffle_val', default=True, type=bool_flag)
37 | parser.add_argument('--loader_num_workers', default=4, type=int)
38 | parser.add_argument('--include_relationships', default=True, type=bool_flag)
39 |
40 | # COCO-specific options
41 | COCO_DIR = os.path.expanduser('datasets/coco')
42 | parser.add_argument('--coco_train_image_dir',
43 | default=os.path.join(COCO_DIR, 'images/train2017'))
44 | parser.add_argument('--coco_val_image_dir',
45 | default=os.path.join(COCO_DIR, 'images/val2017'))
46 | parser.add_argument('--coco_train_instances_json',
47 | default=os.path.join(COCO_DIR, 'annotations/instances_train2017.json'))
48 | parser.add_argument('--coco_train_stuff_json',
49 | default=os.path.join(COCO_DIR, 'annotations/stuff_train2017.json'))
50 | parser.add_argument('--coco_val_instances_json',
51 | default=os.path.join(COCO_DIR, 'annotations/instances_val2017.json'))
52 | parser.add_argument('--coco_val_stuff_json',
53 | default=os.path.join(COCO_DIR, 'annotations/stuff_val2017.json'))
54 | parser.add_argument('--instance_whitelist', default=None, type=str_tuple)
55 | parser.add_argument('--stuff_whitelist', default=None, type=str_tuple)
56 | parser.add_argument('--coco_include_other', default=False, type=bool_flag)
57 | parser.add_argument('--min_object_size', default=0.02, type=float)
58 | parser.add_argument('--min_objects_per_image', default=3, type=int)
59 | parser.add_argument('--coco_stuff_only', default=True, type=bool_flag)
60 |
61 |
62 | def all_pretrained_models(n_class, name="resnet101", pretrained=True):
63 | if pretrained:
64 | weights = "imagenet"
65 | else:
66 | weights = False
67 |
68 | if name == "resnet18":
69 | print("[Building resnet18]")
70 | model_conv = torchvision.models.resnet18(pretrained=weights)
71 | elif name == "resnet34":
72 | print("[Building resnet34]")
73 | model_conv = torchvision.models.resnet34(pretrained=weights)
74 | elif name == "resnet50":
75 | print("[Building resnet50]")
76 | model_conv = torchvision.models.resnet50(pretrained=weights)
77 | elif name == "resnet101":
78 | print("[Building resnet101]")
79 | model_conv = torchvision.models.resnet101(pretrained=weights)
80 | elif name == "resnet152":
81 | print("[Building resnet152]")
82 | model_conv = torchvision.models.resnet152(pretrained=weights)
83 | else:
84 | raise ValueError
85 |
86 | for i, param in model_conv.named_parameters():
87 | param.requires_grad = False
88 |
89 | num_ftrs = model_conv.fc.in_features
90 | model_conv.fc = nn.Linear(num_ftrs, n_class)
91 |
92 | if "resnet" in name:
93 | print("[Resnet: Freezing layers only till layer1 including]")
94 | ct = []
95 | for name, child in model_conv.named_children():
96 | if "layer1" in ct:
97 | for params in child.parameters():
98 | params.requires_grad = True
99 | ct.append(name)
100 |
101 | return model_conv
102 |
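# Usage sketch (illustrative): builds a ResNet with a fresh n_class-way head;
# after the freezing logic above, only the children following layer1 (layer2-4,
# avgpool and the new fc) remain trainable.
#   model = all_pretrained_models(172, name="resnet18")
#   assert model.fc.out_features == 172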
103 |
104 | def build_coco_dsets(args):
105 | dset_kwargs = {
106 | 'image_dir': args.coco_train_image_dir,
107 | 'instances_json': args.coco_train_instances_json,
108 | 'stuff_json': args.coco_train_stuff_json,
109 | 'stuff_only': args.coco_stuff_only,
110 | 'image_size': args.image_size,
111 | 'mask_size': args.mask_size,
112 | 'max_samples': args.num_train_samples,
113 | 'min_object_size': args.min_object_size,
114 | 'min_objects_per_image': args.min_objects_per_image,
115 | 'instance_whitelist': args.instance_whitelist,
116 | 'stuff_whitelist': args.stuff_whitelist,
117 | 'include_other': args.coco_include_other,
118 | 'include_relationships': args.include_relationships,
119 | 'no__img__': True
120 | }
121 | train_dset = CocoSceneGraphDataset(**dset_kwargs)
122 | num_objs = train_dset.total_objects()
123 | num_imgs = len(train_dset)
124 | print('Training dataset has %d images and %d objects' % (num_imgs, num_objs))
125 | print('(%.2f objects per image)' % (float(num_objs) / num_imgs))
126 |
127 | dset_kwargs['image_dir'] = args.coco_val_image_dir
128 | dset_kwargs['instances_json'] = args.coco_val_instances_json
129 | dset_kwargs['stuff_json'] = args.coco_val_stuff_json
130 | dset_kwargs['max_samples'] = args.num_val_samples
131 | val_dset = CocoSceneGraphDataset(**dset_kwargs)
132 |
133 | assert train_dset.vocab == val_dset.vocab
134 | vocab = json.loads(json.dumps(train_dset.vocab))
135 |
136 | return vocab, train_dset, val_dset
137 |
138 |
139 | def build_loaders(args):
140 | vocab, train_dset, val_dset = build_coco_dsets(args)
141 | collate_fn = coco_collate_fn
142 |
143 | loader_kwargs = {
144 | 'batch_size': args.batch_size,
145 | 'num_workers': args.loader_num_workers,
146 | 'shuffle': True,
147 | 'collate_fn': collate_fn,
148 | }
149 | train_loader = DataLoader(train_dset, **loader_kwargs)
150 |
151 | loader_kwargs['shuffle'] = args.shuffle_val
152 | val_loader = DataLoader(val_dset, **loader_kwargs)
153 | return vocab, train_loader, val_loader
154 |
155 |
156 | def train_model(model, test_dataloader, val_dataloader, criterion, optimizer, scheduler, use_gpu, num_epochs=10,
157 | input_shape=224):
158 | since = time.time()
159 |
160 | best_model_wts = model.state_dict()
161 | best_acc = 0.0
162 | device = 'cuda' if use_gpu else 'cpu'
163 |
164 | for epoch in range(num_epochs):
165 | print('Epoch {}/{}'.format(epoch, num_epochs - 1))
166 | print('-' * 10)
167 |
168 | # Each epoch has a training and validation phase
169 | for phase in ['train', 'val']:
170 | if phase == 'train':
171 | scheduler.step()
172 | model.train(True) # Set model to training mode
173 | dataloader = test_dataloader
174 | else:
175 | model.train(False) # Set model to evaluate mode
176 | dataloader = val_dataloader
177 |
178 | running_loss = 0.0
179 | running_corrects = 0
180 | objects_len = 0
181 |
182 | # Iterate over data.
183 | for data in dataloader:
184 | # get the inputs
185 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = data
186 | imgs = imgs.to(device)
187 | boxes = boxes.to(device)
188 | obj_to_img = obj_to_img.to(device)
189 | labels = objs.to(device)
190 |
191 | objects_len += obj_to_img.shape[0]
192 |
193 | with torch.no_grad():
194 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, input_shape)
195 |
196 | # zero the parameter gradients
197 | optimizer.zero_grad()
198 |
199 | # forward
200 | outputs = model(crops)
201 | if isinstance(outputs, tuple):
202 | outputs, _ = outputs
203 | _, preds = torch.max(outputs, 1)
204 | loss = criterion(outputs, labels)
205 |
206 | # backward + optimize only if in training phase
207 | if phase == 'train':
208 | loss.backward()
209 | optimizer.step()
210 |
211 | # statistics
212 | running_loss += loss.item()
213 | running_corrects += torch.sum(preds == labels)
214 |
215 | epoch_loss = running_loss / objects_len
216 | epoch_acc = running_corrects.item() / objects_len
217 |
218 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
219 |
220 | # deep copy the model
221 | if phase == 'val' and epoch_acc > best_acc:
222 | best_acc = epoch_acc
223 | best_model_wts = model.state_dict()
224 |
225 | print()
226 |
227 | time_elapsed = time.time() - since
228 | print('Training complete in {:.0f}m {:.0f}s'.format(
229 | time_elapsed // 60, time_elapsed % 60))
230 | print('Best val Acc: {:4f}'.format(best_acc))
231 |
232 | # load best model weights
233 | model.load_state_dict(best_model_wts)
234 | return model
235 |
236 |
237 | def load_model(model_path):
238 | model_name = 'resnet101'
239 | model = all_pretrained_models(172, name=model_name).cuda()
240 | model.load_state_dict(torch.load(model_path))
241 | model.eval()
242 | return model
243 |
244 |
245 | if __name__ == '__main__':
246 | args = parser.parse_args()
247 | print(args)
248 | device = 'cuda' if args.use_gpu else 'cpu'
249 | vocab, train_loader, val_loader = build_loaders(args)
250 | num_objs = 172 # len(vocab['object_to_idx'])
251 | print("[Load the model...]")
252 | # Parameters of newly constructed modules have requires_grad=True by default
253 | print(
254 | "Loading model using class: {}, use_gpu: {}, freeze_layers: {}, freeze_initial_layers: {}, name_of_model: {}".format(
255 | num_objs, args.use_gpu, args.freeze_layers, args.freeze_initial_layers, args.model_name))
256 | model_conv = all_pretrained_models(num_objs, name=args.model_name).to(device)
257 | if args.use_parallel:
258 | print("[Using all the available GPUs]")
259 | model_conv = nn.DataParallel(model_conv, device_ids=[0, 1])
260 |
261 | print("[Using CrossEntropyLoss...]")
262 | criterion = nn.CrossEntropyLoss()
263 |
264 | print("[Using small learning rate with momentum...]")
265 | optimizer_conv = optim.SGD(list(filter(lambda p: p.requires_grad, model_conv.parameters())), lr=0.001, momentum=0.9)
266 |
267 | print("[Creating Learning rate scheduler...]")
268 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
269 |
270 | print("[Training the model begun ....]")
271 | model_ft = train_model(model_conv, train_loader, val_loader, criterion, optimizer_conv, exp_lr_scheduler,
272 | args.use_gpu, num_epochs=args.epochs, input_shape=args.input_shape)
273 |
274 | print("[Save the best model]")
275 | model_save_loc = './{}_{}_classes.pth'.format(args.model_name, num_objs)
276 | torch.save(model_ft.state_dict(), model_save_loc)
277 |
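A possible invocation (an illustration, not from the repository docs), assuming the COCO images and annotations are under datasets/coco as in the COCO_DIR defaults above and that the scene_generation package is importable from the working directory: `python scripts/train_accuracy_net.py -mo resnet101 -b 16 -ep 20`. The short flags map to the argparse options defined at the top of the file, and the best weights are written to ./resnet101_172_classes.pth per the save path above.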
--------------------------------------------------------------------------------
/scene_generation/layers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import functools
18 |
19 | import torch.nn as nn
20 | from torch.nn.functional import interpolate
21 |
22 |
23 | def get_normalization_2d(channels, normalization):
24 | if normalization == 'instance':
25 | return nn.InstanceNorm2d(channels)
26 | elif normalization == 'batch':
27 | return nn.BatchNorm2d(channels)
28 | elif normalization == 'none':
29 | return None
30 | else:
31 | raise ValueError('Unrecognized normalization type "%s"' % normalization)
32 |
33 |
34 | def get_activation(name):
35 | kwargs = {}
36 | if name.lower().startswith('leakyrelu'):
37 | if '-' in name:
38 | slope = float(name.split('-')[1])
39 | kwargs = {'negative_slope': slope}
40 | name = 'leakyrelu'
41 | activations = {
42 | 'relu': nn.ReLU,
43 | 'leakyrelu': nn.LeakyReLU,
44 | }
45 | if name.lower() not in activations:
46 | raise ValueError('Invalid activation "%s"' % name)
47 | return activations[name.lower()](**kwargs)
48 |
49 |
50 | def _init_conv(layer, method):
51 | if not isinstance(layer, nn.Conv2d):
52 | return
53 | if method == 'default':
54 | return
55 | elif method == 'kaiming-normal':
56 | nn.init.kaiming_normal(layer.weight)
57 | elif method == 'kaiming-uniform':
58 | nn.init.kaiming_uniform(layer.weight)
59 |
60 |
61 | class Flatten(nn.Module):
62 | def forward(self, x):
63 | return x.view(x.size(0), -1)
64 |
65 | def __repr__(self):
66 | return 'Flatten()'
67 |
68 |
69 | class Unflatten(nn.Module):
70 | def __init__(self, size):
71 | super(Unflatten, self).__init__()
72 | self.size = size
73 |
74 | def forward(self, x):
75 | return x.view(*self.size)
76 |
77 | def __repr__(self):
78 | size_str = ', '.join('%d' % d for d in self.size)
79 | return 'Unflatten(%s)' % size_str
80 |
81 |
82 | class GlobalAvgPool(nn.Module):
83 | def forward(self, x):
84 | N, C = x.size(0), x.size(1)
85 | return x.view(N, C, -1).mean(dim=2)
86 |
87 |
88 | class ResidualBlock(nn.Module):
89 | def __init__(self, channels, normalization='batch', activation='relu',
90 | padding='same', kernel_size=3, init='default'):
91 | super(ResidualBlock, self).__init__()
92 |
93 | K = kernel_size
94 | P = _get_padding(K, padding)
95 | C = channels
96 | self.padding = P
97 | layers = [
98 | get_normalization_2d(C, normalization),
99 | get_activation(activation),
100 | nn.Conv2d(C, C, kernel_size=K, padding=P),
101 | get_normalization_2d(C, normalization),
102 | get_activation(activation),
103 | nn.Conv2d(C, C, kernel_size=K, padding=P),
104 | ]
105 | layers = [layer for layer in layers if layer is not None]
106 | for layer in layers:
107 | _init_conv(layer, method=init)
108 | self.net = nn.Sequential(*layers)
109 |
110 | def forward(self, x):
111 | P = self.padding
112 | y = self.net(x)
113 | shortcut = x
114 | if P == 0:  # 'valid' padding shrinks y, so crop the shortcut to match
115 | shortcut = x[:, :, (x.size(2) - y.size(2)) // 2:(x.size(2) + y.size(2)) // 2, (x.size(3) - y.size(3)) // 2:(x.size(3) + y.size(3)) // 2]
116 | return shortcut + y
117 |
118 |
119 | def _get_padding(K, mode):
120 | """ Helper method to compute padding size """
121 | if mode == 'valid':
122 | return 0
123 | elif mode == 'same':
124 | assert K % 2 == 1, 'Invalid kernel size %d for "same" padding' % K
125 | return (K - 1) // 2
126 |
127 |
128 | def build_cnn(arch, normalization='batch', activation='relu', padding='same',
129 | pooling='max', init='default'):
130 | """
131 | Build a CNN from an architecture string, which is a list of layer
132 | specification strings. The overall architecture can be given as a list or as
133 | a comma-separated string.
134 |
135 | All convolutions *except for the first* are preceded by normalization and
136 | nonlinearity.
137 | 
138 | The following layer specifications are supported:
139 | - IX: Indicates that the number of input channels to the network is X.
140 | Can only be used at the first layer; if not present then we assume
141 | 3 input channels.
142 | - CK-X: KxK convolution with X output channels
143 | - CK-X-S: KxK convolution with X output channels and stride S
144 | - R: Residual block keeping the same number of channels
145 | - UX: Nearest-neighbor upsampling with factor X
146 | - PX: Spatial pooling with factor X
147 | - FC-X-Y: Flatten followed by fully-connected layer
148 |
149 | Returns a tuple of:
150 | - cnn: An nn.Sequential
151 | - channels: Number of output channels
152 | """
153 | if isinstance(arch, str):
154 | arch = arch.split(',')
155 | cur_C = 3
156 | if len(arch) > 0 and arch[0][0] == 'I':
157 | cur_C = int(arch[0][1:])
158 | arch = arch[1:]
159 |
160 | first_conv = True
161 | flat = False
162 | layers = []
163 | for i, s in enumerate(arch):
164 | if s[0] == 'C':
165 | if not first_conv:
166 | layers.append(get_normalization_2d(cur_C, normalization))
167 | layers.append(get_activation(activation))
168 | first_conv = False
169 | vals = [int(i) for i in s[1:].split('-')]
170 | if len(vals) == 2:
171 | K, next_C = vals
172 | stride = 1
173 | elif len(vals) == 3:
174 | K, next_C, stride = vals
175 | # K, next_C = (int(i) for i in s[1:].split('-'))
176 | P = _get_padding(K, padding)
177 | conv = nn.Conv2d(cur_C, next_C, kernel_size=K, padding=P, stride=stride)
178 | layers.append(conv)
179 | _init_conv(layers[-1], init)
180 | cur_C = next_C
181 | elif s[0] == 'R':
182 | norm = 'none' if first_conv else normalization
183 | res = ResidualBlock(cur_C, normalization=norm, activation=activation,
184 | padding=padding, init=init)
185 | layers.append(res)
186 | first_conv = False
187 | elif s[0] == 'U':
188 | factor = int(s[1:])
189 | layers.append(Interpolate(scale_factor=factor, mode='nearest'))
190 | elif s[0] == 'P':
191 | factor = int(s[1:])
192 | if pooling == 'max':
193 | pool = nn.MaxPool2d(kernel_size=factor, stride=factor)
194 | elif pooling == 'avg':
195 | pool = nn.AvgPool2d(kernel_size=factor, stride=factor)
196 | layers.append(pool)
197 | elif s[:2] == 'FC':
198 | _, Din, Dout = s.split('-')
199 | Din, Dout = int(Din), int(Dout)
200 | if not flat:
201 | layers.append(Flatten())
202 | flat = True
203 | layers.append(nn.Linear(Din, Dout))
204 | if i + 1 < len(arch):
205 | layers.append(get_activation(activation))
206 | cur_C = Dout
207 | else:
208 | raise ValueError('Invalid layer "%s"' % s)
209 | layers = [layer for layer in layers if layer is not None]
210 | # for layer in layers:
211 | # print(layer)
212 | return nn.Sequential(*layers), cur_C
213 |
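# Illustrative example (not in the original file; requires `import torch` to run):
# the architecture string follows the grammar documented above -- 3 input channels,
# a 3x3/64 conv, a strided 3x3/64 conv, a residual block, 2x nearest-neighbor
# upsampling and a final 3x3/32 conv.
#   cnn, out_channels = build_cnn('I3,C3-64,C3-64-2,R,U2,C3-32')
#   y = cnn(torch.randn(2, 3, 64, 64))   # y.shape == (2, 32, 64, 64), out_channels == 32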
214 |
215 | def build_mlp(dim_list, activation='relu', batch_norm='none',
216 | dropout=0, final_nonlinearity=True):
217 | layers = []
218 | for i in range(len(dim_list) - 1):
219 | dim_in, dim_out = dim_list[i], dim_list[i + 1]
220 | layers.append(nn.Linear(dim_in, dim_out))
221 | final_layer = (i == len(dim_list) - 2)
222 | if not final_layer or final_nonlinearity:
223 | if batch_norm == 'batch':
224 | layers.append(nn.BatchNorm1d(dim_out))
225 | if activation == 'relu':
226 | layers.append(nn.ReLU())
227 | elif activation == 'leakyrelu':
228 | layers.append(nn.LeakyReLU())
229 | if dropout > 0:
230 | layers.append(nn.Dropout(p=dropout))
231 | return nn.Sequential(*layers)
232 |
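# Example (illustrative): build_mlp([128, 512, 4]) gives
#   Linear(128, 512) -> ReLU -> Linear(512, 4) -> ReLU
# (the trailing ReLU because final_nonlinearity defaults to True); model.py builds
# its box regressor with this helper.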
233 |
234 | class ResnetBlock(nn.Module):
235 | def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
236 | super(ResnetBlock, self).__init__()
237 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
238 |
239 | def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
240 | conv_block = []
241 | p = 0
242 | if padding_type == 'reflect':
243 | conv_block += [nn.ReflectionPad2d(1)]
244 | elif padding_type == 'replicate':
245 | conv_block += [nn.ReplicationPad2d(1)]
246 | elif padding_type == 'zero':
247 | p = 1
248 | else:
249 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
250 |
251 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
252 | norm_layer(dim),
253 | activation]
254 | if use_dropout:
255 | conv_block += [nn.Dropout(0.5)]
256 |
257 | p = 0
258 | if padding_type == 'reflect':
259 | conv_block += [nn.ReflectionPad2d(1)]
260 | elif padding_type == 'replicate':
261 | conv_block += [nn.ReplicationPad2d(1)]
262 | elif padding_type == 'zero':
263 | p = 1
264 | else:
265 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
266 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
267 | norm_layer(dim)]
268 |
269 | return nn.Sequential(*conv_block)
270 |
271 | def forward(self, x):
272 | out = x + self.conv_block(x)
273 | return out
274 |
275 |
276 | class ConditionalBatchNorm2d(nn.Module):
277 | def __init__(self, num_features, num_classes):
278 | super(ConditionalBatchNorm2d, self).__init__()
279 | self.num_features = num_features
280 | self.bn = nn.BatchNorm2d(num_features, affine=False)
281 | self.embed = nn.Embedding(num_classes, num_features * 2)
282 | self.embed.weight.data[:, :num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
283 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
284 |
285 | def forward(self, x, y):
286 | out = self.bn(x)
287 | gamma, beta = self.embed(y).chunk(2, 1)
288 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
289 | return out
290 |
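# Usage sketch (illustrative): per-class affine parameters are looked up from the
# embedding and applied to the normalized activations.
#   cbn = ConditionalBatchNorm2d(num_features=64, num_classes=10)
#   out = cbn(torch.randn(4, 64, 8, 8), torch.randint(0, 10, (4,)))   # -> (4, 64, 8, 8)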
291 |
292 | def get_norm_layer(norm_type='instance'):
293 | if norm_type == 'batch':
294 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
295 | elif norm_type == 'instance':
296 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
297 | elif norm_type == 'conditional':
298 | norm_layer = functools.partial(ConditionalBatchNorm2d)
299 | else:
300 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
301 | return norm_layer
302 |
303 |
304 | class Interpolate(nn.Module):
305 | def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
306 | super(Interpolate, self).__init__()
307 | self.size = size
308 | self.scale_factor = scale_factor
309 | self.mode = mode
310 | self.align_corners = align_corners
311 |
312 | def forward(self, x):
313 | return interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
314 | align_corners=self.align_corners)
315 |
--------------------------------------------------------------------------------
/scene_generation/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from scene_generation.bilinear import crop_bbox_batch
5 | from scene_generation.generators import mask_net, AppearanceEncoder, define_G
6 | from scene_generation.graph import GraphTripleConv, GraphTripleConvNet
7 | from scene_generation.layers import build_mlp
8 | from scene_generation.layout import masks_to_layout
9 | from scene_generation.utils import VectorPool
10 |
11 |
12 | class Model(nn.Module):
13 | def __init__(self, vocab, image_size=(64, 64), embedding_dim=128,
14 | gconv_dim=128, gconv_hidden_dim=512,
15 | gconv_pooling='avg', gconv_num_layers=5,
16 | mask_size=32, mlp_normalization='none', appearance_normalization='', activation='',
17 | n_downsample_global=4, box_dim=128,
18 | use_attributes=False, box_noise_dim=64,
19 | mask_noise_dim=64, pool_size=100, rep_size=32):
20 | super(Model, self).__init__()
21 |
22 | self.vocab = vocab
23 | self.image_size = image_size
24 | self.use_attributes = use_attributes
25 | self.box_noise_dim = box_noise_dim
26 | self.mask_noise_dim = mask_noise_dim
27 | self.object_size = 64
28 | self.fake_pool = VectorPool(pool_size)
29 |
30 | self.num_objs = len(vocab['object_to_idx'])
31 | self.num_preds = len(vocab['pred_idx_to_name'])
32 | self.obj_embeddings = nn.Embedding(self.num_objs, embedding_dim)
33 | self.pred_embeddings = nn.Embedding(self.num_preds, embedding_dim)
34 |
35 | if use_attributes:
36 | attributes_dim = vocab['num_attributes']
37 | else:
38 | attributes_dim = 0
39 | if gconv_num_layers == 0:
40 | self.gconv = nn.Linear(embedding_dim, gconv_dim)
41 | elif gconv_num_layers > 0:
42 | gconv_kwargs = {
43 | 'input_dim': embedding_dim,
44 | 'attributes_dim': attributes_dim,
45 | 'output_dim': gconv_dim,
46 | 'hidden_dim': gconv_hidden_dim,
47 | 'pooling': gconv_pooling,
48 | 'mlp_normalization': mlp_normalization,
49 | }
50 | self.gconv = GraphTripleConv(**gconv_kwargs)
51 |
52 | self.gconv_net = None
53 | if gconv_num_layers > 1:
54 | gconv_kwargs = {
55 | 'input_dim': gconv_dim,
56 | 'hidden_dim': gconv_hidden_dim,
57 | 'pooling': gconv_pooling,
58 | 'num_layers': gconv_num_layers - 1,
59 | 'mlp_normalization': mlp_normalization,
60 | }
61 | self.gconv_net = GraphTripleConvNet(**gconv_kwargs)
62 |
63 | box_net_dim = 4
64 | self.box_dim = box_dim
65 | box_net_layers = [self.box_dim, gconv_hidden_dim, box_net_dim]
66 | self.box_net = build_mlp(box_net_layers, batch_norm=mlp_normalization)
67 |
68 | self.g_mask_dim = gconv_dim + mask_noise_dim
69 | self.mask_net = mask_net(self.g_mask_dim, mask_size)
70 |
71 | self.repr_input = self.g_mask_dim
72 | rep_size = rep_size
73 | rep_hidden_size = 64
74 | repr_layers = [self.repr_input, rep_hidden_size, rep_size]
75 | self.repr_net = build_mlp(repr_layers, batch_norm=mlp_normalization)
76 |
77 | appearance_encoder_kwargs = {
78 | 'vocab': vocab,
79 | 'arch': 'C4-64-2,C4-128-2,C4-256-2',
80 | 'normalization': appearance_normalization,
81 | 'activation': activation,
82 | 'padding': 'valid',
83 | 'vecs_size': self.g_mask_dim
84 | }
85 | self.image_encoder = AppearanceEncoder(**appearance_encoder_kwargs)
86 |
87 | netG_input_nc = self.num_objs + rep_size
88 | output_nc = 3
89 | ngf = 64
90 | n_blocks_global = 9
91 | norm = 'instance'
92 | self.layout_to_image = define_G(netG_input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm)
93 |
94 | def forward(self, gt_imgs, objs, triples, obj_to_img, boxes_gt=None, masks_gt=None, attributes=None,
95 | test_mode=False, use_gt_box=False, features=None):
96 | O, T = objs.size(0), triples.size(0)
97 | obj_vecs, pred_vecs = self.scene_graph_to_vectors(objs, triples, attributes)
98 |
99 | box_vecs, mask_vecs, scene_layout_vecs, wrong_layout_vecs = \
100 | self.create_components_vecs(gt_imgs, boxes_gt, obj_to_img, objs, obj_vecs, features)
101 |
102 | # Generate Boxes
103 | boxes_pred = self.box_net(box_vecs)
104 |
105 | # Generate Masks
106 | mask_scores = self.mask_net(mask_vecs.view(O, -1, 1, 1))
107 | masks_pred = mask_scores.squeeze(1).sigmoid()
108 |
109 | H, W = self.image_size
110 |
111 | if test_mode:
112 | boxes = boxes_gt if use_gt_box else boxes_pred
113 | masks = masks_gt if masks_gt is not None else masks_pred
114 | gt_layout = None
115 | pred_layout = masks_to_layout(scene_layout_vecs, boxes, masks, obj_to_img, H, W, test_mode=True)
116 | wrong_layout = None
117 | imgs_pred = self.layout_to_image(pred_layout)
118 | else:
119 | gt_layout = masks_to_layout(scene_layout_vecs, boxes_gt, masks_gt, obj_to_img, H, W, test_mode=False)
120 | pred_layout = masks_to_layout(scene_layout_vecs, boxes_gt, masks_pred, obj_to_img, H, W, test_mode=False)
121 | wrong_layout = masks_to_layout(wrong_layout_vecs, boxes_gt, masks_gt, obj_to_img, H, W, test_mode=False)
122 |
123 | imgs_pred = self.layout_to_image(gt_layout)
124 | return imgs_pred, boxes_pred, masks_pred, gt_layout, pred_layout, wrong_layout
125 |
126 | def scene_graph_to_vectors(self, objs, triples, attributes):
127 | s, p, o = triples.chunk(3, dim=1)
128 | s, p, o = [x.squeeze(1) for x in [s, p, o]]
129 | edges = torch.stack([s, o], dim=1)
130 |
131 | obj_vecs = self.obj_embeddings(objs)
132 | pred_vecs = self.pred_embeddings(p)
133 | if self.use_attributes:
134 | obj_vecs = torch.cat([obj_vecs, attributes], dim=1)
135 |
136 | if isinstance(self.gconv, nn.Linear):
137 | obj_vecs = self.gconv(obj_vecs)
138 | else:
139 | obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
140 | if self.gconv_net is not None:
141 | obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)
142 |
143 | return obj_vecs, pred_vecs
144 |
145 | def create_components_vecs(self, imgs, boxes, obj_to_img, objs, obj_vecs, features):
146 | O = objs.size(0)
147 | box_vecs = obj_vecs
148 | mask_vecs = obj_vecs
149 | layout_noise = torch.randn((1, self.mask_noise_dim), dtype=mask_vecs.dtype, device=mask_vecs.device) \
150 | .repeat((O, 1)) \
151 | .view(O, self.mask_noise_dim)
152 | mask_vecs = torch.cat([mask_vecs, layout_noise], dim=1)
153 |
154 | # create encoding
155 | if features is None:
156 | crops = crop_bbox_batch(imgs, boxes, obj_to_img, self.object_size)
157 | obj_repr = self.repr_net(self.image_encoder(crops))
158 | else:
159 | # Only in inference time
160 | obj_repr = self.repr_net(mask_vecs)
161 | for ind, feature in enumerate(features):
162 | if feature is not None:
163 | obj_repr[ind, :] = feature
164 | # create one-hot vector for label map
165 | one_hot_size = (O, self.num_objs)
166 | one_hot_obj = torch.zeros(one_hot_size, dtype=obj_repr.dtype, device=obj_repr.device)
167 | one_hot_obj = one_hot_obj.scatter_(1, objs.view(-1, 1).long(), 1.0)
168 | layout_vecs = torch.cat([one_hot_obj, obj_repr], dim=1)
169 |
170 | wrong_objs_rep = self.fake_pool.query(objs, obj_repr)
171 | wrong_layout_vecs = torch.cat([one_hot_obj, wrong_objs_rep], dim=1)
172 | return box_vecs, mask_vecs, layout_vecs, wrong_layout_vecs
173 |
174 | def encode_scene_graphs(self, scene_graphs, rand=False):
175 | """
176 | Encode one or more scene graphs using this model's vocabulary. Inputs to
177 | this method are scene graphs represented as dictionaries like the following:
178 |
179 | {
180 | "objects": ["cat", "dog", "sky"],
181 | "relationships": [
182 | [0, "next to", 1],
183 | [0, "beneath", 2],
184 | [2, "above", 1],
185 | ]
186 | }
187 |
188 | This scene graph has three relationships: cat next to dog, cat beneath sky,
189 | and sky above dog.
190 |
191 | Inputs:
192 | - scene_graphs: A dictionary giving a single scene graph, or a list of
193 | dictionaries giving a sequence of scene graphs.
194 |
195 | Returns a tuple of LongTensors (objs, triples, obj_to_img) that have the
196 | same semantics as self.forward. The returned LongTensors will be on the
197 | same device as the model parameters.
198 | """
199 | if isinstance(scene_graphs, dict):
200 | # We just got a single scene graph, so promote it to a list
201 | scene_graphs = [scene_graphs]
202 | device = next(self.parameters()).device
203 | objs, triples, obj_to_img = [], [], []
204 | all_attributes = []
205 | all_features = []
206 | obj_offset = 0
207 | for i, sg in enumerate(scene_graphs):
208 | attributes = torch.zeros([len(sg['objects']) + 1, 25 + 10], dtype=torch.float, device=device)
209 | # Insert dummy __image__ object and __in_image__ relationships
210 | sg['objects'].append('__image__')
211 | sg['features'].append(sg['image_id'])
212 | image_idx = len(sg['objects']) - 1
213 | for j in range(image_idx):
214 | sg['relationships'].append([j, '__in_image__', image_idx])
215 |
216 | for obj in sg['objects']:
217 | obj_idx = self.vocab['object_to_idx'][str(self.vocab['object_name_to_idx'][obj])]
218 | if obj_idx is None:
219 | raise ValueError('Object "%s" not in vocab' % obj)
220 | objs.append(obj_idx)
221 | obj_to_img.append(i)
222 | if self.features is not None:
223 | for obj_name, feat_num in zip(objs, sg['features']):
224 | if feat_num == -1:
225 | feat = self.features_one[obj_name][0]
226 | else:
227 | feat = self.features[obj_name][min(feat_num, 99), :]
228 | feat = torch.from_numpy(feat).type(torch.float32).to(device)
229 | all_features.append(feat)
230 | for s, p, o in sg['relationships']:
231 | pred_idx = self.vocab['pred_name_to_idx'].get(p, None)
232 | if pred_idx is None:
233 | raise ValueError('Relationship "%s" not in vocab' % p)
234 | triples.append([s + obj_offset, pred_idx, o + obj_offset])
235 | for i, size_attr in enumerate(sg['attributes']['size']):
236 | attributes[i, size_attr] = 1
237 | # in image size
238 | attributes[-1, 9] = 1
239 | for i, location_attr in enumerate(sg['attributes']['location']):
240 | attributes[i, location_attr + 10] = 1
241 | # in image location
242 | attributes[-1, 12 + 10] = 1
243 | obj_offset += len(sg['objects'])
244 | all_attributes.append(attributes)
245 | objs = torch.tensor(objs, dtype=torch.int64, device=device)
246 | triples = torch.tensor(triples, dtype=torch.int64, device=device)
247 | obj_to_img = torch.tensor(obj_to_img, dtype=torch.int64, device=device)
248 | attributes = torch.cat(all_attributes)
249 | features = all_features
250 | return objs, triples, obj_to_img, attributes, features
251 |
252 | def forward_json(self, scene_graphs):
253 | """ Convenience method that combines encode_scene_graphs and forward. """
254 | objs, triples, obj_to_img, attributes, features = self.encode_scene_graphs(scene_graphs)
255 | return self.forward(None, objs, triples, obj_to_img, attributes=attributes, test_mode=True,
256 | use_gt_box=False, features=features), objs
257 |
--------------------------------------------------------------------------------
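The docstring of encode_scene_graphs above spells out the scene-graph dictionary format, and the code additionally reads "features", "image_id", and "attributes". A minimal usage sketch (not part of the repository) tying the two together; the checkpoint path, feature ids, and attribute bins below are illustrative assumptions, and every object and predicate name must exist in the loaded vocabulary:

import torch

from scene_generation.model import Model

# Hypothetical checkpoint path -- substitute a real trained checkpoint file.
checkpoint = torch.load('models/checkpoint_with_model.pt')
model = Model(**checkpoint['model_kwargs'])
model.load_state_dict(checkpoint['model_state'])
model.eval()
if torch.cuda.is_available():
    model.cuda()  # mirrors build_model in scripts/sample_images.py

# Scene graph in the format documented by encode_scene_graphs. Size bins lie
# in [0, 9] and location bins in [0, 24]; a feature id of -1 selects the
# stored default appearance for that object.
scene_graph = {
    'objects': ['cat', 'dog', 'sky'],
    'relationships': [
        [0, 'next to', 1],
        [0, 'beneath', 2],
        [2, 'above', 1],
    ],
    'features': [-1, -1, -1],
    'image_id': 0,
    'attributes': {'size': [5, 5, 9], 'location': [12, 12, 2]},
}

with torch.no_grad():
    model_out, objs = model.forward_json(scene_graph)
# model_out follows the same tuple layout as Model.forward; its first element
# is the batch of generated images (cf. scripts/sample_images.py below).
imgs_pred = model_out[0]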
/scripts/sample_images.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from random import randint
4 |
5 | import numpy as np
6 | import torch
7 | from imageio import imwrite as imsave  # scipy.misc.imsave was removed in SciPy >= 1.2; imageio (see requirements.txt) is used instead
8 | from torch.utils.data import DataLoader
9 |
10 | from scene_generation.data import imagenet_deprocess_batch
11 | from scene_generation.data.coco import CocoSceneGraphDataset, coco_collate_fn
12 | from scene_generation.data.coco_panoptic import CocoPanopticSceneGraphDataset, coco_panoptic_collate_fn
13 | from scene_generation.data.utils import split_graph_batch
14 | from scene_generation.metrics import jaccard
15 | from scene_generation.model import Model
16 | from scene_generation.utils import int_tuple, bool_flag
17 | from scene_generation.bilinear import crop_bbox_batch
18 | from scripts.train_accuracy_net import all_pretrained_models
19 |
20 |
21 | def load_model(model_path):
22 | model_name = 'resnet101'
23 | model = all_pretrained_models(172, name=model_name).cuda()
24 | model.load_state_dict(torch.load(model_path))
25 | model.eval()
26 | return model
27 |
28 |
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--checkpoint', required=True)
31 | parser.add_argument('--model_mode', default='eval', choices=['train', 'eval'])
32 | parser.add_argument('--accuracy_model_path', default=None)
33 |
34 | # Shared dataset options
35 | parser.add_argument('--dataset', default='coco')
36 | parser.add_argument('--image_size', default=(128, 128), type=int_tuple)
37 | parser.add_argument('--batch_size', default=24, type=int)
38 | parser.add_argument('--shuffle', default=False, type=bool_flag)
39 | parser.add_argument('--loader_num_workers', default=4, type=int)
40 | parser.add_argument('--num_samples', default=10000, type=int)
41 | parser.add_argument('--save_gt_imgs', default=False, type=bool_flag)
42 | parser.add_argument('--save_graphs', default=False, type=bool_flag)
43 | parser.add_argument('--use_gt_boxes', default=False, type=bool_flag)
44 | parser.add_argument('--use_gt_masks', default=False, type=bool_flag)
45 | parser.add_argument('--use_gt_attr', default=False, type=bool_flag)
46 | parser.add_argument('--use_gt_textures', default=False, type=bool_flag)
47 | parser.add_argument('--save_layout', default=False, type=bool_flag)
48 | parser.add_argument('--sample_attributes', default=False, type=bool_flag)
49 | parser.add_argument('--object_size', default=64, type=int)
50 | parser.add_argument('--grid_size', default=25, type=int)
51 |
52 | parser.add_argument('--output_dir', default='output')
53 |
54 | COCO_DIR = os.path.expanduser('../data/coco')
55 | parser.add_argument('--coco_image_dir',
56 | default=os.path.join(COCO_DIR, 'images/val2017'))
57 | parser.add_argument('--instances_json',
58 | default=os.path.join(COCO_DIR, 'annotations/instances_val2017.json'))
59 | parser.add_argument('--stuff_json',
60 | default=os.path.join(COCO_DIR, 'annotations/stuff_val2017.json'))
61 |
62 |
63 | def build_coco_dset(args, checkpoint):
64 | checkpoint_args = checkpoint['args']
65 | print('include other: ', checkpoint_args.get('coco_include_other'))
66 |     # When using GT masks, use the full image resolution as the mask size
67 | mask_size = args.image_size[0] if args.use_gt_masks else checkpoint_args['mask_size']
68 | dset_kwargs = {
69 | 'image_dir': args.coco_image_dir,
70 | 'instances_json': args.instances_json,
71 | 'stuff_json': args.stuff_json,
72 | 'image_size': args.image_size,
73 | 'mask_size': mask_size,
74 | 'max_samples': args.num_samples,
75 | 'min_object_size': checkpoint_args['min_object_size'],
76 | 'min_objects_per_image': checkpoint_args['min_objects_per_image'],
77 | 'instance_whitelist': checkpoint_args['instance_whitelist'],
78 | 'stuff_whitelist': checkpoint_args['stuff_whitelist'],
79 | 'include_other': checkpoint_args.get('coco_include_other', True),
80 | 'test_part': True,
81 | 'sample_attributes': args.sample_attributes,
82 | 'grid_size': args.grid_size
83 | }
84 | dset = CocoSceneGraphDataset(**dset_kwargs)
85 | return dset
86 |
87 |
88 | def build_coco_panoptic_dset(args, checkpoint):
89 | checkpoint_args = checkpoint['args']
90 | print('include other: ', checkpoint_args.get('coco_include_other'))
91 |     # When using GT masks, use the full image resolution as the mask size
92 | mask_size = args.image_size[0] if args.use_gt_masks else checkpoint_args['mask_size']
93 | dset_kwargs = {
94 | 'image_dir': args.coco_image_dir,
95 | 'instances_json': args.instances_json,
96 | 'panoptic': checkpoint_args['coco_panoptic_val'],
97 | 'panoptic_segmentation': checkpoint_args['coco_panoptic_segmentation_val'],
98 | 'stuff_json': args.stuff_json,
99 | 'image_size': args.image_size,
100 | 'mask_size': mask_size,
101 | 'max_samples': args.num_samples,
102 | 'min_object_size': checkpoint_args['min_object_size'],
103 | 'min_objects_per_image': checkpoint_args['min_objects_per_image'],
104 | 'instance_whitelist': checkpoint_args['instance_whitelist'],
105 | 'stuff_whitelist': checkpoint_args['stuff_whitelist'],
106 | 'include_other': checkpoint_args.get('coco_include_other', True),
107 | 'test_part': True,
108 | 'sample_attributes': args.sample_attributes,
109 | 'grid_size': args.grid_size
110 | }
111 | dset = CocoPanopticSceneGraphDataset(**dset_kwargs)
112 | return dset
113 |
114 |
115 | def build_loader(args, checkpoint, is_panoptic):
116 | if is_panoptic:
117 | dset = build_coco_panoptic_dset(args, checkpoint)
118 | collate_fn = coco_panoptic_collate_fn
119 | else:
120 | dset = build_coco_dset(args, checkpoint)
121 | collate_fn = coco_collate_fn
122 |
123 | loader_kwargs = {
124 | 'batch_size': args.batch_size,
125 | 'num_workers': args.loader_num_workers,
126 | 'shuffle': args.shuffle,
127 | 'collate_fn': collate_fn,
128 | }
129 | loader = DataLoader(dset, **loader_kwargs)
130 | return loader
131 |
132 |
133 | def build_model(args, checkpoint):
134 | kwargs = checkpoint['model_kwargs']
135 | model = Model(**kwargs)
136 | model_state = checkpoint['model_state']
137 | model.load_state_dict(model_state)
138 | if args.model_mode == 'eval':
139 | model.eval()
140 | elif args.model_mode == 'train':
141 | model.train()
142 | model.image_size = args.image_size
143 | model.cuda()
144 | return model
145 |
146 |
147 | def makedir(base, name, flag=True):
148 | dir_name = None
149 | if flag:
150 | dir_name = os.path.join(base, name)
151 | if not os.path.isdir(dir_name):
152 | os.makedirs(dir_name)
153 | return dir_name
154 |
155 |
156 | def one_hot_to_rgb(layout_pred, colors, num_objs):
157 | one_hot = layout_pred[:, :num_objs, :, :]
158 | one_hot_3d = torch.einsum('abcd,be->aecd', [one_hot.cpu(), colors])
159 | one_hot_3d *= (255.0 / one_hot_3d.max())
160 | return one_hot_3d
161 |
162 |
163 | def run_model(args, checkpoint, output_dir, loader=None):
164 | if args.save_graphs:
165 | from scene_generation.vis import draw_scene_graph
166 | dirname = os.path.dirname(args.checkpoint)
167 | features = None
168 | if not args.use_gt_textures:
169 | features_path = os.path.join(dirname, 'features_clustered_001.npy')
170 | print(features_path)
171 | if os.path.isfile(features_path):
172 | features = np.load(features_path, allow_pickle=True).item()
173 | else:
174 | raise ValueError('No features file')
175 | with torch.no_grad():
176 | vocab = checkpoint['model_kwargs']['vocab']
177 | model = build_model(args, checkpoint)
178 | if loader is None:
179 | loader = build_loader(args, checkpoint, vocab['is_panoptic'])
180 | accuracy_model = None
181 | if args.accuracy_model_path is not None and os.path.isfile(args.accuracy_model_path):
182 | accuracy_model = load_model(args.accuracy_model_path)
183 |
184 | img_dir = makedir(output_dir, 'images')
185 | graph_dir = makedir(output_dir, 'graphs', args.save_graphs)
186 | gt_img_dir = makedir(output_dir, 'images_gt', args.save_gt_imgs)
187 | layout_dir = makedir(output_dir, 'layouts', args.save_layout)
188 |
189 | img_idx = 0
190 | total_iou = 0
191 | total_boxes = 0
192 | r_05 = 0
193 | r_03 = 0
194 | corrects = 0
195 | real_objects_count = 0
196 | num_objs = model.num_objs
197 | colors = torch.randint(0, 256, [num_objs, 3]).float()
198 | for batch in loader:
199 | imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, attributes = [x.cuda() for x in batch]
200 |
201 | imgs_gt = imagenet_deprocess_batch(imgs)
202 |
203 | if args.use_gt_masks:
204 | masks_gt = masks
205 | else:
206 | masks_gt = None
207 | if args.use_gt_textures:
208 | all_features = None
209 | else:
210 | all_features = []
211 | for obj_name in objs:
212 | obj_feature = features[obj_name.item()]
213 | random_index = randint(0, obj_feature.shape[0] - 1)
214 | feat = torch.from_numpy(obj_feature[random_index, :]).type(torch.float32).cuda()
215 | all_features.append(feat)
216 | if not args.use_gt_attr:
217 | attributes = torch.zeros_like(attributes)
218 |
219 | # Run the model with predicted masks
220 | model_out = model(imgs, objs, triples, obj_to_img, boxes_gt=boxes, masks_gt=masks_gt, attributes=attributes,
221 | test_mode=True, use_gt_box=args.use_gt_boxes, features=all_features)
222 | imgs_pred, boxes_pred, masks_pred, _, layout, _ = model_out
223 |
224 | if accuracy_model is not None:
225 | if args.use_gt_boxes:
226 | crops = crop_bbox_batch(imgs_pred, boxes, obj_to_img, 224)
227 | else:
228 | crops = crop_bbox_batch(imgs_pred, boxes_pred, obj_to_img, 224)
229 |
230 | outputs = accuracy_model(crops)
231 |                 if isinstance(outputs, tuple):
232 | outputs, _ = outputs
233 | _, preds = torch.max(outputs, 1)
234 |
235 | # statistics
236 | for pred, label in zip(preds, objs):
237 | if label.item() != 0:
238 | real_objects_count += 1
239 | corrects += 1 if pred.item() == label.item() else 0
240 |
241 | # Remove the __image__ object
242 | boxes_pred_no_image = []
243 | boxes_gt_no_image = []
244 | for o_index in range(len(obj_to_img)):
245 | if o_index < len(obj_to_img) - 1 and obj_to_img[o_index] == obj_to_img[o_index + 1]:
246 | boxes_pred_no_image.append(boxes_pred[o_index])
247 | boxes_gt_no_image.append(boxes[o_index])
248 | boxes_pred_no_image = torch.stack(boxes_pred_no_image)
249 | boxes_gt_no_image = torch.stack(boxes_gt_no_image)
250 |
251 | iou, bigger_05, bigger_03 = jaccard(boxes_pred_no_image, boxes_gt_no_image)
252 | total_iou += iou
253 | r_05 += bigger_05
254 | r_03 += bigger_03
255 | total_boxes += boxes_pred_no_image.size(0)
256 | imgs_pred = imagenet_deprocess_batch(imgs_pred)
257 |
258 | obj_data = [objs, boxes_pred, masks_pred]
259 | _, obj_data = split_graph_batch(triples, obj_data, obj_to_img, triple_to_img)
260 | objs, boxes_pred, masks_pred = obj_data
261 |
262 | obj_data_gt = [boxes.data]
263 | if masks is not None:
264 | obj_data_gt.append(masks.data)
265 | triples, obj_data_gt = split_graph_batch(triples, obj_data_gt, obj_to_img, triple_to_img)
266 | layouts_3d = one_hot_to_rgb(layout, colors, num_objs)
267 | for i in range(imgs_pred.size(0)):
268 | img_filename = '%04d.png' % img_idx
269 | if args.save_gt_imgs:
270 | img_gt = imgs_gt[i].numpy().transpose(1, 2, 0)
271 | img_gt_path = os.path.join(gt_img_dir, img_filename)
272 | imsave(img_gt_path, img_gt)
273 | if args.save_layout:
274 | layout_3d = layouts_3d[i].numpy().transpose(1, 2, 0)
275 | layout_path = os.path.join(layout_dir, img_filename)
276 | imsave(layout_path, layout_3d)
277 |
278 | img_pred_np = imgs_pred[i].numpy().transpose(1, 2, 0)
279 | img_path = os.path.join(img_dir, img_filename)
280 | imsave(img_path, img_pred_np)
281 |
282 | if args.save_graphs:
283 | graph_img = draw_scene_graph(objs[i], triples[i], vocab)
284 | graph_path = os.path.join(graph_dir, img_filename)
285 | imsave(graph_path, graph_img)
286 |
287 | img_idx += 1
288 |
289 | print('Saved %d images' % img_idx)
290 | avg_iou = total_iou / total_boxes
291 | print('avg_iou {}'.format(avg_iou.item()))
292 | print('r0.5 {}'.format(r_05 / total_boxes))
293 | print('r0.3 {}'.format(r_03 / total_boxes))
294 | if accuracy_model is not None:
295 | print('Accuracy {}'.format(corrects / real_objects_count))
296 |
297 |
298 | if __name__ == '__main__':
299 | args = parser.parse_args()
300 | if args.checkpoint is None:
301 | raise ValueError('Must specify --checkpoint')
302 |
303 |     print('Loading model from', args.checkpoint)
304 |     checkpoint = torch.load(args.checkpoint)
305 | run_model(args, checkpoint, args.output_dir)
306 |
--------------------------------------------------------------------------------
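For completeness, a hedged sketch of driving the sampling script programmatically, equivalent to running python scripts/sample_images.py from the repository root. The checkpoint path and sample count are placeholders; it assumes the repository root is on PYTHONPATH, a CUDA device is available, and the COCO validation set has been downloaded (all flags come from the argparse block defined in the script):

import torch

from scripts.sample_images import parser, run_model

# Hypothetical paths/values -- adjust to a real checkpoint and output folder.
args = parser.parse_args([
    '--checkpoint', 'models/checkpoint_with_model.pt',
    '--output_dir', 'output',
    '--num_samples', '100',
])
checkpoint = torch.load(args.checkpoint)
run_model(args, checkpoint, args.output_dir)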