├── README.md ├── chamfer_iou_clevr.py ├── circles └── placeholder ├── data.py ├── dspn.py ├── fspool.py ├── full_iou_clevr.py ├── imgs ├── clevr_tile_1.jpg ├── set.png └── tiled_samples_1.png ├── models.py ├── preprocess-images.py ├── run_isodistance.py ├── run_reconstruct_circles.py ├── run_reconstruct_clevr.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # [NeurIPS 2020] Better Set Representations For Relational Reasoning 2 | 3 | ![main figure](https://github.com/CUAI/BetterSetRepresentations/blob/master/imgs/set.png) 4 | 5 | ## Software Requirements 6 | 7 | This codebase requires Python 3, PyTorch 1.0+, and Torchvision 0.2+. In principle, the code can run on a CPU, but we assume a GPU is used throughout the codebase. 8 | 9 | ## Usage 10 | 11 | The files `run_reconstruct_circles.py` and `run_reconstruct_clevr.py` correspond to the explanatory experiments in the paper. We implemented the three other experiments by plugging our module into the existing repositories linked in the supplementary material, where we provide more details. 12 | 13 | Full usage: 14 | ``` 15 | usage: run_reconstruct_circles.py [-h] [--model_type MODEL_TYPE] 16 | [--batch_size BATCH_SIZE] [--lr LR] 17 | [--inner_lr INNER_LR] 18 | 19 | optional arguments: 20 | -h, --help show this help message and exit 21 | --model_type MODEL_TYPE 22 | model type: srn | mlp 23 | --batch_size BATCH_SIZE 24 | batch size 25 | --lr LR lr 26 | --inner_lr INNER_LR inner lr 27 | ``` 28 | ``` 29 | usage: run_reconstruct_clevr.py [-h] [--model_type MODEL_TYPE] 30 | [--batch_size BATCH_SIZE] [--lr LR] 31 | [--inner_lr INNER_LR] 32 | 33 | optional arguments: 34 | -h, --help show this help message and exit 35 | --model_type MODEL_TYPE 36 | model type: srn | mlp 37 | --batch_size BATCH_SIZE 38 | batch size 39 | --lr LR lr 40 | --inner_lr INNER_LR inner lr 41 | --save SAVE path to save checkpoint 42 | --resume RESUME path to resume a saved checkpoint 43 | ``` 44 | 45 | ## Data Generation 46 | 47 | The data for CLEVR with masks was generated with https://github.com/facebookresearch/clevr-dataset-gen by adding the following line: 48 | ```render_shadeless(blender_objects, path=output_image[:-4]+'_mask.png')``` 49 | to ```image_generation/render_images.py``` at ~line 311 (right after the call to ```add_random_objects```).
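For reference, here is a minimal sketch of what the patched region of ```render_images.py``` might look like; the surrounding code and exact call signature belong to the upstream clevr-dataset-gen repository and may differ between versions:
```python
# image_generation/render_images.py (sketch -- upstream code may differ)
objects, blender_objects = add_random_objects(scene_struct, num_objects, args, camera)

# Added line: also render a shadeless segmentation mask next to each rendered image.
render_shadeless(blender_objects, path=output_image[:-4] + '_mask.png')
```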
50 | 51 | ## Results 52 | 53 | Circles reconstruction samples (From left to right, column-wise: original images, SRN reconstruction, SRN decomposition, baseline reconstruction, baseline decomposition.): 54 | 55 | 56 | 57 | 58 | CLEVR reconstruction samples: 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /chamfer_iou_clevr.py: -------------------------------------------------------------------------------- 1 | from run_reconstruct_clevr import SSLR 2 | import os 3 | import data 4 | import torch 5 | from tqdm import tqdm 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pickle 9 | from utils import chamfer_score, cv_bbox 10 | import torch.multiprocessing as mp 11 | import gc 12 | import cv2 13 | import argparse 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--model_type', help='model type: srn | mlp', default="srn") 18 | parser.add_argument('--batch_size', type=int, help='batch size', default=32) 19 | parser.add_argument('--resume', help='path to resume a saved checkpoint', default=None) 20 | args = parser.parse_args() 21 | return args 22 | 23 | if __name__ == '__main__': 24 | args = get_args() 25 | print(args) 26 | 27 | use_srn = args.model_type == "srn" 28 | 29 | dataset_test = data.CLEVR( 30 | "clevr_no_mask", "val", box=True, full=True, chamfer=True 31 | ) 32 | batch_size = args.batch_size 33 | test_loader = data.get_loader( 34 | dataset_test, batch_size=batch_size, shuffle=False 35 | ) 36 | 37 | net = SSLR(use_srn=use_srn).float().cuda() 38 | net.eval() 39 | net.load_state_dict(torch.load(args.resume)) 40 | 41 | test_loader = tqdm( 42 | test_loader, 43 | ncols=0, 44 | desc="test" 45 | ) 46 | 47 | full_score = 0 48 | for idx, sample in enumerate(test_loader): 49 | def tfunc(): 50 | gc.collect() 51 | image, masks = [x.cuda() for x in sample] 52 | 53 | p_, inner_losses, gs = net(image) 54 | 55 | thresh_mask = gs < 1e-2 56 | gs[thresh_mask] = 0 57 | gs[~thresh_mask] = 1 58 | gs = gs.sum(2).clamp(0,1) 59 | gs = gs.to(dtype=torch.uint8) 60 | 61 | img = cv_bbox(gs.detach().cpu().numpy().reshape(-1,128,128)) 62 | 63 | score = chamfer_score(img.cuda().to(dtype=torch.uint8), masks.to(dtype=torch.uint8)) 64 | 65 | return score 66 | full_score += tfunc() 67 | 68 | 69 | full_score /= len(test_loader) 70 | print(full_score) 71 | -------------------------------------------------------------------------------- /circles/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CUAI/BetterSetRepresentations/0510627834462a498d42b80c08dbb4fea946b9c8/circles/placeholder -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset, TensorDataset 4 | import numpy as np 5 | import os 6 | import random 7 | import cv2 8 | import h5py 9 | import json 10 | 11 | 12 | def get_loader(dataset, batch_size, num_workers=8, shuffle=True): 13 | return torch.utils.data.DataLoader( 14 | dataset, 15 | shuffle=shuffle, 16 | batch_size=batch_size, 17 | pin_memory=True, 18 | num_workers=num_workers, 19 | drop_last=True, 20 | ) 21 | 22 | class IsoColorCircles(torch.utils.data.Dataset): 23 | def __init__(self, train=True, root='circles', size=1000, n = None): 24 | self.train = train 25 | self.root = root 26 | self.size = size 27 | self.n = n 28 
| self.data = self.cache() 29 | 30 | def cache(self): 31 | cache_path = os.path.join(self.root, f"iso_color_circles_{self.train}_{self.n}.pth") 32 | if os.path.exists(cache_path): 33 | return torch.load(cache_path) 34 | 35 | print("Processing dataset...") 36 | data = [] 37 | for i in range(self.size): 38 | if i%10000 == 0: 39 | print(i) 40 | img = np.zeros((64, 64,3), dtype = "float") 41 | n = int(random.randint(1, 10)) 42 | if self.n is not None: 43 | n = self.n 44 | color_count = [0,0] 45 | circle_features = torch.zeros([10,4]).float() 46 | # Creating circle 47 | j = 0 48 | while j < n: 49 | tmp = np.zeros((64, 64,3), dtype = "float") 50 | l = range(1,12) 51 | r = l[int(random.random()*11)] 52 | center = (int(random.random()*(64-2*r)+r), int(random.random()*(64-2*r)+r)) 53 | c_p = random.randint(0, 1) 54 | c = [0,0,0] 55 | c[c_p] = 1 56 | tmp = cv2.circle(tmp, center, r+1, c, -1) 57 | if (img + tmp).max() > 1: 58 | continue 59 | elif img.min() >= 1: 60 | assert(False) 61 | else: 62 | tmp = np.zeros((64, 64,3), dtype = "float") 63 | tmp = cv2.circle(tmp, center, r, c, -1) 64 | color_count[c_p] += 1 65 | img+= tmp 66 | circle_features[j] = torch.tensor([center[0], center[1],r, c_p+1]) 67 | j+=1 68 | 69 | 70 | 71 | l = range(1,12) 72 | 73 | # iso 74 | 75 | fail = True 76 | while fail: 77 | s = torch.zeros([10]).float() 78 | fail = False 79 | iso = np.zeros((64, 64,3), dtype = "float") 80 | for idx, f in enumerate(circle_features): 81 | if f[3].int() == 0 : 82 | break 83 | tmp = np.zeros((64, 64,3), dtype = "float") 84 | r = f[2] 85 | c = [0,0,0] 86 | c[f[3].int() - 1] = 1 87 | center = (int(random.random()*(64-2*r)+r), int(random.random()*(64-2*r)+r)) 88 | 89 | tmp = cv2.circle(tmp, center, r+1, c, -1) 90 | if (iso + tmp).max() > 1: 91 | fail = True 92 | break 93 | elif iso.min() >= 1: 94 | assert(False) 95 | else: 96 | tmp = np.zeros((64, 64,3), dtype = "float") 97 | tmp = cv2.circle(tmp, center, r, c, -1) 98 | s[idx] = (f[0] - center[0])**2 + (f[1] - center[1])**2 99 | iso+= tmp 100 | 101 | i+=1 102 | data.append((torch.tensor(img).transpose(0,2).float(), torch.tensor(iso).transpose(0,2).float(), s)) 103 | torch.save(data, cache_path) 104 | print("Done!") 105 | return data 106 | 107 | def __getitem__(self, item): 108 | return self.data[item] 109 | 110 | def __len__(self): 111 | return self.size 112 | 113 | class MarkedColorCircles(torch.utils.data.Dataset): 114 | def __init__(self, train=True, root='circles', size=1000, colors = [[1,0,0],[0,1,0]]): 115 | self.train = train 116 | self.root = root 117 | self.size = size 118 | self.data = self.cache() 119 | self.colors = colors 120 | 121 | def cache(self): 122 | cache_path = os.path.join(self.root, f"marked_color_circles_{self.train}.pth") 123 | if os.path.exists(cache_path): 124 | return torch.load(cache_path) 125 | 126 | print("Processing dataset...") 127 | data = [] 128 | for i in range(self.size): 129 | img = np.zeros((64, 64,3), dtype = "float") 130 | n = int(random.randint(0, 10)) 131 | color_count = [0,0] 132 | circle_features = torch.zeros([10,4]).float() 133 | # Creating circle 134 | j = 0 135 | while j < n: 136 | tmp = np.zeros((64, 64,3), dtype = "float") 137 | l = range(1,12) 138 | r = l[int(random.random()*11)] 139 | center = (int(random.random()*(64-2*r)+r), int(random.random()*(64-2*r)+r)) 140 | c_p = random.randint(0, 1) 141 | c = [0,0,0] 142 | c[c_p] = 1 143 | tmp = cv2.circle(tmp, center, r+1, c, -1) 144 | if (img + tmp).max() > 1: 145 | continue 146 | elif img.min() >= 1: 147 | assert(False) 148 | else: 149 | tmp = 
np.zeros((64, 64,3), dtype = "float") 150 | tmp = cv2.circle(tmp, center, r, c, -1) 151 | color_count[c_p] += 1 152 | img+= tmp 153 | circle_features[j] = torch.tensor([center[0], center[1],r, c_p+1]) 154 | j+=1 155 | i+=1 156 | data.append((torch.tensor(img).transpose(0,2).float(), circle_features)) 157 | torch.save(data, cache_path) 158 | print("Done!") 159 | return data 160 | 161 | def __getitem__(self, item): 162 | return self.data[item] 163 | 164 | def __len__(self): 165 | return self.size 166 | 167 | class CLEVR(torch.utils.data.Dataset): 168 | def __init__(self, base_path, split, box=False, full=False, chamfer=False): 169 | assert split in { 170 | "train", 171 | "val", 172 | "test", 173 | } # note: test isn't very useful since it doesn't have ground-truth scene information 174 | self.base_path = base_path 175 | self.split = split 176 | self.max_objects = 10 177 | self.box = box # True if clevr-box version, False if clevr-state version 178 | self.full = full # Use full validation set? 179 | self.chamfer = chamfer # Use Chamfer data? 180 | 181 | with self.img_db() as db: 182 | ids = db["image_ids"] 183 | self.image_id_to_index = {id: i for i, id in enumerate(ids)} 184 | self.image_db = None 185 | 186 | with open(self.scenes_path) as fd: 187 | scenes = json.load(fd)["scenes"] 188 | self.img_ids, self.scenes = self.prepare_scenes(scenes) 189 | 190 | def object_to_fv(self, obj): 191 | coords = [p / 3 for p in obj["3d_coords"]] 192 | one_hot = lambda key: [obj[key] == x for x in CLASSES[key]] 193 | material = one_hot("material") 194 | color = one_hot("color") 195 | shape = one_hot("shape") 196 | size = one_hot("size") 197 | assert sum(material) == 1 198 | assert sum(color) == 1 199 | assert sum(shape) == 1 200 | assert sum(size) == 1 201 | # concatenate all the classes 202 | return coords + material + color + shape + size 203 | 204 | def prepare_scenes(self, scenes_json): 205 | img_ids = [] 206 | scenes = [] 207 | for scene in scenes_json: 208 | img_idx = scene["image_index"] 209 | # different objects depending on bbox version or attribute version of CLEVR sets 210 | if self.box: 211 | objects = self.extract_bounding_boxes(scene) 212 | objects = torch.FloatTensor(objects) 213 | else: 214 | objects = [self.object_to_fv(obj) for obj in scene["objects"]] 215 | objects = torch.FloatTensor(objects).transpose(0, 1) 216 | num_objects = objects.size(1) 217 | # pad with 0s 218 | if num_objects < self.max_objects: 219 | objects = torch.cat( 220 | [ 221 | objects, 222 | torch.zeros(objects.size(0), self.max_objects - num_objects), 223 | ], 224 | dim=1, 225 | ) 226 | # fill in masks 227 | mask = torch.zeros(self.max_objects) 228 | mask[:num_objects] = 1 229 | 230 | img_ids.append(img_idx) 231 | scenes.append((objects, mask)) 232 | return img_ids, scenes 233 | 234 | def extract_bounding_boxes(self, scene): 235 | """ 236 | Code used for 'Object-based Reasoning in VQA' to generate bboxes 237 | https://arxiv.org/abs/1801.09718 238 | https://github.com/larchen/clevr-vqa/blob/master/bounding_box.py#L51-L107 239 | """ 240 | objs = scene["objects"] 241 | rotation = scene["directions"]["right"] 242 | 243 | num_boxes = len(objs) 244 | 245 | boxes = np.zeros((1, num_boxes, 4)) 246 | 247 | xmin = [] 248 | ymin = [] 249 | xmax = [] 250 | ymax = [] 251 | classes = [] 252 | classes_text = [] 253 | 254 | for i, obj in enumerate(objs): 255 | [x, y, z] = obj["pixel_coords"] 256 | 257 | [x1, y1, z1] = obj["3d_coords"] 258 | 259 | cos_theta, sin_theta, _ = rotation 260 | 261 | x1 = x1 * cos_theta + y1 * sin_theta 262 | 
y1 = x1 * -sin_theta + y1 * cos_theta 263 | 264 | height_d = 6.9 * z1 * (15 - y1) / 2.0 265 | height_u = height_d 266 | width_l = height_d 267 | width_r = height_d 268 | 269 | if obj["shape"] == "cylinder": 270 | d = 9.4 + y1 271 | h = 6.4 272 | s = z1 273 | 274 | height_u *= (s * (h / d + 1)) / ((s * (h / d + 1)) - (s * (h - s) / d)) 275 | height_d = height_u * (h - s + d) / (h + s + d) 276 | 277 | width_l *= 11 / (10 + y1) 278 | width_r = width_l 279 | 280 | if obj["shape"] == "cube": 281 | height_u *= 1.3 * 10 / (10 + y1) 282 | height_d = height_u 283 | width_l = height_u 284 | width_r = height_u 285 | 286 | obj_name = ( 287 | obj["size"] 288 | + " " 289 | + obj["color"] 290 | + " " 291 | + obj["material"] 292 | + " " 293 | + obj["shape"] 294 | ) 295 | ymin.append((y - height_d) / 320.0) 296 | ymax.append((y + height_u) / 320.0) 297 | xmin.append((x - width_l) / 480.0) 298 | xmax.append((x + width_r) / 480.0) 299 | 300 | return xmin, ymin, xmax, ymax 301 | 302 | @property 303 | def images_folder(self): 304 | return os.path.join(self.base_path, "images", self.split) 305 | 306 | @property 307 | def scenes_path(self): 308 | if self.split == "test": 309 | raise ValueError("Scenes are not available for test") 310 | return os.path.join( 311 | self.base_path, "scenes", "CLEVR_{}_scenes.json".format(self.split) 312 | ) 313 | 314 | def img_db(self): 315 | path = os.path.join(self.base_path, "{}-images.h5".format(self.split)) 316 | return h5py.File(path, "r") 317 | 318 | def load_image(self, image_id): 319 | if self.image_db is None: 320 | self.image_db = self.img_db() 321 | index = self.image_id_to_index[image_id] 322 | image = self.image_db["images"][index] 323 | return image 324 | 325 | def make_mask(self, objects, size, num_objs): 326 | num_objs = len(size[size == 1]) 327 | masks = torch.zeros([16,128,128]) 328 | for i in range(num_objs): 329 | masks[i, objects[1, i]:objects[3, i], objects[0, i]:objects[2, i]] = 1 330 | return masks 331 | 332 | def __getitem__(self, item): 333 | image_id = self.img_ids[item] 334 | image = self.load_image(image_id) 335 | objects, size = self.scenes[item] 336 | if self.chamfer: 337 | objects = (objects * 128).to(dtype=torch.uint8) 338 | num_objs = len(size[size == 1]) 339 | return image, self.make_mask(objects, size, num_objs) 340 | return image 341 | 342 | def __len__(self): 343 | if self.split == "train" or self.full: 344 | return len(self.scenes) 345 | else: 346 | return len(self.scenes) // 10 347 | 348 | 349 | class CLEVRMasked(torch.utils.data.Dataset): 350 | def __init__(self, base_path, split, full=False, iou=False): 351 | assert split in { 352 | "train", 353 | "test", 354 | } # note: test isn't very useful since it doesn't have ground-truth scene information 355 | self.base_path = base_path 356 | self.split = split 357 | self.full = full # Use full validation set? 
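# iou: if True, __getitem__ additionally returns the ground-truth mask image (used for IoU evaluation in full_iou_clevr.py)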
358 | self.iou = iou 359 | 360 | with self.img_db() as db: 361 | ids = db["image_ids"] 362 | self.image_id_to_index = {id: i for i, id in enumerate(ids)} 363 | self.image_db = None 364 | self.img_ids = [i for i in range(len(self.image_id_to_index))] 365 | 366 | @property 367 | def images_folder(self): 368 | return os.path.join(self.base_path, "images", self.split) 369 | 370 | def img_db(self): 371 | path = os.path.join(self.base_path, "{}-images-foreground.h5".format(self.split)) 372 | return h5py.File(path, "r") 373 | 374 | def load_image(self, image_id): 375 | if self.image_db is None: 376 | self.image_db = self.img_db() 377 | index = self.image_id_to_index[image_id] 378 | image = self.image_db["images"][index] 379 | image_mask = self.image_db["images_mask"][index] 380 | image_foreground = self.image_db["images_foreground"][index] 381 | return image, image_mask, image_foreground 382 | 383 | def __getitem__(self, item): 384 | image_id = self.img_ids[item] 385 | image, image_mask, image_foreground = self.load_image(image_id) 386 | if self.iou: 387 | return image, image_mask, image_foreground 388 | return image, image_foreground 389 | 390 | def __len__(self): 391 | if self.split == "train" or self.full: 392 | return len(self.img_ids) 393 | else: 394 | return len(self.img_ids) // 10 395 | -------------------------------------------------------------------------------- /dspn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import higher 5 | 6 | 7 | class InnerSet(nn.Module): 8 | def __init__(self, mask): 9 | super().__init__() 10 | self.mask = mask 11 | 12 | def forward(self): 13 | return self.mask 14 | 15 | class DSPN(nn.Module): 16 | """ Deep Set Prediction Networks 17 | Yan Zhang, Jonathon Hare, Adam Prügel-Bennett 18 | https://arxiv.org/abs/1906.06565 19 | """ 20 | 21 | def __init__(self, encoder, set_channels, iters, lr): 22 | """ 23 | encoder: Set encoder module that takes a set as input and returns a representation thereof. 24 | It should have a forward function that takes two arguments: 25 | - a set: FloatTensor of size (batch_size, input_channels, maximum_set_size). Each set 26 | should be padded to the same maximum size with 0s, even across batches. 27 | - a mask: FloatTensor of size (batch_size, maximum_set_size). This should take the value 1 28 | if the corresponding element is present and 0 if not. 29 | 30 | channels: Number of channels of the set to predict. 31 | 32 | max_set_size: Maximum size of the set. 33 | 34 | iter: Number of iterations to run the DSPN algorithm for. 35 | 36 | lr: Learning rate of inner gradient descent in DSPN. 37 | """ 38 | super().__init__() 39 | self.encoder = encoder 40 | self.iters = iters 41 | self.lr = lr 42 | 43 | def forward(self, target_repr, init): 44 | """ 45 | Conceptually, DSPN simply turns the target_repr feature vector into a set. 46 | 47 | target_repr: Representation that the predicted set should match. FloatTensor of size (batch_size, repr_channels). 48 | This can come from a set processed with the same encoder as self.encoder (auto-encoder), or a different 49 | input completely (normal supervised learning), such as an image encoded into a feature vector. 
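init: Initial set from which the inner optimisation starts. FloatTensor of size (batch_size, set_channels, maximum_set_size). It is wrapped in an nn.Parameter and refined by gradient descent on the inner representation loss.

Returns the tuple (intermediate_sets, repr_losses, grad_norms) collected over the inner iterations.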
50 | """ 51 | # copy same initial set over batch 52 | current_set = nn.Parameter(init) 53 | inner_set = InnerSet(current_set) 54 | 55 | # info used for loss computation 56 | intermediate_sets = [current_set] 57 | # info used for debugging 58 | repr_losses = [] 59 | grad_norms = [] 60 | 61 | # optimise repr_loss for fixed number of steps 62 | with torch.enable_grad(): 63 | opt = torch.optim.SGD(inner_set.parameters(), lr=self.lr, momentum=0.5) 64 | with higher.innerloop_ctx(inner_set, opt) as (fset, diffopt): 65 | for i in range(self.iters): 66 | predicted_repr = self.encoder(fset()) 67 | # how well does the representation matches the target 68 | repr_loss = ((predicted_repr- target_repr)**2).sum() 69 | diffopt.step(repr_loss) 70 | intermediate_sets.append(fset.mask) 71 | repr_losses.append(repr_loss) 72 | grad_norms.append(()) 73 | 74 | return intermediate_sets, repr_losses, grad_norms 75 | -------------------------------------------------------------------------------- /fspool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FSPool(nn.Module): 7 | """ 8 | Featurewise sort pooling. From: 9 | 10 | FSPool: Learning Set Representations with Featurewise Sort Pooling. 11 | Yan Zhang, Jonathon Hare, Adam Prügel-Bennett 12 | https://arxiv.org/abs/1906.02795 13 | https://github.com/Cyanogenoid/fspool 14 | """ 15 | 16 | def __init__(self, in_channels, n_pieces, relaxed=False): 17 | """ 18 | in_channels: Number of channels in input 19 | n_pieces: Number of pieces in piecewise linear 20 | relaxed: Use sorting networks relaxation instead of traditional sorting 21 | """ 22 | super().__init__() 23 | self.n_pieces = n_pieces 24 | self.weight = nn.Parameter(torch.zeros(in_channels, n_pieces + 1)) 25 | self.relaxed = relaxed 26 | 27 | self.reset_parameters() 28 | 29 | def reset_parameters(self): 30 | nn.init.normal_(self.weight) 31 | 32 | def forward(self, x, n=None): 33 | """ FSPool 34 | 35 | x: FloatTensor of shape (batch_size, in_channels, set size). 36 | This should contain the features of the elements in the set. 37 | Variable set sizes should be padded to the maximum set size in the batch with 0s. 38 | 39 | n: LongTensor of shape (batch_size). 40 | This tensor contains the sizes of each set in the batch. 41 | If not specified, assumes that every set has the same size of x.size(2). 42 | Note that n.max() should never be greater than x.size(2), i.e. the specified set size in the 43 | n tensor must not be greater than the number of elements stored in the x tensor. 
44 | 45 | Returns: pooled input x, used permutation matrix perm 46 | """ 47 | assert x.size(1) == self.weight.size( 48 | 0 49 | ), "incorrect number of input channels in weight" 50 | # can call withtout length tensor, uses same length for all sets in the batch 51 | if n is None: 52 | n = x.new(x.size(0)).fill_(x.size(2)).long() 53 | # create tensor of ratios $r$ 54 | sizes, mask = fill_sizes(n, x) 55 | mask = mask.expand_as(x) 56 | 57 | # turn continuous into concrete weights 58 | weight = self.determine_weight(sizes) 59 | 60 | # make sure that fill value isn't affecting sort result 61 | # sort is descending, so put unreasonably low value in places to be masked away 62 | x = x + (1 - mask).float() * -99999 63 | if self.relaxed: 64 | x, perm = cont_sort(x, temp=self.relaxed) 65 | else: 66 | x, perm = x.sort(dim=2, descending=True) 67 | 68 | x = (x * weight * mask.float()).sum(dim=2) 69 | return x, perm 70 | 71 | def forward_transpose(self, x, perm, n=None): 72 | """ FSUnpool 73 | 74 | x: FloatTensor of shape (batch_size, in_channels) 75 | perm: Permutation matrix returned by forward function. 76 | n: LongTensor fo shape (batch_size) 77 | """ 78 | if n is None: 79 | n = x.new(x.size(0)).fill_(perm.size(2)).long() 80 | sizes, mask = fill_sizes(n) 81 | mask = mask.expand(mask.size(0), x.size(1), mask.size(2)) 82 | 83 | weight = self.determine_weight(sizes) 84 | 85 | x = x.unsqueeze(2) * weight * mask.float() 86 | 87 | if self.relaxed: 88 | x, _ = cont_sort(x, perm) 89 | else: 90 | x = x.scatter(2, perm, x) 91 | return x, mask 92 | 93 | def determine_weight(self, sizes): 94 | """ 95 | Piecewise linear function. Evaluates f at the ratios in sizes. 96 | This should be a faster implementation than doing the sum over max terms, since we know that most terms in it are 0. 97 | """ 98 | # share same sequence length within each sample, so copy weighht across batch dim 99 | weight = self.weight.unsqueeze(0) 100 | weight = weight.expand(sizes.size(0), weight.size(1), weight.size(2)) 101 | 102 | # linspace [0, 1] -> linspace [0, n_pieces] 103 | index = self.n_pieces * sizes 104 | index = index.unsqueeze(1) 105 | index = index.expand(index.size(0), weight.size(1), index.size(2)) 106 | 107 | # points in the weight vector to the left and right 108 | idx = index.long() 109 | frac = index.frac() 110 | left = weight.gather(2, idx) 111 | right = weight.gather(2, (idx + 1).clamp(max=self.n_pieces)) 112 | 113 | # interpolate between left and right point 114 | return (1 - frac) * left + frac * right 115 | 116 | 117 | def fill_sizes(sizes, x=None): 118 | """ 119 | sizes is a LongTensor of size [batch_size], containing the set sizes. 120 | Each set size n is turned into [0/(n-1), 1/(n-1), ..., (n-2)/(n-1), 1, 0, 0, ..., 0, 0]. 121 | These are the ratios r at which f is evaluated at. 122 | The 0s at the end are there for padding to the largest n in the batch. 123 | If the input set x is passed in, it guarantees that the mask is the correct size even when sizes.max() 124 | is less than x.size(), which can be a case if there is at least one padding element in each set in the batch. 
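For example, sizes = [2, 3] with x padded to 4 elements yields ratios [[0, 1, 1, 1], [0, 0.5, 1, 1]] and masks [[1, 1, 0, 0], [1, 1, 1, 0]].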
125 | """ 126 | if x is not None: 127 | max_size = x.size(2) 128 | else: 129 | max_size = sizes.max() 130 | size_tensor = sizes.new(sizes.size(0), max_size).float().fill_(-1) 131 | 132 | size_tensor = torch.arange(end=max_size, device=sizes.device, dtype=torch.float32) 133 | size_tensor = size_tensor.unsqueeze(0) / (sizes.float() - 1).clamp(min=1).unsqueeze( 134 | 1 135 | ) 136 | 137 | mask = size_tensor <= 1 138 | mask = mask.unsqueeze(1) 139 | 140 | return size_tensor.clamp(max=1), mask.float() 141 | 142 | 143 | def deterministic_sort(s, tau): 144 | """ 145 | "Stochastic Optimization of Sorting Networks via Continuous Relaxations" https://openreview.net/forum?id=H1eSS3CcKX 146 | 147 | Aditya Grover, Eric Wang, Aaron Zweig, Stefano Ermon 148 | 149 | s: input elements to be sorted. Shape: batch_size x n x 1 150 | tau: temperature for relaxation. Scalar. 151 | """ 152 | n = s.size()[1] 153 | one = torch.ones((n, 1), dtype=torch.float32, device=s.device) 154 | A_s = torch.abs(s - s.permute(0, 2, 1)) 155 | B = torch.matmul(A_s, torch.matmul(one, one.transpose(0, 1))) 156 | scaling = (n + 1 - 2 * (torch.arange(n, device=s.device) + 1)).type(torch.float32) 157 | C = torch.matmul(s, scaling.unsqueeze(0)) 158 | P_max = (C - B).permute(0, 2, 1) 159 | sm = torch.nn.Softmax(-1) 160 | P_hat = sm(P_max / tau) 161 | return P_hat 162 | 163 | 164 | def cont_sort(x, perm=None, temp=1): 165 | """ Helper function that calls deterministic_sort with the right shape. 166 | Since it assumes a shape of (batch_size, n, 1) while the input x is of shape (batch_size, channels, n), 167 | we can get this to the right shape by merging the first two dimensions. 168 | If an existing perm is passed in, we compute the "inverse" (transpose of perm) and just use that to unsort x. 169 | """ 170 | original_size = x.size() 171 | x = x.view(-1, x.size(2), 1) 172 | if perm is None: 173 | perm = deterministic_sort(x, temp) 174 | else: 175 | perm = perm.transpose(1, 2) 176 | x = perm.matmul(x) 177 | x = x.view(original_size) 178 | return x, perm 179 | 180 | 181 | if __name__ == "__main__": 182 | pool = FSort(2, 1) 183 | x = torch.arange(0, 2 * 3 * 4).view(3, 2, 4).float() 184 | print("x", x) 185 | y, perm = pool(x, torch.LongTensor([2, 3, 4])) 186 | print("perm") 187 | print(perm) 188 | print("result") 189 | print(y) 190 | -------------------------------------------------------------------------------- /full_iou_clevr.py: -------------------------------------------------------------------------------- 1 | from run_reconstruct_clevr import SSLR 2 | import os 3 | import data 4 | import torch 5 | from tqdm import tqdm 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pickle 9 | import argparse 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--model_type', help='model type: srn | mlp', default="srn") 15 | parser.add_argument('--batch_size', type=int, help='batch size', default=32) 16 | parser.add_argument('--resume', help='path to resume a saved checkpoint', default=None) 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | if __name__ == '__main__': 22 | args = get_args() 23 | print(args) 24 | 25 | use_srn = args.model_type == "srn" 26 | 27 | dataset_test = data.CLEVRMasked( 28 | "clevr", "test", full=True, iou=True 29 | ) 30 | batch_size = args.batch_size 31 | test_loader = data.get_loader( 32 | dataset_test, batch_size=batch_size, shuffle=False 33 | ) 34 | 35 | net = SSLR(use_srn=use_srn).float().cuda() 36 | net.eval() 37 | 
net.load_state_dict(torch.load(args.resume)) 38 | 39 | test_loader = tqdm( 40 | test_loader, 41 | ncols=0, 42 | desc="test" 43 | ) 44 | 45 | SMOOTH = 1e-6 46 | full_iou = 0 47 | import gc 48 | for idx, data in enumerate(test_loader): 49 | def tfunc(): 50 | gc.collect() 51 | image, image_mask, image_foreground_ = [x.cuda() for x in data] 52 | 53 | p_, inner_losses, gs_ = net(image) 54 | 55 | image, image_mask, image_foreground = [x.detach().cpu().numpy() for x in data] 56 | 57 | p = p_.detach().cpu().numpy() 58 | gs = gs_.detach().cpu().numpy() 59 | 60 | thresh_mask = p < 1e-2 61 | p[thresh_mask] = 0 62 | p[~thresh_mask] = 1 63 | p = p.astype('uint8') 64 | 65 | image_foreground[image_foreground != 0] = 1 66 | image_foreground = image_foreground.astype('uint8') 67 | 68 | intersect = (p & image_foreground).sum((1,2,3)) 69 | union = (p | image_foreground).sum((1,2,3)) 70 | iou = ((intersect + SMOOTH) / (union + SMOOTH)).sum() 71 | 72 | return iou 73 | full_iou += tfunc() 74 | 75 | full_iou /= len(test_loader) * batch_size 76 | print(full_iou) 77 | -------------------------------------------------------------------------------- /imgs/clevr_tile_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CUAI/BetterSetRepresentations/0510627834462a498d42b80c08dbb4fea946b9c8/imgs/clevr_tile_1.jpg -------------------------------------------------------------------------------- /imgs/set.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CUAI/BetterSetRepresentations/0510627834462a498d42b80c08dbb4fea946b9c8/imgs/set.png -------------------------------------------------------------------------------- /imgs/tiled_samples_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CUAI/BetterSetRepresentations/0510627834462a498d42b80c08dbb4fea946b9c8/imgs/tiled_samples_1.png -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import utils 5 | from utils import hungarian_loss_each 6 | from dspn import DSPN 7 | from fspool import FSPool 8 | 9 | 10 | class Encoder(nn.Module): 11 | 12 | def __init__(self, element_dims, set_size, out_size): 13 | super(Encoder, self).__init__() 14 | self.nef = 64 15 | self.e_ksize = 4 16 | self.set_size = set_size 17 | self.out_size = out_size 18 | self.element_dims = element_dims 19 | 20 | self.conv1 = nn.Conv2d(3, self.nef, self.e_ksize, stride = 2, padding = 1, bias = False) 21 | 22 | self.conv2 = nn.Conv2d(self.nef, self.nef*2, self.e_ksize, stride = 2, padding = 1, bias = False) 23 | self.bn2 = nn.BatchNorm2d(self.nef*2) 24 | 25 | self.conv3 = nn.Conv2d(self.nef*2, self.nef*4, self.e_ksize, stride = 2, padding = 1, bias = False) 26 | self.bn3 = nn.BatchNorm2d(self.nef*4) 27 | 28 | self.conv4 = nn.Conv2d(self.nef*4, self.nef*8, self.e_ksize, stride = 4, padding = 1, bias = False) 29 | 30 | 31 | self.bn4 = nn.BatchNorm2d(self.nef*8) 32 | 33 | self.proj = nn.Linear(self.nef*32, self.out_size) 34 | 35 | self.proj_s = nn.Conv1d(2048//self.set_size, self.element_dims, 1) 36 | 37 | def forward(self, x): 38 | out = F.relu(self.conv1(x)) 39 | out = F.relu(self.bn2(self.conv2(out))) 40 | out = F.relu(self.bn3(self.conv3(out))) 41 | out = F.relu(self.bn4(self.conv4(out))) 42 | 43 | s = 
self.proj_s(out.view(out.shape[0], self.set_size, 2048//self.set_size).transpose(1,2)) 44 | 45 | out = out.view(out.shape[0], self.nef*32) 46 | return self.proj(out), s 47 | 48 | 49 | class Decoder(nn.Module): 50 | 51 | def __init__(self, input_dim): 52 | super(Decoder, self).__init__() 53 | 54 | self.ngf = 256 55 | g_ksize = 4 56 | self.proj = nn.Linear(input_dim, self.ngf * 4 * 4 * 4) 57 | self.bn0 = nn.BatchNorm1d(self.ngf * 4 * 4 * 4) 58 | 59 | self.dconv1 = nn.ConvTranspose2d(self.ngf * 4,self.ngf*2, g_ksize, 60 | stride=2, padding=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(self.ngf*2) 62 | 63 | self.dconv2 = nn.ConvTranspose2d(self.ngf*2, self.ngf, g_ksize, 64 | stride=2, padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(self.ngf) 66 | 67 | self.dconv3 = nn.ConvTranspose2d(self.ngf, 3, g_ksize, 68 | stride=4, padding=0, bias=False) 69 | 70 | def forward(self, z, c=None): 71 | out = F.relu(self.bn0(self.proj(z)).view(-1, self.ngf* 4, 4, 4)) 72 | out = F.relu(self.bn1(self.dconv1(out))) 73 | out = F.relu(self.bn2(self.dconv2(out))) 74 | out = self.dconv3(out) 75 | return out 76 | 77 | 78 | class FSEncoder(nn.Module): 79 | def __init__(self, input_channels, output_channels, dim): 80 | super().__init__() 81 | self.conv = nn.Sequential( 82 | nn.Conv1d(input_channels, dim, 1), 83 | nn.ReLU(), 84 | nn.Conv1d(dim, dim, 1), 85 | nn.ReLU(), 86 | nn.Conv1d(dim, output_channels, 1), 87 | ) 88 | self.pool = FSPool(output_channels, 20, relaxed=False) 89 | 90 | def forward(self, x, mask=None): 91 | x = self.conv(x) 92 | x = x / x.size(2) # normalise so that activations aren't too high with big sets 93 | x, _ = self.pool(x) 94 | return x 95 | 96 | 97 | class SetGen(nn.Module): 98 | def __init__(self, element_dims=10, set_size=16, lr=200, use_srn= True, iters=5): 99 | super(SetGen, self).__init__() 100 | self.use_srn = use_srn 101 | CNN_ENCODER_SPACE = 100 102 | # H_{agg} 103 | self.encoder = FSEncoder(element_dims, CNN_ENCODER_SPACE, 512) 104 | self.decoder = DSPN( 105 | self.encoder, element_dims, iters=iters, lr=lr 106 | ) 107 | # H_{set} and H_{embed} 108 | self.img_encoder = Encoder(element_dims, set_size, CNN_ENCODER_SPACE) 109 | 110 | def forward(self, x): 111 | x, s = self.img_encoder(x) 112 | if self.use_srn: 113 | intermediate_sets, losses, grad_norms = self.decoder(x, s) 114 | x = intermediate_sets[-1] 115 | else: 116 | x = s 117 | 118 | if self.use_srn: 119 | return x, losses 120 | else: 121 | return x, None 122 | 123 | 124 | class F_match(nn.Module): 125 | def __init__(self): 126 | super(F_match, self).__init__() 127 | self.proj1 = torch.nn.Conv1d(10, 3, 1) 128 | self.proj2 = torch.nn.Conv1d(10, 3, 1) 129 | 130 | def forward(self, x_set, y_set, pool): 131 | # x_set shape: B, element_dims, set_size 132 | x_att = self.proj1(x_set) 133 | y_att = self.proj1(y_set) 134 | 135 | x_loc = self.proj2(x_set) 136 | y_loc = self.proj2(y_set) 137 | 138 | # matching 139 | indices = hungarian_loss_each(x_att, y_att, pool) 140 | l = [ 141 | (x_loc[idx,:,row_idx] - y_loc[idx,:,col_idx])**2 142 | for idx, (row_idx, col_idx) in enumerate(indices) 143 | ] 144 | l_m = [ 145 | ((x_att[idx,:,row_idx] - y_att[idx,:,col_idx])**2).sum() 146 | for idx, (row_idx, col_idx) in enumerate(indices) 147 | ] 148 | match_dist = torch.stack(list(l)).sum(1).sum(1) 149 | match_score = torch.stack(list(l_m)) 150 | return match_dist, match_score 151 | 152 | 153 | class F_reconstruct(nn.Module): 154 | def __init__(self, element_dims=10): 155 | super(F_reconstruct, self).__init__() 156 | self.vec_decoder = 
Decoder(element_dims) 157 | 158 | def forward(self, x_set): 159 | batch_size = x_set.size(0) 160 | element_dims = x_set.size(1) 161 | set_size = x_set.size(2) 162 | 163 | x = x_set.transpose(1,2).reshape(-1,element_dims) 164 | generated = self.vec_decoder(x) 165 | generated = generated.reshape(batch_size, set_size, 3, 64, 64) 166 | 167 | attention = torch.softmax(generated, dim=1) 168 | generated_set = torch.sigmoid(generated) 169 | 170 | generated_set = generated_set*attention 171 | generated_f = generated_set.sum(dim=1).clamp(0,1) 172 | 173 | return generated_f, generated_set 174 | 175 | 176 | class EncoderCLEVR(nn.Module): 177 | 178 | def __init__(self, element_dims=10, set_size=16, out_size=512): 179 | super(EncoderCLEVR, self).__init__() 180 | self.nef = 64 181 | self.e_ksize = 4 182 | self.set_size = set_size 183 | self.out_size = out_size 184 | self.element_dims = element_dims 185 | 186 | self.conv1 = nn.Conv2d(3, self.nef, self.e_ksize, stride = 2, padding = 1, bias = False) 187 | 188 | self.conv2 = nn.Conv2d(self.nef, self.nef*2, self.e_ksize, stride = 2, padding = 1, bias = False) 189 | self.bn2 = nn.BatchNorm2d(self.nef*2) 190 | 191 | self.conv3 = nn.Conv2d(self.nef*2, self.nef*4, self.e_ksize, stride = 2, padding = 1, bias = False) 192 | self.bn3 = nn.BatchNorm2d(self.nef*4) 193 | 194 | self.conv4 = nn.Conv2d(self.nef*4, self.nef*8, self.e_ksize, stride = 4, padding = 1, bias = False) 195 | self.bn4 = nn.BatchNorm2d(self.nef*8) 196 | 197 | self.proj = nn.Linear(self.nef*128, self.out_size) 198 | self.proj_s = nn.Conv1d(8192//self.set_size, self.element_dims, 1) 199 | 200 | 201 | def forward(self, x): 202 | out = F.relu(self.conv1(x)) 203 | out = F.relu(self.bn2(self.conv2(out))) 204 | out = F.relu(self.bn3(self.conv3(out))) 205 | out = F.relu(self.bn4(self.conv4(out))) 206 | 207 | s = self.proj_s(out.view(out.shape[0], self.set_size, 8192//self.set_size).transpose(1,2)) 208 | 209 | out = out.view(out.shape[0], self.nef*128) 210 | return self.proj(out), s 211 | 212 | 213 | class DecoderCLEVR(nn.Module): 214 | def __init__(self, input_dim): 215 | super(DecoderCLEVR, self).__init__() 216 | 217 | self.ngf = 256 218 | g_ksize = 4 219 | self.proj = nn.Linear(input_dim, self.ngf * 4 * 4 * 4 * 4) 220 | self.bn0 = nn.BatchNorm1d(self.ngf * 4 * 4 * 4 * 4) 221 | 222 | self.dconv1 = nn.ConvTranspose2d(self.ngf * 4,self.ngf*2, g_ksize, 223 | stride=2, padding=1, bias=False) 224 | self.bn1 = nn.BatchNorm2d(self.ngf*2) 225 | 226 | self.dconv2 = nn.ConvTranspose2d(self.ngf*2, self.ngf, g_ksize, 227 | stride=2, padding=1, bias=False) 228 | self.bn2 = nn.BatchNorm2d(self.ngf) 229 | 230 | self.dconv3 = nn.ConvTranspose2d(self.ngf, 3, g_ksize, 231 | stride=4, padding=0, bias=False) 232 | 233 | def forward(self, z, c=None): 234 | out = F.relu(self.bn0(self.proj(z)).view(-1, self.ngf* 4, 4*2, 4*2)) 235 | out = F.relu(self.bn1(self.dconv1(out))) 236 | out = F.relu(self.bn2(self.dconv2(out))) 237 | out = self.dconv3(out) 238 | return out 239 | 240 | 241 | class F_reconstruct_CLEVR(nn.Module): 242 | def __init__(self, element_dims=10): 243 | super(F_reconstruct_CLEVR, self).__init__() 244 | self.vec_decoder = DecoderCLEVR(element_dims) 245 | 246 | def forward(self, x_set): 247 | batch_size = x_set.size(0) 248 | element_dims = x_set.size(1) 249 | set_size = x_set.size(2) 250 | 251 | x = x_set.transpose(1,2).reshape(-1,element_dims) 252 | generated = self.vec_decoder(x) 253 | generated = generated.reshape(batch_size, set_size, 3, 128, 128) 254 | 255 | attention = torch.softmax(generated, dim=1) 256 | 
generated_set = torch.sigmoid(generated) 257 | 258 | generated_set = generated_set*attention 259 | generated_f = generated_set.sum(dim=1).clamp(0,1) 260 | 261 | return generated_f, generated_set 262 | 263 | 264 | class RNFSEncoder(nn.Module): 265 | def __init__(self, input_channels, output_channels, dim): 266 | super().__init__() 267 | self.conv = nn.Sequential( 268 | nn.Conv2d(2 * input_channels, dim, 1), 269 | nn.ReLU(), 270 | nn.Conv2d(dim, output_channels, 1), 271 | ) 272 | self.lin = nn.Linear(dim, output_channels) 273 | self.pool = FSPool(output_channels, 20, relaxed=False) 274 | 275 | def forward(self, x, mask=None): 276 | # create all pairs of elements 277 | x = torch.cat(utils.outer(x), dim=1) 278 | x = self.conv(x) 279 | # flatten pairs and scale appropriately 280 | n, c, l, _ = x.size() 281 | x = x.view(x.size(0), x.size(1), -1) / l / l 282 | x, _ = self.pool(x) 283 | return x 284 | 285 | 286 | class SetGenCLEVR(nn.Module): 287 | def __init__(self, element_dims=10, set_size=16, lr=8, use_srn=True): 288 | super(SetGenCLEVR, self).__init__() 289 | self.use_srn = use_srn 290 | CNN_ENCODER_SPACE = 512 291 | # H_{agg} 292 | self.encoder = RNFSEncoder(element_dims, CNN_ENCODER_SPACE, 512) 293 | self.decoder = DSPN( 294 | self.encoder, element_dims, iters=10, lr=lr 295 | ) 296 | # H_{set} and H_{embed} 297 | self.img_encoder = EncoderCLEVR(element_dims, set_size, CNN_ENCODER_SPACE) 298 | 299 | def forward(self, x): 300 | x, s = self.img_encoder(x) 301 | if self.use_srn: 302 | intermediate_sets, losses, grad_norms = self.decoder(x, s) 303 | x = intermediate_sets[-1] 304 | else: 305 | x = s 306 | 307 | if self.use_srn: 308 | return x, losses 309 | else: 310 | return x, None 311 | 312 | -------------------------------------------------------------------------------- /preprocess-images.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import torch.utils.data 5 | import torchvision.models as models 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | from tqdm import tqdm 9 | import re 10 | 11 | 12 | class CLEVR_Images(torch.utils.data.Dataset): 13 | """ Dataset for MSCOCO images located in a folder on the filesystem """ 14 | 15 | def __init__(self, path, transform=None): 16 | super().__init__() 17 | self.p = re.compile('\d+') 18 | self.path = path 19 | self.id_to_filename = self._find_images() 20 | self.sorted_ids = sorted( 21 | self.id_to_filename.keys() 22 | ) # used for deterministic iteration order 23 | print("found {} images in {}".format(len(self), self.path)) 24 | self.transform = transform 25 | 26 | def _find_images(self): 27 | id_to_filename = {} 28 | for filename in os.listdir(self.path): 29 | if not filename.endswith(".png") or 'mask' in filename or 'foreground' in filename: 30 | continue 31 | id = int(self.p.search(filename).group()) 32 | no_ext = filename[:filename.rfind('.')] 33 | filename_mask = no_ext + '_mask.png' 34 | filename_foreground = no_ext + '_foreground.png' 35 | id_to_filename[id] = (filename, filename_mask, filename_foreground) 36 | return id_to_filename 37 | 38 | def __getitem__(self, item): 39 | id = self.sorted_ids[item] 40 | path = os.path.join(self.path, self.id_to_filename[id][0]) 41 | path_mask = os.path.join(self.path, self.id_to_filename[id][1]) 42 | path_foreground = os.path.join(self.path, self.id_to_filename[id][2]) 43 | 44 | img = Image.open(path).convert("RGB") 45 | img_mask = Image.open(path_mask).convert("RGB") 46 | img_foreground = 
Image.open(path_foreground).convert("RGB") 47 | 48 | if self.transform is not None: 49 | img = self.transform(img) 50 | img_mask = self.transform(img_mask) 51 | img_foreground = self.transform(img_foreground) 52 | return id, img, img_mask, img_foreground 53 | 54 | def __len__(self): 55 | return len(self.sorted_ids) 56 | 57 | 58 | def create_coco_loader(path): 59 | transform = transforms.Compose( 60 | [transforms.Resize((128, 128)), transforms.ToTensor()] 61 | ) 62 | dataset = CLEVR_Images(path, transform=transform) 63 | data_loader = torch.utils.data.DataLoader( 64 | dataset, batch_size=64, num_workers=12, shuffle=False, pin_memory=True 65 | ) 66 | return data_loader 67 | 68 | 69 | def main(): 70 | for split_name in ["train", "test"]: 71 | path = os.path.join("clevr", "images", split_name) 72 | loader = create_coco_loader(path) 73 | images_shape = (len(loader.dataset), 3, 128, 128) 74 | 75 | with h5py.File("{}-images-foreground.h5".format(split_name), libver="latest") as fd: 76 | 77 | images = fd.create_dataset("images", shape=images_shape, dtype="float32") 78 | images_mask = fd.create_dataset("images_mask", shape=images_shape, dtype="float32") 79 | images_foreground = fd.create_dataset("images_foreground", shape=images_shape, dtype="float32") 80 | image_ids = fd.create_dataset("image_ids", shape=(len(loader.dataset),), dtype="int32") 81 | 82 | i = 0 83 | for ids, imgs, imgs_mask, imgs_foreground in tqdm(loader): 84 | assert imgs.size(0) == imgs_mask.size(0) == imgs_foreground.size(0) 85 | j = i + imgs.size(0) 86 | images[i:j, :, :] = imgs.numpy() 87 | images_mask[i:j, :, :] = imgs_mask.numpy() 88 | images_foreground[i:j, :, :] = imgs_foreground.numpy() 89 | image_ids[i:j] = ids.numpy().astype("int32") 90 | i = j 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /run_isodistance.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import data 3 | import torch 4 | import torch.nn as nn 5 | from models import SetGen, F_match, F_reconstruct 6 | import torch.multiprocessing as mp 7 | from tensorboardX import SummaryWriter 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model_type', help='model type: srn | mlp | cnn', default="srn") 12 | parser.add_argument('--batch_size', type=int, help='batch size', default=32) 13 | parser.add_argument('--recon', action="store_true" , help='transfer models', default=False) 14 | parser.add_argument('--resume', help='Resume checkpoint', default=None) 15 | parser.add_argument('--lr', type=float, help='lr', default=5e-4) 16 | parser.add_argument('--weight_decay', type=float, help='weight decay', default=0) 17 | parser.add_argument('--inner_lr',type=float, help='inner lr', default=0.1) 18 | parser.add_argument('--save', help='Path of the saved checkpoint', default=None) 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | class Net(nn.Module): 24 | def __init__(self, lr=200): 25 | super(Net, self).__init__() 26 | self.img_encoder = Encoder() 27 | self.proj = nn.Linear(100, 1) 28 | 29 | def forward(self, x, y): 30 | all_images = torch.cat((x, y)) 31 | x, s = self.img_encoder(all_images) 32 | batch_size = x.size(0) // 2 33 | 34 | reference = x[:batch_size,:] 35 | mem = x[batch_size:,:] 36 | 37 | x =(reference- mem)**2 38 | 39 | return self.proj(x) 40 | 41 | 42 | 43 | class SSLR(nn.Module): 44 | def __init__(self, lr=200, use_srn= True): 45 | super(SSLR, 
self).__init__() 46 | self.use_srn = use_srn 47 | element_dims=10 48 | set_size=16 49 | self.set_generator = SetGen(element_dims, set_size, lr, use_srn) 50 | self.f_match = F_match() 51 | self.f_reconstruct = F_reconstruct(element_dims) 52 | 53 | def forward(self, x, y, pool): 54 | all_images = torch.cat((x, y)) 55 | x, losses = self.set_generator(all_images) 56 | 57 | batch_size = x.size(0) // 2 58 | 59 | match_dist, match_score = self.f_match(x[:batch_size,:,:], x[batch_size:,:,:], pool) 60 | generated_f, _ = self.f_reconstruct(x[:batch_size,:,:]) 61 | 62 | if self.use_srn: 63 | return match_dist, losses, match_score, generated_f 64 | else: 65 | return match_dist, match_score, generated_f 66 | 67 | 68 | def eval(net, batch_size, test_loader, pool, epoch, model_type): 69 | net.eval() 70 | all_loss = 0 71 | acc = 0 72 | import gc; 73 | for idx, data in enumerate(test_loader): 74 | images_x, images_y, s = data 75 | images_x, images_y, s = images_x.cuda(), images_y.cuda(), s.sum(1).cuda()/(64*64) 76 | 77 | if model_type == "srn": 78 | match_dist, inner_losses, match_score, re = net(images_x, images_y, pool) 79 | elif model_type == "mlp": 80 | match_dist, match_score, re = net(images_x, images_y, pool) 81 | else: 82 | match_dist = net(images_x, images_y).view(-1) 83 | 84 | loss = ((match_dist- s)**2).sum() 85 | all_loss += loss.item() 86 | 87 | acc += torch.abs((match_dist- s)/s).mean() 88 | acc = acc.detach().cpu() 89 | 90 | gc.collect() 91 | return all_loss/len(test_loader), acc/len(test_loader) 92 | 93 | 94 | 95 | if __name__ == "__main__": 96 | 97 | args = get_args() 98 | print(args) 99 | 100 | train_loader = data.get_loader(data.IsoColorCircles(train=True, size=64000, n = 2), batch_size = args.batch_size) 101 | test_loader = data.get_loader(data.IsoColorCircles(train=False, size=4000, n = 2), batch_size = args.batch_size) 102 | 103 | if args.model_type == "srn": 104 | net = SSLR(float(args.inner_lr)).float().cuda() 105 | 106 | if args.resume is not None: 107 | print("resume from ", args.resume) 108 | # state_dict = torch.load("set_model_recon_0.1_l2.pt") 109 | state_dict = torch.load(args.resume) 110 | own_state = net.state_dict() 111 | for name, param in state_dict.items(): 112 | if isinstance(param, torch.nn.Parameter): 113 | param = param.data 114 | own_state[name].copy_(param) 115 | 116 | elif args.model_type == "mlp": 117 | net = SSLR(use_srn = False).float().cuda() 118 | 119 | else: 120 | assert args.model_type == "cnn" 121 | net = Net().float().cuda() 122 | 123 | optimizer = torch.optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) 124 | writer = SummaryWriter(f"match_runs/{args.model_type}_lr={args.lr}_wd={args.weight_decay}_ilr={args.inner_lr}", purge_step=0, flush_secs = 10) 125 | 126 | running_loss = 0 127 | for epoch in range(1000+1): 128 | with mp.Pool(10) as pool: 129 | print(f"epoch {epoch}") 130 | 131 | net.train() 132 | running_loss = 0 133 | for idx, data in enumerate(train_loader): 134 | images_x, images_y, s = data 135 | images_x, images_y, s = images_x.cuda(), images_y.cuda(), s.sum(1).cuda()/(64*64) 136 | optimizer.zero_grad() 137 | 138 | if args.model_type == "srn": 139 | match_dist, inner_losses, match_score, re = net(images_x, images_y, pool) 140 | elif args.model_type == "mlp": 141 | match_dist, match_score, re = net(images_x, images_y, pool) 142 | else: 143 | match_dist = net(images_x, images_y).view(-1) 144 | 145 | dist_loss = ((match_dist- s)**2).sum() 146 | 147 | use_set = (args.model_type == "srn") or (args.model_type == "mlp") 148 | 
if use_set: 149 | match_loss = match_score.mean() 150 | loss = dist_loss + 10*match_loss 151 | else: 152 | loss = dist_loss 153 | 154 | if args.recon : 155 | recon_loss = ((re - images)**2).mean() 156 | loss += recon_loss 157 | 158 | if use_set: 159 | writer.add_scalar("train/dist_loss", dist_loss.item(), global_step=epoch*len(train_loader) + idx) 160 | writer.add_scalar("train/match_loss", match_loss.item(), global_step=epoch*len(train_loader) + idx) 161 | if args.recon : 162 | writer.add_scalar("train/recon_loss", recon_loss.item(), global_step=epoch*len(train_loader) + idx) 163 | writer.add_scalar("train/loss", loss.item(), global_step=epoch*len(train_loader) + idx) 164 | 165 | loss.backward() 166 | alpha = 0.05 167 | optimizer.step() 168 | 169 | 170 | if idx % (len(train_loader)//4) == 0: 171 | if use_set: 172 | if args.model_type == "srn": 173 | print(f"inner loss {[l.item()/args.batch_size for l in inner_losses]}") 174 | print("dist_loss", dist_loss.item()) 175 | print("match_loss",match_loss.item()) 176 | if args.recon : 177 | print("recon_loss",recon_loss.item()) 178 | print("loss",loss.item()) 179 | 180 | running_loss += loss.item() 181 | print(running_loss/ len(train_loader)) 182 | if epoch % 1 ==0: 183 | with mp.Pool(10) as pool: 184 | eval_loss, acc = eval(net, args.batch_size, test_loader, pool, epoch, args.model_type) 185 | print(f"eval: {eval_loss} {acc}") 186 | writer.add_scalar("eval/loss", eval_loss, global_step=epoch) 187 | writer.add_scalar("eval/acc", acc, global_step=epoch) 188 | writer.flush() 189 | 190 | print() 191 | #save model 192 | if args.save is not None: 193 | torch.save(net.state_dict(), args.save) 194 | -------------------------------------------------------------------------------- /run_reconstruct_circles.py: -------------------------------------------------------------------------------- 1 | import data 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import random 7 | from dspn import DSPN 8 | from fspool import FSPool 9 | from tensorboardX import SummaryWriter 10 | import matplotlib 11 | from models import * 12 | import argparse 13 | 14 | matplotlib.use("Agg") 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--model_type', help='model type: srn | mlp', default="srn") 21 | parser.add_argument('--batch_size', type=int, help='batch size', default=64) 22 | parser.add_argument('--lr', type=float, help='lr', default=3e-4) 23 | parser.add_argument('--inner_lr',type=float, help='inner lr', default=0.1) 24 | parser.add_argument('--inner_iters',type=int, help='# of inner iterations steps to perform', default=10) 25 | parser.add_argument('--start_epoch',type=int, help='epoch to start at', default=0) 26 | parser.add_argument('--load_ckpt', default=False, action='store_true') 27 | 28 | args = parser.parse_args() 29 | return args 30 | 31 | class SSLR(nn.Module): 32 | def __init__(self, lr=200, num_iters=10, use_srn=True): 33 | super(SSLR, self).__init__() 34 | self.element_dims = 10 35 | self.set_generator = SetGen(element_dims = self.element_dims, set_size=16, lr=lr, use_srn=use_srn, iters=num_iters) 36 | self.f_reconstruct = F_reconstruct(element_dims = self.element_dims) 37 | self.use_srn = use_srn 38 | 39 | def forward(self, x, print_interm=False): 40 | x, losses = self.set_generator(x) 41 | generated_f, generated_set = self.f_reconstruct(x) 42 | 43 | if self.use_srn: 44 | return generated_f, losses, generated_set 45 | 
else: 46 | return generated_f, [], generated_set 47 | 48 | 49 | 50 | def eval(net, batch_size, test_loader, epoch, writer, use_srn = True): 51 | net.eval() 52 | all_loss = 0 53 | rel_error = 0 54 | for idx, data in enumerate(test_loader): 55 | images, labels = data 56 | images, labels = images.cuda(), labels.cuda() 57 | 58 | if use_srn: 59 | p, inner_losses, gs = net(images) 60 | else: 61 | p = net(images) 62 | 63 | loss = F.binary_cross_entropy(p, images) 64 | 65 | for j, s_ in enumerate(gs[0]): 66 | fig = plt.figure() 67 | plt.imshow(s_.transpose(0,2).detach().cpu()) 68 | writer.add_figure(f"epoch-{epoch}/img-{idx}", fig, global_step=j) 69 | 70 | fig = plt.figure() 71 | plt.imshow(p[0].transpose(0,2).detach().cpu()) 72 | writer.add_figure(f"epoch-{epoch}/img-{idx}", fig, global_step=len(gs[0])) 73 | 74 | fig = plt.figure() 75 | plt.imshow(images[0].transpose(0,2).detach().cpu()) 76 | writer.add_figure(f"epoch-{epoch}/img-{idx}-target", fig, global_step=epoch) 77 | all_loss += loss.item() 78 | return all_loss/len(test_loader) 79 | 80 | if __name__ == "__main__": 81 | args = get_args() 82 | print(args) 83 | use_srn = args.model_type == "srn" 84 | 85 | batch_size = args.batch_size 86 | train_loader = data.get_loader(data.MarkedColorCircles(train=True, size=64000), batch_size = batch_size) 87 | test_loader = data.get_loader(data.MarkedColorCircles(train=False, size=4000), batch_size = batch_size) 88 | 89 | use_srn = True 90 | net = SSLR(lr = args.inner_lr, num_iters=args.inner_iters, use_srn=use_srn).float().cuda() 91 | if args.load_ckpt: 92 | net.load_state_dict(torch.load("set_model_recon.pt")) 93 | 94 | 95 | net.train() 96 | optimizer = torch.optim.Adam(net.parameters(), lr=args.lr) 97 | 98 | writer = SummaryWriter(f"recon_run/test_run", purge_step=0, flush_secs = 10) 99 | 100 | print(type(net)) 101 | print(net.set_generator.decoder.iters) 102 | running_loss = 0 103 | best_loss = 1e50 104 | for epoch in range(args.start_epoch, 1000+1): 105 | if epoch == 20: 106 | net.set_generator.decoder.iters = 20 107 | net.train() 108 | print(f"epoch {epoch}") 109 | running_loss = 0 110 | for idx, data in enumerate(train_loader): 111 | images, labels = data 112 | images, labels = images.cuda(), labels.cuda() 113 | optimizer.zero_grad() 114 | 115 | if use_srn: 116 | p, inner_losses, _ = net(images) 117 | else: 118 | p = net(images) 119 | loss = ((images - p)**2).sum() 120 | writer.add_scalar("train/loss", loss.item(), global_step=epoch*len(train_loader) + idx) 121 | 122 | loss.backward() 123 | optimizer.step() 124 | if idx % (len(train_loader)//4) == 0: 125 | if use_srn: 126 | print(f"inner loss {[l.item()/batch_size for l in inner_losses]}") 127 | print(loss.item()) 128 | running_loss += loss.item() 129 | 130 | print(running_loss/len(train_loader)) 131 | if epoch % 1 ==0: 132 | eval_loss = eval(net, batch_size, test_loader, epoch, writer, use_srn) 133 | if eval_loss < best_loss: 134 | best_loss = eval_loss 135 | torch.save(net.state_dict(), "set_model_recon.pt") 136 | print(f"eval: {eval_loss}") 137 | writer.add_scalar("eval/loss", eval_loss, global_step=epoch) 138 | writer.flush() 139 | 140 | print() 141 | -------------------------------------------------------------------------------- /run_reconstruct_clevr.py: -------------------------------------------------------------------------------- 1 | import data 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import random 7 | from dspn import DSPN 8 | from fspool import FSPool 9 | from tensorboardX 
import SummaryWriter 10 | import matplotlib 11 | import utils 12 | from tqdm import tqdm 13 | from models import * 14 | import argparse 15 | 16 | matplotlib.use("Agg") 17 | import matplotlib.pyplot as plt 18 | 19 | 20 | def get_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--model_type', help='model type: srn | mlp', default="srn") 23 | parser.add_argument('--batch_size', type=int, help='batch size', default=32) 24 | parser.add_argument('--lr', type=float, help='lr', default=3e-4) 25 | parser.add_argument('--inner_lr',type=float, help='inner lr', default=8) 26 | parser.add_argument('--save', help='path to save checkpoint', default=None) 27 | parser.add_argument('--resume', help='path to resume a saved checkpoint', default=None) 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | class SSLR(nn.Module): 33 | def __init__(self, lr=8, use_srn=True): 34 | super(SSLR, self).__init__() 35 | self.use_srn = use_srn 36 | element_dims=10 37 | set_size=16 38 | self.g = SetGenCLEVR(element_dims, set_size, lr, use_srn) 39 | self.F_reconstruct = F_reconstruct_CLEVR() 40 | 41 | def forward(self, images): 42 | x, inner_losses = self.g(images) 43 | generated_f, generated_set = self.F_reconstruct(x) 44 | return generated_f, inner_losses, generated_set 45 | 46 | 47 | def eval(net, batch_size, test_loader, epoch, writer, use_srn=True): 48 | with torch.no_grad(): 49 | net.eval() 50 | all_loss = 0 51 | rel_error = 0 52 | test_loader = tqdm( 53 | test_loader, 54 | ncols=0, 55 | desc="test E{0:02d}".format(epoch), 56 | ) 57 | iters_per_epoch = len(test_loader) 58 | for idx, (images, images_foreground) in enumerate(test_loader, start=epoch * iters_per_epoch): 59 | images, images_foreground = images.cuda(), images_foreground.cuda() 60 | 61 | p, inner_losses, gs = net(images) 62 | 63 | loss = F.binary_cross_entropy(p, images_foreground) 64 | 65 | for j, s_ in enumerate(gs[0]): 66 | fig = plt.figure() 67 | plt.imshow(s_.permute(1,2,0).detach().cpu()) 68 | writer.add_figure(f"epoch-{epoch}/img-{idx}", fig, global_step=j) 69 | 70 | fig = plt.figure() 71 | plt.imshow(p[0].permute(1,2,0).detach().cpu()) 72 | writer.add_figure(f"epoch-{epoch}/img-{idx}", fig, global_step=len(gs[0])) 73 | 74 | fig = plt.figure() 75 | plt.imshow(images[0].permute(1,2,0).detach().cpu()) 76 | writer.add_figure(f"epoch-{epoch}/img-{idx}-target", fig, global_step=epoch) 77 | 78 | all_loss += loss.item() 79 | return all_loss/len(test_loader) 80 | 81 | 82 | 83 | 84 | if __name__ == "__main__": 85 | args = get_args() 86 | print(args) 87 | 88 | use_srn = args.model_type == "srn" 89 | 90 | dataset_train = data.CLEVRMasked( 91 | "clevr", "train", full=True 92 | ) 93 | dataset_test = data.CLEVRMasked( 94 | "clevr", "test", full=False 95 | ) 96 | 97 | batch_size = args.batch_size 98 | train_loader = data.get_loader( 99 | dataset_train, batch_size=batch_size 100 | ) 101 | test_loader = data.get_loader( 102 | dataset_test, batch_size=batch_size 103 | ) 104 | 105 | net = SSLR(args.inner_lr, use_srn).float().cuda() 106 | 107 | if args.resume: 108 | net.load_state_dict(torch.load(args.resume)) 109 | 110 | optimizer = torch.optim.Adam( 111 | [p for p in net.parameters() if p.requires_grad], lr=args.lr 112 | ) 113 | writer = SummaryWriter(f"runs/recon_clevr", purge_step=0, flush_secs = 10) 114 | 115 | 116 | print(type(net)) 117 | iters_per_epoch = len(train_loader) 118 | 119 | running_loss = 0 120 | 121 | for epoch in range(1000+1): 122 | train_loader = tqdm( 123 | train_loader, 124 | ncols=0, 125 | desc="train 
E{0:02d}".format(epoch), 126 | ) 127 | 128 | net.train() 129 | running_loss = 0 130 | 131 | for idx, (images, images_foreground) in enumerate(train_loader, start=epoch * iters_per_epoch): 132 | images, images_foreground = images.cuda(), images_foreground.cuda() 133 | optimizer.zero_grad() 134 | 135 | p, inner_losses, _ = net(images) 136 | 137 | loss = F.binary_cross_entropy(p, images_foreground) 138 | 139 | writer.add_scalar("train/loss", loss.item(), global_step=idx) 140 | 141 | loss.backward() 142 | optimizer.step() 143 | 144 | if use_srn: 145 | print(f"inner loss {[l.item()/batch_size for l in inner_losses]}") 146 | print(f"{loss.item()}\n") 147 | running_loss += loss.item() 148 | print(running_loss/len(train_loader)) 149 | 150 | if args.save: 151 | torch.save(net.state_dict(), args.save) 152 | 153 | eval_loss = eval(net, batch_size, test_loader, epoch, writer, use_srn) 154 | print(f"eval: {eval_loss}\n") 155 | writer.add_scalar("eval/loss", eval_loss, global_step=epoch) 156 | writer.flush() 157 | 158 | print() 159 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import scipy.optimize 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import cv2 7 | 8 | def hungarian_loss_each(predictions, targets, thread_pool): 9 | # predictions and targets shape :: (n, c, s) 10 | predictions, targets = outer(predictions, targets) 11 | # squared_error shape :: (n, s, s) 12 | squared_error = F.smooth_l1_loss(predictions, targets, reduction="none").mean(1) 13 | 14 | squared_error_np = squared_error.detach().cpu().numpy() 15 | 16 | indices = thread_pool.map(hungarian_loss_per_sample, squared_error_np) 17 | return indices 18 | 19 | def hungarian_loss_per_sample(sample_np): 20 | return scipy.optimize.linear_sum_assignment(sample_np) 21 | 22 | 23 | def chamfer_loss(predictions, targets): 24 | # predictions and targets shape :: (k, n, c, s) 25 | predictions, targets = outer(predictions, targets) 26 | # squared_error shape :: (k, n, s, s) 27 | squared_error = F.smooth_l1_loss(predictions, targets, reduction="none").mean(2) 28 | loss = squared_error.min(2)[0] + squared_error.min(3)[0] 29 | return loss.view(loss.size(0), -1).mean(1) 30 | 31 | 32 | def chamfer_loss_each(predictions, targets): 33 | # predictions and targets shape :: (k, n, c, s) 34 | predictions, targets = outer(predictions, targets) 35 | # squared_error shape :: (k, n, s, s) 36 | squared_error = F.smooth_l1_loss(predictions, targets, reduction="none").mean(2) 37 | return torch.cat((squared_error.min(2)[0], squared_error.min(3)[0]),2)[0] 38 | 39 | 40 | 41 | def scatter_masked(tensor, mask, p, binned=False, threshold=None): 42 | s = tensor[0].detach().cpu() 43 | mask = mask[0].detach().clamp(min=0, max=1).cpu() 44 | p = p[0].detach().clamp(min=0, max=1).cpu() 45 | if binned: 46 | s = s * 128 47 | s = s.view(-1, s.size(-1)) 48 | mask = mask.view(-1) 49 | if threshold is not None: 50 | keep = mask.view(-1) > threshold 51 | s = s[:, keep] 52 | mask = mask[keep] 53 | return s, mask, p 54 | 55 | 56 | def cv_bbox(np_imgs): 57 | imgs = [] 58 | for np_img in np_imgs: 59 | new_img = np_img.copy() 60 | cnts = cv2.findContours(new_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 61 | cnts = cnts[0] if len(cnts) == 2 else cnts[1] 62 | for c in cnts: 63 | x,y,w,h = cv2.boundingRect(c) 64 | new_img[y:y+h, x:x+w] = 1 65 | imgs.append(new_img) 66 | return 
torch.tensor(imgs).reshape(-1,16,128,128)
67 | 
68 | 
69 | def chamfer_score(s1, s2, SMOOTH=1e-6):
70 |     """Mean best-match IoU between two sets of binary masks.
71 | 
72 |     s1 and s2 are uint8 tensors of shape (batch, set_size, 128, 128).
73 |     For every mask in s2, the highest IoU over the masks in s1 is taken,
74 |     and the result is averaged over masks and over the batch.
75 |     """
76 |     batch = s1.size(0)
77 |     size = s1.size(1)
78 |     # pair every mask in s1 with every mask in s2
79 |     a = torch.cat(size*[s1.unsqueeze(1)],1).reshape(-1, 128,128)
80 |     b = torch.cat(size*[s2.unsqueeze(2)],2).reshape(-1, 128,128)
81 | 
82 |     intersect = (a & b).sum((1,2)).float()
83 |     union = (a | b).sum((1,2)).float()
84 |     iou = ((intersect + SMOOTH) / (union + SMOOTH))
85 | 
86 |     # r[n, i, j] = IoU between s2[n, i] and s1[n, j]
87 |     r = iou.reshape(batch, size,size, -1).squeeze(3)
88 | 
89 |     return r.max(2)[0].mean()
90 | 
91 | 
92 | def outer(a, b=None):
93 |     """ Compute outer product between a and b (or a and a if b is not specified). """
94 |     if b is None:
95 |         b = a
96 |     size_a = tuple(a.size()) + (b.size()[-1],)
97 |     size_b = tuple(b.size()) + (a.size()[-1],)
98 |     a = a.unsqueeze(dim=-1).expand(*size_a)
99 |     b = b.unsqueeze(dim=-2).expand(*size_b)
100 |     return a, b
101 | 
--------------------------------------------------------------------------------
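As a quick illustration of how `chamfer_score` matches two sets of masks, here is a minimal sketch. It only assumes the repo root is on the import path, and the square masks below are made-up dummy data; the function expects binary `uint8` masks of size 128x128 (hard-coded in its reshapes):

```
import torch

from utils import chamfer_score

# One sample with three binary 128x128 masks per set (dummy data).
pred = torch.zeros(1, 3, 128, 128, dtype=torch.uint8)
gt = torch.zeros(1, 3, 128, 128, dtype=torch.uint8)

# Ground-truth squares.
gt[0, 0, 10:40, 10:40] = 1
gt[0, 1, 50:90, 50:90] = 1
gt[0, 2, 100:120, 20:60] = 1

# Predictions: the same squares, with the first one shifted by two pixels.
pred[0, 0, 12:42, 12:42] = 1
pred[0, 1, 50:90, 50:90] = 1
pred[0, 2, 100:120, 20:60] = 1

# For every ground-truth mask, take the best IoU over the predictions, then average.
score = chamfer_score(pred, gt)
print(score.item())  # slightly below 1.0 because of the shifted square
```

Because each ground-truth mask is matched to its best prediction independently, this score does not penalize duplicate or unmatched predictions; the Hungarian-based losses earlier in `utils.py` are the stricter one-to-one alternative.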