├── README.md
├── chamfer_iou_clevr.py
├── circles
│   └── placeholder
├── data.py
├── dspn.py
├── fspool.py
├── full_iou_clevr.py
├── imgs
│   ├── clevr_tile_1.jpg
│   ├── set.png
│   └── tiled_samples_1.png
├── models.py
├── preprocess-images.py
├── run_isodistance.py
├── run_reconstruct_circles.py
├── run_reconstruct_clevr.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # [NeurIPS 2020] Better Set Representations For Relational Reasoning
2 |
3 | 
4 |
5 | ## Software Requirements
6 |
7 | This codebase requires Python 3, PyTorch 1.0+, and Torchvision 0.2+. In principle the code can run on CPU, but GPU usage is assumed throughout the codebase.
8 |
9 | ## Usage
10 |
11 | The files `run_reconstruct_circles.py` and `run_reconstruct_clevr.py` correspond to the explanatory experiments in the paper. We implemented the three other experiments by plugging our module into the existing repositories linked in the supplementary materials, where we give more details.
12 |
13 | Full usage:
14 | ```
15 | usage: run_reconstruct_circles.py [-h] [--model_type MODEL_TYPE]
16 | [--batch_size BATCH_SIZE] [--lr LR]
17 | [--inner_lr INNER_LR]
18 |
19 | optional arguments:
20 | -h, --help show this help message and exit
21 | --model_type MODEL_TYPE
22 | model type: srn | mlp
23 | --batch_size BATCH_SIZE
24 | batch size
25 | --lr LR lr
26 | --inner_lr INNER_LR inner lr
27 | ```
28 | ```
29 | usage: run_reconstruct_clevr.py [-h] [--model_type MODEL_TYPE]
30 | [--batch_size BATCH_SIZE] [--lr LR]
31 | [--inner_lr INNER_LR]
32 |
33 | optional arguments:
34 | -h, --help show this help message and exit
35 | --model_type MODEL_TYPE
36 | model type: srn | mlp
37 | --batch_size BATCH_SIZE
38 | batch size
39 | --lr LR lr
40 | --inner_lr INNER_LR inner lr
41 | --save SAVE path to save checkpoint
42 | --resume RESUME path to resume a saved checkpoint
43 | ```
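
For example, to train each model with the script defaults and then score the CLEVR model (the checkpoint path `clevr_srn.pt` is illustrative):
```
python run_reconstruct_circles.py --model_type srn --batch_size 64 --lr 3e-4 --inner_lr 0.1
python run_reconstruct_clevr.py --model_type srn --batch_size 32 --lr 3e-4 --inner_lr 8 --save clevr_srn.pt
python full_iou_clevr.py --model_type srn --resume clevr_srn.pt
```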
44 |
45 | ## Data Generation
46 |
47 | The data for CLEVR with masks was generated with https://github.com/facebookresearch/clevr-dataset-gen, adding the following line:
48 | ```render_shadeless(blender_objects, path=output_image[:-4]+'_mask.png')```
49 | to ```image_generation/render_images.py``` around line 311 (right after the call to ```add_random_objects```).
50 |
51 | ## Results
52 |
53 | Circles reconstruction samples (columns, from left to right: original images, SRN reconstruction, SRN decomposition, baseline reconstruction, baseline decomposition):
54 |
55 |
56 |
57 |
58 | CLEVR reconstruction samples:
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/chamfer_iou_clevr.py:
--------------------------------------------------------------------------------
1 | from run_reconstruct_clevr import SSLR
2 | import os
3 | import data
4 | import torch
5 | from tqdm import tqdm
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import pickle
9 | from utils import chamfer_score, cv_bbox
10 | import torch.multiprocessing as mp
11 | import gc
12 | import cv2
13 | import argparse
14 |
15 | def get_args():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--model_type', help='model type: srn | mlp', default="srn")
18 | parser.add_argument('--batch_size', type=int, help='batch size', default=32)
19 | parser.add_argument('--resume', help='path to resume a saved checkpoint', default=None)
20 | args = parser.parse_args()
21 | return args
22 |
23 | if __name__ == '__main__':
24 | args = get_args()
25 | print(args)
26 |
27 | use_srn = args.model_type == "srn"
28 |
29 | dataset_test = data.CLEVR(
30 | "clevr_no_mask", "val", box=True, full=True, chamfer=True
31 | )
32 | batch_size = args.batch_size
33 | test_loader = data.get_loader(
34 | dataset_test, batch_size=batch_size, shuffle=False
35 | )
36 |
37 | net = SSLR(use_srn=use_srn).float().cuda()
38 | net.eval()
39 | net.load_state_dict(torch.load(args.resume))
40 |
41 | test_loader = tqdm(
42 | test_loader,
43 | ncols=0,
44 | desc="test"
45 | )
46 |
47 | full_score = 0
48 | for idx, sample in enumerate(test_loader):
49 | def tfunc():
50 | gc.collect()
51 | image, masks = [x.cuda() for x in sample]
52 |
53 | p_, inner_losses, gs = net(image)
54 |
55 | thresh_mask = gs < 1e-2
56 | gs[thresh_mask] = 0
57 | gs[~thresh_mask] = 1
58 | gs = gs.sum(2).clamp(0,1)
59 | gs = gs.to(dtype=torch.uint8)
60 |
61 | img = cv_bbox(gs.detach().cpu().numpy().reshape(-1,128,128))
62 |
63 | score = chamfer_score(img.cuda().to(dtype=torch.uint8), masks.to(dtype=torch.uint8))
64 |
65 | return score
66 | full_score += tfunc()
67 |
68 |
69 | full_score /= len(test_loader)
70 | print(full_score)
71 |
--------------------------------------------------------------------------------
/circles/placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CUAI/BetterSetRepresentations/0510627834462a498d42b80c08dbb4fea946b9c8/circles/placeholder
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torch
3 | from torch.utils.data import DataLoader, Dataset, TensorDataset
4 | import numpy as np
5 | import os
6 | import random
7 | import cv2
8 | import h5py
9 | import json
10 |
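# Attribute classes for CLEVR objects, used by CLEVR.object_to_fv below.
# Assumed to be the standard CLEVR attribute values; only set membership matters
# for the one-hot encoding, the ordering is arbitrary but must stay fixed.
CLASSES = {
    "size": ["large", "small"],
    "color": ["gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow"],
    "material": ["rubber", "metal"],
    "shape": ["cube", "sphere", "cylinder"],
}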
11 |
12 | def get_loader(dataset, batch_size, num_workers=8, shuffle=True):
13 | return torch.utils.data.DataLoader(
14 | dataset,
15 | shuffle=shuffle,
16 | batch_size=batch_size,
17 | pin_memory=True,
18 | num_workers=num_workers,
19 | drop_last=True,
20 | )
21 |
22 | class IsoColorCircles(torch.utils.data.Dataset):
23 | def __init__(self, train=True, root='circles', size=1000, n = None):
24 | self.train = train
25 | self.root = root
26 | self.size = size
27 | self.n = n
28 | self.data = self.cache()
29 |
30 | def cache(self):
31 | cache_path = os.path.join(self.root, f"iso_color_circles_{self.train}_{self.n}.pth")
32 | if os.path.exists(cache_path):
33 | return torch.load(cache_path)
34 |
35 | print("Processing dataset...")
36 | data = []
37 | for i in range(self.size):
38 | if i%10000 == 0:
39 | print(i)
40 | img = np.zeros((64, 64,3), dtype = "float")
41 | n = int(random.randint(1, 10))
42 | if self.n is not None:
43 | n = self.n
44 | color_count = [0,0]
45 | circle_features = torch.zeros([10,4]).float()
46 | # Creating circle
47 | j = 0
48 | while j < n:
49 | tmp = np.zeros((64, 64,3), dtype = "float")
50 | l = range(1,12)
51 | r = l[int(random.random()*11)]
52 | center = (int(random.random()*(64-2*r)+r), int(random.random()*(64-2*r)+r))
53 | c_p = random.randint(0, 1)
54 | c = [0,0,0]
55 | c[c_p] = 1
56 | tmp = cv2.circle(tmp, center, r+1, c, -1)
57 | if (img + tmp).max() > 1:
58 | continue
59 | elif img.min() >= 1:
60 | assert(False)
61 | else:
62 | tmp = np.zeros((64, 64,3), dtype = "float")
63 | tmp = cv2.circle(tmp, center, r, c, -1)
64 | color_count[c_p] += 1
65 | img+= tmp
66 | circle_features[j] = torch.tensor([center[0], center[1],r, c_p+1])
67 | j+=1
68 |
69 |
70 |
71 | l = range(1,12)
72 |
73 | # iso
74 |
75 | fail = True
76 | while fail:
77 | s = torch.zeros([10]).float()
78 | fail = False
79 | iso = np.zeros((64, 64,3), dtype = "float")
80 | for idx, f in enumerate(circle_features):
81 | if f[3].int() == 0 :
82 | break
83 | tmp = np.zeros((64, 64,3), dtype = "float")
84 | r = f[2]
85 | c = [0,0,0]
86 | c[f[3].int() - 1] = 1
87 | center = (int(random.random()*(64-2*r)+r), int(random.random()*(64-2*r)+r))
88 |
89 | tmp = cv2.circle(tmp, center, r+1, c, -1)
90 | if (iso + tmp).max() > 1:
91 | fail = True
92 | break
93 | elif iso.min() >= 1:
94 | assert(False)
95 | else:
96 | tmp = np.zeros((64, 64,3), dtype = "float")
97 | tmp = cv2.circle(tmp, center, r, c, -1)
98 | s[idx] = (f[0] - center[0])**2 + (f[1] - center[1])**2
99 | iso+= tmp
100 |
101 | i+=1
102 | data.append((torch.tensor(img).transpose(0,2).float(), torch.tensor(iso).transpose(0,2).float(), s))
103 | torch.save(data, cache_path)
104 | print("Done!")
105 | return data
106 |
107 | def __getitem__(self, item):
108 | return self.data[item]
109 |
110 | def __len__(self):
111 | return self.size
112 |
113 | class MarkedColorCircles(torch.utils.data.Dataset):
114 | def __init__(self, train=True, root='circles', size=1000, colors = [[1,0,0],[0,1,0]]):
115 | self.train = train
116 | self.root = root
117 | self.size = size
118 | self.data = self.cache()
119 | self.colors = colors
120 |
121 | def cache(self):
122 | cache_path = os.path.join(self.root, f"marked_color_circles_{self.train}.pth")
123 | if os.path.exists(cache_path):
124 | return torch.load(cache_path)
125 |
126 | print("Processing dataset...")
127 | data = []
128 | for i in range(self.size):
129 | img = np.zeros((64, 64,3), dtype = "float")
130 | n = int(random.randint(0, 10))
131 | color_count = [0,0]
132 | circle_features = torch.zeros([10,4]).float()
133 | # Creating circle
134 | j = 0
135 | while j < n:
136 | tmp = np.zeros((64, 64,3), dtype = "float")
137 | l = range(1,12)
138 | r = l[int(random.random()*11)]
139 | center = (int(random.random()*(64-2*r)+r), int(random.random()*(64-2*r)+r))
140 | c_p = random.randint(0, 1)
141 | c = [0,0,0]
142 | c[c_p] = 1
143 | tmp = cv2.circle(tmp, center, r+1, c, -1)
144 | if (img + tmp).max() > 1:
145 | continue
146 | elif img.min() >= 1:
147 | assert(False)
148 | else:
149 | tmp = np.zeros((64, 64,3), dtype = "float")
150 | tmp = cv2.circle(tmp, center, r, c, -1)
151 | color_count[c_p] += 1
152 | img+= tmp
153 | circle_features[j] = torch.tensor([center[0], center[1],r, c_p+1])
154 | j+=1
155 | i+=1
156 | data.append((torch.tensor(img).transpose(0,2).float(), circle_features))
157 | torch.save(data, cache_path)
158 | print("Done!")
159 | return data
160 |
161 | def __getitem__(self, item):
162 | return self.data[item]
163 |
164 | def __len__(self):
165 | return self.size
166 |
167 | class CLEVR(torch.utils.data.Dataset):
168 | def __init__(self, base_path, split, box=False, full=False, chamfer=False):
169 | assert split in {
170 | "train",
171 | "val",
172 | "test",
173 | } # note: test isn't very useful since it doesn't have ground-truth scene information
174 | self.base_path = base_path
175 | self.split = split
176 | self.max_objects = 10
177 | self.box = box # True if clevr-box version, False if clevr-state version
178 | self.full = full # Use full validation set?
179 | self.chamfer = chamfer # Use Chamfer data?
180 |
181 | with self.img_db() as db:
182 | ids = db["image_ids"]
183 | self.image_id_to_index = {id: i for i, id in enumerate(ids)}
184 | self.image_db = None
185 |
186 | with open(self.scenes_path) as fd:
187 | scenes = json.load(fd)["scenes"]
188 | self.img_ids, self.scenes = self.prepare_scenes(scenes)
189 |
190 | def object_to_fv(self, obj):
191 | coords = [p / 3 for p in obj["3d_coords"]]
192 | one_hot = lambda key: [obj[key] == x for x in CLASSES[key]]
193 | material = one_hot("material")
194 | color = one_hot("color")
195 | shape = one_hot("shape")
196 | size = one_hot("size")
197 | assert sum(material) == 1
198 | assert sum(color) == 1
199 | assert sum(shape) == 1
200 | assert sum(size) == 1
201 | # concatenate all the classes
202 | return coords + material + color + shape + size
203 |
204 | def prepare_scenes(self, scenes_json):
205 | img_ids = []
206 | scenes = []
207 | for scene in scenes_json:
208 | img_idx = scene["image_index"]
209 | # different objects depending on bbox version or attribute version of CLEVR sets
210 | if self.box:
211 | objects = self.extract_bounding_boxes(scene)
212 | objects = torch.FloatTensor(objects)
213 | else:
214 | objects = [self.object_to_fv(obj) for obj in scene["objects"]]
215 | objects = torch.FloatTensor(objects).transpose(0, 1)
216 | num_objects = objects.size(1)
217 | # pad with 0s
218 | if num_objects < self.max_objects:
219 | objects = torch.cat(
220 | [
221 | objects,
222 | torch.zeros(objects.size(0), self.max_objects - num_objects),
223 | ],
224 | dim=1,
225 | )
226 | # fill in masks
227 | mask = torch.zeros(self.max_objects)
228 | mask[:num_objects] = 1
229 |
230 | img_ids.append(img_idx)
231 | scenes.append((objects, mask))
232 | return img_ids, scenes
233 |
234 | def extract_bounding_boxes(self, scene):
235 | """
236 | Code used for 'Object-based Reasoning in VQA' to generate bboxes
237 | https://arxiv.org/abs/1801.09718
238 | https://github.com/larchen/clevr-vqa/blob/master/bounding_box.py#L51-L107
239 | """
240 | objs = scene["objects"]
241 | rotation = scene["directions"]["right"]
242 |
243 | num_boxes = len(objs)
244 |
245 | boxes = np.zeros((1, num_boxes, 4))
246 |
247 | xmin = []
248 | ymin = []
249 | xmax = []
250 | ymax = []
251 | classes = []
252 | classes_text = []
253 |
254 | for i, obj in enumerate(objs):
255 | [x, y, z] = obj["pixel_coords"]
256 |
257 | [x1, y1, z1] = obj["3d_coords"]
258 |
259 | cos_theta, sin_theta, _ = rotation
260 |
261 | x1 = x1 * cos_theta + y1 * sin_theta
262 | y1 = x1 * -sin_theta + y1 * cos_theta
263 |
264 | height_d = 6.9 * z1 * (15 - y1) / 2.0
265 | height_u = height_d
266 | width_l = height_d
267 | width_r = height_d
268 |
269 | if obj["shape"] == "cylinder":
270 | d = 9.4 + y1
271 | h = 6.4
272 | s = z1
273 |
274 | height_u *= (s * (h / d + 1)) / ((s * (h / d + 1)) - (s * (h - s) / d))
275 | height_d = height_u * (h - s + d) / (h + s + d)
276 |
277 | width_l *= 11 / (10 + y1)
278 | width_r = width_l
279 |
280 | if obj["shape"] == "cube":
281 | height_u *= 1.3 * 10 / (10 + y1)
282 | height_d = height_u
283 | width_l = height_u
284 | width_r = height_u
285 |
286 | obj_name = (
287 | obj["size"]
288 | + " "
289 | + obj["color"]
290 | + " "
291 | + obj["material"]
292 | + " "
293 | + obj["shape"]
294 | )
295 | ymin.append((y - height_d) / 320.0)
296 | ymax.append((y + height_u) / 320.0)
297 | xmin.append((x - width_l) / 480.0)
298 | xmax.append((x + width_r) / 480.0)
299 |
300 | return xmin, ymin, xmax, ymax
301 |
302 | @property
303 | def images_folder(self):
304 | return os.path.join(self.base_path, "images", self.split)
305 |
306 | @property
307 | def scenes_path(self):
308 | if self.split == "test":
309 | raise ValueError("Scenes are not available for test")
310 | return os.path.join(
311 | self.base_path, "scenes", "CLEVR_{}_scenes.json".format(self.split)
312 | )
313 |
314 | def img_db(self):
315 | path = os.path.join(self.base_path, "{}-images.h5".format(self.split))
316 | return h5py.File(path, "r")
317 |
318 | def load_image(self, image_id):
319 | if self.image_db is None:
320 | self.image_db = self.img_db()
321 | index = self.image_id_to_index[image_id]
322 | image = self.image_db["images"][index]
323 | return image
324 |
325 | def make_mask(self, objects, size, num_objs):
326 | num_objs = len(size[size == 1])
327 | masks = torch.zeros([16,128,128])
328 | for i in range(num_objs):
329 | masks[i, objects[1, i]:objects[3, i], objects[0, i]:objects[2, i]] = 1
330 | return masks
331 |
332 | def __getitem__(self, item):
333 | image_id = self.img_ids[item]
334 | image = self.load_image(image_id)
335 | objects, size = self.scenes[item]
336 | if self.chamfer:
337 | objects = (objects * 128).to(dtype=torch.uint8)
338 | num_objs = len(size[size == 1])
339 | return image, self.make_mask(objects, size, num_objs)
340 | return image
341 |
342 | def __len__(self):
343 | if self.split == "train" or self.full:
344 | return len(self.scenes)
345 | else:
346 | return len(self.scenes) // 10
347 |
348 |
349 | class CLEVRMasked(torch.utils.data.Dataset):
350 | def __init__(self, base_path, split, full=False, iou=False):
351 | assert split in {
352 | "train",
353 | "test",
354 | } # note: test isn't very useful since it doesn't have ground-truth scene information
355 | self.base_path = base_path
356 | self.split = split
357 | self.full = full # Use full validation set?
358 | self.iou = iou
359 |
360 | with self.img_db() as db:
361 | ids = db["image_ids"]
362 | self.image_id_to_index = {id: i for i, id in enumerate(ids)}
363 | self.image_db = None
364 | self.img_ids = [i for i in range(len(self.image_id_to_index))]
365 |
366 | @property
367 | def images_folder(self):
368 | return os.path.join(self.base_path, "images", self.split)
369 |
370 | def img_db(self):
371 | path = os.path.join(self.base_path, "{}-images-foreground.h5".format(self.split))
372 | return h5py.File(path, "r")
373 |
374 | def load_image(self, image_id):
375 | if self.image_db is None:
376 | self.image_db = self.img_db()
377 | index = self.image_id_to_index[image_id]
378 | image = self.image_db["images"][index]
379 | image_mask = self.image_db["images_mask"][index]
380 | image_foreground = self.image_db["images_foreground"][index]
381 | return image, image_mask, image_foreground
382 |
383 | def __getitem__(self, item):
384 | image_id = self.img_ids[item]
385 | image, image_mask, image_foreground = self.load_image(image_id)
386 | if self.iou:
387 | return image, image_mask, image_foreground
388 | return image, image_foreground
389 |
390 | def __len__(self):
391 | if self.split == "train" or self.full:
392 | return len(self.img_ids)
393 | else:
394 | return len(self.img_ids) // 10
395 |
--------------------------------------------------------------------------------
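A minimal sketch of how these datasets are consumed (mirroring the loader built in run_reconstruct_circles.py; the sizes below are illustrative):
```
import data

# Generates (and caches under circles/) 1000 images of up to 10 non-overlapping circles.
dataset = data.MarkedColorCircles(train=True, size=1000)
loader = data.get_loader(dataset, batch_size=32)

images, circle_features = next(iter(loader))
# images:          (32, 3, 64, 64) float tensor with values in [0, 1]
# circle_features: (32, 10, 4), each row is (x, y, radius, color_index + 1), zero-padded
```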
/dspn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import higher
5 |
6 |
7 | class InnerSet(nn.Module):
8 | def __init__(self, mask):
9 | super().__init__()
10 | self.mask = mask
11 |
12 | def forward(self):
13 | return self.mask
14 |
15 | class DSPN(nn.Module):
16 | """ Deep Set Prediction Networks
17 | Yan Zhang, Jonathon Hare, Adam Prügel-Bennett
18 | https://arxiv.org/abs/1906.06565
19 | """
20 |
21 | def __init__(self, encoder, set_channels, iters, lr):
22 | """
23 | encoder: Set encoder module that takes a set as input and returns a representation thereof.
24 | It should have a forward function that takes two arguments:
25 | - a set: FloatTensor of size (batch_size, input_channels, maximum_set_size). Each set
26 | should be padded to the same maximum size with 0s, even across batches.
27 | - a mask: FloatTensor of size (batch_size, maximum_set_size). This should take the value 1
28 | if the corresponding element is present and 0 if not.
29 |
30 | set_channels: Number of channels of the set to predict.
31 |
32 | max_set_size: Maximum size of the set (not a constructor argument; it is implied by the shape of init in forward).
33 |
34 | iters: Number of iterations to run the DSPN algorithm for.
35 |
36 | lr: Learning rate of inner gradient descent in DSPN.
37 | """
38 | super().__init__()
39 | self.encoder = encoder
40 | self.iters = iters
41 | self.lr = lr
42 |
43 | def forward(self, target_repr, init):
44 | """
45 | Conceptually, DSPN simply turns the target_repr feature vector into a set.
46 |
47 | target_repr: Representation that the predicted set should match. FloatTensor of size (batch_size, repr_channels).
48 | This can come from a set processed with the same encoder as self.encoder (auto-encoder), or a different
49 | input completely (normal supervised learning), such as an image encoded into a feature vector.
50 | """
51 | # copy same initial set over batch
52 | current_set = nn.Parameter(init)
53 | inner_set = InnerSet(current_set)
54 |
55 | # info used for loss computation
56 | intermediate_sets = [current_set]
57 | # info used for debugging
58 | repr_losses = []
59 | grad_norms = []
60 |
61 | # optimise repr_loss for fixed number of steps
62 | with torch.enable_grad():
63 | opt = torch.optim.SGD(inner_set.parameters(), lr=self.lr, momentum=0.5)
64 | with higher.innerloop_ctx(inner_set, opt) as (fset, diffopt):
65 | for i in range(self.iters):
66 | predicted_repr = self.encoder(fset())
67 | # how well the representation matches the target
68 | repr_loss = ((predicted_repr- target_repr)**2).sum()
69 | diffopt.step(repr_loss)
70 | intermediate_sets.append(fset.mask)
71 | repr_losses.append(repr_loss)
72 | grad_norms.append(())
73 |
74 | return intermediate_sets, repr_losses, grad_norms
75 |
--------------------------------------------------------------------------------
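An illustrative sketch of the DSPN inner loop in isolation, using the FSEncoder from models.py (batch size, set size, and channel counts below are arbitrary):
```
import torch
from models import FSEncoder
from dspn import DSPN

encoder = FSEncoder(input_channels=10, output_channels=100, dim=512)
dspn = DSPN(encoder, set_channels=10, iters=5, lr=200)

target_repr = torch.randn(4, 100)   # representation the predicted set should match
init_set = torch.randn(4, 10, 16)   # initial set: (batch, element_dims, set_size)

sets, repr_losses, _ = dspn(target_repr, init_set)
predicted_set = sets[-1]            # refined set after 5 inner gradient steps, shape (4, 10, 16)
```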
/fspool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FSPool(nn.Module):
7 | """
8 | Featurewise sort pooling. From:
9 |
10 | FSPool: Learning Set Representations with Featurewise Sort Pooling.
11 | Yan Zhang, Jonathon Hare, Adam Prügel-Bennett
12 | https://arxiv.org/abs/1906.02795
13 | https://github.com/Cyanogenoid/fspool
14 | """
15 |
16 | def __init__(self, in_channels, n_pieces, relaxed=False):
17 | """
18 | in_channels: Number of channels in input
19 | n_pieces: Number of pieces in piecewise linear
20 | relaxed: Use sorting networks relaxation instead of traditional sorting
21 | """
22 | super().__init__()
23 | self.n_pieces = n_pieces
24 | self.weight = nn.Parameter(torch.zeros(in_channels, n_pieces + 1))
25 | self.relaxed = relaxed
26 |
27 | self.reset_parameters()
28 |
29 | def reset_parameters(self):
30 | nn.init.normal_(self.weight)
31 |
32 | def forward(self, x, n=None):
33 | """ FSPool
34 |
35 | x: FloatTensor of shape (batch_size, in_channels, set size).
36 | This should contain the features of the elements in the set.
37 | Variable set sizes should be padded to the maximum set size in the batch with 0s.
38 |
39 | n: LongTensor of shape (batch_size).
40 | This tensor contains the sizes of each set in the batch.
41 | If not specified, assumes that every set has the same size of x.size(2).
42 | Note that n.max() should never be greater than x.size(2), i.e. the specified set size in the
43 | n tensor must not be greater than the number of elements stored in the x tensor.
44 |
45 | Returns: pooled input x, used permutation matrix perm
46 | """
47 | assert x.size(1) == self.weight.size(
48 | 0
49 | ), "incorrect number of input channels in weight"
50 | # can call without length tensor, uses same length for all sets in the batch
51 | if n is None:
52 | n = x.new(x.size(0)).fill_(x.size(2)).long()
53 | # create tensor of ratios $r$
54 | sizes, mask = fill_sizes(n, x)
55 | mask = mask.expand_as(x)
56 |
57 | # turn continuous into concrete weights
58 | weight = self.determine_weight(sizes)
59 |
60 | # make sure that fill value isn't affecting sort result
61 | # sort is descending, so put unreasonably low value in places to be masked away
62 | x = x + (1 - mask).float() * -99999
63 | if self.relaxed:
64 | x, perm = cont_sort(x, temp=self.relaxed)
65 | else:
66 | x, perm = x.sort(dim=2, descending=True)
67 |
68 | x = (x * weight * mask.float()).sum(dim=2)
69 | return x, perm
70 |
71 | def forward_transpose(self, x, perm, n=None):
72 | """ FSUnpool
73 |
74 | x: FloatTensor of shape (batch_size, in_channels)
75 | perm: Permutation matrix returned by forward function.
76 | n: LongTensor of shape (batch_size)
77 | """
78 | if n is None:
79 | n = x.new(x.size(0)).fill_(perm.size(2)).long()
80 | sizes, mask = fill_sizes(n)
81 | mask = mask.expand(mask.size(0), x.size(1), mask.size(2))
82 |
83 | weight = self.determine_weight(sizes)
84 |
85 | x = x.unsqueeze(2) * weight * mask.float()
86 |
87 | if self.relaxed:
88 | x, _ = cont_sort(x, perm)
89 | else:
90 | x = x.scatter(2, perm, x)
91 | return x, mask
92 |
93 | def determine_weight(self, sizes):
94 | """
95 | Piecewise linear function. Evaluates f at the ratios in sizes.
96 | This should be a faster implementation than doing the sum over max terms, since we know that most terms in it are 0.
97 | """
98 | # share same sequence length within each sample, so copy weight across batch dim
99 | weight = self.weight.unsqueeze(0)
100 | weight = weight.expand(sizes.size(0), weight.size(1), weight.size(2))
101 |
102 | # linspace [0, 1] -> linspace [0, n_pieces]
103 | index = self.n_pieces * sizes
104 | index = index.unsqueeze(1)
105 | index = index.expand(index.size(0), weight.size(1), index.size(2))
106 |
107 | # points in the weight vector to the left and right
108 | idx = index.long()
109 | frac = index.frac()
110 | left = weight.gather(2, idx)
111 | right = weight.gather(2, (idx + 1).clamp(max=self.n_pieces))
112 |
113 | # interpolate between left and right point
114 | return (1 - frac) * left + frac * right
115 |
116 |
117 | def fill_sizes(sizes, x=None):
118 | """
119 | sizes is a LongTensor of size [batch_size], containing the set sizes.
120 | Each set size n is turned into [0/(n-1), 1/(n-1), ..., (n-2)/(n-1), 1, 0, 0, ..., 0, 0].
121 | These are the ratios r at which f is evaluated.
122 | The 0s at the end are there for padding to the largest n in the batch.
123 | If the input set x is passed in, it guarantees that the mask is the correct size even when sizes.max()
124 | is less than x.size(2), which can be the case if there is at least one padding element in each set in the batch.
125 | """
126 | if x is not None:
127 | max_size = x.size(2)
128 | else:
129 | max_size = sizes.max()
130 | size_tensor = sizes.new(sizes.size(0), max_size).float().fill_(-1)
131 |
132 | size_tensor = torch.arange(end=max_size, device=sizes.device, dtype=torch.float32)
133 | size_tensor = size_tensor.unsqueeze(0) / (sizes.float() - 1).clamp(min=1).unsqueeze(
134 | 1
135 | )
136 |
137 | mask = size_tensor <= 1
138 | mask = mask.unsqueeze(1)
139 |
140 | return size_tensor.clamp(max=1), mask.float()
141 |
142 |
143 | def deterministic_sort(s, tau):
144 | """
145 | "Stochastic Optimization of Sorting Networks via Continuous Relaxations" https://openreview.net/forum?id=H1eSS3CcKX
146 |
147 | Aditya Grover, Eric Wang, Aaron Zweig, Stefano Ermon
148 |
149 | s: input elements to be sorted. Shape: batch_size x n x 1
150 | tau: temperature for relaxation. Scalar.
151 | """
152 | n = s.size()[1]
153 | one = torch.ones((n, 1), dtype=torch.float32, device=s.device)
154 | A_s = torch.abs(s - s.permute(0, 2, 1))
155 | B = torch.matmul(A_s, torch.matmul(one, one.transpose(0, 1)))
156 | scaling = (n + 1 - 2 * (torch.arange(n, device=s.device) + 1)).type(torch.float32)
157 | C = torch.matmul(s, scaling.unsqueeze(0))
158 | P_max = (C - B).permute(0, 2, 1)
159 | sm = torch.nn.Softmax(-1)
160 | P_hat = sm(P_max / tau)
161 | return P_hat
162 |
163 |
164 | def cont_sort(x, perm=None, temp=1):
165 | """ Helper function that calls deterministic_sort with the right shape.
166 | Since it assumes a shape of (batch_size, n, 1) while the input x is of shape (batch_size, channels, n),
167 | we can get this to the right shape by merging the first two dimensions.
168 | If an existing perm is passed in, we compute the "inverse" (transpose of perm) and just use that to unsort x.
169 | """
170 | original_size = x.size()
171 | x = x.view(-1, x.size(2), 1)
172 | if perm is None:
173 | perm = deterministic_sort(x, temp)
174 | else:
175 | perm = perm.transpose(1, 2)
176 | x = perm.matmul(x)
177 | x = x.view(original_size)
178 | return x, perm
179 |
180 |
181 | if __name__ == "__main__":
182 | pool = FSPool(2, 1)
183 | x = torch.arange(0, 2 * 3 * 4).view(3, 2, 4).float()
184 | print("x", x)
185 | y, perm = pool(x, torch.LongTensor([2, 3, 4]))
186 | print("perm")
187 | print(perm)
188 | print("result")
189 | print(y)
190 |
--------------------------------------------------------------------------------
/full_iou_clevr.py:
--------------------------------------------------------------------------------
1 | from run_reconstruct_clevr import SSLR
2 | import os
3 | import data
4 | import torch
5 | from tqdm import tqdm
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import pickle
9 | import argparse
10 |
11 |
12 | def get_args():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--model_type', help='model type: srn | mlp', default="srn")
15 | parser.add_argument('--batch_size', type=int, help='batch size', default=32)
16 | parser.add_argument('--resume', help='path to resume a saved checkpoint', default=None)
17 | args = parser.parse_args()
18 | return args
19 |
20 |
21 | if __name__ == '__main__':
22 | args = get_args()
23 | print(args)
24 |
25 | use_srn = args.model_type == "srn"
26 |
27 | dataset_test = data.CLEVRMasked(
28 | "clevr", "test", full=True, iou=True
29 | )
30 | batch_size = args.batch_size
31 | test_loader = data.get_loader(
32 | dataset_test, batch_size=batch_size, shuffle=False
33 | )
34 |
35 | net = SSLR(use_srn=use_srn).float().cuda()
36 | net.eval()
37 | net.load_state_dict(torch.load(args.resume))
38 |
39 | test_loader = tqdm(
40 | test_loader,
41 | ncols=0,
42 | desc="test"
43 | )
44 |
45 | SMOOTH = 1e-6
46 | full_iou = 0
47 | import gc
48 | for idx, data in enumerate(test_loader):
49 | def tfunc():
50 | gc.collect()
51 | image, image_mask, image_foreground_ = [x.cuda() for x in data]
52 |
53 | p_, inner_losses, gs_ = net(image)
54 |
55 | image, image_mask, image_foreground = [x.detach().cpu().numpy() for x in data]
56 |
57 | p = p_.detach().cpu().numpy()
58 | gs = gs_.detach().cpu().numpy()
59 |
60 | thresh_mask = p < 1e-2
61 | p[thresh_mask] = 0
62 | p[~thresh_mask] = 1
63 | p = p.astype('uint8')
64 |
65 | image_foreground[image_foreground != 0] = 1
66 | image_foreground = image_foreground.astype('uint8')
67 |
68 | intersect = (p & image_foreground).sum((1,2,3))
69 | union = (p | image_foreground).sum((1,2,3))
70 | iou = ((intersect + SMOOTH) / (union + SMOOTH)).sum()
71 |
72 | return iou
73 | full_iou += tfunc()
74 |
75 | full_iou /= len(test_loader) * batch_size
76 | print(full_iou)
77 |
--------------------------------------------------------------------------------
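The smoothed IoU used above, on a toy pair of binary masks (same SMOOTH constant):
```
import numpy as np

SMOOTH = 1e-6
pred   = np.array([1, 1, 0, 0], dtype=np.uint8)
target = np.array([1, 0, 1, 0], dtype=np.uint8)

intersect = (pred & target).sum()
union = (pred | target).sum()
print((intersect + SMOOTH) / (union + SMOOTH))  # 1 / 3 ~= 0.333
```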
/imgs/clevr_tile_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CUAI/BetterSetRepresentations/0510627834462a498d42b80c08dbb4fea946b9c8/imgs/clevr_tile_1.jpg
--------------------------------------------------------------------------------
/imgs/set.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CUAI/BetterSetRepresentations/0510627834462a498d42b80c08dbb4fea946b9c8/imgs/set.png
--------------------------------------------------------------------------------
/imgs/tiled_samples_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CUAI/BetterSetRepresentations/0510627834462a498d42b80c08dbb4fea946b9c8/imgs/tiled_samples_1.png
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import utils
5 | from utils import hungarian_loss_each
6 | from dspn import DSPN
7 | from fspool import FSPool
8 |
9 |
10 | class Encoder(nn.Module):
11 |
12 | def __init__(self, element_dims, set_size, out_size):
13 | super(Encoder, self).__init__()
14 | self.nef = 64
15 | self.e_ksize = 4
16 | self.set_size = set_size
17 | self.out_size = out_size
18 | self.element_dims = element_dims
19 |
20 | self.conv1 = nn.Conv2d(3, self.nef, self.e_ksize, stride = 2, padding = 1, bias = False)
21 |
22 | self.conv2 = nn.Conv2d(self.nef, self.nef*2, self.e_ksize, stride = 2, padding = 1, bias = False)
23 | self.bn2 = nn.BatchNorm2d(self.nef*2)
24 |
25 | self.conv3 = nn.Conv2d(self.nef*2, self.nef*4, self.e_ksize, stride = 2, padding = 1, bias = False)
26 | self.bn3 = nn.BatchNorm2d(self.nef*4)
27 |
28 | self.conv4 = nn.Conv2d(self.nef*4, self.nef*8, self.e_ksize, stride = 4, padding = 1, bias = False)
29 |
30 |
31 | self.bn4 = nn.BatchNorm2d(self.nef*8)
32 |
33 | self.proj = nn.Linear(self.nef*32, self.out_size)
34 |
35 | self.proj_s = nn.Conv1d(2048//self.set_size, self.element_dims, 1)
36 |
37 | def forward(self, x):
38 | out = F.relu(self.conv1(x))
39 | out = F.relu(self.bn2(self.conv2(out)))
40 | out = F.relu(self.bn3(self.conv3(out)))
41 | out = F.relu(self.bn4(self.conv4(out)))
42 |
43 | s = self.proj_s(out.view(out.shape[0], self.set_size, 2048//self.set_size).transpose(1,2))
44 |
45 | out = out.view(out.shape[0], self.nef*32)
46 | return self.proj(out), s
47 |
48 |
49 | class Decoder(nn.Module):
50 |
51 | def __init__(self, input_dim):
52 | super(Decoder, self).__init__()
53 |
54 | self.ngf = 256
55 | g_ksize = 4
56 | self.proj = nn.Linear(input_dim, self.ngf * 4 * 4 * 4)
57 | self.bn0 = nn.BatchNorm1d(self.ngf * 4 * 4 * 4)
58 |
59 | self.dconv1 = nn.ConvTranspose2d(self.ngf * 4,self.ngf*2, g_ksize,
60 | stride=2, padding=1, bias=False)
61 | self.bn1 = nn.BatchNorm2d(self.ngf*2)
62 |
63 | self.dconv2 = nn.ConvTranspose2d(self.ngf*2, self.ngf, g_ksize,
64 | stride=2, padding=1, bias=False)
65 | self.bn2 = nn.BatchNorm2d(self.ngf)
66 |
67 | self.dconv3 = nn.ConvTranspose2d(self.ngf, 3, g_ksize,
68 | stride=4, padding=0, bias=False)
69 |
70 | def forward(self, z, c=None):
71 | out = F.relu(self.bn0(self.proj(z)).view(-1, self.ngf* 4, 4, 4))
72 | out = F.relu(self.bn1(self.dconv1(out)))
73 | out = F.relu(self.bn2(self.dconv2(out)))
74 | out = self.dconv3(out)
75 | return out
76 |
77 |
78 | class FSEncoder(nn.Module):
79 | def __init__(self, input_channels, output_channels, dim):
80 | super().__init__()
81 | self.conv = nn.Sequential(
82 | nn.Conv1d(input_channels, dim, 1),
83 | nn.ReLU(),
84 | nn.Conv1d(dim, dim, 1),
85 | nn.ReLU(),
86 | nn.Conv1d(dim, output_channels, 1),
87 | )
88 | self.pool = FSPool(output_channels, 20, relaxed=False)
89 |
90 | def forward(self, x, mask=None):
91 | x = self.conv(x)
92 | x = x / x.size(2) # normalise so that activations aren't too high with big sets
93 | x, _ = self.pool(x)
94 | return x
95 |
96 |
97 | class SetGen(nn.Module):
98 | def __init__(self, element_dims=10, set_size=16, lr=200, use_srn= True, iters=5):
99 | super(SetGen, self).__init__()
100 | self.use_srn = use_srn
101 | CNN_ENCODER_SPACE = 100
102 | # H_{agg}
103 | self.encoder = FSEncoder(element_dims, CNN_ENCODER_SPACE, 512)
104 | self.decoder = DSPN(
105 | self.encoder, element_dims, iters=iters, lr=lr
106 | )
107 | # H_{set} and H_{embed}
108 | self.img_encoder = Encoder(element_dims, set_size, CNN_ENCODER_SPACE)
109 |
110 | def forward(self, x):
111 | x, s = self.img_encoder(x)
112 | if self.use_srn:
113 | intermediate_sets, losses, grad_norms = self.decoder(x, s)
114 | x = intermediate_sets[-1]
115 | else:
116 | x = s
117 |
118 | if self.use_srn:
119 | return x, losses
120 | else:
121 | return x, None
122 |
123 |
124 | class F_match(nn.Module):
125 | def __init__(self):
126 | super(F_match, self).__init__()
127 | self.proj1 = torch.nn.Conv1d(10, 3, 1)
128 | self.proj2 = torch.nn.Conv1d(10, 3, 1)
129 |
130 | def forward(self, x_set, y_set, pool):
131 | # x_set shape: B, element_dims, set_size
132 | x_att = self.proj1(x_set)
133 | y_att = self.proj1(y_set)
134 |
135 | x_loc = self.proj2(x_set)
136 | y_loc = self.proj2(y_set)
137 |
138 | # matching
139 | indices = hungarian_loss_each(x_att, y_att, pool)
140 | l = [
141 | (x_loc[idx,:,row_idx] - y_loc[idx,:,col_idx])**2
142 | for idx, (row_idx, col_idx) in enumerate(indices)
143 | ]
144 | l_m = [
145 | ((x_att[idx,:,row_idx] - y_att[idx,:,col_idx])**2).sum()
146 | for idx, (row_idx, col_idx) in enumerate(indices)
147 | ]
148 | match_dist = torch.stack(list(l)).sum(1).sum(1)
149 | match_score = torch.stack(list(l_m))
150 | return match_dist, match_score
151 |
152 |
153 | class F_reconstruct(nn.Module):
154 | def __init__(self, element_dims=10):
155 | super(F_reconstruct, self).__init__()
156 | self.vec_decoder = Decoder(element_dims)
157 |
158 | def forward(self, x_set):
159 | batch_size = x_set.size(0)
160 | element_dims = x_set.size(1)
161 | set_size = x_set.size(2)
162 |
163 | x = x_set.transpose(1,2).reshape(-1,element_dims)
164 | generated = self.vec_decoder(x)
165 | generated = generated.reshape(batch_size, set_size, 3, 64, 64)
166 |
167 | attention = torch.softmax(generated, dim=1)
168 | generated_set = torch.sigmoid(generated)
169 |
170 | generated_set = generated_set*attention
171 | generated_f = generated_set.sum(dim=1).clamp(0,1)
172 |
173 | return generated_f, generated_set
174 |
175 |
176 | class EncoderCLEVR(nn.Module):
177 |
178 | def __init__(self, element_dims=10, set_size=16, out_size=512):
179 | super(EncoderCLEVR, self).__init__()
180 | self.nef = 64
181 | self.e_ksize = 4
182 | self.set_size = set_size
183 | self.out_size = out_size
184 | self.element_dims = element_dims
185 |
186 | self.conv1 = nn.Conv2d(3, self.nef, self.e_ksize, stride = 2, padding = 1, bias = False)
187 |
188 | self.conv2 = nn.Conv2d(self.nef, self.nef*2, self.e_ksize, stride = 2, padding = 1, bias = False)
189 | self.bn2 = nn.BatchNorm2d(self.nef*2)
190 |
191 | self.conv3 = nn.Conv2d(self.nef*2, self.nef*4, self.e_ksize, stride = 2, padding = 1, bias = False)
192 | self.bn3 = nn.BatchNorm2d(self.nef*4)
193 |
194 | self.conv4 = nn.Conv2d(self.nef*4, self.nef*8, self.e_ksize, stride = 4, padding = 1, bias = False)
195 | self.bn4 = nn.BatchNorm2d(self.nef*8)
196 |
197 | self.proj = nn.Linear(self.nef*128, self.out_size)
198 | self.proj_s = nn.Conv1d(8192//self.set_size, self.element_dims, 1)
199 |
200 |
201 | def forward(self, x):
202 | out = F.relu(self.conv1(x))
203 | out = F.relu(self.bn2(self.conv2(out)))
204 | out = F.relu(self.bn3(self.conv3(out)))
205 | out = F.relu(self.bn4(self.conv4(out)))
206 |
207 | s = self.proj_s(out.view(out.shape[0], self.set_size, 8192//self.set_size).transpose(1,2))
208 |
209 | out = out.view(out.shape[0], self.nef*128)
210 | return self.proj(out), s
211 |
212 |
213 | class DecoderCLEVR(nn.Module):
214 | def __init__(self, input_dim):
215 | super(DecoderCLEVR, self).__init__()
216 |
217 | self.ngf = 256
218 | g_ksize = 4
219 | self.proj = nn.Linear(input_dim, self.ngf * 4 * 4 * 4 * 4)
220 | self.bn0 = nn.BatchNorm1d(self.ngf * 4 * 4 * 4 * 4)
221 |
222 | self.dconv1 = nn.ConvTranspose2d(self.ngf * 4,self.ngf*2, g_ksize,
223 | stride=2, padding=1, bias=False)
224 | self.bn1 = nn.BatchNorm2d(self.ngf*2)
225 |
226 | self.dconv2 = nn.ConvTranspose2d(self.ngf*2, self.ngf, g_ksize,
227 | stride=2, padding=1, bias=False)
228 | self.bn2 = nn.BatchNorm2d(self.ngf)
229 |
230 | self.dconv3 = nn.ConvTranspose2d(self.ngf, 3, g_ksize,
231 | stride=4, padding=0, bias=False)
232 |
233 | def forward(self, z, c=None):
234 | out = F.relu(self.bn0(self.proj(z)).view(-1, self.ngf* 4, 4*2, 4*2))
235 | out = F.relu(self.bn1(self.dconv1(out)))
236 | out = F.relu(self.bn2(self.dconv2(out)))
237 | out = self.dconv3(out)
238 | return out
239 |
240 |
241 | class F_reconstruct_CLEVR(nn.Module):
242 | def __init__(self, element_dims=10):
243 | super(F_reconstruct_CLEVR, self).__init__()
244 | self.vec_decoder = DecoderCLEVR(element_dims)
245 |
246 | def forward(self, x_set):
247 | batch_size = x_set.size(0)
248 | element_dims = x_set.size(1)
249 | set_size = x_set.size(2)
250 |
251 | x = x_set.transpose(1,2).reshape(-1,element_dims)
252 | generated = self.vec_decoder(x)
253 | generated = generated.reshape(batch_size, set_size, 3, 128, 128)
254 |
255 | attention = torch.softmax(generated, dim=1)
256 | generated_set = torch.sigmoid(generated)
257 |
258 | generated_set = generated_set*attention
259 | generated_f = generated_set.sum(dim=1).clamp(0,1)
260 |
261 | return generated_f, generated_set
262 |
263 |
264 | class RNFSEncoder(nn.Module):
265 | def __init__(self, input_channels, output_channels, dim):
266 | super().__init__()
267 | self.conv = nn.Sequential(
268 | nn.Conv2d(2 * input_channels, dim, 1),
269 | nn.ReLU(),
270 | nn.Conv2d(dim, output_channels, 1),
271 | )
272 | self.lin = nn.Linear(dim, output_channels)
273 | self.pool = FSPool(output_channels, 20, relaxed=False)
274 |
275 | def forward(self, x, mask=None):
276 | # create all pairs of elements
277 | x = torch.cat(utils.outer(x), dim=1)
278 | x = self.conv(x)
279 | # flatten pairs and scale appropriately
280 | n, c, l, _ = x.size()
281 | x = x.view(x.size(0), x.size(1), -1) / l / l
282 | x, _ = self.pool(x)
283 | return x
284 |
285 |
286 | class SetGenCLEVR(nn.Module):
287 | def __init__(self, element_dims=10, set_size=16, lr=8, use_srn=True):
288 | super(SetGenCLEVR, self).__init__()
289 | self.use_srn = use_srn
290 | CNN_ENCODER_SPACE = 512
291 | # H_{agg}
292 | self.encoder = RNFSEncoder(element_dims, CNN_ENCODER_SPACE, 512)
293 | self.decoder = DSPN(
294 | self.encoder, element_dims, iters=10, lr=lr
295 | )
296 | # H_{set} and H_{embed}
297 | self.img_encoder = EncoderCLEVR(element_dims, set_size, CNN_ENCODER_SPACE)
298 |
299 | def forward(self, x):
300 | x, s = self.img_encoder(x)
301 | if self.use_srn:
302 | intermediate_sets, losses, grad_norms = self.decoder(x, s)
303 | x = intermediate_sets[-1]
304 | else:
305 | x = s
306 |
307 | if self.use_srn:
308 | return x, losses
309 | else:
310 | return x, None
311 |
312 |
--------------------------------------------------------------------------------
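A shape walkthrough of the circles pipeline defined above (SetGen followed by F_reconstruct) with its default hyperparameters; the random batch is only for illustration:
```
import torch
from models import SetGen, F_reconstruct

gen = SetGen(element_dims=10, set_size=16, lr=200, use_srn=True, iters=5)
recon = F_reconstruct(element_dims=10)

images = torch.rand(4, 3, 64, 64)     # batch of 64x64 RGB images
x_set, inner_losses = gen(images)     # predicted set: (4, 10, 16)
full_img, per_element = recon(x_set)  # (4, 3, 64, 64) composite and (4, 16, 3, 64, 64) per element
```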
/preprocess-images.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import h5py
4 | import torch.utils.data
5 | import torchvision.models as models
6 | import torchvision.transforms as transforms
7 | from PIL import Image
8 | from tqdm import tqdm
9 | import re
10 |
11 |
12 | class CLEVR_Images(torch.utils.data.Dataset):
13 | """ Dataset for MSCOCO images located in a folder on the filesystem """
14 |
15 | def __init__(self, path, transform=None):
16 | super().__init__()
17 | self.p = re.compile(r'\d+')
18 | self.path = path
19 | self.id_to_filename = self._find_images()
20 | self.sorted_ids = sorted(
21 | self.id_to_filename.keys()
22 | ) # used for deterministic iteration order
23 | print("found {} images in {}".format(len(self), self.path))
24 | self.transform = transform
25 |
26 | def _find_images(self):
27 | id_to_filename = {}
28 | for filename in os.listdir(self.path):
29 | if not filename.endswith(".png") or 'mask' in filename or 'foreground' in filename:
30 | continue
31 | id = int(self.p.search(filename).group())
32 | no_ext = filename[:filename.rfind('.')]
33 | filename_mask = no_ext + '_mask.png'
34 | filename_foreground = no_ext + '_foreground.png'
35 | id_to_filename[id] = (filename, filename_mask, filename_foreground)
36 | return id_to_filename
37 |
38 | def __getitem__(self, item):
39 | id = self.sorted_ids[item]
40 | path = os.path.join(self.path, self.id_to_filename[id][0])
41 | path_mask = os.path.join(self.path, self.id_to_filename[id][1])
42 | path_foreground = os.path.join(self.path, self.id_to_filename[id][2])
43 |
44 | img = Image.open(path).convert("RGB")
45 | img_mask = Image.open(path_mask).convert("RGB")
46 | img_foreground = Image.open(path_foreground).convert("RGB")
47 |
48 | if self.transform is not None:
49 | img = self.transform(img)
50 | img_mask = self.transform(img_mask)
51 | img_foreground = self.transform(img_foreground)
52 | return id, img, img_mask, img_foreground
53 |
54 | def __len__(self):
55 | return len(self.sorted_ids)
56 |
57 |
58 | def create_coco_loader(path):
59 | transform = transforms.Compose(
60 | [transforms.Resize((128, 128)), transforms.ToTensor()]
61 | )
62 | dataset = CLEVR_Images(path, transform=transform)
63 | data_loader = torch.utils.data.DataLoader(
64 | dataset, batch_size=64, num_workers=12, shuffle=False, pin_memory=True
65 | )
66 | return data_loader
67 |
68 |
69 | def main():
70 | for split_name in ["train", "test"]:
71 | path = os.path.join("clevr", "images", split_name)
72 | loader = create_coco_loader(path)
73 | images_shape = (len(loader.dataset), 3, 128, 128)
74 |
75 | with h5py.File("{}-images-foreground.h5".format(split_name), "w", libver="latest") as fd:
76 |
77 | images = fd.create_dataset("images", shape=images_shape, dtype="float32")
78 | images_mask = fd.create_dataset("images_mask", shape=images_shape, dtype="float32")
79 | images_foreground = fd.create_dataset("images_foreground", shape=images_shape, dtype="float32")
80 | image_ids = fd.create_dataset("image_ids", shape=(len(loader.dataset),), dtype="int32")
81 |
82 | i = 0
83 | for ids, imgs, imgs_mask, imgs_foreground in tqdm(loader):
84 | assert imgs.size(0) == imgs_mask.size(0) == imgs_foreground.size(0)
85 | j = i + imgs.size(0)
86 | images[i:j, :, :] = imgs.numpy()
87 | images_mask[i:j, :, :] = imgs_mask.numpy()
88 | images_foreground[i:j, :, :] = imgs_foreground.numpy()
89 | image_ids[i:j] = ids.numpy().astype("int32")
90 | i = j
91 |
92 |
93 | if __name__ == "__main__":
94 | main()
95 |
--------------------------------------------------------------------------------
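A quick sanity check of the file written by this script (dataset names match those read back by data.CLEVRMasked; the split is illustrative):
```
import h5py

with h5py.File("train-images-foreground.h5", "r") as fd:
    print(fd["images"].shape)             # (N, 3, 128, 128)
    print(fd["images_mask"].shape)        # (N, 3, 128, 128)
    print(fd["images_foreground"].shape)  # (N, 3, 128, 128)
    print(fd["image_ids"][:5])
```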
/run_isodistance.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import data
3 | import torch
4 | import torch.nn as nn
5 | from models import SetGen, F_match, F_reconstruct, Encoder
6 | import torch.multiprocessing as mp
7 | from tensorboardX import SummaryWriter
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--model_type', help='model type: srn | mlp | cnn', default="srn")
12 | parser.add_argument('--batch_size', type=int, help='batch size', default=32)
13 | parser.add_argument('--recon', action="store_true", help='add reconstruction loss to the objective', default=False)
14 | parser.add_argument('--resume', help='Resume checkpoint', default=None)
15 | parser.add_argument('--lr', type=float, help='lr', default=5e-4)
16 | parser.add_argument('--weight_decay', type=float, help='weight decay', default=0)
17 | parser.add_argument('--inner_lr',type=float, help='inner lr', default=0.1)
18 | parser.add_argument('--save', help='Path of the saved checkpoint', default=None)
19 | args = parser.parse_args()
20 | return args
21 |
22 |
23 | class Net(nn.Module):
24 | def __init__(self, lr=200):
25 | super(Net, self).__init__()
26 | self.img_encoder = Encoder(10, 16, 100)  # element_dims/set_size as in SSLR; out_size=100 to match self.proj below
27 | self.proj = nn.Linear(100, 1)
28 |
29 | def forward(self, x, y):
30 | all_images = torch.cat((x, y))
31 | x, s = self.img_encoder(all_images)
32 | batch_size = x.size(0) // 2
33 |
34 | reference = x[:batch_size,:]
35 | mem = x[batch_size:,:]
36 |
37 | x =(reference- mem)**2
38 |
39 | return self.proj(x)
40 |
41 |
42 |
43 | class SSLR(nn.Module):
44 | def __init__(self, lr=200, use_srn= True):
45 | super(SSLR, self).__init__()
46 | self.use_srn = use_srn
47 | element_dims=10
48 | set_size=16
49 | self.set_generator = SetGen(element_dims, set_size, lr, use_srn)
50 | self.f_match = F_match()
51 | self.f_reconstruct = F_reconstruct(element_dims)
52 |
53 | def forward(self, x, y, pool):
54 | all_images = torch.cat((x, y))
55 | x, losses = self.set_generator(all_images)
56 |
57 | batch_size = x.size(0) // 2
58 |
59 | match_dist, match_score = self.f_match(x[:batch_size,:,:], x[batch_size:,:,:], pool)
60 | generated_f, _ = self.f_reconstruct(x[:batch_size,:,:])
61 |
62 | if self.use_srn:
63 | return match_dist, losses, match_score, generated_f
64 | else:
65 | return match_dist, match_score, generated_f
66 |
67 |
68 | def eval(net, batch_size, test_loader, pool, epoch, model_type):
69 | net.eval()
70 | all_loss = 0
71 | acc = 0
72 | import gc;
73 | for idx, data in enumerate(test_loader):
74 | images_x, images_y, s = data
75 | images_x, images_y, s = images_x.cuda(), images_y.cuda(), s.sum(1).cuda()/(64*64)
76 |
77 | if model_type == "srn":
78 | match_dist, inner_losses, match_score, re = net(images_x, images_y, pool)
79 | elif model_type == "mlp":
80 | match_dist, match_score, re = net(images_x, images_y, pool)
81 | else:
82 | match_dist = net(images_x, images_y).view(-1)
83 |
84 | loss = ((match_dist- s)**2).sum()
85 | all_loss += loss.item()
86 |
87 | acc += torch.abs((match_dist- s)/s).mean()
88 | acc = acc.detach().cpu()
89 |
90 | gc.collect()
91 | return all_loss/len(test_loader), acc/len(test_loader)
92 |
93 |
94 |
95 | if __name__ == "__main__":
96 |
97 | args = get_args()
98 | print(args)
99 |
100 | train_loader = data.get_loader(data.IsoColorCircles(train=True, size=64000, n = 2), batch_size = args.batch_size)
101 | test_loader = data.get_loader(data.IsoColorCircles(train=False, size=4000, n = 2), batch_size = args.batch_size)
102 |
103 | if args.model_type == "srn":
104 | net = SSLR(float(args.inner_lr)).float().cuda()
105 |
106 | if args.resume is not None:
107 | print("resume from ", args.resume)
108 | # state_dict = torch.load("set_model_recon_0.1_l2.pt")
109 | state_dict = torch.load(args.resume)
110 | own_state = net.state_dict()
111 | for name, param in state_dict.items():
112 | if isinstance(param, torch.nn.Parameter):
113 | param = param.data
114 | own_state[name].copy_(param)
115 |
116 | elif args.model_type == "mlp":
117 | net = SSLR(use_srn = False).float().cuda()
118 |
119 | else:
120 | assert args.model_type == "cnn"
121 | net = Net().float().cuda()
122 |
123 | optimizer = torch.optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
124 | writer = SummaryWriter(f"match_runs/{args.model_type}_lr={args.lr}_wd={args.weight_decay}_ilr={args.inner_lr}", purge_step=0, flush_secs = 10)
125 |
126 | running_loss = 0
127 | for epoch in range(1000+1):
128 | with mp.Pool(10) as pool:
129 | print(f"epoch {epoch}")
130 |
131 | net.train()
132 | running_loss = 0
133 | for idx, data in enumerate(train_loader):
134 | images_x, images_y, s = data
135 | images_x, images_y, s = images_x.cuda(), images_y.cuda(), s.sum(1).cuda()/(64*64)
136 | optimizer.zero_grad()
137 |
138 | if args.model_type == "srn":
139 | match_dist, inner_losses, match_score, re = net(images_x, images_y, pool)
140 | elif args.model_type == "mlp":
141 | match_dist, match_score, re = net(images_x, images_y, pool)
142 | else:
143 | match_dist = net(images_x, images_y).view(-1)
144 |
145 | dist_loss = ((match_dist- s)**2).sum()
146 |
147 | use_set = (args.model_type == "srn") or (args.model_type == "mlp")
148 | if use_set:
149 | match_loss = match_score.mean()
150 | loss = dist_loss + 10*match_loss
151 | else:
152 | loss = dist_loss
153 |
154 | if args.recon :
155 | recon_loss = ((re - images_x)**2).mean()
156 | loss += recon_loss
157 |
158 | if use_set:
159 | writer.add_scalar("train/dist_loss", dist_loss.item(), global_step=epoch*len(train_loader) + idx)
160 | writer.add_scalar("train/match_loss", match_loss.item(), global_step=epoch*len(train_loader) + idx)
161 | if args.recon :
162 | writer.add_scalar("train/recon_loss", recon_loss.item(), global_step=epoch*len(train_loader) + idx)
163 | writer.add_scalar("train/loss", loss.item(), global_step=epoch*len(train_loader) + idx)
164 |
165 | loss.backward()
166 | alpha = 0.05
167 | optimizer.step()
168 |
169 |
170 | if idx % (len(train_loader)//4) == 0:
171 | if use_set:
172 | if args.model_type == "srn":
173 | print(f"inner loss {[l.item()/args.batch_size for l in inner_losses]}")
174 | print("dist_loss", dist_loss.item())
175 | print("match_loss",match_loss.item())
176 | if args.recon :
177 | print("recon_loss",recon_loss.item())
178 | print("loss",loss.item())
179 |
180 | running_loss += loss.item()
181 | print(running_loss/ len(train_loader))
182 | if epoch % 1 ==0:
183 | with mp.Pool(10) as pool:
184 | eval_loss, acc = eval(net, args.batch_size, test_loader, pool, epoch, args.model_type)
185 | print(f"eval: {eval_loss} {acc}")
186 | writer.add_scalar("eval/loss", eval_loss, global_step=epoch)
187 | writer.add_scalar("eval/acc", acc, global_step=epoch)
188 | writer.flush()
189 |
190 | print()
191 | #save model
192 | if args.save is not None:
193 | torch.save(net.state_dict(), args.save)
194 |
--------------------------------------------------------------------------------
/run_reconstruct_circles.py:
--------------------------------------------------------------------------------
1 | import data
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import random
7 | from dspn import DSPN
8 | from fspool import FSPool
9 | from tensorboardX import SummaryWriter
10 | import matplotlib
11 | from models import *
12 | import argparse
13 |
14 | matplotlib.use("Agg")
15 | import matplotlib.pyplot as plt
16 |
17 |
18 | def get_args():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--model_type', help='model type: srn | mlp', default="srn")
21 | parser.add_argument('--batch_size', type=int, help='batch size', default=64)
22 | parser.add_argument('--lr', type=float, help='lr', default=3e-4)
23 | parser.add_argument('--inner_lr',type=float, help='inner lr', default=0.1)
24 | parser.add_argument('--inner_iters',type=int, help='# of inner iterations steps to perform', default=10)
25 | parser.add_argument('--start_epoch',type=int, help='epoch to start at', default=0)
26 | parser.add_argument('--load_ckpt', default=False, action='store_true')
27 |
28 | args = parser.parse_args()
29 | return args
30 |
31 | class SSLR(nn.Module):
32 | def __init__(self, lr=200, num_iters=10, use_srn=True):
33 | super(SSLR, self).__init__()
34 | self.element_dims = 10
35 | self.set_generator = SetGen(element_dims = self.element_dims, set_size=16, lr=lr, use_srn=use_srn, iters=num_iters)
36 | self.f_reconstruct = F_reconstruct(element_dims = self.element_dims)
37 | self.use_srn = use_srn
38 |
39 | def forward(self, x, print_interm=False):
40 | x, losses = self.set_generator(x)
41 | generated_f, generated_set = self.f_reconstruct(x)
42 |
43 | if self.use_srn:
44 | return generated_f, losses, generated_set
45 | else:
46 | return generated_f, [], generated_set
47 |
48 |
49 |
50 | def eval(net, batch_size, test_loader, epoch, writer, use_srn = True):
51 | net.eval()
52 | all_loss = 0
53 | rel_error = 0
54 | for idx, data in enumerate(test_loader):
55 | images, labels = data
56 | images, labels = images.cuda(), labels.cuda()
57 |
58 | if use_srn:
59 | p, inner_losses, gs = net(images)
60 | else:
61 | p, inner_losses, gs = net(images)  # SSLR also returns a triple when use_srn is False (losses is empty)
62 |
63 | loss = F.binary_cross_entropy(p, images)
64 |
65 | for j, s_ in enumerate(gs[0]):
66 | fig = plt.figure()
67 | plt.imshow(s_.transpose(0,2).detach().cpu())
68 | writer.add_figure(f"epoch-{epoch}/img-{idx}", fig, global_step=j)
69 |
70 | fig = plt.figure()
71 | plt.imshow(p[0].transpose(0,2).detach().cpu())
72 | writer.add_figure(f"epoch-{epoch}/img-{idx}", fig, global_step=len(gs[0]))
73 |
74 | fig = plt.figure()
75 | plt.imshow(images[0].transpose(0,2).detach().cpu())
76 | writer.add_figure(f"epoch-{epoch}/img-{idx}-target", fig, global_step=epoch)
77 | all_loss += loss.item()
78 | return all_loss/len(test_loader)
79 |
80 | if __name__ == "__main__":
81 | args = get_args()
82 | print(args)
83 | use_srn = args.model_type == "srn"
84 |
85 | batch_size = args.batch_size
86 | train_loader = data.get_loader(data.MarkedColorCircles(train=True, size=64000), batch_size = batch_size)
87 | test_loader = data.get_loader(data.MarkedColorCircles(train=False, size=4000), batch_size = batch_size)
88 |
89 | # use_srn follows --model_type (set above)
90 | net = SSLR(lr = args.inner_lr, num_iters=args.inner_iters, use_srn=use_srn).float().cuda()
91 | if args.load_ckpt:
92 | net.load_state_dict(torch.load("set_model_recon.pt"))
93 |
94 |
95 | net.train()
96 | optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
97 |
98 | writer = SummaryWriter(f"recon_run/test_run", purge_step=0, flush_secs = 10)
99 |
100 | print(type(net))
101 | print(net.set_generator.decoder.iters)
102 | running_loss = 0
103 | best_loss = 1e50
104 | for epoch in range(args.start_epoch, 1000+1):
105 | if epoch == 20:
106 | net.set_generator.decoder.iters = 20
107 | net.train()
108 | print(f"epoch {epoch}")
109 | running_loss = 0
110 | for idx, data in enumerate(train_loader):
111 | images, labels = data
112 | images, labels = images.cuda(), labels.cuda()
113 | optimizer.zero_grad()
114 |
115 | if use_srn:
116 | p, inner_losses, _ = net(images)
117 | else:
118 | p, inner_losses, _ = net(images)  # SSLR also returns a triple when use_srn is False
119 | loss = ((images - p)**2).sum()
120 | writer.add_scalar("train/loss", loss.item(), global_step=epoch*len(train_loader) + idx)
121 |
122 | loss.backward()
123 | optimizer.step()
124 | if idx % (len(train_loader)//4) == 0:
125 | if use_srn:
126 | print(f"inner loss {[l.item()/batch_size for l in inner_losses]}")
127 | print(loss.item())
128 | running_loss += loss.item()
129 |
130 | print(running_loss/len(train_loader))
131 | if epoch % 1 ==0:
132 | eval_loss = eval(net, batch_size, test_loader, epoch, writer, use_srn)
133 | if eval_loss < best_loss:
134 | best_loss = eval_loss
135 | torch.save(net.state_dict(), "set_model_recon.pt")
136 | print(f"eval: {eval_loss}")
137 | writer.add_scalar("eval/loss", eval_loss, global_step=epoch)
138 | writer.flush()
139 |
140 | print()
141 |
--------------------------------------------------------------------------------
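The best checkpoint saved by the training loop above can be reloaded for inspection; a minimal sketch (the path matches the torch.save call, and a GPU is assumed as in the rest of the codebase):
```
import torch
from run_reconstruct_circles import SSLR

net = SSLR(lr=0.1, num_iters=10, use_srn=True).float().cuda()
net.load_state_dict(torch.load("set_model_recon.pt"))
net.eval()
```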
/run_reconstruct_clevr.py:
--------------------------------------------------------------------------------
1 | import data
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import random
7 | from dspn import DSPN
8 | from fspool import FSPool
9 | from tensorboardX import SummaryWriter
10 | import matplotlib
11 | import utils
12 | from tqdm import tqdm
13 | from models import *
14 | import argparse
15 |
16 | matplotlib.use("Agg")
17 | import matplotlib.pyplot as plt
18 |
19 |
20 | def get_args():
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--model_type', help='model type: srn | mlp', default="srn")
23 | parser.add_argument('--batch_size', type=int, help='batch size', default=32)
24 | parser.add_argument('--lr', type=float, help='lr', default=3e-4)
25 | parser.add_argument('--inner_lr',type=float, help='inner lr', default=8)
26 | parser.add_argument('--save', help='path to save checkpoint', default=None)
27 | parser.add_argument('--resume', help='path to resume a saved checkpoint', default=None)
28 | args = parser.parse_args()
29 | return args
30 |
31 |
32 | class SSLR(nn.Module):
33 | def __init__(self, lr=8, use_srn=True):
34 | super(SSLR, self).__init__()
35 | self.use_srn = use_srn
36 | element_dims=10
37 | set_size=16
38 | self.g = SetGenCLEVR(element_dims, set_size, lr, use_srn)
39 | self.F_reconstruct = F_reconstruct_CLEVR()
40 |
41 | def forward(self, images):
42 | x, inner_losses = self.g(images)
43 | generated_f, generated_set = self.F_reconstruct(x)
44 | return generated_f, inner_losses, generated_set
45 |
46 |
47 | def eval(net, batch_size, test_loader, epoch, writer, use_srn=True):
48 | with torch.no_grad():
49 | net.eval()
50 | all_loss = 0
51 | rel_error = 0
52 | test_loader = tqdm(
53 | test_loader,
54 | ncols=0,
55 | desc="test E{0:02d}".format(epoch),
56 | )
57 | iters_per_epoch = len(test_loader)
58 | for idx, (images, images_foreground) in enumerate(test_loader, start=epoch * iters_per_epoch):
59 | images, images_foreground = images.cuda(), images_foreground.cuda()
60 |
61 | p, inner_losses, gs = net(images)
62 |
63 | loss = F.binary_cross_entropy(p, images_foreground)
64 |
65 | for j, s_ in enumerate(gs[0]):
66 | fig = plt.figure()
67 | plt.imshow(s_.permute(1,2,0).detach().cpu())
68 | writer.add_figure(f"epoch-{epoch}/img-{idx}", fig, global_step=j)
69 |
70 | fig = plt.figure()
71 | plt.imshow(p[0].permute(1,2,0).detach().cpu())
72 | writer.add_figure(f"epoch-{epoch}/img-{idx}", fig, global_step=len(gs[0]))
73 |
74 | fig = plt.figure()
75 | plt.imshow(images[0].permute(1,2,0).detach().cpu())
76 | writer.add_figure(f"epoch-{epoch}/img-{idx}-target", fig, global_step=epoch)
77 |
78 | all_loss += loss.item()
79 | return all_loss/len(test_loader)
80 |
81 |
82 |
83 |
84 | if __name__ == "__main__":
85 | args = get_args()
86 | print(args)
87 |
88 | use_srn = args.model_type == "srn"
89 |
90 | dataset_train = data.CLEVRMasked(
91 | "clevr", "train", full=True
92 | )
93 | dataset_test = data.CLEVRMasked(
94 | "clevr", "test", full=False
95 | )
96 |
97 | batch_size = args.batch_size
98 | train_loader = data.get_loader(
99 | dataset_train, batch_size=batch_size
100 | )
101 | test_loader = data.get_loader(
102 | dataset_test, batch_size=batch_size
103 | )
104 |
105 | net = SSLR(args.inner_lr, use_srn).float().cuda()
106 |
107 | if args.resume:
108 | net.load_state_dict(torch.load(args.resume))
109 |
110 | optimizer = torch.optim.Adam(
111 | [p for p in net.parameters() if p.requires_grad], lr=args.lr
112 | )
113 | writer = SummaryWriter("runs/recon_clevr", purge_step=0, flush_secs=10)
114 |
115 |
116 | print(type(net))
117 | iters_per_epoch = len(train_loader)
118 |
119 | running_loss = 0
120 |
121 | for epoch in range(1000+1):
122 | # wrap the loader under a new name so the DataLoader itself is not re-wrapped in tqdm every epoch
123 | epoch_loader = tqdm(
124 | train_loader,
125 | ncols=0,
126 | desc="train E{0:02d}".format(epoch),
127 | )
128 | net.train()
129 | running_loss = 0
130 |
131 | for idx, (images, images_foreground) in enumerate(epoch_loader, start=epoch * iters_per_epoch):
132 | images, images_foreground = images.cuda(), images_foreground.cuda()
133 | optimizer.zero_grad()
134 |
135 | p, inner_losses, _ = net(images)
136 |
137 | loss = F.binary_cross_entropy(p, images_foreground)
138 |
139 | writer.add_scalar("train/loss", loss.item(), global_step=idx)
140 |
141 | loss.backward()
142 | optimizer.step()
143 |
144 | if use_srn:
145 | print(f"inner loss {[l.item()/batch_size for l in inner_losses]}")
146 | print(f"{loss.item()}\n")
147 | running_loss += loss.item()
148 | print(running_loss/len(train_loader))
149 |
150 | if args.save:
151 | torch.save(net.state_dict(), args.save)
152 |
153 | eval_loss = eval(net, batch_size, test_loader, epoch, writer, use_srn)
154 | print(f"eval: {eval_loss}\n")
155 | writer.add_scalar("eval/loss", eval_loss, global_step=epoch)
156 | writer.flush()
157 |
158 | print()
159 |
--------------------------------------------------------------------------------
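`SSLR` above bundles the CLEVR set generator and the reconstruction decoder into a single module, so it can be reused outside the training script. A minimal sketch; the input resolution is an assumption, and in practice a real batch would come from `data.CLEVRMasked` via `data.get_loader` as in the script:

```
import torch
from run_reconstruct_clevr import SSLR

net = SSLR(lr=8, use_srn=True).float().cuda()

with torch.no_grad():
    images = torch.rand(2, 3, 128, 128).cuda()  # assumed resolution; use the CLEVRMasked loader in practice
    recon, inner_losses, generated_set = net(images)
    # recon is the reconstructed foreground image, generated_set holds the per-element renderings
```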
/utils.py:
--------------------------------------------------------------------------------
1 | import scipy
2 | import scipy.optimize
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import cv2
7 |
8 | def hungarian_loss_each(predictions, targets, thread_pool):
9 | # predictions and targets shape :: (n, c, s); returns the optimal (row, col) assignment per sample
10 | predictions, targets = outer(predictions, targets)
11 | # squared_error shape :: (n, s, s)
12 | squared_error = F.smooth_l1_loss(predictions, targets, reduction="none").mean(1)
13 |
14 | squared_error_np = squared_error.detach().cpu().numpy()
15 |
16 | indices = thread_pool.map(hungarian_loss_per_sample, squared_error_np)
17 | return indices
18 |
19 | def hungarian_loss_per_sample(sample_np):
20 | return scipy.optimize.linear_sum_assignment(sample_np)
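# For intuition, a small hypothetical example of the assignment computed above: given the
# pairwise-cost matrix [[1, 3], [2, 0]], scipy.optimize.linear_sum_assignment returns row
# indices [0, 1] and column indices [0, 1], i.e. the matching with total cost 1 + 0 = 1.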
21 |
22 |
23 | def chamfer_loss(predictions, targets):
24 | # predictions and targets shape :: (k, n, c, s)
25 | predictions, targets = outer(predictions, targets)
26 | # squared_error shape :: (k, n, s, s)
27 | squared_error = F.smooth_l1_loss(predictions, targets, reduction="none").mean(2)
28 | loss = squared_error.min(2)[0] + squared_error.min(3)[0]
29 | return loss.view(loss.size(0), -1).mean(1)
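# For intuition: after outer(), squared_error[k, n, i, j] is the distance between predicted
# element i and target element j, so the two min terms match every prediction to its nearest
# target and every target to its nearest prediction; the result is averaged to one value per k.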
30 |
31 |
32 | def chamfer_loss_each(predictions, targets):
33 | # predictions and targets shape :: (k, n, c, s)
34 | predictions, targets = outer(predictions, targets)
35 | # squared_error shape :: (k, n, s, s)
36 | squared_error = F.smooth_l1_loss(predictions, targets, reduction="none").mean(2)
37 | return torch.cat((squared_error.min(2)[0], squared_error.min(3)[0]),2)[0]
38 |
39 |
40 |
41 | def scatter_masked(tensor, mask, p, binned=False, threshold=None):
42 | s = tensor[0].detach().cpu()
43 | mask = mask[0].detach().clamp(min=0, max=1).cpu()
44 | p = p[0].detach().clamp(min=0, max=1).cpu()
45 | if binned:
46 | s = s * 128
47 | s = s.view(-1, s.size(-1))
48 | mask = mask.view(-1)
49 | if threshold is not None:
50 | keep = mask.view(-1) > threshold
51 | s = s[:, keep]
52 | mask = mask[keep]
53 | return s, mask, p
54 |
55 |
56 | def cv_bbox(np_imgs):
57 | imgs = []  # fill the axis-aligned bounding box of every contour found in each binary mask
58 | for np_img in np_imgs:
59 | new_img = np_img.copy()
60 | cnts = cv2.findContours(new_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
61 | cnts = cnts[0] if len(cnts) == 2 else cnts[1]  # handle the OpenCV 3 vs 4 return convention
62 | for c in cnts:
63 | x, y, w, h = cv2.boundingRect(c)
64 | new_img[y:y+h, x:x+w] = 1
65 | imgs.append(new_img)
66 | return torch.tensor(imgs).reshape(-1, 16, 128, 128)
67 |
68 |
69 | def chamfer_score(s1, s2, SMOOTH=1e-6):
70 | # s1, s2 :: (batch, set_size, 128, 128) boolean masks
71 | batch, size = s1.size(0), s1.size(1)
72 | # tile both sets so every (i, j) pair of masks can be compared element-wise
73 | a = torch.cat(size * [s1.unsqueeze(1)], 1).reshape(-1, 128, 128)
74 | b = torch.cat(size * [s2.unsqueeze(2)], 2).reshape(-1, 128, 128)
75 | # pairwise IoU with additive smoothing to avoid division by zero
76 | intersect = (a & b).sum((1, 2)).float()
77 | union = (a | b).sum((1, 2)).float()
78 | iou = (intersect + SMOOTH) / (union + SMOOTH)
79 | # (batch, size, size) IoU matrix: keep the best-matching counterpart of each mask, then average
80 | r = iou.reshape(batch, size, size, -1).squeeze(3)
81 | return r.max(2)[0].mean()
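# For intuition: iou[b, i, j] holds the IoU between the i-th mask of one set and the j-th mask
# of the other, so identical sets score 1.0 and fully disjoint sets score close to 0; the SMOOTH
# term only prevents a division by zero when both masks are empty.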
82 |
83 |
84 | def outer(a, b=None):
85 | """ Compute outer product between a and b (or a and a if b is not specified). """
86 | if b is None:
87 | b = a
88 | size_a = tuple(a.size()) + (b.size()[-1],)
89 | size_b = tuple(b.size()) + (a.size()[-1],)
90 | a = a.unsqueeze(dim=-1).expand(*size_a)
91 | b = b.unsqueeze(dim=-2).expand(*size_b)
92 | return a, b
93 |
--------------------------------------------------------------------------------
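The shape conventions documented in `utils.py` can be exercised directly. A small self-contained sketch; the sizes below are illustrative only, not the ones used in training:

```
import torch
from multiprocessing.dummy import Pool
from utils import outer, chamfer_loss, hungarian_loss_each

# (k, n, c, s): 1 set estimate, batch of 4, 2 features per element, 5 elements per set
pred = torch.rand(1, 4, 2, 5)
target = torch.rand(1, 4, 2, 5)
print(chamfer_loss(pred, target).shape)  # torch.Size([1]): one loss value per leading index k

# Hungarian matching expects (n, c, s) plus a thread pool and returns one (row, col) assignment per sample
with Pool(2) as pool:
    assignments = hungarian_loss_each(pred[0], target[0], pool)
print(len(assignments))  # 4

# outer() pairs every element of one set with every element of the other
a, b = outer(torch.tensor([1., 2.]), torch.tensor([10., 20.]))
print(a * b)  # tensor([[10., 20.], [20., 40.]])
```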