├── README.md
├── Triangle
│   ├── dataset.py
│   ├── run.py
│   ├── run.sh
│   ├── transformer_utilities
│   │   ├── Gelu.py
│   │   ├── GroupGRUCell.py
│   │   ├── GroupLinearLayer.py
│   │   ├── attention_rim.py
│   │   ├── basic_mha.py
│   │   ├── efficientnet_model.py
│   │   ├── efficientnet_utils.py
│   │   ├── fairseq_dropout.py
│   │   ├── fairseq_utils.py
│   │   ├── group_linear_layer.py
│   │   ├── isab.py
│   │   ├── layer_norm.py
│   │   ├── multihead_attention.py
│   │   ├── pos_enc.py
│   │   ├── quant_noise.py
│   │   ├── relational_memory.py
│   │   ├── relational_memory_lstm.py
│   │   ├── relational_memory_regressive.py
│   │   ├── relational_memory_volatile.py
│   │   ├── set_transformer.py
│   │   ├── sparse_attn.py
│   │   ├── sparse_grad_attn.py
│   │   ├── transformer_helper.py
│   │   ├── transformer_interface.py
│   │   └── transformer_layer.py
│   └── transformers.py
├── requirements.txt
└── sort_of_clevr
    ├── README.md
    ├── main.py
    ├── main_splits.py
    ├── model.py
    ├── run_transformer.sh
    ├── run_transformer_splits.sh
    ├── sort_of_clevr_generator.py
    ├── sort_of_clevr_splits.py
    ├── transformer_utilities
    │   ├── Gelu.py
    │   ├── GroupGRUCell.py
    │   ├── GroupLinearLayer.py
    │   ├── attention_rim.py
    │   ├── basic_mha.py
    │   ├── efficientnet_model.py
    │   ├── efficientnet_utils.py
    │   ├── fairseq_dropout.py
    │   ├── fairseq_utils.py
    │   ├── group_linear_layer.py
    │   ├── isab.py
    │   ├── layer_norm.py
    │   ├── multihead_attention.py
    │   ├── pos_enc.py
    │   ├── quant_noise.py
    │   ├── relational_memory.py
    │   ├── relational_memory_lstm.py
    │   ├── relational_memory_regressive.py
    │   ├── relational_memory_volatile.py
    │   ├── set_transformer.py
    │   ├── sparse_attn.py
    │   ├── sparse_grad_attn.py
    │   ├── transformer_helper.py
    │   ├── transformer_interface.py
    │   └── transformer_layer.py
    ├── transformers.py
    └── translator.py
/README.md:
--------------------------------------------------------------------------------
1 | # Coordination Among Neural Modules Through a Shared Global Workspace
2 | 
3 | This repository contains the code to reproduce the `relational reasoning (Sort-of-CLEVR)` and `detecting equilateral triangles` tasks from our paper.
4 | 
5 | 
6 | ## Install relevant libraries
7 | ```
8 | pip install -r requirements.txt
9 | ```
10 | ## Detecting Equilateral Triangles
11 | Folder: Triangle/
12 | 
13 | The following commands should be executed from inside the `Triangle` folder.
14 | 
15 | ```
16 | sh run.sh num_layers h_dim ffn_dim share_vanilla_parameters use_topk topk shared_memory_attention seed mem_slots model
17 | 
18 | share_vanilla_parameters: Whether to share parameters across layers. If False, it will run TR + HC. For shared-workspace experiments it should be True.
19 | 
20 | use_topk: Whether to use top-k competition
21 | 
22 | topk: Value of k in top-k competition
23 | 
24 | shared_memory_attention: Whether to use the shared workspace
25 | 
26 | mem_slots: Number of slots in memory
27 | ```
28 | 
29 | To reproduce the experiments in the paper:
30 | ```
31 | TR + HSW
32 | sh run.sh 4 256 512 True True 20 True 1 8 default
33 | 
34 | TR + SSW
35 | sh run.sh 4 256 512 True False 20 True 1 8 default
36 | 
37 | TR
38 | sh run.sh 4 256 512 True False 20 False 1 8 default
39 | 
40 | STR
41 | sh run.sh 4 256 512 True True 20 False 1 8 default
42 | 
43 | TR + HC
44 | sh run.sh 4 256 512 False False 20 False 1 8 default
45 | 
46 | ISAB
47 | sh run.sh 4 256 512 False False 20 False 1 8 functional
48 | ```
49 | 
50 | ## Sort-of-CLEVR
51 | The following commands should be executed from inside the `sort_of_clevr` folder.
52 | 53 | Dataset generation: 54 | ``` 55 | python sort_of_clevr_generator.py 56 | ``` 57 | 58 | ``` 59 | sh run_transformer.sh h_dim num_layers share_vanilla_parameters use_topk topk shared_memory_attention mem_slots seed 60 | ``` 61 | To reproduce experiments in paper: 62 | ``` 63 | TR + HSW 64 | sh run_transformer.sh 256 4 True True 5 True 8 1 False 65 | 66 | TR 67 | sh run_transformer.sh 256 4 True False 5 False 8 1 False 68 | 69 | STR 70 | sh run_transformer.sh 256 4 True True 5 False 8 1 False 71 | 72 | TR + HC 73 | sh run_transformer.sh 256 4 False False 5 False 8 1 False 74 | 75 | ISAB 76 | sh run_transformer.sh 256 4 False False 5 False 8 1 True 77 | 78 | 79 | ``` 80 | 81 | -------------------------------------------------------------------------------- /Triangle/dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import math 4 | import numpy as np 5 | import torch 6 | 7 | def get_cluster_points(num_points_per_cluster, x, y, points, cluster_radius): 8 | for c in range(num_points_per_cluster // 2): 9 | x_range = (x - cluster_radius, x + cluster_radius) 10 | y_range = (y - cluster_radius, y + cluster_radius) 11 | 12 | c_1_x = random.randint(x_range[0], x_range[1]) 13 | c_1_y = random.randint(y_range[0], y_range[1]) 14 | 15 | c_2_x = 2 * x - c_1_x 16 | c_2_y = 2 * y - c_1_y 17 | 18 | points.append((c_1_x, c_1_y)) 19 | points.append((c_2_x, c_2_y)) 20 | return points 21 | 22 | 23 | def get_point(x1, y1, x2, y2): 24 | #express coordinates of the point (x2, y2) with respect to point (x1, y1) 25 | dx = x2 - x1 26 | dy = y2 - y1 27 | 28 | alpha = 60./180*math.pi 29 | #rotate the displacement vector and add the result back to the original point 30 | xp = x1 + math.cos( alpha)*dx + math.sin(alpha)*dy 31 | yp = y1 + math.sin(-alpha)*dx + math.cos(alpha)*dy 32 | 33 | return (int(xp), int(yp)) 34 | 35 | def get_point_square(x1, y1, x2, y2): 36 | direction = random.randint(0, 1) 37 | if direction == 0: 38 | slope_y = (y2 - y1) 39 | slope_x = -(x2 - x1) 40 | else: 41 | slope_y = -(y2 - y1) 42 | slope_x = (x2 - x1) 43 | 44 | 45 | 46 | 47 | x3 = x1 + slope_y 48 | x4 = x2 + slope_y 49 | 50 | y3 = y1 + slope_x 51 | y4 = y2 + slope_x 52 | 53 | return int(x3), int(y3), int(x4), int(y4) 54 | 55 | def get_point_rectangle(x1, y1, x2, y2): 56 | direction = random.randint(0, 1) 57 | if direction == 0: 58 | slope_y = (y2 - y1) 59 | slope_x = -(x2 - x1) 60 | else: 61 | slope_y = -(y2 - y1) 62 | slope_x = (x2 - x1) 63 | 64 | 65 | 66 | length = random.uniform(0, 1) + 0.5 67 | 68 | x3 = x1 + length * slope_y 69 | x4 = x2 + length * slope_y 70 | 71 | y3 = y1 + length * slope_x 72 | y4 = y2 + length * slope_x 73 | 74 | return int(x3), int(y3), int(x4), int(y4) 75 | 76 | 77 | 78 | 79 | def make_square(img_size = (64, 64), num_points_per_cluster = 8, cluster_radius = 1): 80 | is_square = False 81 | while not is_square: 82 | point_1_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 83 | point_1_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 84 | 85 | point_2_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 86 | point_2_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 87 | 88 | point_3_x, point_3_y, point_4_x, point_4_y = get_point_square(point_1_x, point_1_y, point_2_x, point_2_y) 89 | 90 | if point_3_x + cluster_radius > img_size[0] or point_3_y + cluster_radius > img_size[1] or point_3_x - cluster_radius < 0 or point_3_y - cluster_radius < 0: 91 | continue 
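            # Each corner of the square is rendered as a small cluster of points: get_cluster_points
            # draws points in mirror-image pairs around the corner, so the cluster's centroid sits
            # exactly on the corner. The bounds checks above and below re-sample the whole shape
            # whenever a derived corner would push its cluster outside the image.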
92 | 93 | if point_4_x + cluster_radius > img_size[0] or point_4_y + cluster_radius > img_size[1] or point_4_x - cluster_radius < 0 or point_4_y - cluster_radius < 0: 94 | continue 95 | 96 | 97 | 98 | #points = [(point_1_x, point_1_y), (point_2_x, point_2_y), (point_3_x, point_3_y), (point_4_x, point_4_y)] 99 | 100 | points = [] 101 | 102 | points = get_cluster_points(num_points_per_cluster, point_1_x, point_1_y, points, cluster_radius) 103 | points = get_cluster_points(num_points_per_cluster, point_2_x, point_2_y, points, cluster_radius) 104 | points = get_cluster_points(num_points_per_cluster, point_3_x, point_3_y, points, cluster_radius) 105 | points = get_cluster_points(num_points_per_cluster, point_4_x, point_4_y, points, cluster_radius) 106 | 107 | 108 | image = np.zeros((img_size[0], img_size[1], 1)) 109 | 110 | #print(points) 111 | 112 | for p in points: 113 | image = cv2.circle(image, p, radius=2, color=255, thickness=-1) 114 | 115 | #cv2.imshow('image', image) 116 | #cv2.waitKey(0) 117 | #cv2.destroyAllWindows() 118 | is_square = True 119 | return image 120 | 121 | def make_rectangle(img_size = (64, 64), num_points_per_cluster = 8, cluster_radius = 1): 122 | is_rectangle = False 123 | while not is_rectangle: 124 | point_1_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 125 | point_1_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 126 | 127 | point_2_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 128 | point_2_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 129 | 130 | point_3_x, point_3_y, point_4_x, point_4_y = get_point_rectangle(point_1_x, point_1_y, point_2_x, point_2_y) 131 | 132 | if point_3_x + cluster_radius > img_size[0] or point_3_y + cluster_radius > img_size[1] or point_3_x - cluster_radius < 0 or point_3_y - cluster_radius < 0: 133 | continue 134 | 135 | if point_4_x + cluster_radius > img_size[0] or point_4_y + cluster_radius > img_size[1] or point_4_x - cluster_radius < 0 or point_4_y - cluster_radius < 0: 136 | continue 137 | 138 | 139 | 140 | #points = [(point_1_x, point_1_y), (point_2_x, point_2_y), (point_3_x, point_3_y), (point_4_x, point_4_y)] 141 | 142 | points = [] 143 | 144 | points = get_cluster_points(num_points_per_cluster, point_1_x, point_1_y, points, cluster_radius) 145 | points = get_cluster_points(num_points_per_cluster, point_2_x, point_2_y, points, cluster_radius) 146 | points = get_cluster_points(num_points_per_cluster, point_3_x, point_3_y, points, cluster_radius) 147 | points = get_cluster_points(num_points_per_cluster, point_4_x, point_4_y, points, cluster_radius) 148 | 149 | 150 | image = np.zeros((img_size[0], img_size[1], 1)) 151 | 152 | #print(points) 153 | 154 | for p in points: 155 | image = cv2.circle(image, p, radius=2, color=255, thickness=-1) 156 | 157 | #cv2.imshow('image', image) 158 | #cv2.waitKey(0) 159 | #cv2.destroyAllWindows() 160 | is_rectangle = True 161 | return image 162 | 163 | def make_right_angle_triangle(img_size = (64, 64), num_points_per_cluster = 8, cluster_radius = 1): 164 | is_rectangle = False 165 | while not is_rectangle: 166 | point_1_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 167 | point_1_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 168 | 169 | point_2_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 170 | point_2_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 171 | 172 | point_3_x, point_3_y, point_4_x, point_4_y = 
get_point_rectangle(point_1_x, point_1_y, point_2_x, point_2_y) 173 | 174 | if point_3_x + cluster_radius > img_size[0] or point_3_y + cluster_radius > img_size[1] or point_3_x - cluster_radius < 0 or point_3_y - cluster_radius < 0: 175 | continue 176 | 177 | if point_4_x + cluster_radius > img_size[0] or point_4_y + cluster_radius > img_size[1] or point_4_x - cluster_radius < 0 or point_4_y - cluster_radius < 0: 178 | continue 179 | 180 | 181 | 182 | #points = [(point_1_x, point_1_y), (point_2_x, point_2_y), (point_3_x, point_3_y), (point_4_x, point_4_y)] 183 | 184 | points = [] 185 | 186 | points = get_cluster_points(num_points_per_cluster, point_1_x, point_1_y, points, cluster_radius) 187 | points = get_cluster_points(num_points_per_cluster, point_2_x, point_2_y, points, cluster_radius) 188 | points = get_cluster_points(num_points_per_cluster, point_3_x, point_3_y, points, cluster_radius) 189 | #points = get_cluster_points(num_points_per_cluster, point_4_x, point_4_y, points, cluster_radius) 190 | 191 | 192 | image = np.zeros((img_size[0], img_size[1], 1)) 193 | 194 | #print(points) 195 | 196 | for p in points: 197 | image = cv2.circle(image, p, radius=2, color=255, thickness=-1) 198 | 199 | #cv2.imshow('image', image) 200 | #cv2.waitKey(0) 201 | #cv2.destroyAllWindows() 202 | is_rectangle = True 203 | return image 204 | 205 | 206 | def make_equilateral_triangle(img_size = (64, 64), make_equilateral = True, num_points_per_cluster = 8, cluster_radius = 1): 207 | if make_equilateral: 208 | is_equilateral = False 209 | while not is_equilateral: 210 | point_1_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 211 | point_1_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 212 | 213 | point_2_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 214 | point_2_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 215 | 216 | point_3_x, point_3_y = get_point(point_1_x, point_1_y, point_2_x, point_2_y) 217 | 218 | if point_3_x + cluster_radius > img_size[0] or point_3_y + cluster_radius > img_size[1] or point_3_x - cluster_radius < 0 or point_3_y - cluster_radius < 0: 219 | continue 220 | 221 | points = [] 222 | for c in range(num_points_per_cluster // 2): 223 | x_range = (point_1_x - cluster_radius, point_1_x + cluster_radius) 224 | y_range = (point_1_y - cluster_radius, point_1_y + cluster_radius) 225 | 226 | c_1_x = random.randint(x_range[0], x_range[1]) 227 | c_1_y = random.randint(y_range[0], y_range[1]) 228 | 229 | c_2_x = 2 * point_1_x - c_1_x 230 | c_2_y = 2 * point_1_y - c_1_y 231 | 232 | points.append((c_1_x, c_1_y)) 233 | points.append((c_2_x, c_2_y)) 234 | 235 | for c in range(num_points_per_cluster // 2): 236 | x_range = (point_2_x - cluster_radius, point_2_x + cluster_radius) 237 | y_range = (point_2_y - cluster_radius, point_2_y + cluster_radius) 238 | 239 | c_1_x = random.randint(x_range[0], x_range[1]) 240 | c_1_y = random.randint(y_range[0], y_range[1]) 241 | 242 | c_2_x = 2 * point_2_x - c_1_x 243 | c_2_y = 2 * point_2_y - c_1_y 244 | 245 | points.append((c_1_x, c_1_y)) 246 | points.append((c_2_x, c_2_y)) 247 | 248 | 249 | for c in range(num_points_per_cluster // 2): 250 | x_range = (point_3_x - cluster_radius, point_3_x + cluster_radius) 251 | y_range = (point_3_y - cluster_radius, point_3_y + cluster_radius) 252 | 253 | c_1_x = random.randint(x_range[0], x_range[1]) 254 | c_1_y = random.randint(y_range[0], y_range[1]) 255 | 256 | c_2_x = 2 * point_3_x - c_1_x 257 | c_2_y = 2 * point_3_y - c_1_y 258 | 259 | 
points.append((c_1_x, c_1_y)) 260 | points.append((c_2_x, c_2_y)) 261 | 262 | 263 | image = np.zeros((img_size[0], img_size[1], 1)) 264 | 265 | #print(points) 266 | 267 | for p in points: 268 | image = cv2.circle(image, p, radius=2, color=255, thickness=-1) 269 | #print(image.max()) 270 | #cv2.imshow('image', image) 271 | #cv2.waitKey(0) 272 | #cv2.destroyAllWindows() 273 | is_equilateral = True 274 | return image 275 | else: 276 | is_equilateral = False 277 | while not is_equilateral: 278 | point_1_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 279 | point_1_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 280 | 281 | point_2_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 282 | point_2_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 283 | 284 | point_3_x = random.randint(0 + cluster_radius, img_size[0] - cluster_radius) 285 | point_3_y = random.randint(0 + cluster_radius, img_size[1] - cluster_radius) 286 | 287 | if point_3_x + cluster_radius > img_size[0] or point_3_y + cluster_radius > img_size[1] or point_3_x - cluster_radius < 0 or point_3_y - cluster_radius < 0: 288 | continue 289 | 290 | points = [] 291 | for c in range(num_points_per_cluster // 2): 292 | x_range = (point_1_x - cluster_radius, point_1_x + cluster_radius) 293 | y_range = (point_1_y - cluster_radius, point_1_y + cluster_radius) 294 | 295 | c_1_x = random.randint(x_range[0], x_range[1]) 296 | c_1_y = random.randint(y_range[0], y_range[1]) 297 | 298 | c_2_x = 2 * point_1_x - c_1_x 299 | c_2_y = 2 * point_1_y - c_1_y 300 | 301 | points.append((c_1_x, c_1_y)) 302 | points.append((c_2_x, c_2_y)) 303 | 304 | for c in range(num_points_per_cluster // 2): 305 | x_range = (point_2_x - cluster_radius, point_2_x + cluster_radius) 306 | y_range = (point_2_y - cluster_radius, point_2_y + cluster_radius) 307 | 308 | c_1_x = random.randint(x_range[0], x_range[1]) 309 | c_1_y = random.randint(y_range[0], y_range[1]) 310 | 311 | c_2_x = 2 * point_2_x - c_1_x 312 | c_2_y = 2 * point_2_y - c_1_y 313 | 314 | points.append((c_1_x, c_1_y)) 315 | points.append((c_2_x, c_2_y)) 316 | 317 | 318 | for c in range(num_points_per_cluster // 2): 319 | x_range = (point_3_x - cluster_radius, point_3_x + cluster_radius) 320 | y_range = (point_3_y - cluster_radius, point_3_y + cluster_radius) 321 | 322 | c_1_x = random.randint(x_range[0], x_range[1]) 323 | c_1_y = random.randint(y_range[0], y_range[1]) 324 | 325 | c_2_x = 2 * point_3_x - c_1_x 326 | c_2_y = 2 * point_3_y - c_1_y 327 | 328 | points.append((c_1_x, c_1_y)) 329 | points.append((c_2_x, c_2_y)) 330 | 331 | 332 | image = np.zeros((img_size[0], img_size[1], 1)) 333 | 334 | #print(points) 335 | 336 | for p in points: 337 | image = cv2.circle(image, p, radius=2, color=255, thickness=-1) 338 | #print(image.max()) 339 | #cv2.imshow('image', image) 340 | #cv2.waitKey(0) 341 | #cv2.destroyAllWindows() 342 | is_equilateral = True 343 | return image 344 | 345 | class TriangleDataset(torch.utils.data.Dataset): 346 | def __init__(self, num_examples = 60000): 347 | self.num_examples = num_examples 348 | 349 | def __len__(self): 350 | return self.num_examples 351 | 352 | def __getitem__(self, i): 353 | n = random.randint(0, 1) 354 | if n == 0: 355 | image = make_equilateral_triangle(make_equilateral = True) 356 | elif n == 1: 357 | image = make_equilateral_triangle(make_equilateral = False) 358 | #elif n == 2: 359 | # image = make_rectangle() 360 | #else: 361 | # image = make_right_angle_triangle() 362 | 363 | image = 
torch.from_numpy(image)#.cuda() 364 | image = image.permute(2, 0, 1) 365 | return image.float(), torch.tensor(n)#.cuda() 366 | #make_image(make_equilateral = False) 367 | -------------------------------------------------------------------------------- /Triangle/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | import argparse 5 | import numpy as np 6 | 7 | import torch 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | import torch.optim as optim 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.backends.cudnn as cudnn 14 | 15 | from transformer_utilities.set_transformer import SetTransformer 16 | from transformers import TransformerEncoder #FunctionalVisionTransformer, ViT 17 | from einops import rearrange, repeat 18 | from dataset import TriangleDataset 19 | 20 | def str2bool(v): 21 | """Method to map string to bool for argument parser""" 22 | if isinstance(v, bool): 23 | return v 24 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 25 | return True 26 | if v.lower() in ('no', 'false', 'f', 'n', '0'): 27 | return False 28 | raise argparse.ArgumentTypeError('Boolean value expected.') 29 | 30 | parser = argparse.ArgumentParser(description='Image Classification Tasks') 31 | parser = argparse.ArgumentParser(description='Image Classification Tasks') 32 | parser.add_argument('--model', default="functional", type=str, choices=('default','functional') ,help='type of transformer to use') 33 | parser.add_argument('--data', default="cifar10", type=str, choices=('cifar10','cifar100','pathfinder', 'MNIST', 'Triangle') ,help='data to train on') 34 | parser.add_argument('--version', default=0, type=int, help='version for shared transformer-- 0 or 1') 35 | parser.add_argument('--num_layers', default=12, type=int, help='num of layers') 36 | parser.add_argument('--num_templates', default=12, type=int, help='num of templates for shared transformer') 37 | parser.add_argument('--num_heads', default=4, type=int, help='num of heads in Multi Head attention layer') 38 | parser.add_argument('--batch_size', default=64, type=int, help='batch_size to use') 39 | parser.add_argument('--patch_size', default=4, type=int, help='patch_size for transformer') 40 | parser.add_argument('--epochs', default=200, type=int, help='num of epochs to train') 41 | parser.add_argument('--lr', default=0.0001, type=float, help='learning rate') 42 | parser.add_argument('--dropout', default=0.1, type=float, help='dropout') 43 | parser.add_argument('--name', default="model", type=str, help='Model name for logs and checkpoint') 44 | parser.add_argument('--resume', '-r', action='store_true', 45 | help='resume from checkpoint') 46 | parser.add_argument('--h_dim', type = int, default = 256) 47 | parser.add_argument('--ffn_dim', type = int, default = 512) 48 | parser.add_argument('--num_gru_schemas', type = int, default = 1) 49 | parser.add_argument('--num_attention_schemas', type = int, default = 1) 50 | parser.add_argument('--schema_specific', type = str2bool, default = False) 51 | parser.add_argument('--num_eval_layers', type = int, default = 1) 52 | parser.add_argument('--share_vanilla_parameters', type = str2bool, default = False) 53 | parser.add_argument('--num_digits_for_mnist', type = int, default = 3) 54 | parser.add_argument('--use_topk', type = str2bool, default = False) 55 | parser.add_argument('--topk', type = int, default = 3) 56 | parser.add_argument('--shared_memory_attention', 
type = str2bool, default = False) 57 | parser.add_argument('--mem_slots', type = int, default = 4) 58 | parser.add_argument('--null_attention', type = str2bool, default = False) 59 | parser.add_argument('--seed', type = int, default = 0) 60 | args = parser.parse_args() 61 | 62 | MIN_NUM_PATCHES=16 63 | 64 | 65 | # logging config 66 | 67 | #if not os.path.isdir('logs'): 68 | # os.mkdir('logs') 69 | 70 | #logging.basicConfig(filename='./logs/%s.log' % args.name, 71 | # level=logging.DEBUG, format='%(asctime)s %(levelname)-10s %(message)s') 72 | 73 | #logging.info("Using args: {}".format(args)) 74 | 75 | def seed_everything(seed=1234): 76 | random.seed(seed) 77 | os.environ['PYTHONHASHSEED'] = str(seed) 78 | np.random.seed(seed) 79 | torch.manual_seed(seed) 80 | torch.cuda.manual_seed(seed) 81 | torch.backends.cudnn.deterministic = True 82 | 83 | seed_everything(seed = args.seed) 84 | 85 | image_size =0 86 | num_classes=0 87 | 88 | #logging.info("Loading data: {}".format(args.data)) 89 | if args.data =="cifar10": 90 | # settings from https://github.com/kuangliu/pytorch-cifar/blob/master/main.py 91 | 92 | transform_train = transforms.Compose([ 93 | transforms.RandomCrop(32, padding=4), 94 | transforms.RandomHorizontalFlip(), 95 | transforms.ToTensor(), 96 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 97 | ]) 98 | 99 | transform_test = transforms.Compose([ 100 | transforms.ToTensor(), 101 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 102 | ]) 103 | 104 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 105 | download=True, transform=transform_train) 106 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 107 | shuffle=True, num_workers=2) 108 | 109 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 110 | download=True, transform=transform_test) 111 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, 112 | shuffle=False, num_workers=2) 113 | 114 | classes = ('plane', 'car', 'bird', 'cat', 115 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 116 | num_classes =10 117 | image_size = 32 118 | channels=3 119 | 120 | elif args.data =="cifar100": 121 | # settings from https://github.com/weiaicunzai/pytorch-cifar100/blob/master/conf/global_settings.py 122 | 123 | transform_train = transforms.Compose([ 124 | transforms.RandomCrop(32, padding=4), 125 | transforms.RandomHorizontalFlip(), 126 | transforms.RandomRotation(15), 127 | transforms.ToTensor(), 128 | transforms.Normalize((0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761)), 129 | ]) 130 | 131 | transform_test = transforms.Compose([ 132 | transforms.ToTensor(), 133 | transforms.Normalize((0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761)), 134 | ]) 135 | 136 | trainset = torchvision.datasets.CIFAR100(root='./data', train=True, 137 | download=True, transform=transform_train) 138 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 139 | shuffle=True, num_workers=2) 140 | 141 | testset = torchvision.datasets.CIFAR100(root='./data', train=False, 142 | download=True, transform=transform_test) 143 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, 144 | shuffle=False, num_workers=2) 145 | 146 | num_classes = 100 147 | image_size = 32 148 | channels=3 149 | 150 | elif args.data == "pathfinder": 151 | trainset = np.load('./data/train.npz') 152 | trainset = 
torch.utils.data.TensorDataset(torch.Tensor(trainset['x']).reshape(-1,1,32,32),torch.LongTensor(trainset['y'])) 153 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 154 | shuffle=True, num_workers=2) 155 | testset = np.load('./data/test.npz') 156 | testset = torch.utils.data.TensorDataset(torch.Tensor(testset['x']).reshape(-1,1,32,32),torch.LongTensor(testset['y'])) 157 | 158 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, 159 | shuffle=False, num_workers=2) 160 | num_classes=2 161 | image_size = 32 162 | channels=1 163 | elif args.data == 'MNIST': 164 | train_dataset = CountingMNISTDataset(split = "train", path = "MNIST", dig_range = [1,3], num_examples = 10000) 165 | test_dataset = CountingMNISTDataset(split = "test", path = "MNIST", dig_range = [4,5], num_examples = 2000) 166 | 167 | trainloader = torch.utils.data.DataLoader(train_dataset, batch_size = args.batch_size, num_workers = 2, shuffle = False) 168 | testloader = torch.utils.data.DataLoader(test_dataset, batch_size = args.batch_size, num_workers = 2, shuffle = False) 169 | num_classes = 10 170 | image_size = 100 171 | channels = 1 172 | elif args.data == 'Triangle': 173 | train_dataset = TriangleDataset(num_examples = 50000) 174 | test_dataset = TriangleDataset(num_examples = 10000) 175 | trainloader = torch.utils.data.DataLoader(train_dataset, batch_size = args.batch_size, num_workers = 2, shuffle = False) 176 | testloader = torch.utils.data.DataLoader(test_dataset, batch_size = args.batch_size, num_workers = 2, shuffle = False) 177 | num_classes = 4 178 | image_size = 64 179 | channels = 1 180 | 181 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 182 | best_acc = 0 # best test accuracy 183 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 184 | 185 | 186 | if args.model == "functional": 187 | transformer = SetTransformer(args.h_dim, dim_hidden = args.h_dim, num_inds = args.mem_slots) 188 | #net = FunctionalVisionTransformer( 189 | # image_size = image_size, 190 | # patch_size = args.patch_size, 191 | # num_classes = num_classes, 192 | # dim = 1024, 193 | # depth = args.num_layers, 194 | # heads = args.num_heads, 195 | # mlp_dim = 2048, 196 | # dropout = args.dropout, 197 | # emb_dropout = 0.1, 198 | # num_templates = args.num_templates, 199 | # version = args.version, 200 | # channels=channels 201 | 202 | # ) 203 | 204 | elif args.model == "default": 205 | transformer = TransformerEncoder( 206 | args.h_dim, 207 | args.ffn_dim, 208 | num_layers = args.num_layers, 209 | num_heads = args.num_heads, 210 | dropout = args.dropout, 211 | share_parameters = args.share_vanilla_parameters, 212 | shared_memory_attention = args.shared_memory_attention, 213 | use_topk = args.use_topk, 214 | topk = args.topk, 215 | mem_slots = args.mem_slots, 216 | null_attention = args.null_attention, 217 | num_steps = int((image_size*image_size) / (args.patch_size * args.patch_size) + 1) ) 218 | #net = ViT( 219 | # image_size = image_size, 220 | # patch_size = args.patch_size, 221 | # num_classes = num_classes, 222 | # dim = 1024, 223 | # depth = args.num_layers, 224 | # heads = args.num_heads, 225 | # mlp_dim = 2048, 226 | # dropout = args.dropout, 227 | # emb_dropout = 0.1, 228 | # channels=channels 229 | 230 | # ) 231 | #print(int((image_size*image_size) / (args.patch_size * args.patch_size))) 232 | class model(nn.Module): 233 | def __init__(self, net, image_size, patch_size, num_classes): 234 | super().__init__() 235 | #print(image_size) 236 | #print(patch_size) 
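        # ViT-style tokenization: the image is cut into non-overlapping patch_size x patch_size
        # patches, giving (image_size // patch_size) ** 2 tokens of dimension channels * patch_size ** 2.
        # Each flattened patch is linearly projected to h_dim, a learned CLS token is prepended, and
        # the CLS position of the encoder output is fed to the classification head (mlp_head).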
237 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 238 | num_patches = (image_size // patch_size) ** 2 239 | patch_dim = channels * patch_size ** 2 240 | assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size' 241 | 242 | self.net = net 243 | self.patch_size = patch_size 244 | #print(patch_dim) 245 | self.patch_to_embedding = nn.Linear(patch_dim, args.h_dim) 246 | self.cls_token = nn.Parameter(torch.randn(1, 1, args.h_dim)) 247 | 248 | self.mlp_head = nn.Linear(args.h_dim, num_classes) 249 | 250 | def forward(self, img, mask = None): 251 | p = self.patch_size 252 | #print(img.size()) 253 | x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) 254 | #print(x.size()) 255 | #print(x.type()) 256 | x = self.patch_to_embedding(x) 257 | 258 | b, n, _ = x.shape 259 | #print(x.shape) 260 | 261 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 262 | x = torch.cat((cls_tokens, x), dim=1) 263 | #print(x.size()) 264 | 265 | x = self.net(x) 266 | 267 | x = self.mlp_head(x[:,0]) 268 | 269 | return x 270 | #print('line 234') 271 | net = model(transformer, image_size, args.patch_size, num_classes) 272 | 273 | 274 | net = net.to(device) 275 | 276 | if os.path.exists('./checkpoint/'+args.name+'_ckpt.pth'): 277 | args.resume =True 278 | 279 | if False and args.resume: 280 | # Load checkpoint. 281 | #logging.info("==> Resuming from checkpoint..") 282 | print('==> Resuming from checkpoint..') 283 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 284 | checkpoint = torch.load('./checkpoint/'+args.name+'_ckpt.pth') 285 | net.load_state_dict(checkpoint['net']) 286 | best_acc = checkpoint['acc'] 287 | start_epoch = checkpoint['epoch'] 288 | 289 | pytorch_total_params = sum(p.numel() for p in net.parameters() if p.requires_grad) 290 | try: 291 | rmc_params = sum(p.numel() for p in net.net.enc.self_attn.relational_memory.parameters() if p.requires_grad) 292 | print(rmc_params) 293 | except: 294 | pass 295 | #logging.info("Total number of parameters:{}".format(pytorch_total_params)) 296 | print("Total number of parameters:{}".format(pytorch_total_params)) 297 | 298 | if args.data == 'MNIST': 299 | pre_loss_fn = nn.Sigmoid() 300 | else: 301 | pre_loss_fn = nn.Identity() 302 | 303 | if args.data == "MNIST": 304 | criterion = nn.BCELoss() 305 | else: 306 | criterion = nn.CrossEntropyLoss() 307 | optimizer = optim.Adam(net.parameters(), lr=args.lr) 308 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) 309 | 310 | def mnist_acc(outputs, targets): 311 | outputs[outputs >= 0.5] = 1. 312 | outputs[outputs<0.5] = 0. 
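    # Counting-MNIST is scored as a multi-label task: the sigmoid outputs are thresholded at 0.5
    # above, and an example only counts as correct when all num_classes digit indicators match the
    # target (exact-match accuracy), which the per-row equality sum below implements.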
313 | 314 | #print(outputs) 315 | #print(targets) 316 | 317 | equality = torch.eq(outputs, targets) 318 | 319 | equality = equality.int() 320 | 321 | #print(equality) 322 | print('-----') 323 | equality = equality.sum(1) 324 | equality[equality < num_classes] = 0 325 | equality[equality == num_classes] = 1 326 | 327 | correct = equality.sum().item() 328 | print(correct) 329 | 330 | return correct 331 | 332 | 333 | 334 | 335 | 336 | def train(epoch): 337 | print('\nEpoch: %d' % epoch) 338 | #logging.info('Epoch: %d' % epoch) 339 | net.train() 340 | train_loss = 0 341 | correct = 0 342 | total = 0 343 | for batch_idx, (inputs, targets) in enumerate(trainloader): 344 | inputs, targets = inputs.to(device), targets.to(device) 345 | #print(inputs.shape) 346 | #print(targets) 347 | optimizer.zero_grad() 348 | outputs = net(inputs) 349 | outputs = pre_loss_fn(outputs) 350 | loss = criterion(outputs, targets) 351 | loss.backward() 352 | optimizer.step() 353 | 354 | train_loss += loss.item() 355 | _, predicted = outputs.max(1) 356 | total += targets.size(0) 357 | if args.data == "MNIST": 358 | correct += mnist_acc(outputs, targets) 359 | else: 360 | correct += predicted.eq(targets).sum().item() 361 | 362 | 363 | if batch_idx % 100 == 99: # print every 100 mini-batches 364 | #net.net.enc.self_attn.relational_memory.print_log() 365 | print('[%d, %5d] loss: %.3f accuracy:%.3f' % 366 | (epoch + 1, batch_idx + 1, train_loss / (batch_idx+1), 100.*correct/total)) 367 | #logging.info('[%d, %5d] loss: %.3f accuracy:%.3f' % 368 | # (epoch + 1, batch_idx + 1, train_loss / (batch_idx+1), 100.*correct/total)) 369 | 370 | 371 | 372 | def test(epoch): 373 | global best_acc 374 | net.eval() 375 | test_loss = 0 376 | correct = 0 377 | total = 0 378 | with torch.no_grad(): 379 | for batch_idx, (inputs, targets) in enumerate(testloader): 380 | inputs, targets = inputs.to(device), targets.to(device) 381 | outputs = net(inputs) 382 | outputs = pre_loss_fn(outputs) 383 | loss = criterion(outputs, targets) 384 | 385 | test_loss += loss.item() 386 | _, predicted = outputs.max(1) 387 | total += targets.size(0) 388 | if args.data == "MNIST": 389 | correct += mnist_acc(outputs, targets) 390 | else: 391 | correct += predicted.eq(targets).sum().item() 392 | 393 | # Save checkpoint. 
394 | acc = 100.*correct/total 395 | print("test_accuracy is %.3f after epochs %d"%(acc,epoch)) 396 | #logging.info("test_accuracy is %.3f after epochs %d"%(acc,epoch)) 397 | if acc > best_acc: 398 | print('Saving..') 399 | #logging.info("==> Saving...") 400 | state = { 401 | 'net': net.state_dict(), 402 | 'acc': acc, 403 | 'epoch': epoch, 404 | } 405 | if not os.path.isdir('checkpoint'): 406 | os.mkdir('checkpoint') 407 | torch.save(state, './checkpoint/'+args.name+'_ckpt.pth') 408 | best_acc = acc 409 | 410 | #logging.info("Starting Training...") 411 | for epoch in range(start_epoch, start_epoch+200): 412 | train(epoch) 413 | test(epoch) 414 | scheduler.step() 415 | -------------------------------------------------------------------------------- /Triangle/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #run.sh default Triangle 0 4 0.1 100 4 4 8 0.0001 64 128 2 2 True 6 True False 10 False 4 False 3 4 | 5 | 6 | data='Triangle' 7 | version=0 8 | num_layers=$1 9 | num_templates=1 10 | dropout=0.1 11 | epochs=100 12 | patch_size=4 13 | num_heads=4 14 | batch_size=64 15 | lr=0.0001 16 | h_dim=$2 17 | ffn_dim=$3 18 | num_gru_schemas=2 19 | num_attention_schemas=2 20 | schema_specific=True 21 | num_eval_layers=6 22 | share_vanilla_parameters=${4} 23 | use_topk=${5} 24 | topk=${6} 25 | shared_memory_attention=${7} 26 | null_attention=False 27 | seed=${8} 28 | mem_slots=${9} 29 | model=${10} 30 | 31 | 32 | 33 | name="HXLN_LN_LSTM-"$model"-data-"$data"-version-"$version"-num_layers-"$num_layers"-num_templates-"$num_layers"-dropout-"$dropout"-epochs-"$epochs"-patch_size-"$patch_size"-num_heads-"$num_heads"-batch_size-"$batch_size"-lr-"$lr-$h_dim-$ffn_dim-$num_gru_schemas-$num_attention_schemas-$schema_specific-$num_eval_layers-$share_vanilla_parameters-$use_topk-$topk-$shared_memory_attention-$mem_slots-$null_attention-$seed 34 | 35 | echo $name 36 | 37 | python run.py --model $model --data $data --version $version --num_layers $num_layers --num_templates $num_templates --dropout $dropout --epochs $epochs --patch_size $patch_size --num_heads $num_heads --name $name --batch_size $batch_size --lr $lr \ 38 | --h_dim $h_dim --ffn_dim $ffn_dim --num_gru_schemas $num_gru_schemas \ 39 | --num_attention_schemas $num_attention_schemas --schema_specific $schema_specific \ 40 | --num_eval_layers $num_eval_layers --share_vanilla_parameters $share_vanilla_parameters --use_topk $use_topk --topk $topk --shared_memory_attention $shared_memory_attention \ 41 | --mem_slots $mem_slots --null_attention $null_attention \ 42 | --seed $seed 43 | 44 | #sh run_local.sh functional cifar10 1 12 3 0.1 200 4 4 128 0.0001 45 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/Gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
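# gelu_accurate() below is the tanh approximation of GELU, while gelu() defers to PyTorch's exact
# implementation, upcasting to float32 and casting back so lower-precision inputs stay numerically stable.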
5 | """ 6 | See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with 7 | the corresponding GitHub repo: https://github.com/hendrycks/GELUs 8 | """ 9 | 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | def gelu_accurate(x): 17 | if not hasattr(gelu_accurate, "_a"): 18 | gelu_accurate._a = math.sqrt(2 / math.pi) 19 | return ( 20 | 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 21 | ) 22 | 23 | 24 | def gelu(x: torch.Tensor) -> torch.Tensor: 25 | return torch.nn.functional.gelu(x.float()).type_as(x) 26 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/GroupGRUCell.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from .GroupLinearLayer import GroupLinearLayer 5 | 6 | 7 | class GroupGRUCell(nn.Module): 8 | """ 9 | GroupGRUCell can compute the operation of N GRU Cells at once. 10 | """ 11 | def __init__(self, input_size, hidden_size, num_grus): 12 | super(GroupGRUCell, self).__init__() 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.x2h = GroupLinearLayer(input_size, 3 * hidden_size, num_grus) 16 | self.h2h = GroupLinearLayer(hidden_size, 3 * hidden_size, num_grus) 17 | self.reset_parameters() 18 | 19 | 20 | 21 | def reset_parameters(self): 22 | std = 1.0 / math.sqrt(self.hidden_size) 23 | for w in self.parameters(): 24 | w.data = torch.ones(w.data.size())#.uniform_(-std, std) 25 | 26 | def forward(self, x, hidden): 27 | """ 28 | input: x (batch_size, num_grus, input_size) 29 | hidden (batch_size, num_grus, hidden_size) 30 | output: hidden (batch_size, num_grus, hidden_size) 31 | """ 32 | gate_x = self.x2h(x) 33 | gate_h = self.h2h(hidden) 34 | 35 | i_r, i_i, i_n = gate_x.chunk(3, 2) 36 | h_r, h_i, h_n = gate_h.chunk(3, 2) 37 | 38 | 39 | resetgate = torch.sigmoid(i_r + h_r) 40 | inputgate = torch.sigmoid(i_i + h_i) 41 | newgate = torch.tanh(i_n + (resetgate * h_n)) 42 | 43 | hy = newgate + inputgate * (hidden - newgate) 44 | 45 | return hy 46 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/GroupLinearLayer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | class GroupLinearLayer(nn.Module): 6 | def __init__(self, din, dout, num_blocks, bias=True, a = None): 7 | super(GroupLinearLayer, self).__init__() 8 | self.nb = num_blocks 9 | #din = din // num_blocks 10 | #dout = dout // num_blocks 11 | self.dout = dout 12 | if a is None: 13 | a = 1. 
/ math.sqrt(dout) 14 | self.weight = nn.Parameter(torch.FloatTensor(num_blocks,din,dout).uniform_(-a,a)) 15 | self.bias = bias 16 | if bias is True: 17 | self.bias = nn.Parameter(torch.FloatTensor(num_blocks,dout).uniform_(-a,a)) 18 | #self.bias = nn.Parameter(torch.zeros(dout*num_blocks)) 19 | else: 20 | self.bias = None 21 | def forward(self,x): 22 | ts,bs,m = x.shape 23 | #x = x.reshape((ts*bs, self.nb, m//self.nb)) 24 | x = x.permute(1,0,2) 25 | x = torch.bmm(x,self.weight) 26 | x = x.permute(1,0,2) 27 | if not self.bias is None: 28 | x = x + self.bias 29 | #x = x.reshape((ts, bs, self.dout*self.nb)) 30 | return x 31 | 32 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/attention_rim.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import random 7 | from .sparse_attn import Sparse_attention 8 | import torch.nn.functional as F 9 | from .GroupLinearLayer import GroupLinearLayer 10 | from .sparse_grad_attn import Sparse_grad_attention 11 | 12 | 13 | class Identity_2(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, input): 16 | return input * 1.0 17 | def backward(ctx, grad_output): 18 | print(torch.sqrt(torch.sum(torch.pow(grad_output,2)))) 19 | print('+++++++++') 20 | return grad_output * 1.0 21 | 22 | class Identity(torch.autograd.Function): 23 | @staticmethod 24 | def forward(ctx, input): 25 | return input * 1.0 26 | def backward(ctx, grad_output): 27 | print(torch.sqrt(torch.sum(torch.pow(grad_output,2)))) 28 | print('-----------') 29 | return grad_output * 1.0 30 | 31 | class ScaledDotProductAttention(nn.Module): 32 | ''' Scaled Dot-Product Attention ''' 33 | 34 | def __init__(self, temperature, topk = -1, grad_sparse=False, attn_dropout=0.1, flag=False): 35 | super().__init__() 36 | self.temperature = temperature 37 | #self.dropout = nn.Dropout(attn_dropout) 38 | self.softmax = nn.Softmax(dim=2) 39 | self.grad_sparse = grad_sparse 40 | self.topk = topk 41 | self.sa = Sparse_attention(top_k=topk) #k=2 42 | self.flag = flag 43 | 44 | def forward(self, q, k, v, mask=None): 45 | 46 | # print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Forward of Scaled Dot Product Attention~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") 47 | # print("q: ", q.size()) 48 | # print("k: ", k.size()) 49 | # print("v: ", v.size()) 50 | # print("k transpose: ", k.transpose(1,2).size()) 51 | # input() 52 | 53 | attn = torch.bmm(q, k.transpose(1, 2)) 54 | attn = attn / self.temperature 55 | 56 | #print('in forward attn shape', attn.shape) 57 | 58 | if mask is not None: 59 | attn = attn.masked_fill(mask, -np.inf) 60 | 61 | if self.flag: 62 | n_b,k_1,k_2 = attn.size() 63 | attn = self.softmax(attn.permute(0,2,1)).reshape(n_b,k_1,k_2) 64 | else: 65 | attn = self.softmax(attn) 66 | 67 | extra_loss = 0.0 68 | 69 | use_sparse = False#False 70 | 71 | if use_sparse: 72 | mb, ins, outs = attn.shape[0], attn.shape[1], attn.shape[2] 73 | if self.flag: 74 | sparse_attn = attn.permute(0,2,1).reshape(mb*outs, ins) 75 | else: 76 | sparse_attn = attn.reshape((mb*ins, outs)) 77 | #print('sparse attn shape 1', sparse_attn.shape) 78 | #sga = Sparse_grad_attention(2) 79 | if self.grad_sparse: 80 | sga = Sparse_grad_attention(self.topk) 81 | sparse_attn = sga(sparse_attn) 82 | else: 83 | sparse_attn = self.sa(sparse_attn) 84 | if self.flag: 85 | sparse_attn = sparse_attn.reshape(mb, outs, ins).permute(0, 2, 1) 86 | else: 87 | sparse_attn = 
sparse_attn.reshape((mb,ins,outs)) 88 | attn = sparse_attn*1.0 89 | 90 | output = torch.bmm(attn, v) 91 | 92 | return output, attn, extra_loss 93 | 94 | import torch.nn.functional as F 95 | 96 | class MultiHeadAttention(nn.Module): 97 | ''' Multi-Head Attention module ''' 98 | 99 | def __init__(self, n_head, d_model_read, d_model_write, d_model_out, d_k, d_v, grad_sparse, residual=True, dropout=0.1, skip_write=False, flag=False): 100 | super().__init__() 101 | 102 | self.n_head = n_head 103 | self.d_k = d_k 104 | self.d_v = d_v 105 | 106 | # print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Initialize Multi-Head Attention~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") 107 | # print('d model read: ', d_model_read) 108 | # print('d_model_write: ', d_model_write) 109 | # print('d_model_out: ', d_model_out) 110 | # print('n_head: ', n_head) 111 | # print('d_k: ', d_k) 112 | # print('d_v: ', d_v) 113 | # print('num_blocks_read: ', num_blocks_read) 114 | # print('num_blocks_write: ', num_blocks_write) 115 | # input() 116 | 117 | self.GLN_qs = nn.Linear(d_model_read, n_head * d_k) 118 | self.GLN_ks = nn.Linear(d_model_write, n_head * d_k) 119 | self.GLN_vs = nn.Linear(d_model_write, n_head * d_v) 120 | 121 | self.residual = residual 122 | 123 | #self.w_qs = nn.Linear(d_model_read, n_head * d_k) 124 | #self.w_ks = nn.Linear(d_model_write, n_head * d_k) 125 | #self.w_vs = nn.Linear(d_model_write, n_head * d_v) 126 | 127 | #nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 128 | #nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 129 | #nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 130 | 131 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), flag=flag) 132 | #self.layer_norm = nn.LayerNorm(d_model) 133 | 134 | self.gate_fc = nn.Linear(n_head * d_v, d_model_out) 135 | 136 | if not skip_write: 137 | self.fc = nn.Linear(n_head * d_v, d_model_out) 138 | else: 139 | self.fc = lambda a: a 140 | 141 | #nn.init.xavier_normal_(self.fc.weight) 142 | 143 | self.dropout = nn.Dropout(dropout) 144 | 145 | self.ln = nn.LayerNorm(d_model_out) 146 | 147 | def forward(self, q, k, v, mask=None): 148 | 149 | #print('attn input shape', q.shape) 150 | 151 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 152 | 153 | sz_b, len_q, _ = q.size() 154 | sz_b, len_k, _ = k.size() 155 | sz_b, len_v, _ = v.size() 156 | 157 | residual = q 158 | 159 | #print('q shape', q.shape) 160 | 161 | # print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Forward of Multi-Head Attention~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") 162 | # print("q: ", q.size()) 163 | # print("k: ", k.size()) 164 | # print("v: ", v.size()) 165 | # input() 166 | 167 | q = self.GLN_qs(q).view(sz_b, len_q, n_head, d_k) 168 | #q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 169 | k = self.GLN_ks(k).view(sz_b, len_k, n_head, d_k) 170 | v = self.GLN_vs(v).reshape(sz_b, len_v, n_head, d_v) 171 | #v = v.view(sz_b, len_v, n_head, d_v) 172 | 173 | # print("GLN q: ", q.size()) 174 | # print("GLN k: ", k.size()) 175 | # print("GLN v: ", v.size()) 176 | 177 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 178 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 179 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 180 | 181 | # print("Permute q: ", q.size()) 182 | # print("Permute k: ", k.size()) 183 | # print("Permute v: ", v.size()) 184 | 185 | #mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
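        # The heads were folded into the batch dimension above ((n_head * batch) x len x d), so the
        # scaled dot-product attention below runs all heads in one batched matmul; the output is then
        # unfolded and the heads concatenated back to (batch, len_q, n_head * d_v). Note that any
        # provided mask is currently ignored here (the repeat is commented out and mask=None is passed).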
186 | output, attn, extra_loss = self.attention(q, k, v, mask=None) 187 | 188 | # print("Output: ", output.size()) 189 | # print("Attention: ", attn.size()) 190 | 191 | output = output.view(n_head, sz_b, len_q, d_v) 192 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 193 | 194 | # print("Here Output: ", output.size()) 195 | 196 | #print('output shape before fc', output.shape) 197 | 198 | #TODO: probably shouldn't just apply residual layer in the forward pass. 199 | 200 | output_init = output*1.0 201 | 202 | output = self.dropout(self.fc(output_init)) 203 | 204 | gate = torch.sigmoid(self.gate_fc(output_init)) 205 | 206 | #output = self.layer_norm(gate * output + (1 - gate) * residual) 207 | #output = gate * output + (1 - gate) * residual 208 | 209 | if self.residual: 210 | output = gate * torch.tanh(output) 211 | else: 212 | #output = self.ln(output) 213 | pass 214 | 215 | # print("Final Output: ", output.size()) 216 | 217 | #output 218 | 219 | #print('attn', attn[0]) 220 | #print('output input diff', output - residual) 221 | 222 | return output, attn, extra_loss 223 | 224 | class PositionwiseFeedForward(nn.Module): 225 | ''' A two-feed-forward-layer module ''' 226 | 227 | def __init__(self, d_in, d_hid, dropout=0.1): 228 | super().__init__() 229 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 230 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 231 | self.layer_norm = nn.LayerNorm(d_in) 232 | self.dropout = nn.Dropout(dropout) 233 | 234 | def forward(self, x): 235 | residual = x 236 | output = x.transpose(1, 2) 237 | output = self.w_2(F.relu(self.w_1(output))) 238 | output = output.transpose(1, 2) 239 | output = self.dropout(output) 240 | output = self.layer_norm(output + residual) 241 | return output 242 | 243 | 244 | class Seq2SeqAttention(nn.Module): 245 | def __init__(self, enc_hid_dim, dec_hid_dim): 246 | super().__init__() 247 | 248 | self.attn = nn.Linear(enc_hid_dim + dec_hid_dim, dec_hid_dim) 249 | self.v = nn.Linear(dec_hid_dim, 1, bias = False) 250 | 251 | def forward(self, hidden, encoder_outputs): 252 | 253 | #hidden = [batch size, dec hid dim] 254 | #encoder_outputs = [src len, batch size, enc hid dim * 2] 255 | 256 | batch_size = encoder_outputs.shape[1] 257 | src_len = encoder_outputs.shape[0] 258 | 259 | #repeat decoder hidden state src_len times 260 | hidden = hidden.unsqueeze(1).repeat(1, src_len, 1) 261 | 262 | encoder_outputs = encoder_outputs.permute(1, 0, 2) 263 | 264 | #hidden = [batch size, src len, dec hid dim] 265 | #encoder_outputs = [batch size, src len, enc hid dim * 2] 266 | 267 | energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) 268 | 269 | #energy = [batch size, src len, dec hid dim] 270 | 271 | attention = self.v(energy).squeeze(2) 272 | 273 | #attention= [batch size, src len] 274 | 275 | return F.softmax(attention, dim=1) 276 | 277 | 278 | if __name__ == "__main__": 279 | 280 | x = torch.randn((64,3,100)) 281 | 282 | mha = MultiHeadAttention(n_head=8, d_model=100, d_k=64, d_v=64) 283 | 284 | out, attn = mha(x,x,x) 285 | 286 | print('out shape', out.shape) 287 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/basic_mha.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .GroupLinearLayer import GroupLinearLayer 5 | 6 | class MemoryAttention(nn.Module): 7 | def __init__(self, n_blocks_query, 
n_blocks_val, dim_query, dim_val, n_heads=8): 8 | super(MemoryAttention, self).__init__() 9 | 10 | self.n_heads = n_heads 11 | self.n_blocks_val = n_blocks_val 12 | self.dim_val = dim_val 13 | self.block_dim_val = dim_val // self.n_blocks_val 14 | 15 | self.n_blocks_query = n_blocks_query 16 | self.dim_query = dim_query 17 | self.block_dim_query = dim_query // self.n_blocks_query 18 | 19 | self.head_dim = 64 20 | self.scale = self.head_dim ** -0.5 21 | 22 | #self.n_blocks_val * self.block_dim_val 23 | 24 | self.query_net = GroupLinearLayer(self.block_dim_query, self.head_dim * self.n_heads, n_blocks_query) 25 | self.key_net = GroupLinearLayer(self.block_dim_val, self.head_dim * self.n_heads, n_blocks_val) 26 | self.value_net = GroupLinearLayer(self.block_dim_val, self.head_dim * self.n_heads, n_blocks_val) 27 | self.final = GroupLinearLayer(self.head_dim * self.n_heads, self.block_dim_query, n_blocks_query) 28 | 29 | def forward(self, q, kv): 30 | 31 | #comes in as: bs, pos*emb. 32 | #positions_attend x T*bs x emb 33 | 34 | 35 | #q = q.permute(1,0,2) 36 | #kv = kv.permute(1,0,2) 37 | 38 | #print('kv shape after permute', kv.shape) 39 | 40 | seq_len_q,bsz,_ = q.shape 41 | seq_len_v,bsz,_ = kv.shape 42 | 43 | q = q.reshape((seq_len_q, bsz, self.n_blocks_query * self.block_dim_query)) 44 | 45 | kv = kv.reshape((seq_len_v, bsz, self.n_blocks_val * self.block_dim_val)) 46 | 47 | q = self.query_net(q).view(seq_len_q, bsz, self.n_blocks_query, self.n_heads, self.head_dim) 48 | k = self.key_net(kv).view(seq_len_v, bsz, self.n_blocks_val, self.n_heads, self.head_dim) 49 | v = self.value_net(kv).view(seq_len_v, bsz, self.n_blocks_val, self.n_heads, self.head_dim) 50 | 51 | q = q.transpose(2,3) * self.scale 52 | k = k.transpose(2,3) 53 | v = v.transpose(2,3) 54 | score = torch.matmul(q, k.transpose(3,4)) 55 | #print('score shape', score.shape) 56 | score = F.softmax(score, dim=-1) 57 | out = torch.matmul(score, v).transpose(2,3) 58 | #print('out shape', out.shape) 59 | score = score.mean(dim=2) 60 | 61 | out = out.reshape(seq_len_q, bsz, self.n_blocks_query * self.head_dim * self.n_heads) 62 | out = self.final(out) 63 | out = out.view(seq_len_q, bsz, self.dim_query) 64 | 65 | 66 | return out, score 67 | 68 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/fairseq_dropout.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
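# FairseqDropout behaves like standard dropout during training, but it can also be kept active at
# inference time: make_generation_fast_() sets apply_during_inference for the selected module names,
# and forward() then keeps applying dropout even when the module is in eval mode.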
5 | 6 | import logging 7 | from typing import List, Optional 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class FairseqDropout(nn.Module): 17 | 18 | def __init__(self, p, module_name=None): 19 | super().__init__() 20 | self.p = p 21 | self.module_name = module_name 22 | self.apply_during_inference = False 23 | 24 | def forward(self, x, inplace: bool = False): 25 | if self.training or self.apply_during_inference: 26 | return F.dropout(x, p=self.p, training=True, inplace=inplace) 27 | else: 28 | return x 29 | 30 | def make_generation_fast_( 31 | self, 32 | name: str, 33 | retain_dropout: bool = False, 34 | retain_dropout_modules: Optional[List[str]] = None, 35 | **kwargs 36 | ): 37 | if retain_dropout: 38 | if retain_dropout_modules is not None and self.module_name is None: 39 | logger.warning( 40 | 'Cannot enable dropout during inference for module {} ' 41 | 'because module_name was not set'.format(name) 42 | ) 43 | elif ( 44 | retain_dropout_modules is None # if None, apply to all modules 45 | or self.module_name in retain_dropout_modules 46 | ): 47 | logger.info( 48 | 'Enabling dropout during inference for module: {}'.format(name) 49 | ) 50 | self.apply_during_inference = True 51 | else: 52 | logger.info('Disabling dropout for module: {}'.format(name)) 53 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/group_linear_layer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch.nn.functional as F 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | class GroupLinearLayer(nn.Module): 9 | 10 | def __init__(self, din, dout, num_blocks, bias=True, a = None): 11 | super(GroupLinearLayer, self).__init__() 12 | self.nb = num_blocks 13 | self.dout = dout 14 | 15 | if a is None: 16 | a = 1. 
/ math.sqrt(dout * num_blocks) 17 | 18 | #gain = 1.0 / math.sqrt(2) 19 | #a = gain * math.sqrt(6.0 / (din + dout)) 20 | 21 | self.weight = nn.Parameter(torch.FloatTensor(num_blocks,din,dout).uniform_(-a,a)) 22 | 23 | self.bias = bias 24 | 25 | if bias is True: 26 | self.bias = nn.Parameter(torch.FloatTensor(num_blocks,dout).uniform_(-a,a)) 27 | #self.bias = nn.Parameter(torch.zeros(dout*num_blocks)) 28 | else: 29 | self.bias = None 30 | 31 | def forward(self,x): 32 | 33 | #input: ts x bs x blocks*nhid 34 | #ts*bs , blocks, nhid 35 | #blocks, ts*bs, nhid 36 | ts,bs,m = x.shape 37 | 38 | x = x.reshape((ts*bs, self.nb, m//self.nb)) 39 | x = x.permute(1,0,2) 40 | x = torch.bmm(x,self.weight) 41 | x = x.permute(1,0,2) 42 | 43 | if not self.bias is None: 44 | x = x + self.bias 45 | 46 | x = x.reshape((ts, bs, self.dout*self.nb)) 47 | 48 | #if not self.bias is None: 49 | # x += self.bias 50 | 51 | return x 52 | 53 | class GroupMLP(nn.Module): 54 | """Container module with an encoder, a recurrent module, and a decoder.""" 55 | 56 | def __init__(self, din, dout, num_blocks, dropout=0.1): 57 | super(GroupMLP, self).__init__() 58 | 59 | self.w_1 = nn.Parameter(0.01 * torch.randn(num_blocks,din,dout)) 60 | self.w_2 = nn.Parameter(0.01 * torch.randn(num_blocks,dout,din)) 61 | 62 | self.layer_norm = nn.LayerNorm(din) 63 | self.dropout = nn.Dropout(dropout) 64 | 65 | def forward(self,x): 66 | 67 | residual = x*1.0 68 | x = x.permute(1,0,2) 69 | x = torch.bmm(F.relu(torch.bmm(x,self.w_1)), self.w_2) 70 | x = x.permute(1,0,2) 71 | x = self.dropout(x) 72 | x = self.layer_norm(x + residual) 73 | 74 | return x 75 | 76 | if __name__ == "__main__": 77 | 78 | GLN = GroupLinearLayer(512, 512, 2, bias=True) 79 | 80 | print('params', sum(g.numel() for g in GLN.parameters())) 81 | 82 | #bs, blocks, nhid 83 | x = torch.randn(64,12,2*512) 84 | 85 | print(GLN(x).shape) 86 | 87 | #for p in GLN.parameters(): 88 | # print(p.shape) 89 | 90 | 91 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/isab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class MAB(nn.Module): 7 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 8 | super(MAB, self).__init__() 9 | self.dim_V = dim_V 10 | self.num_heads = num_heads 11 | self.fc_q = nn.Linear(dim_Q, dim_V) 12 | self.fc_k = nn.Linear(dim_K, dim_V) 13 | self.fc_v = nn.Linear(dim_K, dim_V) 14 | if ln: 15 | self.ln0 = nn.LayerNorm(dim_V) 16 | self.ln1 = nn.LayerNorm(dim_V) 17 | self.fc_o = nn.Linear(dim_V, dim_V) 18 | 19 | def forward(self, Q, K): 20 | Q = self.fc_q(Q) 21 | K, V = self.fc_k(K), self.fc_v(K) 22 | 23 | dim_split = self.dim_V // self.num_heads 24 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 25 | K_ = torch.cat(K.split(dim_split, 2), 0) 26 | V_ = torch.cat(V.split(dim_split, 2), 0) 27 | 28 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 29 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 30 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 31 | O = O + F.relu(self.fc_o(O)) 32 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 33 | return O 34 | 35 | class SAB(nn.Module): 36 | def __init__(self, dim_in, dim_out, num_heads, ln=False): 37 | super(SAB, self).__init__() 38 | self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln) 39 | 40 | def forward(self, X): 41 | return self.mab(X, X) 42 | 43 | class 
ISAB(nn.Module): 44 | def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False): 45 | super(ISAB, self).__init__() 46 | self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) 47 | nn.init.xavier_uniform_(self.I) 48 | self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) 49 | self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) 50 | 51 | def forward(self, X): 52 | H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) 53 | return self.mab1(X, H) 54 | 55 | class PMA(nn.Module): 56 | def __init__(self, dim, num_heads, num_seeds, ln=False): 57 | super(PMA, self).__init__() 58 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 59 | nn.init.xavier_uniform_(self.S) 60 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 61 | 62 | def forward(self, X): 63 | return self.mab(self.S.repeat(X.size(0), 1, 1), X) 64 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/layer_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | try: 12 | from apex.normalization import FusedLayerNorm as _FusedLayerNorm 13 | 14 | has_fused_layernorm = True 15 | 16 | class FusedLayerNorm(_FusedLayerNorm): 17 | @torch.jit.unused 18 | def forward(self, x): 19 | if not x.is_cuda: 20 | return super().forward(x) 21 | else: 22 | with torch.cuda.device(x.device): 23 | return super().forward(x) 24 | 25 | except ImportError: 26 | has_fused_layernorm = False 27 | 28 | 29 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): 30 | if not export and torch.cuda.is_available() and has_fused_layernorm: 31 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine) 32 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) 33 | 34 | 35 | class Fp32LayerNorm(nn.LayerNorm): 36 | def __init__(self, *args, **kwargs): 37 | super().__init__(*args, **kwargs) 38 | 39 | def forward(self, input): 40 | output = F.layer_norm( 41 | input.float(), 42 | self.normalized_shape, 43 | self.weight.float() if self.weight is not None else None, 44 | self.bias.float() if self.bias is not None else None, 45 | self.eps, 46 | ) 47 | return output.type_as(input) 48 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/pos_enc.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import math 7 | import torch.nn.functional as F 8 | import random 9 | 10 | class PositionEncoder(nn.Module): 11 | def __init__(self, d_model, max_seq_len = 300): 12 | super().__init__() 13 | self.d_model = d_model 14 | # create constant 'pe' matrix with values dependant on 15 | # pos and i 16 | pe = torch.zeros(max_seq_len, d_model) 17 | for pos in range(max_seq_len): 18 | for i in range(0, d_model, 2): 19 | pe[pos, i] = \ 20 | math.sin(pos / (10000 ** ((2 * i)/d_model))) 21 | pe[pos, i + 1] = \ 22 | math.cos(pos / (10000 ** ((2 * (i + 1))/d_model))) 23 | 24 | pe = pe.unsqueeze(0) 25 | self.register_buffer('pe', pe) 26 | 27 | self.pos_emb_weight = nn.Parameter(torch.ones_like(pe)) 28 | 29 | def forward(self, x): 30 | # make embeddings relatively larger 31 
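# Shapes: the input arrives as (seq_len, batch, d_model). It is permuted to (batch, seq_len, d_model),
# the first seq_len rows of the constant sinusoidal table `pe` are gated element-wise by
# sigmoid(self.pos_emb_weight) (a learned gate over positions and channels), the gated encoding is
# added to x, and the result is permuted back to (seq_len, batch, d_model). The gated table is
# built with .cuda(), so this forward pass assumes a CUDA device is available.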
| 32 | x = x.permute(1,0,2) 33 | 34 | #x = x * math.sqrt(self.d_model) 35 | #add constant to embedding 36 | 37 | seq_len = x.size(1) 38 | 39 | #width x channel 40 | #pe_use = F.interpolate(self.pe.permute(0,2,1), size=seq_len).permute(0,2,1) 41 | 42 | pe_use = Variable(self.pe[:,:seq_len] * F.sigmoid(self.pos_emb_weight[:,:seq_len]), requires_grad=False).cuda() 43 | 44 | #bs x pos x nhid --> bs x nhid x pos --> bs x pos x nhid 45 | 46 | x = x + pe_use 47 | #Variable(pe_use, requires_grad=False).cuda() 48 | 49 | x = x.permute(1,0,2) 50 | 51 | return x 52 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/quant_noise.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def quant_noise(module, p, block_size): 11 | """ 12 | Wraps modules and applies quantization noise to the weights for 13 | subsequent quantization with Iterative Product Quantization as 14 | described in "Training with Quantization Noise for Extreme Model Compression" 15 | 16 | Args: 17 | - module: nn.Module 18 | - p: amount of Quantization Noise 19 | - block_size: size of the blocks for subsequent quantization with iPQ 20 | 21 | Remarks: 22 | - Module weights must have the right sizes wrt the block size 23 | - Only Linear, Embedding and Conv2d modules are supported for the moment 24 | - For more detail on how to quantize by blocks with convolutional weights, 25 | see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" 26 | - We implement the simplest form of noise here as stated in the paper 27 | which consists in randomly dropping blocks 28 | """ 29 | 30 | # if no quantization noise, don't register hook 31 | if p <= 0: 32 | return module 33 | 34 | # supported modules 35 | assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) 36 | 37 | # test whether module.weight has the right sizes wrt block_size 38 | is_conv = module.weight.ndim == 4 39 | 40 | # 2D matrix 41 | if not is_conv: 42 | assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes" 43 | 44 | # 4D matrix 45 | else: 46 | # 1x1 convolutions 47 | if module.kernel_size == (1, 1): 48 | assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes" 49 | # regular convolutions 50 | else: 51 | k = module.kernel_size[0] * module.kernel_size[1] 52 | assert k % block_size == 0, "Kernel size must be a multiple of block size" 53 | 54 | def _forward_pre_hook(mod, input): 55 | # no noise for evaluation 56 | if mod.training: 57 | if not is_conv: 58 | # gather weight and sizes 59 | weight = mod.weight 60 | in_features = weight.size(1) 61 | out_features = weight.size(0) 62 | 63 | # split weight matrix into blocks and randomly drop selected blocks 64 | mask = torch.zeros(in_features // block_size * out_features, device=weight.device) 65 | mask.bernoulli_(p) 66 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) 67 | 68 | else: 69 | # gather weight and sizes 70 | weight = mod.weight 71 | in_channels = mod.in_channels 72 | out_channels = mod.out_channels 73 | 74 | # split weight matrix into blocks and randomly drop selected blocks 75 | if mod.kernel_size == (1, 1): 76 | mask = torch.zeros(int(in_channels // block_size * 
out_channels), device=weight.device) 77 | mask.bernoulli_(p) 78 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) 79 | else: 80 | mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device) 81 | mask.bernoulli_(p) 82 | mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) 83 | 84 | # scale weights and apply mask 85 | mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript 86 | s = 1 / (1 - p) 87 | mod.weight.data = s * weight.masked_fill(mask, 0) 88 | 89 | module.register_forward_pre_hook(_forward_pre_hook) 90 | return module 91 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/set_transformer.py: -------------------------------------------------------------------------------- 1 | from .isab import * 2 | from .pos_enc import PositionEncoder 3 | 4 | class DeepSet(nn.Module): 5 | def __init__(self, dim_input, num_outputs, dim_output, dim_hidden=128): 6 | super(DeepSet, self).__init__() 7 | self.num_outputs = num_outputs 8 | self.dim_output = dim_output 9 | self.enc = nn.Sequential( 10 | nn.Linear(dim_input, dim_hidden), 11 | nn.ReLU(), 12 | nn.Linear(dim_hidden, dim_hidden), 13 | nn.ReLU(), 14 | nn.Linear(dim_hidden, dim_hidden), 15 | nn.ReLU(), 16 | nn.Linear(dim_hidden, dim_hidden)) 17 | self.dec = nn.Sequential( 18 | nn.Linear(dim_hidden, dim_hidden), 19 | nn.ReLU(), 20 | nn.Linear(dim_hidden, dim_hidden), 21 | nn.ReLU(), 22 | nn.Linear(dim_hidden, dim_hidden), 23 | nn.ReLU(), 24 | nn.Linear(dim_hidden, num_outputs*dim_output)) 25 | 26 | def forward(self, X): 27 | X = self.enc(X).mean(-2) 28 | X = self.dec(X).reshape(-1, self.num_outputs, self.dim_output) 29 | return X 30 | 31 | class SetTransformer(nn.Module): 32 | def __init__(self, dim_input, 33 | num_inds=32, dim_hidden=128, num_heads=4, ln=True, num_layers = 4): 34 | super(SetTransformer, self).__init__() 35 | self.pe = PositionEncoder(dim_input) 36 | layers = [] 37 | layers.append(ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln)) 38 | for _ in range(num_layers-1): 39 | layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln)) 40 | self.layers = nn.ModuleList(layers) 41 | # self.enc = nn.Sequential( 42 | # ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), 43 | # ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln), 44 | # ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln), 45 | # ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln)) 46 | 47 | def forward(self, X): 48 | X=X.permute(1,0,2) #self.pe expects T,B,D 49 | X = self.pe(X) 50 | X=X.permute(1,0,2) #layer expects B,T,D 51 | for layer in self.layers: 52 | X=layer(X) 53 | return X 54 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/sparse_attn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy 5 | 6 | class Sparse_attention(nn.Module): 7 | def __init__(self, top_k = 5): 8 | super(Sparse_attention,self).__init__() 9 | top_k += 1 10 | self.top_k = top_k 11 | 12 | def forward(self, attn_s): 13 | 14 | # normalize the attention weights using piece-wise Linear function 15 | # only top k should 16 | attn_plot = [] 17 | # torch.max() returns both value and location 18 | #attn_s_max = torch.max(attn_s, dim = 1)[0] 19 | #attn_w = torch.clamp(attn_s_max, min = 0, max = attn_s_max) 20 | eps = 10e-8 21 | time_step = attn_s.size()[1] 22 
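# Top-k sparsification over the last (time) dimension of a (batch, time) attention matrix:
# if there are no more than self.top_k weights they are returned unchanged; otherwise the
# self.top_k-th largest value (top_k was incremented by 1 in __init__) plus a small eps is
# subtracted from every weight, negatives are clamped to zero so that only the k largest entries
# survive, and the surviving weights are renormalized to sum to one.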
| if time_step <= self.top_k: 23 | # just make everything greater than 0, and return it 24 | #delta = torch.min(attn_s, dim = 1)[0] 25 | return attn_s 26 | else: 27 | # get top k and return it 28 | # bottom_k = attn_s.size()[1] - self.top_k 29 | # value of the top k elements 30 | #delta = torch.kthvalue(attn_s, bottm_k, dim= 1 )[0] 31 | delta = torch.topk(attn_s, self.top_k, dim= 1)[0][:,-1] + eps 32 | #delta = attn_s_max - torch.topk(attn_s, self.top_k, dim= 1)[0][:,-1] + eps 33 | # normalize 34 | delta = delta.reshape((delta.shape[0],1)) 35 | 36 | 37 | attn_w = attn_s - delta.repeat(1, time_step) 38 | attn_w = torch.clamp(attn_w, min = 0) 39 | attn_w_sum = torch.sum(attn_w, dim = 1, keepdim=True) 40 | attn_w_sum = attn_w_sum + eps 41 | attn_w_normalize = attn_w / attn_w_sum.repeat(1, time_step) 42 | 43 | #print('attn', attn_w_normalize) 44 | 45 | return attn_w_normalize 46 | 47 | 48 | if __name__ == "__main__": 49 | k = 1 50 | print('take top k', k) 51 | sa = Sparse_attention(top_k=k) 52 | 53 | #batch x time 54 | 55 | x = torch.from_numpy(numpy.array([[[0.1, 0.0, 0.3, 0.2, 0.4],[0.5,0.4,0.1,0.0,0.0]]])) 56 | 57 | x = x.reshape((2,5)) 58 | 59 | print('x shape', x.shape) 60 | print('x', x) 61 | 62 | o = sa(x) 63 | 64 | 65 | print('o', o) 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/sparse_grad_attn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Giving an N x M attention matrix, returns the same matrix, 3 | but performs masking to determine where to block gradients. 4 | ''' 5 | 6 | import numpy 7 | import torch 8 | from torch.autograd import Variable 9 | 10 | from .sparse_attn import Sparse_attention 11 | 12 | 13 | class blocked_grad(torch.autograd.Function): 14 | 15 | @staticmethod 16 | def forward(ctx, x, mask): 17 | ctx.save_for_backward(x, mask) 18 | return x 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | x, mask = ctx.saved_tensors 23 | return grad_output * mask, mask * 0.0 24 | 25 | 26 | class Sparse_grad_attention(torch.autograd.Function): 27 | # def __init__(self, top_k): 28 | # super(Sparse_grad_attention,self).__init__() 29 | # 30 | # self.sa = Sparse_attention(top_k=top_k) 31 | 32 | @staticmethod 33 | def forward(ctx, inp, sa): 34 | sparsified = sa(inp) 35 | ctx.save_for_backward(inp, sparsified) 36 | 37 | return inp 38 | 39 | @staticmethod 40 | def backward(ctx, grad_output): 41 | inp, sparsified = ctx.saved_tensors 42 | # print('sparsified', sparsified) 43 | return (grad_output) * (sparsified > 0.0).float() 44 | 45 | 46 | if __name__ == "__main__": 47 | k = 2 48 | sga = Sparse_grad_attention(k) 49 | sa = Sparse_attention(k) 50 | 51 | x = torch.from_numpy(numpy.array([[[0.1, 0.0, 0.3, 0.2, 0.4], 52 | [0.5, 0.4, 0.1, 0.0, 0.0]]])) 53 | x = x.reshape((2, 5)) 54 | 55 | x = Variable(x, requires_grad=True) 56 | 57 | print(x) 58 | print('output', sga(x)) 59 | 60 | (sga(x).sum()).backward() 61 | 62 | print('sparse grad', x.grad) 63 | 64 | x = Variable(x.data, requires_grad=True) 65 | 66 | (sa(x).sum()).backward() 67 | 68 | print('normal grad', x.grad) 69 | -------------------------------------------------------------------------------- /Triangle/transformer_utilities/transformer_helper.py: -------------------------------------------------------------------------------- 1 | 2 | ''' Define the sublayers in encoder/decoder layer ''' 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 
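# Standard Transformer building blocks: scaled dot-product attention, a sinusoidal positional
# encoding table, multi-head attention with residual connection and LayerNorm, a position-wise
# feed-forward block, and an EncoderLayer that composes the attention and feed-forward sublayers.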
| import random 8 | 9 | __author__ = "Yu-Hsiang Huang" 10 | 11 | class ScaledDotProductAttention(nn.Module): 12 | ''' Scaled Dot-Product Attention ''' 13 | 14 | def __init__(self, temperature, attn_dropout=0.1): 15 | super().__init__() 16 | self.temperature = temperature 17 | self.dropout = nn.Dropout(attn_dropout) 18 | 19 | def forward(self, q, k, v, mask=None): 20 | 21 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 22 | if mask is not None: 23 | attn = attn.masked_fill(mask == 0, -1e9) 24 | 25 | attn = self.dropout(F.softmax(attn, dim=-1)) 26 | output = torch.matmul(attn, v) 27 | 28 | return output, attn 29 | 30 | class PositionalEncoding(nn.Module): 31 | 32 | def __init__(self, d_hid, n_position=200): 33 | super(PositionalEncoding, self).__init__() 34 | 35 | # Not a parameter 36 | self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid)) 37 | 38 | def _get_sinusoid_encoding_table(self, n_position, d_hid): 39 | ''' Sinusoid position encoding table ''' 40 | # TODO: make it with torch instead of numpy 41 | 42 | def get_position_angle_vec(position): 43 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 44 | 45 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 46 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 47 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 48 | 49 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 50 | 51 | def forward(self, x): 52 | #if self.train: 53 | # ind = random.randint(0, 160) 54 | #else: 55 | ind = 0 56 | return x + self.pos_table[:, ind:ind + x.size(1)].clone().detach() 57 | 58 | class MultiHeadAttention(nn.Module): 59 | ''' Multi-Head Attention module ''' 60 | 61 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 62 | super().__init__() 63 | 64 | self.n_head = n_head 65 | self.d_k = d_k 66 | self.d_v = d_v 67 | 68 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 69 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 70 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 71 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 72 | 73 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) 74 | 75 | self.dropout = nn.Dropout(dropout) 76 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 77 | 78 | 79 | def forward(self, q, k, v, mask=None): 80 | 81 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 82 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 83 | 84 | residual = q 85 | 86 | # Pass through the pre-attention projection: b x lq x (n*dv) 87 | # Separate different heads: b x lq x n x dv 88 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 89 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 90 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 91 | 92 | # Transpose for attention dot product: b x n x lq x dv 93 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 94 | 95 | if mask is not None: 96 | mask = mask.unsqueeze(1) # For head axis broadcasting. 
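# At this point q, k and v have shape (batch, n_head, len, d_k or d_v) and the optional mask has
# gained a singleton head dimension for broadcasting. ScaledDotProductAttention below computes
# softmax(q @ k^T / temperature) @ v per head, with temperature = sqrt(d_k) and dropout applied to
# the attention weights; the heads are then transposed back, concatenated, projected by self.fc,
# added to the residual, and passed through LayerNorm (post-norm).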
97 | 98 | q, attn = self.attention(q, k, v, mask=mask) 99 | 100 | # Transpose to move the head dimension back: b x lq x n x dv 101 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 102 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 103 | q = self.dropout(self.fc(q)) 104 | q += residual 105 | 106 | q = self.layer_norm(q) 107 | 108 | return q, attn 109 | 110 | 111 | class PositionwiseFeedForward(nn.Module): 112 | ''' A two-feed-forward-layer module ''' 113 | 114 | def __init__(self, d_in, d_hid, dropout=0.1): 115 | super().__init__() 116 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 117 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 118 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 119 | self.dropout = nn.Dropout(dropout) 120 | 121 | def forward(self, x): 122 | 123 | residual = x 124 | 125 | 126 | x = self.w_2(F.relu(self.w_1(x))) 127 | x = self.dropout(x) 128 | x += residual 129 | 130 | x = self.layer_norm(x) 131 | 132 | return x 133 | 134 | 135 | class EncoderLayer(nn.Module): 136 | ''' Compose with two layers ''' 137 | 138 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 139 | super(EncoderLayer, self).__init__() 140 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 141 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 142 | 143 | def forward(self, enc_input, slf_attn_mask=None, seperate_queries = None): 144 | enc_output, enc_slf_attn = self.slf_attn( 145 | seperate_queries if seperate_queries is not None else enc_input, enc_input, enc_input, mask=slf_attn_mask) 146 | enc_output = self.pos_ffn(enc_output) 147 | return enc_output, enc_slf_attn -------------------------------------------------------------------------------- /Triangle/transformer_utilities/transformer_interface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | #from transformer import TransformerEncoder 4 | import types 5 | import math 6 | 7 | args = types.SimpleNamespace() 8 | args.use_module_communication = 'true' 9 | args.encoder_embed_dim = 512 10 | args.encoder_attention_heads = 8 #was 8 11 | args.attention_dropout = 0.1 12 | args.topk_ratio = 1.0 13 | args.dropout = 0.2 14 | args.encoder_normalize_before = True 15 | args.encoder_ffn_embed_dim = 2048 16 | args.use_nfm = 'false' 17 | 18 | from models.transformer_layer import TransformerEncoderLayer, TransformerEncoderLayerVanilla 19 | from models.pos_enc import PositionEncoder 20 | #from transformer_utilities.GroupLinearLayer import GroupLinearLayer 21 | import math 22 | class GroupLinearLayer(nn.Module): 23 | def __init__(self, din, dout, num_blocks, bias=True, a = None): 24 | super(GroupLinearLayer, self).__init__() 25 | self.nb = num_blocks 26 | #din = din // num_blocks 27 | #dout = dout // num_blocks 28 | self.dout = dout 29 | if a is None: 30 | a = 1. 
/ math.sqrt(dout) 31 | self.weight = nn.Parameter(torch.FloatTensor(num_blocks,din,dout).uniform_(-a,a)) 32 | self.bias = bias 33 | if bias is True: 34 | self.bias = nn.Parameter(torch.FloatTensor(num_blocks,dout).uniform_(-a,a)) 35 | #self.bias = nn.Parameter(torch.zeros(dout*num_blocks)) 36 | else: 37 | self.bias = None 38 | def forward(self,x): 39 | ts,bs,m = x.shape 40 | #x = x.reshape((ts*bs, self.nb, m//self.nb)) 41 | x = x.permute(1,0,2) 42 | x = torch.bmm(x,self.weight) 43 | x = x.permute(1,0,2) 44 | if not self.bias is None: 45 | x = x + self.bias 46 | #x = x.reshape((ts, bs, self.dout*self.nb)) 47 | return x 48 | 49 | 50 | 51 | class SelectAttention(nn.Module): 52 | """docstring for SelectAttention""" 53 | def __init__(self, d_read, d_write, d_k = 16, num_read = 5, num_write = 5, share_query = False, share_key = False): 54 | super(SelectAttention, self).__init__() 55 | if not share_key: 56 | self.gll_write = GroupLinearLayer(d_write,d_k, num_write) 57 | else: 58 | self.gll_write = nn.Linear(d_write, d_k) 59 | 60 | if not share_query: 61 | self.gll_read = GroupLinearLayer(d_read,d_k, num_read) 62 | else: 63 | self.gll_read = nn.Linear(d_read, d_k) 64 | 65 | self.temperature = math.sqrt(d_k) 66 | 67 | def forward(self, q, k): 68 | read = self.gll_read(q) 69 | write = self.gll_write(k) 70 | 71 | return torch.bmm(read, write.permute(0, 2, 1)) / self.temperature 72 | 73 | class TransformerEncoder(nn.Module): 74 | 75 | def __init__(self, inp_dim, h_dim, inp_nb, nb, functional = True): 76 | super().__init__() 77 | 78 | args.encoder_embed_dim = h_dim 79 | 80 | print('transformer h_dim', h_dim) 81 | 82 | 83 | 84 | args.encoder_embed_dim = h_dim 85 | self.functional = functional 86 | print('functional? '+str(self.functional)) 87 | if not self.functional: 88 | layer_lst = [] 89 | 90 | args.encoder_embed_dim = h_dim 91 | #layer_lst.append(TransformerEncoderLayer(args=args, nb=inp_nb, blockatt=False, blockatt_memory=True, use_nfm=False, out_proj_dim=h_dim)) 92 | #for j in range(0,6): 93 | # layer_lst.append(TransformerEncoderLayer(args=args, nb=nb, blockatt=False, blockatt_memory=True, use_nfm=False)) 94 | self.enc = TransformerEncoderLayerVanilla(args) 95 | #self.layers = nn.ModuleList(layer_lst) 96 | else: 97 | #args.encoder_embed_dim = inp_dim 98 | #print('init_layer initialize') 99 | #self.init_layer = TransformerEncoderLayerVanilla(args=args, out_proj=h_dim) 100 | args.encoder_embed_dim = h_dim 101 | hidden_dim = args.encoder_embed_dim 102 | print('inp_att initialize') 103 | self.inp_att = TransformerEncoderLayerVanilla(args=args) 104 | print('gru initialize') 105 | self.gru_pool = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(1)]) 106 | self.state_att = TransformerEncoderLayerVanilla(args=args) 107 | self.select_attention = SelectAttention( hidden_dim + hidden_dim, hidden_dim, num_read = 1, num_write = 1) 108 | 109 | self.pe = PositionEncoder(inp_dim) 110 | self.pe_state = PositionEncoder(args.encoder_embed_dim) 111 | 112 | def forward(self, x, mask = None): 113 | 114 | x = x.permute(1, 0, 2) 115 | 116 | x = self.pe(x) 117 | if not self.functional: 118 | """klst = [] 119 | vlst = [] 120 | 121 | initial_state = self.layers[0].memory_layer.initial_state(batch_size=x.shape[0]*x.shape[1]).type(x.dtype).to(x.device) 122 | memory_obj = [initial_state] 123 | 124 | for layer in self.layers: 125 | layer.klst = klst 126 | layer.vlst = vlst 127 | layer.memory_obj = memory_obj 128 | 129 | """ 130 | for i in range(6): 131 | x = self.enc(x, None) 132 | return x.permute(1, 0, 2) 133 
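# The `functional` branch below is the alternative computation path: a randomly initialised,
# position-encoded state attends over the input through self.inp_att, a (here single-element)
# pool of GRU cells proposes state updates, SelectAttention scores each proposal against the
# concatenated (attended input, current state) pair, a hard Gumbel-softmax picks one proposal per
# position, and self.state_att refines the selected state; this is repeated 6 times and the final
# state is returned (permuted back to batch-first).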
| else: 134 | """ 135 | klst = [] 136 | vlst = [] 137 | 138 | initial_state = self.init_layer.memory_layer.initial_state(batch_size=x.shape[0]*x.shape[1]).type(x.dtype).to(x.device) 139 | memory_obj = [initial_state] 140 | 141 | self.init_layer.klst = klst 142 | self.init_layer.vlst = vlst 143 | self.init_layer.memory_obj = memory_obj 144 | 145 | 146 | self.inp_att.klst = klst 147 | self.inp_att.vlst = vlst 148 | self.inp_att.memory_obj = memory_obj 149 | 150 | self.state_att.klst = klst 151 | self.state_att.vlst = vlst 152 | self.state_att.memory_obj = memory_obj 153 | """ 154 | T, B, D = x.size() 155 | 156 | #x = self.init_layer(x, None) 157 | state = self.pe_state(torch.randn(x.size()).to(x.device)) 158 | 159 | 160 | 161 | for i in range(0, 6): 162 | gru_in = self.inp_att(x, mask, state = state) 163 | gru_in = gru_in.permute(1, 0, 2) 164 | state = state.permute(1, 0, 2) 165 | 166 | gru_in = gru_in.reshape(B * T, -1) 167 | state = state.reshape(B * T, -1) 168 | 169 | gru_outs = [] 170 | 171 | for gru in self.gru_pool: 172 | gru_outs.append(gru(gru_in, state)) 173 | 174 | gru_outs = torch.stack(gru_outs, dim = 1) 175 | 176 | selector = torch.cat((gru_in, state), dim = 1).unsqueeze(1) 177 | 178 | attn_scores = self.select_attention(selector, gru_outs) 179 | 180 | attn_scores = attn_scores.squeeze(1) 181 | 182 | attn_scores = torch.nn.functional.gumbel_softmax(attn_scores, dim = 1, tau = 1.0, hard = True) 183 | attn_scores = attn_scores.unsqueeze(-1) 184 | gru_outs = (gru_outs * attn_scores).sum(dim = 1) 185 | gru_outs_hidden = gru_outs.reshape(B, T, -1) 186 | gru_outs_hidden = gru_outs_hidden.permute(1, 0, 2) 187 | gru_outs_hidden = self.state_att(gru_outs_hidden, mask) 188 | gru_in = gru_in.reshape(B, T, -1).permute(1, 0, 2) 189 | 190 | x = gru_in 191 | state = gru_outs_hidden 192 | 193 | return state.permute(1,0,2) 194 | 195 | 196 | 197 | if __name__ == "__main__": 198 | x = torch.randn(32, 64, 512) 199 | 200 | TE = TransformerEncoder() 201 | 202 | y = TE(x) 203 | 204 | print(y.shape) 205 | 206 | -------------------------------------------------------------------------------- /Triangle/transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | #from transformer import TransformerEncoder 4 | import types 5 | import math 6 | import numpy as np 7 | 8 | args = types.SimpleNamespace() 9 | args.use_module_communication = 'true' 10 | args.encoder_embed_dim = 512 11 | args.encoder_attention_heads = 8 #was 8 12 | args.attention_dropout = 0.1 13 | args.topk_ratio = 1.0 14 | args.dropout = 0.2 15 | args.encoder_normalize_before = True 16 | args.encoder_ffn_embed_dim = 2048 17 | args.use_nfm = 'false' 18 | args.shared_memory_attention = False 19 | args.self_attention = True 20 | args.mem_slots = 4 21 | args.use_topk = False 22 | args.topk = 3 23 | args.num_steps = 5 24 | 25 | from transformer_utilities.transformer_layer import TransformerEncoderLayer, TransformerEncoderLayerVanilla 26 | from transformer_utilities.pos_enc import PositionEncoder 27 | from transformer_utilities.GroupLinearLayer import GroupLinearLayer 28 | import math 29 | 30 | 31 | class SelectAttention(nn.Module): 32 | """docstring for SelectAttention""" 33 | def __init__(self, d_read, d_write, d_k = 16, num_read = 5, num_write = 5, share_query = False, share_key = False): 34 | super(SelectAttention, self).__init__() 35 | if not share_key: 36 | self.gll_write = GroupLinearLayer(d_write,d_k, num_write) 37 | else: 38 | self.gll_write = 
nn.Linear(d_write, d_k) 39 | 40 | if not share_query: 41 | self.gll_read = GroupLinearLayer(d_read,d_k, num_read) 42 | else: 43 | self.gll_read = nn.Linear(d_read, d_k) 44 | 45 | self.temperature = math.sqrt(d_k) 46 | 47 | def forward(self, q, k): 48 | read = self.gll_read(q) 49 | write = self.gll_write(k) 50 | 51 | return torch.bmm(read, write.permute(0, 2, 1)) / self.temperature 52 | 53 | class TransformerEncoder(nn.Module): 54 | 55 | def __init__(self, 56 | embed_dim, 57 | ffn_dim, 58 | num_layers = 6, 59 | num_heads = 4, 60 | dropout = 0.1, 61 | functional = False, 62 | shared_memory_attention = False, 63 | shared_memory_percentage = 0.1, 64 | share_parameters = False, 65 | mem_slots = 4, 66 | num_attention_schemas = 3, 67 | num_gru_schemas = 3, 68 | schema_specific = False, 69 | use_topk = False, 70 | topk = 3, 71 | num_steps = 5, 72 | null_attention = False, 73 | regressive = False): 74 | super().__init__() 75 | 76 | if schema_specific and (num_gru_schemas != num_attention_schemas): 77 | print('Cannot use schema specific as num_gru_schemas != num_attention_schemas, continuing without') 78 | self.schema_specific = False 79 | else: 80 | self.schema_specific = schema_specific 81 | 82 | args.mem_slots = mem_slots 83 | args.encoder_embed_dim = embed_dim 84 | args.encoder_ffn_embed_dim = ffn_dim 85 | args.encoder_attention_heads = num_heads 86 | args.dropout = dropout 87 | args.shared_memory_attention = shared_memory_attention 88 | args.num_steps = num_steps 89 | args.null_attention = null_attention 90 | args.regressive = regressive 91 | 92 | 93 | self.num_layers = num_layers 94 | self.shared_memory_attention = shared_memory_attention 95 | self.shared_memory_percentage = shared_memory_percentage 96 | 97 | print('transformer embed_dim', embed_dim) 98 | self.functional = functional 99 | print('functional? 
'+str(self.functional)) 100 | if not self.functional: 101 | layer_lst = [] 102 | args.use_topk = use_topk 103 | args.topk = topk 104 | 105 | 106 | args.encoder_embed_dim = embed_dim 107 | self.share_parameters = share_parameters 108 | if share_parameters: 109 | self.enc = TransformerEncoderLayerVanilla(args) 110 | else: 111 | layer_lst = [] 112 | for i in range(self.num_layers): 113 | layer_lst.append(TransformerEncoderLayerVanilla(args)) 114 | print('flmklsd') 115 | self.layers = nn.ModuleList(layer_lst) 116 | else: 117 | #args.encoder_embed_dim = inp_dim 118 | #print('init_layer initialize') 119 | #self.init_layer = TransformerEncoderLayerVanilla(args=args, out_proj=h_dim) 120 | print('NUM GRU SCHEMAS:' + str(num_gru_schemas)) 121 | print('NUM Attention SCHEMAS:' + str(num_attention_schemas)) 122 | print('SCHEMA SPECIFIC:' + str(self.schema_specific)) 123 | args.use_topk = use_topk 124 | args.topk = topk 125 | print('inp_att initialize') 126 | self.num_gru_schemas = num_gru_schemas 127 | self.num_att_schemas = num_attention_schemas 128 | self.schema_stats = np.zeros(self.num_gru_schemas) 129 | args.self_attention = True 130 | self.inp_att = nn.ModuleList([TransformerEncoderLayerVanilla(args=args) for _ in range(num_attention_schemas)]) 131 | self.select_attention_inp_att = SelectAttention( args.encoder_embed_dim, args.encoder_embed_dim, num_read = 1, num_write = num_attention_schemas) 132 | print('gru initialize') 133 | hidden_dim = args.encoder_embed_dim 134 | 135 | 136 | self.gru_pool = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(num_gru_schemas)]) 137 | #args.self_attention = True 138 | #self.state_att = TransformerEncoderLayerVanilla(args=args) 139 | self.select_attention = SelectAttention( hidden_dim + hidden_dim, hidden_dim, num_read = 1, num_write = num_gru_schemas) 140 | 141 | self.pe = PositionEncoder(args.encoder_embed_dim) 142 | self.pe_state = PositionEncoder(args.encoder_embed_dim) 143 | 144 | def forward(self, x, mask = None, num_layers = None): 145 | 146 | x = x.permute(1, 0, 2) 147 | 148 | x = self.pe(x) 149 | 150 | 151 | 152 | if not self.functional: 153 | if self.shared_memory_attention: 154 | memory_size = int(self.shared_memory_percentage * x.size(0)) 155 | 156 | memory = torch.randn(memory_size, 1, x.size(2)).repeat(1 ,x.size(1), 1).to(x.device) 157 | else: 158 | memory = None 159 | if self.shared_memory_attention: 160 | if self.share_parameters: 161 | if self.enc.self_attn.memory is not None: 162 | self.enc.self_attn.init_memory(x.size(1), x.size(0), x.device)#.memory = self.enc.self_attn.memory.detach() 163 | else: 164 | for layer in self.layers: 165 | if layer.self_attn.memory is not None: 166 | layer.self_attn.init_memory(x.size(1), x.device)#.memory = layer.self_attn.memory.detach() 167 | 168 | 169 | for i in range(self.num_layers): 170 | if self.share_parameters: 171 | x, memory = self.enc(x, mask, memory = memory) 172 | else: 173 | x, memory = self.layers[i](x, mask, memory = memory) 174 | return x.permute(1, 0, 2) 175 | else: 176 | 177 | T, B, D = x.size() 178 | 179 | if num_layers is None: 180 | num_layers = self.num_layers 181 | 182 | 183 | #state = self.pe_state(torch.randn(x.size()).to(x.device)) 184 | 185 | if self.shared_memory_attention: 186 | memory_size = int(self.shared_memory_percentage * x.size(0)) 187 | memory_inp = torch.randn( memory_size, 1, x.size(2)).repeat(1, x.size(1), 1).to(x.device) 188 | memory_state = torch.randn(memory_size, 1, x.size(2)).repeat(1, x.size(1), 1).to(x.device) 189 | else: 190 | memory_inp = None 191 
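# In the schema loop below, every attention schema in self.inp_att proposes an update, a hard
# Gumbel-softmax over select_attention_inp_att scores picks one proposal per position, the selected
# update and the current x are fed to every GRU schema, and a second hard Gumbel-softmax over
# select_attention scores picks which GRU output becomes the new x (when schema_specific is set,
# the attention-schema choice is reused instead); the winning schema indices are accumulated in
# self.schema_stats.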
| memory_state = None 192 | 193 | if self.shared_memory_attention: 194 | for inp_att in self.inp_att: 195 | if inp_att.self_attn.memory is not None: 196 | inp_att.self_attn.init_memory(x.size(1), x.device)#memory = inp_att.self_attn.memory.detach() 197 | for i in range(0, num_layers): 198 | gru_ins = [] 199 | for inp_att in self.inp_att: 200 | gru_in, memory_inp = inp_att(x, mask, memory = memory_inp) 201 | gru_ins.append(gru_in.permute(1, 0, 2)) 202 | 203 | gru_ins = torch.stack(gru_ins, dim = 2) 204 | gru_ins = gru_ins.reshape(B * T, -1, D) 205 | 206 | 207 | x = x.permute(1, 0, 2) 208 | x = x.reshape(B * T, -1).unsqueeze(1) 209 | 210 | attn_scores_inp_att = self.select_attention_inp_att(x, gru_ins) 211 | 212 | attn_scores_inp_att = attn_scores_inp_att.squeeze(1) 213 | attn_scores_inp_att = torch.nn.functional.gumbel_softmax(attn_scores_inp_att, dim = 1, hard = True, tau = 0.5) 214 | 215 | attn_scores_inp_att = attn_scores_inp_att.unsqueeze(-1) 216 | 217 | gru_in = (gru_ins * attn_scores_inp_att).sum(dim = 1) 218 | 219 | gru_in = gru_in.reshape(B, T, -1) 220 | x = x.reshape(B, T, -1) 221 | 222 | gru_in = gru_in.reshape(B * T, -1) 223 | x = x.reshape(B * T, -1) 224 | 225 | gru_outs = [] 226 | 227 | for gru in self.gru_pool: 228 | gru_outs.append(gru(gru_in, x)) 229 | 230 | gru_outs = torch.stack(gru_outs, dim = 1) 231 | 232 | selector = torch.cat((gru_in, x), dim = 1).unsqueeze(1) 233 | if not self.schema_specific: 234 | attn_scores = self.select_attention(selector, gru_outs) 235 | 236 | 237 | attn_scores = attn_scores.squeeze(1) 238 | 239 | attn_scores = torch.nn.functional.gumbel_softmax(attn_scores, dim = 1, tau = 1.0, hard = True) 240 | 241 | att_argmax = torch.sum(attn_scores.clone().detach(), dim = 0).cpu().numpy() 242 | 243 | self.schema_stats += att_argmax 244 | 245 | 246 | attn_scores = attn_scores.unsqueeze(-1) 247 | else: 248 | attn_scores = attn_scores_inp_att 249 | att_argmax = torch.sum(attn_scores.squeeze(-1).clone().detach(), dim = 0).cpu().numpy() 250 | 251 | self.schema_stats += att_argmax 252 | 253 | gru_outs = (gru_outs * attn_scores).sum(dim = 1) 254 | gru_outs_hidden = gru_outs.reshape(B, T, -1) 255 | gru_outs_hidden = gru_outs_hidden.permute(1, 0, 2) 256 | #gru_outs_hidden, memory_state = self.state_att(gru_outs_hidden, mask, memory = memory_state) 257 | #gru_in = gru_in.reshape(B, T, -1).permute(1, 0, 2) 258 | #x = gru_in 259 | x = gru_outs_hidden 260 | 261 | return x.permute(1,0,2) 262 | 263 | def print_schema_stats(self): 264 | total = np.sum(self.schema_stats) 265 | for k in range(self.schema_stats.shape[0]): 266 | print('schema ' + str(k) + ' used ' + str(self.schema_stats[k]) + ' out of ' + str(total) + ' times') 267 | 268 | 269 | def reset_schema_stats(self): 270 | self.schema_stats = np.zeros(self.num_gru_schemas) 271 | 272 | 273 | if __name__ == "__main__": 274 | x = torch.randn(8, 20, 256).cuda() 275 | import time 276 | TE1 = TransformerEncoder(256, 512, num_layers = 1, functional = False, num_gru_schemas = 3, num_attention_schemas = 3, schema_specific = False, shared_memory_attention = True, mem_slots = 8, num_steps = 20).cuda() 277 | t1 = time.time() 278 | for i in range(5): 279 | 280 | x = TE1(x) 281 | print(time.time() - t1) 282 | 283 | 284 | x = torch.randn(8, 20, 256).cuda() 285 | import time 286 | TE1 = TransformerEncoder(256, 512, num_layers = 1, functional = False, num_gru_schemas = 3, num_attention_schemas = 3, schema_specific = False, shared_memory_attention = True, mem_slots = 8, num_steps = 20).cuda() 287 | t1 = time.time() 288 | for i in 
range(5): 289 | 290 | x = TE1(x) 291 | print(time.time() - t1) 292 | x = torch.randn(8, 20, 256).cuda() 293 | TE2 = TransformerEncoder(256, 512, num_layers = 1, functional = False, num_gru_schemas = 3, num_attention_schemas = 3, schema_specific = True, shared_memory_attention = False, mem_slots = 8, num_steps = 20).cuda() 294 | t1 = time.time() 295 | for i in range(5): 296 | x = TE2(x) 297 | print(time.time() - t1) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.12.5 2 | cycler==0.10.0 3 | dataclasses==0.8 4 | decorator==4.4.2 5 | einops==0.3.0 6 | h5py==2.10.0 7 | imageio==2.9.0 8 | kiwisolver==1.3.1 9 | matplotlib==3.3.2 10 | networkx==2.5 11 | numpy==1.19.1 12 | opencv-python==3.4.4.19 13 | Pillow==7.2.0 14 | pyparsing==2.4.7 15 | python-dateutil==2.8.1 16 | PyWavelets==1.1.1 17 | scikit-image==0.17.2 18 | scipy==1.4.1 19 | six==1.15.0 20 | slot-attention==1.0.2 21 | tifffile==2020.9.3 22 | torch==1.7.1 23 | torchvision==0.8.2 24 | tqdm==4.56.0 25 | typing-extensions==3.7.4.3 26 | -------------------------------------------------------------------------------- /sort_of_clevr/README.md: -------------------------------------------------------------------------------- 1 | Pytorch implementation of Relational Networks - [A simple neural network module for relational reasoning](https://arxiv.org/pdf/1706.01427.pdf) 2 | 3 | Implemented & tested on the Sort-of-CLEVR task. 4 | 5 | ## Sort-of-CLEVR 6 | 7 | Sort-of-CLEVR is a simplified version of [CLEVR](http://cs.stanford.edu/people/jcjohns/clevr/). It is composed of 10000 images and 20 questions (10 relational questions and 10 non-relational questions) per image. 6 colors (red, green, blue, orange, gray, yellow) are assigned to randomly chosen shapes (square or circle) placed in an image. 8 | 9 | Non-relational questions are composed of 3 subtypes: 10 | 11 | 1) Shape of a certain colored object 12 | 2) Horizontal location of a certain colored object: whether it is on the left side or the right side of the image 13 | 3) Vertical location of a certain colored object: whether it is in the upper half or the lower half of the image 14 | 15 | These questions are "non-relational" because the agent only needs to focus on a certain object. 16 | 17 | Relational questions are composed of 3 subtypes: 18 | 19 | 1) Shape of the object which is closest to a certain colored object 20 | 2) Shape of the object which is furthest from a certain colored object 21 | 3) Number of objects which have the same shape as a certain colored object 22 | 23 | These questions are "relational" because the agent has to consider the relations between objects. 24 | 25 | Questions are encoded into a vector of size 11: 6 for a one-hot vector of the queried color (among 6 colors), 2 for a one-hot vector of relational/non-relational question type, and 3 for a one-hot vector of the 3 subtypes. 26 | 27 | 28 | 29 | For example, with the sample image shown, we can generate non-relational questions like: 30 | 31 | 1) What is the shape of the red object? => Circle (even though it does not really look like a "circle"...) 32 | 2) Is the green object placed on the left side of the image? => yes 33 | 3) Is the orange object placed in the upper half of the image? => no 34 | 35 | And relational questions: 36 | 37 | 1) What is the shape of the object closest to the red object? => square 38 | 2) What is the shape of the object furthest from the orange object? => circle 39 | 3) How many objects have the same shape as the blue object? => 3 40 | 
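A minimal sketch of what this 11-dimensional question encoding looks like. The actual construction lives in `sort_of_clevr_generator.py`, so the index ordering below is only an assumption; note also that `main.py` allocates an 18-dimensional question tensor (`torch.FloatTensor(bs, 18)`) because it additionally handles ternary questions.

```
import numpy as np

COLORS = ['red', 'green', 'blue', 'orange', 'gray', 'yellow']

def encode_question(color, relational, subtype):
    # Illustrative layout (assumed ordering, not the repository's actual encoder):
    # 6 dims for the color one-hot, 2 for non-relational/relational, 3 for the subtype one-hot.
    q = np.zeros(11, dtype=np.float32)
    q[COLORS.index(color)] = 1      # which color the question asks about
    q[6 + int(relational)] = 1      # question family
    q[8 + subtype] = 1              # subtype index in {0, 1, 2}
    return q

# e.g. "What is the shape of the object closest to the red object?"
print(encode_question('red', relational=True, subtype=0))
```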
41 | ## Setup 42 | 43 | Create a conda environment from the `environment.yml` file 44 | ``` 45 | $ conda env create -f environment.yml 46 | ``` 47 | Activate the environment 48 | ``` 49 | $ conda activate RN3 50 | ``` 51 | If you don't use conda, install Python 3 normally and use `pip install` to install the remaining dependencies. The list of dependencies can be found in the `environment.yml` file. 52 | 53 | ## Usage 54 | 55 | $ ./run.sh 56 | 57 | or 58 | 59 | $ python sort_of_clevr_generator.py 60 | 61 | to generate the Sort-of-CLEVR dataset 62 | and 63 | 64 | $ python main.py 65 | 66 | to train the binary RN model. 67 | Alternatively, use 68 | 69 | $ python main.py --relation-type=ternary 70 | 71 | to train the ternary RN model. 72 | 73 | ## Modifications 74 | In the original paper, the Sort-of-CLEVR task used a different model from the CLEVR task. However, because the model used for CLEVR requires much less time to compute (the network is much smaller), that model is used for the Sort-of-CLEVR task as well. 75 | 76 | ## Result 77 | 78 | | | Relational Networks (20th epoch) | CNN + MLP (without RN, 100th epoch) | 79 | | --- | --- | --- | 80 | | Non-relational question | 99% | 66% | 81 | | Relational question | 89% | 66% | 82 | 83 | CNN + MLP overfit the training data. 84 | 85 | Relational Networks show far better results on both relational and non-relational questions. 86 | 87 | ## Contributions 88 | 89 | [@gngdb](https://github.com/gngdb) speeds up the model by 10 times. 90 | -------------------------------------------------------------------------------- /sort_of_clevr/main.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | import argparse 4 | import os 5 | #import cPickle as pickle 6 | import pickle 7 | import random 8 | import numpy as np 9 | import csv 10 | 11 | import torch 12 | #from torch.utils.tensorboard import SummaryWriter 13 | from torch.autograd import Variable 14 | 15 | from model import RN, CNN_MLP, Transformer 16 | 17 | def str2bool(v): 18 | """Method to map string to bool for argument parser""" 19 | if isinstance(v, bool): 20 | return v 21 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 22 | return True 23 | if v.lower() in ('no', 'false', 'f', 'n', '0'): 24 | return False 25 | raise argparse.ArgumentTypeError('Boolean value expected.') 26 | 27 | 28 | # Training settings 29 | parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLVR Example') 30 | parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP', 'Transformer'], default='RN', 31 | help='model to train: RN, CNN_MLP, or Transformer') 32 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 33 | help='input batch size for training (default: 64)') 34 | parser.add_argument('--epochs', type=int, default=20, metavar='N', 35 | help='number of epochs to train (default: 20)') 36 | parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', 37 | help='learning rate (default: 0.0001)') 38 | parser.add_argument('--no-cuda', action='store_true', default=False, 39 | help='disables CUDA training') 40 | #parser.add_argument('--seed', type=int, default=1, metavar='S', 41 | # help='random seed (default: 1)') 42 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 43 | help='how many batches to wait before logging training status') 44 | parser.add_argument('--resume', type=str, 45 | help='resume from model stored') 46 | 
parser.add_argument('--relation-type', type=str, default='binary', 47 | help='what kind of relations to learn. options: binary, ternary (default: binary)') 48 | parser.add_argument('--num_layers', type = int, default = 4) 49 | parser.add_argument('--functional', type = str2bool, default = False) 50 | parser.add_argument('--use_topk', type = str2bool, default = False) 51 | parser.add_argument('--topk', type = int, default = 3) 52 | parser.add_argument('--shared_memory_attention', type= str2bool, default = False) 53 | parser.add_argument('--embed_dim', type = int, default = 256) 54 | parser.add_argument('--share_vanilla_parameters', type= str2bool, default = False) 55 | parser.add_argument('--save_dir', type = str, default = 'model') 56 | parser.add_argument('--mem_slots', type = int, default = 4) 57 | parser.add_argument('--null_attention', type = str2bool, default = False) 58 | parser.add_argument('--seed', type = int, default = 0) 59 | 60 | args = parser.parse_args() 61 | args.cuda = not args.no_cuda and torch.cuda.is_available() 62 | 63 | torch.manual_seed(args.seed) 64 | if args.cuda: 65 | torch.cuda.manual_seed(args.seed) 66 | 67 | torch.manual_seed(args.seed) 68 | torch.cuda.manual_seed_all(args.seed) 69 | torch.backends.cudnn.deterministic = True 70 | np.random.seed(args.seed) 71 | 72 | #summary_writer = SummaryWriter() 73 | 74 | args.image_size = 75 75 | args.patch_size = 15 76 | 77 | if args.model=='CNN_MLP': 78 | model = CNN_MLP(args) 79 | elif args.model == 'Transformer': 80 | model = Transformer(args) 81 | else: 82 | model = RN(args) 83 | 84 | model_dirs = args.save_dir 85 | bs = args.batch_size 86 | input_img = torch.FloatTensor(bs, 3, 75, 75) 87 | input_qst = torch.FloatTensor(bs, 18) 88 | label = torch.LongTensor(bs) 89 | 90 | if args.cuda: 91 | model.cuda() 92 | input_img = input_img.cuda() 93 | input_qst = input_qst.cuda() 94 | label = label.cuda() 95 | 96 | input_img = Variable(input_img) 97 | input_qst = Variable(input_qst) 98 | label = Variable(label) 99 | 100 | def tensor_data(data, i): 101 | img = torch.from_numpy(np.asarray(data[0][bs*i:bs*(i+1)])) 102 | qst = torch.from_numpy(np.asarray(data[1][bs*i:bs*(i+1)])) 103 | ans = torch.from_numpy(np.asarray(data[2][bs*i:bs*(i+1)])) 104 | 105 | input_img.data.resize_(img.size()).copy_(img) 106 | input_qst.data.resize_(qst.size()).copy_(qst) 107 | label.data.resize_(ans.size()).copy_(ans) 108 | 109 | 110 | def cvt_data_axis(data): 111 | img = [e[0] for e in data] 112 | qst = [e[1] for e in data] 113 | ans = [e[2] for e in data] 114 | return (img,qst,ans) 115 | 116 | 117 | def train(epoch, ternary, rel, norel): 118 | model.train() 119 | 120 | if not len(rel[0]) == len(norel[0]): 121 | print('Not equal length for relation dataset and non-relation dataset.') 122 | return 123 | 124 | random.shuffle(ternary) 125 | random.shuffle(rel) 126 | random.shuffle(norel) 127 | 128 | ternary = cvt_data_axis(ternary) 129 | rel = cvt_data_axis(rel) 130 | norel = cvt_data_axis(norel) 131 | 132 | acc_ternary = [] 133 | acc_rels = [] 134 | acc_norels = [] 135 | 136 | l_ternary = [] 137 | l_binary = [] 138 | l_unary = [] 139 | 140 | for batch_idx in range(len(rel[0]) // bs): 141 | tensor_data(ternary, batch_idx) 142 | accuracy_ternary, loss_ternary = model.train_(input_img, input_qst, label) 143 | acc_ternary.append(accuracy_ternary.item()) 144 | l_ternary.append(loss_ternary.item()) 145 | 146 | tensor_data(rel, batch_idx) 147 | accuracy_rel, loss_binary = model.train_(input_img, input_qst, label) 148 | acc_rels.append(accuracy_rel.item()) 149 | 
l_binary.append(loss_binary.item()) 150 | 151 | tensor_data(norel, batch_idx) 152 | accuracy_norel, loss_unary = model.train_(input_img, input_qst, label) 153 | acc_norels.append(accuracy_norel.item()) 154 | l_unary.append(loss_unary.item()) 155 | 156 | if batch_idx % args.log_interval == 0: 157 | print('Train Epoch: {} [{}/{} ({:.0f}%)] ' 158 | 'Ternary accuracy: {:.0f}% | Relations accuracy: {:.0f}% | Non-relations accuracy: {:.0f}%'.format( 159 | epoch, 160 | batch_idx * bs * 2, 161 | len(rel[0]) * 2, 162 | 100. * batch_idx * bs / len(rel[0]), 163 | accuracy_ternary, 164 | accuracy_rel, 165 | accuracy_norel), flush=True) 166 | 167 | avg_acc_ternary = sum(acc_ternary) / len(acc_ternary) 168 | avg_acc_binary = sum(acc_rels) / len(acc_rels) 169 | avg_acc_unary = sum(acc_norels) / len(acc_norels) 170 | 171 | #summary_writer.add_scalars('Accuracy/train', { 172 | # 'ternary': avg_acc_ternary, 173 | # 'binary': avg_acc_binary, 174 | # 'unary': avg_acc_unary 175 | #}, epoch) 176 | 177 | avg_loss_ternary = sum(l_ternary) / len(l_ternary) 178 | avg_loss_binary = sum(l_binary) / len(l_binary) 179 | avg_loss_unary = sum(l_unary) / len(l_unary) 180 | 181 | #summary_writer.add_scalars('Loss/train', { 182 | # 'ternary': avg_loss_ternary, 183 | # 'binary': avg_loss_binary, 184 | # 'unary': avg_loss_unary 185 | #}, epoch) 186 | 187 | # return average accuracy 188 | return avg_acc_ternary, avg_acc_binary, avg_acc_unary 189 | 190 | def test(epoch, ternary, rel, norel): 191 | model.eval() 192 | if not len(rel[0]) == len(norel[0]): 193 | print('Not equal length for relation dataset and non-relation dataset.', flush=True) 194 | return 195 | 196 | ternary = cvt_data_axis(ternary) 197 | rel = cvt_data_axis(rel) 198 | norel = cvt_data_axis(norel) 199 | 200 | accuracy_ternary = [] 201 | accuracy_rels = [] 202 | accuracy_norels = [] 203 | 204 | loss_ternary = [] 205 | loss_binary = [] 206 | loss_unary = [] 207 | 208 | for batch_idx in range(len(rel[0]) // bs): 209 | tensor_data(ternary, batch_idx) 210 | acc_ter, l_ter = model.test_(input_img, input_qst, label) 211 | accuracy_ternary.append(acc_ter.item()) 212 | loss_ternary.append(l_ter.item()) 213 | 214 | tensor_data(rel, batch_idx) 215 | acc_bin, l_bin = model.test_(input_img, input_qst, label) 216 | accuracy_rels.append(acc_bin.item()) 217 | loss_binary.append(l_bin.item()) 218 | 219 | tensor_data(norel, batch_idx) 220 | acc_un, l_un = model.test_(input_img, input_qst, label) 221 | accuracy_norels.append(acc_un.item()) 222 | loss_unary.append(l_un.item()) 223 | 224 | accuracy_ternary = sum(accuracy_ternary) / len(accuracy_ternary) 225 | accuracy_rel = sum(accuracy_rels) / len(accuracy_rels) 226 | accuracy_norel = sum(accuracy_norels) / len(accuracy_norels) 227 | print('\n Test set: Ternary accuracy: {:.0f}% Binary accuracy: {:.0f}% | Unary accuracy: {:.0f}%\n'.format( 228 | accuracy_ternary, accuracy_rel, accuracy_norel), flush=True) 229 | 230 | #summary_writer.add_scalars('Accuracy/test', { 231 | # 'ternary': accuracy_ternary, 232 | # 'binary': accuracy_rel, 233 | # 'unary': accuracy_norel 234 | #}, epoch) 235 | 236 | loss_ternary = sum(loss_ternary) / len(loss_ternary) 237 | loss_binary = sum(loss_binary) / len(loss_binary) 238 | loss_unary = sum(loss_unary) / len(loss_unary) 239 | 240 | #summary_writer.add_scalars('Loss/test', { 241 | # 'ternary': loss_ternary, 242 | # 'binary': loss_binary, 243 | # 'unary': loss_unary 244 | #}, epoch) 245 | 246 | return accuracy_ternary, accuracy_rel, accuracy_norel 247 | 248 | 249 | def load_data(): 250 | print('loading 
data...') 251 | dirs = './data' 252 | filename = os.path.join(dirs,'sort-of-clevr.pickle') 253 | with open(filename, 'rb') as f: 254 | train_datasets, test_datasets = pickle.load(f) 255 | ternary_train = [] 256 | ternary_test = [] 257 | rel_train = [] 258 | rel_test = [] 259 | norel_train = [] 260 | norel_test = [] 261 | print('processing data...', flush=True) 262 | 263 | for img, ternary, relations, norelations in train_datasets: 264 | img = np.swapaxes(img, 0, 2) 265 | for qst, ans in zip(ternary[0], ternary[1]): 266 | ternary_train.append((img,qst,ans)) 267 | for qst,ans in zip(relations[0], relations[1]): 268 | rel_train.append((img,qst,ans)) 269 | for qst,ans in zip(norelations[0], norelations[1]): 270 | norel_train.append((img,qst,ans)) 271 | 272 | for img, ternary, relations, norelations in test_datasets: 273 | img = np.swapaxes(img, 0, 2) 274 | for qst, ans in zip(ternary[0], ternary[1]): 275 | ternary_test.append((img, qst, ans)) 276 | for qst,ans in zip(relations[0], relations[1]): 277 | rel_test.append((img,qst,ans)) 278 | for qst,ans in zip(norelations[0], norelations[1]): 279 | norel_test.append((img,qst,ans)) 280 | 281 | return (ternary_train, ternary_test, rel_train, rel_test, norel_train, norel_test) 282 | 283 | 284 | ternary_train, ternary_test, rel_train, rel_test, norel_train, norel_test = load_data() 285 | 286 | try: 287 | os.makedirs(model_dirs) 288 | except: 289 | print('directory {} already exists'.format(model_dirs), flush=True) 290 | 291 | if args.resume: 292 | filename = os.path.join(model_dirs, args.resume) 293 | if os.path.isfile(filename): 294 | print('==> loading checkpoint {}'.format(filename)) 295 | checkpoint = torch.load(filename) 296 | model.load_state_dict(checkpoint) 297 | print('==> loaded checkpoint {}'.format(filename), flush=True) 298 | 299 | with open(f'{args.save_dir}/{args.model}_{args.seed}_log.csv', 'w') as log_file: 300 | csv_writer = csv.writer(log_file, delimiter=',') 301 | csv_writer.writerow(['epoch', 'train_acc_ternary', 'train_acc_rel', 302 | 'train_acc_norel', 'train_acc_ternary', 'test_acc_rel', 'test_acc_norel']) 303 | 304 | print(f"Training {args.model} {f'({args.relation_type})' if args.model == 'RN' else ''} model...", flush=True) 305 | for epoch in range(1, args.epochs + 1): 306 | train_acc_ternary, train_acc_binary, train_acc_unary = train( 307 | epoch, ternary_train, rel_train, norel_train) 308 | test_acc_ternary, test_acc_binary, test_acc_unary = test( 309 | epoch, ternary_test, rel_test, norel_test) 310 | 311 | csv_writer.writerow([epoch, train_acc_ternary, train_acc_binary, 312 | train_acc_unary, test_acc_ternary, test_acc_binary, test_acc_unary]) 313 | model.save_model(epoch, args.save_dir) 314 | -------------------------------------------------------------------------------- /sort_of_clevr/main_splits.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | import argparse 4 | import os 5 | #import cPickle as pickle 6 | import pickle 7 | import random 8 | import numpy as np 9 | import csv 10 | 11 | import torch 12 | #from torch.utils.tensorboard import SummaryWriter 13 | from torch.autograd import Variable 14 | 15 | from model import RN, CNN_MLP, Transformer 16 | 17 | def str2bool(v): 18 | """Method to map string to bool for argument parser""" 19 | if isinstance(v, bool): 20 | return v 21 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 22 | return True 23 | if v.lower() in ('no', 'false', 'f', 'n', '0'): 24 | return False 25 | raise 
argparse.ArgumentTypeError('Boolean value expected.') 26 | 27 | 28 | # Training settings 29 | parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLVR Example') 30 | parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP', 'Transformer'], default='RN', 31 | help='resume from model stored') 32 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 33 | help='input batch size for training (default: 64)') 34 | parser.add_argument('--epochs', type=int, default=20, metavar='N', 35 | help='number of epochs to train (default: 20)') 36 | parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', 37 | help='learning rate (default: 0.0001)') 38 | parser.add_argument('--no-cuda', action='store_true', default=False, 39 | help='disables CUDA training') 40 | #parser.add_argument('--seed', type=int, default=1, metavar='S', 41 | # help='random seed (default: 1)') 42 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 43 | help='how many batches to wait before logging training status') 44 | parser.add_argument('--resume', type=str, 45 | help='resume from model stored') 46 | parser.add_argument('--relation-type', type=str, default='binary', 47 | help='what kind of relations to learn. options: binary, ternary (default: binary)') 48 | parser.add_argument('--num_layers', type = int, default = 4) 49 | parser.add_argument('--functional', type = str2bool, default = False) 50 | parser.add_argument('--use_topk', type = str2bool, default = False) 51 | parser.add_argument('--topk', type = int, default = 3) 52 | parser.add_argument('--shared_memory_attention', type= str2bool, default = False) 53 | parser.add_argument('--embed_dim', type = int, default = 256) 54 | parser.add_argument('--share_vanilla_parameters', type= str2bool, default = False) 55 | parser.add_argument('--save_dir', type = str, default = 'model') 56 | parser.add_argument('--mem_slots', type = int, default = 4) 57 | parser.add_argument('--null_attention', type = str2bool, default = False) 58 | parser.add_argument('--seed', type = int, default = 0) 59 | parser.add_argument('--nb_heldout_colors', type = int, default = 0) 60 | 61 | args = parser.parse_args() 62 | args.cuda = not args.no_cuda and torch.cuda.is_available() 63 | 64 | torch.manual_seed(args.seed) 65 | if args.cuda: 66 | torch.cuda.manual_seed(args.seed) 67 | 68 | torch.manual_seed(args.seed) 69 | torch.cuda.manual_seed_all(args.seed) 70 | torch.backends.cudnn.deterministic = True 71 | np.random.seed(args.seed) 72 | 73 | #summary_writer = SummaryWriter() 74 | 75 | args.image_size = 75 76 | args.patch_size = 25 77 | 78 | if args.model=='CNN_MLP': 79 | model = CNN_MLP(args) 80 | elif args.model == 'Transformer': 81 | model = Transformer(args) 82 | else: 83 | model = RN(args) 84 | 85 | model_dirs = args.save_dir 86 | bs = args.batch_size 87 | input_img = torch.FloatTensor(bs, 3, 75, 75) 88 | input_qst = torch.FloatTensor(bs, 11) 89 | label = torch.LongTensor(bs) 90 | 91 | if args.cuda: 92 | model.cuda() 93 | input_img = input_img.cuda() 94 | input_qst = input_qst.cuda() 95 | label = label.cuda() 96 | 97 | input_img = Variable(input_img) 98 | input_qst = Variable(input_qst) 99 | label = Variable(label) 100 | 101 | def tensor_data(data, i): 102 | img = torch.from_numpy(np.asarray(data[0][bs*i:bs*(i+1)])) 103 | qst = torch.from_numpy(np.asarray(data[1][bs*i:bs*(i+1)])) 104 | ans = torch.from_numpy(np.asarray(data[2][bs*i:bs*(i+1)])) 105 | 106 | input_img.data.resize_(img.size()).copy_(img) 107 | 
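# tensor_data copies the i-th mini-batch slice of (image, question, answer) into the pre-allocated
# input_img / input_qst / label buffers in place (resize_ + copy_), so the same (possibly CUDA)
# tensors are reused for every batch instead of allocating new ones.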
input_qst.data.resize_(qst.size()).copy_(qst) 108 | label.data.resize_(ans.size()).copy_(ans) 109 | 110 | 111 | def cvt_data_axis(data): 112 | img = [e[0] for e in data] 113 | qst = [e[1] for e in data] 114 | ans = [e[2] for e in data] 115 | return (img,qst,ans) 116 | 117 | 118 | def train(epoch, rel, norel): 119 | model.train() 120 | 121 | if not len(rel[0]) == len(norel[0]): 122 | print('Not equal length for relation dataset and non-relation dataset.') 123 | return 124 | 125 | #random.shuffle(ternary) 126 | random.shuffle(rel) 127 | random.shuffle(norel) 128 | 129 | #ternary = cvt_data_axis(ternary) 130 | rel = cvt_data_axis(rel) 131 | norel = cvt_data_axis(norel) 132 | 133 | acc_ternary = [] 134 | acc_rels = [] 135 | acc_norels = [] 136 | 137 | l_ternary = [] 138 | l_binary = [] 139 | l_unary = [] 140 | 141 | for batch_idx in range(len(rel[0]) // bs): 142 | #tensor_data(ternary, batch_idx) 143 | #accuracy_ternary, loss_ternary = model.train_(input_img, input_qst, label) 144 | #acc_ternary.append(accuracy_ternary.item()) 145 | #l_ternary.append(loss_ternary.item()) 146 | 147 | tensor_data(rel, batch_idx) 148 | accuracy_rel, loss_binary = model.train_(input_img, input_qst, label) 149 | acc_rels.append(accuracy_rel.item()) 150 | l_binary.append(loss_binary.item()) 151 | 152 | tensor_data(norel, batch_idx) 153 | accuracy_norel, loss_unary = model.train_(input_img, input_qst, label) 154 | acc_norels.append(accuracy_norel.item()) 155 | l_unary.append(loss_unary.item()) 156 | 157 | if batch_idx % args.log_interval == 0: 158 | print('Train Epoch: {} [{}/{} ({:.0f}%)] ' 159 | ' | Relations accuracy: {:.0f}% | Non-relations accuracy: {:.0f}%'.format( 160 | epoch, 161 | batch_idx * bs * 2, 162 | len(rel[0]) * 2, 163 | 100. * batch_idx * bs / len(rel[0]), 164 | accuracy_rel, 165 | accuracy_norel)) 166 | 167 | #avg_acc_ternary = sum(acc_ternary) / len(acc_ternary) 168 | avg_acc_binary = sum(acc_rels) / len(acc_rels) 169 | avg_acc_unary = sum(acc_norels) / len(acc_norels) 170 | 171 | #summary_writer.add_scalars('Accuracy/train', { 172 | # 'ternary': avg_acc_ternary, 173 | # 'binary': avg_acc_binary, 174 | # 'unary': avg_acc_unary 175 | #}, epoch) 176 | 177 | #avg_loss_ternary = sum(l_ternary) / len(l_ternary) 178 | avg_loss_binary = sum(l_binary) / len(l_binary) 179 | avg_loss_unary = sum(l_unary) / len(l_unary) 180 | 181 | #summary_writer.add_scalars('Loss/train', { 182 | # 'ternary': avg_loss_ternary, 183 | # 'binary': avg_loss_binary, 184 | # 'unary': avg_loss_unary 185 | #}, epoch) 186 | 187 | # return average accuracy 188 | return avg_acc_binary, avg_acc_unary 189 | 190 | def test(epoch, rel, norel): 191 | model.eval() 192 | if not len(rel[0]) == len(norel[0]): 193 | print('Not equal length for relation dataset and non-relation dataset.') 194 | return 195 | 196 | #ternary = cvt_data_axis(ternary) 197 | rel = cvt_data_axis(rel) 198 | norel = cvt_data_axis(norel) 199 | 200 | accuracy_ternary = [] 201 | accuracy_rels = [] 202 | accuracy_norels = [] 203 | 204 | loss_ternary = [] 205 | loss_binary = [] 206 | loss_unary = [] 207 | 208 | for batch_idx in range(len(rel[0]) // bs): 209 | #tensor_data(ternary, batch_idx) 210 | #acc_ter, l_ter = model.test_(input_img, input_qst, label) 211 | #accuracy_ternary.append(acc_ter.item()) 212 | #loss_ternary.append(l_ter.item()) 213 | 214 | tensor_data(rel, batch_idx) 215 | acc_bin, l_bin = model.test_(input_img, input_qst, label) 216 | accuracy_rels.append(acc_bin.item()) 217 | loss_binary.append(l_bin.item()) 218 | 219 | tensor_data(norel, batch_idx) 220 | 
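# evaluate the same batch index on the non-relational (unary) questions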
acc_un, l_un = model.test_(input_img, input_qst, label) 221 | accuracy_norels.append(acc_un.item()) 222 | loss_unary.append(l_un.item()) 223 | 224 | #accuracy_ternary = sum(accuracy_ternary) / len(accuracy_ternary) 225 | accuracy_rel = sum(accuracy_rels) / len(accuracy_rels) 226 | accuracy_norel = sum(accuracy_norels) / len(accuracy_norels) 227 | print('\n Test set: Binary accuracy: {:.0f}% | Unary accuracy: {:.0f}%\n'.format( accuracy_rel, accuracy_norel)) 228 | 229 | #summary_writer.add_scalars('Accuracy/test', { 230 | # 'ternary': accuracy_ternary, 231 | # 'binary': accuracy_rel, 232 | # 'unary': accuracy_norel 233 | #}, epoch) 234 | 235 | #loss_ternary = sum(loss_ternary) / len(loss_ternary) 236 | loss_binary = sum(loss_binary) / len(loss_binary) 237 | loss_unary = sum(loss_unary) / len(loss_unary) 238 | 239 | #summary_writer.add_scalars('Loss/test', { 240 | # 'ternary': loss_ternary, 241 | # 'binary': loss_binary, 242 | # 'unary': loss_unary 243 | #}, epoch) 244 | 245 | return accuracy_rel, accuracy_norel 246 | 247 | 248 | def load_data(): 249 | print('loading data...') 250 | dirs = './data' 251 | filename = os.path.join(dirs,'sort-of-clevr-{}.pickle'.format(args.nb_heldout_colors)) 252 | with open(filename, 'rb') as f: 253 | train_datasets, test_datasets = pickle.load(f) 254 | ternary_train = [] 255 | ternary_test = [] 256 | rel_train = [] 257 | rel_test = [] 258 | norel_train = [] 259 | norel_test = [] 260 | print('processing data...') 261 | 262 | for img, relations, norelations in train_datasets: 263 | img = np.swapaxes(img, 0, 2) 264 | #for qst, ans in zip(ternary[0], ternary[1]): 265 | # ternary_train.append((img,qst,ans)) 266 | for qst,ans in zip(relations[0], relations[1]): 267 | rel_train.append((img,qst,ans)) 268 | for qst,ans in zip(norelations[0], norelations[1]): 269 | norel_train.append((img,qst,ans)) 270 | 271 | for img, relations, norelations in test_datasets: 272 | img = np.swapaxes(img, 0, 2) 273 | #for qst, ans in zip(ternary[0], ternary[1]): 274 | # ternary_test.append((img, qst, ans)) 275 | for qst,ans in zip(relations[0], relations[1]): 276 | rel_test.append((img,qst,ans)) 277 | for qst,ans in zip(norelations[0], norelations[1]): 278 | norel_test.append((img,qst,ans)) 279 | 280 | return (rel_train, rel_test, norel_train, norel_test) 281 | 282 | 283 | rel_train, rel_test, norel_train, norel_test = load_data() 284 | 285 | try: 286 | os.makedirs(model_dirs) 287 | except: 288 | print('directory {} already exists'.format(model_dirs)) 289 | 290 | if args.resume: 291 | filename = os.path.join(model_dirs, args.resume) 292 | if os.path.isfile(filename): 293 | print('==> loading checkpoint {}'.format(filename)) 294 | checkpoint = torch.load(filename) 295 | model.load_state_dict(checkpoint) 296 | print('==> loaded checkpoint {}'.format(filename)) 297 | 298 | with open(f'{args.save_dir}/{args.model}_{args.seed}_log.csv', 'w') as log_file: 299 | csv_writer = csv.writer(log_file, delimiter=',') 300 | csv_writer.writerow(['epoch', 'train_acc_rel', 301 | 'train_acc_norel', 'test_acc_rel', 'test_acc_norel']) 302 | 303 | print(f"Training {args.model} {f'({args.relation_type})' if args.model == 'RN' else ''} model...") 304 | 305 | for epoch in range(1, args.epochs + 1): 306 | train_acc_binary, train_acc_unary = train( 307 | epoch, rel_train, norel_train) 308 | test_acc_binary, test_acc_unary = test( 309 | epoch, rel_test, norel_test) 310 | 311 | csv_writer.writerow([epoch, train_acc_binary, 312 | train_acc_unary, test_acc_binary, test_acc_unary]) 313 | model.save_model(epoch, 
args.save_dir) 314 | -------------------------------------------------------------------------------- /sort_of_clevr/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.autograd import Variable 7 | from transformers import TransformerEncoder 8 | from einops import rearrange, repeat 9 | from transformer_utilities.set_transformer import SetTransformer 10 | 11 | 12 | class ConvInputModel(nn.Module): 13 | def __init__(self): 14 | super(ConvInputModel, self).__init__() 15 | 16 | self.conv1 = nn.Conv2d(3, 24, 3, stride=2, padding=1) 17 | self.batchNorm1 = nn.BatchNorm2d(24) 18 | self.conv2 = nn.Conv2d(24, 24, 3, stride=2, padding=1) 19 | self.batchNorm2 = nn.BatchNorm2d(24) 20 | self.conv3 = nn.Conv2d(24, 24, 3, stride=2, padding=1) 21 | self.batchNorm3 = nn.BatchNorm2d(24) 22 | self.conv4 = nn.Conv2d(24, 24, 3, stride=2, padding=1) 23 | self.batchNorm4 = nn.BatchNorm2d(24) 24 | 25 | 26 | def forward(self, img): 27 | """convolution""" 28 | x = self.conv1(img) 29 | x = F.relu(x) 30 | x = self.batchNorm1(x) 31 | x = self.conv2(x) 32 | x = F.relu(x) 33 | x = self.batchNorm2(x) 34 | x = self.conv3(x) 35 | x = F.relu(x) 36 | x = self.batchNorm3(x) 37 | x = self.conv4(x) 38 | x = F.relu(x) 39 | x = self.batchNorm4(x) 40 | return x 41 | 42 | 43 | class FCOutputModel(nn.Module): 44 | def __init__(self): 45 | super(FCOutputModel, self).__init__() 46 | 47 | self.fc2 = nn.Linear(256, 256) 48 | self.fc3 = nn.Linear(256, 10) 49 | 50 | def forward(self, x): 51 | x = self.fc2(x) 52 | x = F.relu(x) 53 | x = F.dropout(x) 54 | x = self.fc3(x) 55 | return F.log_softmax(x, dim=1) 56 | 57 | class BasicModel(nn.Module): 58 | def __init__(self, args, name): 59 | super(BasicModel, self).__init__() 60 | self.name=name 61 | 62 | def train_(self, input_img, input_qst, label): 63 | self.optimizer.zero_grad() 64 | output = self(input_img, input_qst) 65 | loss = F.nll_loss(output, label) 66 | loss.backward() 67 | self.optimizer.step() 68 | pred = output.data.max(1)[1] 69 | correct = pred.eq(label.data).cpu().sum() 70 | accuracy = correct * 100. / len(label) 71 | return accuracy, loss 72 | 73 | def test_(self, input_img, input_qst, label): 74 | output = self(input_img, input_qst) 75 | loss = F.nll_loss(output, label) 76 | pred = output.data.max(1)[1] 77 | correct = pred.eq(label.data).cpu().sum() 78 | accuracy = correct * 100. 
/ len(label) 79 | return accuracy, loss 80 | 81 | def save_model(self, epoch, save_dir): 82 | import os 83 | name = 'epoch_{}_{:02d}.pth'.format(self.name, epoch) 84 | path = os.path.join(save_dir, name) 85 | torch.save(self.state_dict(), path) 86 | 87 | 88 | class RN(BasicModel): 89 | def __init__(self, args): 90 | super(RN, self).__init__(args, 'RN') 91 | 92 | self.conv = ConvInputModel() 93 | 94 | self.relation_type = args.relation_type 95 | 96 | if self.relation_type == 'ternary': 97 | ##(number of filters per object+coordinate of object)*3+question vector 98 | self.g_fc1 = nn.Linear((24+2)*3+18, 256) 99 | else: 100 | ##(number of filters per object+coordinate of object)*2+question vector 101 | self.g_fc1 = nn.Linear((24+2)*2+18, 256) 102 | 103 | self.g_fc2 = nn.Linear(256, 256) 104 | self.g_fc3 = nn.Linear(256, 256) 105 | self.g_fc4 = nn.Linear(256, 256) 106 | 107 | self.f_fc1 = nn.Linear(256, 256) 108 | 109 | self.coord_oi = torch.FloatTensor(args.batch_size, 2) 110 | self.coord_oj = torch.FloatTensor(args.batch_size, 2) 111 | if args.cuda: 112 | self.coord_oi = self.coord_oi.cuda() 113 | self.coord_oj = self.coord_oj.cuda() 114 | self.coord_oi = Variable(self.coord_oi) 115 | self.coord_oj = Variable(self.coord_oj) 116 | 117 | # prepare coord tensor 118 | def cvt_coord(i): 119 | return [(i/5-2)/2., (i%5-2)/2.] 120 | 121 | self.coord_tensor = torch.FloatTensor(args.batch_size, 25, 2) 122 | if args.cuda: 123 | self.coord_tensor = self.coord_tensor.cuda() 124 | self.coord_tensor = Variable(self.coord_tensor) 125 | np_coord_tensor = np.zeros((args.batch_size, 25, 2)) 126 | for i in range(25): 127 | np_coord_tensor[:,i,:] = np.array( cvt_coord(i) ) 128 | self.coord_tensor.data.copy_(torch.from_numpy(np_coord_tensor)) 129 | 130 | 131 | self.fcout = FCOutputModel() 132 | 133 | self.optimizer = optim.Adam(self.parameters(), lr=args.lr) 134 | 135 | 136 | def forward(self, img, qst): 137 | x = self.conv(img) ## x = (64 x 24 x 5 x 5) 138 | 139 | """g""" 140 | mb = x.size()[0] 141 | n_channels = x.size()[1] 142 | d = x.size()[2] 143 | # x_flat = (64 x 25 x 24) 144 | x_flat = x.view(mb,n_channels,d*d).permute(0,2,1) 145 | 146 | # add coordinates 147 | x_flat = torch.cat([x_flat, self.coord_tensor],2) 148 | 149 | 150 | if self.relation_type == 'ternary': 151 | # add question everywhere 152 | qst = torch.unsqueeze(qst, 1) # (64x1x18) 153 | qst = qst.repeat(1, 25, 1) # (64x25x18) 154 | qst = torch.unsqueeze(qst, 1) # (64x1x25x18) 155 | qst = torch.unsqueeze(qst, 1) # (64x1x1x25x18) 156 | 157 | # cast all triples against each other 158 | x_i = torch.unsqueeze(x_flat, 1) # (64x1x25x26) 159 | x_i = torch.unsqueeze(x_i, 3) # (64x1x25x1x26) 160 | x_i = x_i.repeat(1, 25, 1, 25, 1) # (64x25x25x25x26) 161 | 162 | x_j = torch.unsqueeze(x_flat, 2) # (64x25x1x26) 163 | x_j = torch.unsqueeze(x_j, 2) # (64x25x1x1x26) 164 | x_j = x_j.repeat(1, 1, 25, 25, 1) # (64x25x25x25x26) 165 | 166 | x_k = torch.unsqueeze(x_flat, 1) # (64x1x25x26) 167 | x_k = torch.unsqueeze(x_k, 1) # (64x1x1x25x26) 168 | x_k = torch.cat([x_k, qst], 4) # (64x1x1x25x26+18) 169 | x_k = x_k.repeat(1, 25, 25, 1, 1) # (64x25x25x25x26+18) 170 | 171 | # concatenate all together 172 | x_full = torch.cat([x_i, x_j, x_k], 4) # (64x25x25x25x3*26+18) 173 | 174 | # reshape for passing through network 175 | x_ = x_full.view(mb * (d * d) * (d * d) * (d * d), 96) # (64*25*25*25x3*26+18) = (1.000.000, 96) 176 | else: 177 | # add question everywhere 178 | qst = torch.unsqueeze(qst, 1) 179 | qst = qst.repeat(1, 25, 1) 180 | qst = torch.unsqueeze(qst, 2) 181 | 
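# the question vector is broadcast over the 25 object slots so it can be concatenated to one member of every object pair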
182 | # cast all pairs against each other 183 | x_i = torch.unsqueeze(x_flat, 1) # (64x1x25x26+18) 184 | x_i = x_i.repeat(1, 25, 1, 1) # (64x25x25x26+18) 185 | x_j = torch.unsqueeze(x_flat, 2) # (64x25x1x26+18) 186 | x_j = torch.cat([x_j, qst], 3) 187 | x_j = x_j.repeat(1, 1, 25, 1) # (64x25x25x26+18) 188 | 189 | # concatenate all together 190 | x_full = torch.cat([x_i,x_j],3) # (64x25x25x2*26+18) 191 | 192 | # reshape for passing through network 193 | x_ = x_full.view(mb * (d * d) * (d * d), 70) # (64*25*25x2*26*18) = (40.000, 70) 194 | 195 | x_ = self.g_fc1(x_) 196 | x_ = F.relu(x_) 197 | x_ = self.g_fc2(x_) 198 | x_ = F.relu(x_) 199 | x_ = self.g_fc3(x_) 200 | x_ = F.relu(x_) 201 | x_ = self.g_fc4(x_) 202 | x_ = F.relu(x_) 203 | 204 | # reshape again and sum 205 | if self.relation_type == 'ternary': 206 | x_g = x_.view(mb, (d * d) * (d * d) * (d * d), 256) 207 | else: 208 | x_g = x_.view(mb, (d * d) * (d * d), 256) 209 | 210 | x_g = x_g.sum(1).squeeze() 211 | 212 | """f""" 213 | x_f = self.f_fc1(x_g) 214 | x_f = F.relu(x_f) 215 | 216 | return self.fcout(x_f) 217 | 218 | 219 | class CNN_MLP(BasicModel): 220 | def __init__(self, args): 221 | super(CNN_MLP, self).__init__(args, 'CNNMLP') 222 | 223 | self.conv = ConvInputModel() 224 | self.fc1 = nn.Linear(5*5*24 + 18, 256) # question concatenated to all 225 | self.fcout = FCOutputModel() 226 | 227 | self.optimizer = optim.Adam(self.parameters(), lr=args.lr) 228 | #print([ a for a in self.parameters() ] ) 229 | 230 | def forward(self, img, qst): 231 | print(qst.size()) 232 | x = self.conv(img) ## x = (64 x 24 x 5 x 5) 233 | 234 | """fully connected layers""" 235 | x = x.view(x.size(0), -1) 236 | 237 | x_ = torch.cat((x, qst), 1) # Concat question 238 | 239 | x_ = self.fc1(x_) 240 | x_ = F.relu(x_) 241 | 242 | return self.fcout(x_) 243 | 244 | class Transformer(BasicModel): 245 | def __init__(self, args): 246 | super(Transformer, self).__init__(args, 'Transformer') 247 | 248 | image_size = args.image_size 249 | patch_size = args.patch_size 250 | h_dim = args.embed_dim 251 | channels = 3 252 | num_classes = 10 253 | 254 | 255 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 256 | num_patches = (image_size // patch_size) ** 2 257 | patch_dim = channels * patch_size ** 2 258 | #assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). 
Try decreasing your patch size' 259 | if args.functional: 260 | print('USING SET TRANSFORMER') 261 | self.net = SetTransformer(h_dim, dim_hidden = 512, num_inds = args.mem_slots) 262 | else: 263 | self.net = TransformerEncoder( 264 | h_dim, 265 | 512, 266 | num_layers = args.num_layers, 267 | num_heads = 4, 268 | dropout = 0.1, 269 | share_parameters = args.share_vanilla_parameters, 270 | shared_memory_attention = args.shared_memory_attention, 271 | use_topk = args.use_topk, 272 | topk = args.topk, 273 | mem_slots = args.mem_slots, 274 | null_attention = args.null_attention, 275 | num_steps = int((image_size*image_size) / (patch_size * patch_size) + 1 + 18) ) 276 | 277 | self.patch_size = patch_size 278 | print(patch_dim) 279 | self.patch_to_embedding = nn.Linear(patch_dim, h_dim) 280 | self.question_to_embedding = nn.Linear(18, h_dim) 281 | #self.question_to_embedding = nn.Linear(1, h_dim) 282 | self.cls_token = nn.Parameter(torch.randn(1, 1, h_dim)) 283 | if args.functional: 284 | self.mlp_head = nn.Linear(512, num_classes) 285 | else: 286 | self.mlp_head = nn.Linear(h_dim, num_classes) 287 | self.optimizer = optim.Adam(self.parameters(), lr=args.lr) 288 | 289 | def forward(self, img, qst): 290 | p = self.patch_size 291 | 292 | x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) 293 | #print(x.size()) 294 | x = self.patch_to_embedding(x) 295 | 296 | #q = self.question_to_embedding(qst.unsqueeze(-1)) 297 | q = self.question_to_embedding(qst) 298 | 299 | q= q.unsqueeze(1) 300 | #print(x.size(), flush=True) 301 | #print(q.size(), flush=True) 302 | x = torch.cat((q, x), dim = 1) 303 | #print(x.size()) 304 | b, n, _ = x.shape 305 | 306 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 307 | x = torch.cat((cls_tokens, x), dim=1) 308 | #print(x.size()) 309 | 310 | x = self.net(x) 311 | 312 | x = F.log_softmax(self.mlp_head(x[:,0]), dim = 1) 313 | 314 | return x 315 | -------------------------------------------------------------------------------- /sort_of_clevr/run_transformer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #source ~/.bashrc 3 | 4 | 5 | embed_dim=$1 6 | num_layers=$2 7 | share_vanilla_parameters=$3 8 | use_topk=$4 9 | topk=$5 10 | shared_memory_attention=$6 11 | mem_slots=$7 12 | null_attention=False 13 | seed=${8} 14 | set_transformer=$9 15 | 16 | 17 | 18 | save_dir=$embed_dim-$num_layers-$set_transformer-$share_vanilla_parameters-$use_topk-$topk-$shared_memory_attention-$mem_slots-$null_attention-$seed 19 | 20 | mkdir $save_dir 21 | 22 | python main.py --model Transformer --epochs 100 --embed_dim $embed_dim --num_layers $num_layers \ 23 | --functional $set_transformer --share_vanilla_parameters $share_vanilla_parameters \ 24 | --use_topk $use_topk --topk $topk --shared_memory_attention $shared_memory_attention \ 25 | --save_dir $save_dir --mem_slots $mem_slots --null_attention $null_attention \ 26 | --seed $seed 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /sort_of_clevr/run_transformer_splits.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | embed_dim=$1 4 | num_layers=$2 5 | functional=$3 6 | share_vanilla_parameters=$4 7 | use_topk=$5 8 | topk=$6 9 | shared_memory_attention=$7 10 | mem_slots=$8 11 | null_attention=$9 12 | seed=${10} 13 | nb_heldout_colors=${11} 14 | 15 | 
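# the directory name below encodes every hyperparameter so runs with different settings do not overwrite each other
# hypothetical example invocation (argument order as defined above):
#   sh run_transformer_splits.sh 256 4 False True True 5 True 8 False 1 1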
save_dir=$embed_dim-$num_layers-$functional-$share_vanilla_parameters-$use_topk-$topk-$shared_memory_attention-$mem_slots-$null_attention-$seed-$nb_heldout_colors 16 | 17 | mkdir $save_dir 18 | 19 | python main_splits.py --model Transformer --epochs 100 --embed_dim $embed_dim --num_layers $num_layers \ 20 | --functional $functional --share_vanilla_parameters $share_vanilla_parameters \ 21 | --use_topk $use_topk --topk $topk --shared_memory_attention $shared_memory_attention \ 22 | --save_dir $save_dir --mem_slots $mem_slots --null_attention $null_attention \ 23 | --seed $seed --nb_heldout_colors $nb_heldout_colors 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /sort_of_clevr/sort_of_clevr_generator.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import random 5 | #import cPickle as pickle 6 | import pickle 7 | import warnings 8 | import argparse 9 | from tqdm import tqdm 10 | 11 | parser = argparse.ArgumentParser(description='Sort-of-CLEVR dataset generator') 12 | parser.add_argument('--seed', type=int, default=1, metavar='S', 13 | help='random seed (default: 1)') 14 | parser.add_argument('--t-subtype', type=int, default=-1, 15 | help='Force ternary questions to be of a given type') 16 | args = parser.parse_args() 17 | 18 | random.seed(args.seed) 19 | np.random.seed(args.seed) 20 | 21 | train_size = 9800 22 | test_size = 200 23 | img_size = 75 24 | size = 5 25 | question_size = 18 ## 2 x (6 for one-hot vector of color), 3 for question type, 3 for question subtype 26 | q_type_idx = 12 27 | sub_q_type_idx = 15 28 | """Answer : [yes, no, rectangle, circle, r, g, b, o, k, y]""" 29 | 30 | nb_questions = 10 31 | dirs = './data' 32 | 33 | colors = [ 34 | (0,0,255),##r 35 | (0,255,0),##g 36 | (255,0,0),##b 37 | (0,156,255),##o 38 | (128,128,128),##k 39 | (0,255,255)##y 40 | ] 41 | 42 | 43 | try: 44 | os.makedirs(dirs) 45 | except: 46 | print('directory {} already exists'.format(dirs)) 47 | 48 | def center_generate(objects): 49 | while True: 50 | pas = True 51 | center = np.random.randint(0+size, img_size - size, 2) 52 | if len(objects) > 0: 53 | for name,c,shape in objects: 54 | if ((center - c) ** 2).sum() < ((size * 2) ** 2): 55 | pas = False 56 | if pas: 57 | return center 58 | 59 | 60 | 61 | def build_dataset(): 62 | objects = [] 63 | img = np.ones((img_size,img_size,3)) * 255 64 | for color_id,color in enumerate(colors): 65 | center = center_generate(objects) 66 | if random.random()<0.5: 67 | start = (center[0]-size, center[1]-size) 68 | end = (center[0]+size, center[1]+size) 69 | cv2.rectangle(img, start, end, color, -1) 70 | objects.append((color_id,center,'r')) 71 | else: 72 | center_ = (center[0], center[1]) 73 | cv2.circle(img, center_, size, color, -1) 74 | objects.append((color_id,center,'c')) 75 | 76 | 77 | ternary_questions = [] 78 | binary_questions = [] 79 | norel_questions = [] 80 | ternary_answers = [] 81 | binary_answers = [] 82 | norel_answers = [] 83 | """Non-relational questions""" 84 | for _ in range(nb_questions): 85 | question = np.zeros((question_size)) 86 | color = random.randint(0,5) 87 | question[color] = 1 88 | question[q_type_idx] = 1 89 | subtype = random.randint(0,2) 90 | question[subtype+sub_q_type_idx] = 1 91 | norel_questions.append(question) 92 | """Answer : [yes, no, rectangle, circle, r, g, b, o, k, y]""" 93 | if subtype == 0: 94 | """query shape->rectangle/circle""" 95 | if objects[color][2] == 'r': 96 | answer = 
2 97 | else: 98 | answer = 3 99 | 100 | elif subtype == 1: 101 | """query horizontal position->yes/no""" 102 | if objects[color][1][0] < img_size / 2: 103 | answer = 0 104 | else: 105 | answer = 1 106 | 107 | elif subtype == 2: 108 | """query vertical position->yes/no""" 109 | if objects[color][1][1] < img_size / 2: 110 | answer = 0 111 | else: 112 | answer = 1 113 | norel_answers.append(answer) 114 | 115 | """Binary Relational questions""" 116 | for _ in range(nb_questions): 117 | question = np.zeros((question_size)) 118 | color = random.randint(0,5) 119 | question[color] = 1 120 | question[q_type_idx+1] = 1 121 | subtype = random.randint(0,2) 122 | question[subtype+sub_q_type_idx] = 1 123 | binary_questions.append(question) 124 | 125 | if subtype == 0: 126 | """closest-to->rectangle/circle""" 127 | my_obj = objects[color][1] 128 | dist_list = [((my_obj - obj[1]) ** 2).sum() for obj in objects] 129 | dist_list[dist_list.index(0)] = 999 130 | closest = dist_list.index(min(dist_list)) 131 | if objects[closest][2] == 'r': 132 | answer = 2 133 | else: 134 | answer = 3 135 | 136 | elif subtype == 1: 137 | """furthest-from->rectangle/circle""" 138 | my_obj = objects[color][1] 139 | dist_list = [((my_obj - obj[1]) ** 2).sum() for obj in objects] 140 | furthest = dist_list.index(max(dist_list)) 141 | if objects[furthest][2] == 'r': 142 | answer = 2 143 | else: 144 | answer = 3 145 | 146 | elif subtype == 2: 147 | """count->1~6""" 148 | my_obj = objects[color][2] 149 | count = -1 150 | for obj in objects: 151 | if obj[2] == my_obj: 152 | count +=1 153 | answer = count+4 154 | 155 | binary_answers.append(answer) 156 | 157 | """Ternary Relational questions""" 158 | for _ in range(nb_questions): 159 | question = np.zeros((question_size)) 160 | rnd_colors = np.random.permutation(np.arange(5)) 161 | # 1st object 162 | color1 = rnd_colors[0] 163 | question[color1] = 1 164 | # 2nd object 165 | color2 = rnd_colors[1] 166 | question[6 + color2] = 1 167 | 168 | question[q_type_idx + 2] = 1 169 | 170 | if args.t_subtype >= 0 and args.t_subtype < 3: 171 | subtype = args.t_subtype 172 | else: 173 | subtype = random.randint(0, 2) 174 | 175 | question[subtype+sub_q_type_idx] = 1 176 | ternary_questions.append(question) 177 | 178 | # get coordiantes of object from question 179 | A = objects[color1][1] 180 | B = objects[color2][1] 181 | 182 | if subtype == 0: 183 | """between->1~4""" 184 | 185 | between_count = 0 186 | # check is any objects lies inside the box 187 | for other_obj in objects: 188 | # skip object A and B 189 | if (other_obj[0] == color1) or (other_obj[0] == color2): 190 | continue 191 | 192 | # Get x and y coordinate of third object 193 | other_objx = other_obj[1][0] 194 | other_objy = other_obj[1][1] 195 | 196 | if (A[0] <= other_objx <= B[0] and A[1] <= other_objy <= B[1]) or \ 197 | (A[0] <= other_objx <= B[0] and B[1] <= other_objy <= A[1]) or \ 198 | (B[0] <= other_objx <= A[0] and B[1] <= other_objy <= A[1]) or \ 199 | (B[0] <= other_objx <= A[0] and A[1] <= other_objy <= B[1]): 200 | between_count += 1 201 | 202 | answer = between_count + 4 203 | elif subtype == 1: 204 | """is-on-band->yes/no""" 205 | 206 | grace_threshold = 12 # half of the size of objects 207 | epsilon = 1e-10 208 | m = (B[1]-A[1])/((B[0]-A[0]) + epsilon ) # add epsilon to prevent dividing by zero 209 | c = A[1] - (m*A[0]) 210 | 211 | answer = 1 # default answer is 'no' 212 | 213 | # check if any object lies on/close the line between object A and object B 214 | for other_obj in objects: 215 | # skip object A and B 216 | if 
(other_obj[0] == color1) or (other_obj[0] == color2): 217 | continue 218 | 219 | other_obj_pos = other_obj[1] 220 | 221 | # y = mx + c 222 | y = (m*other_obj_pos[0]) + c 223 | if (y - grace_threshold) <= other_obj_pos[1] <= (y + grace_threshold): 224 | answer = 0 225 | elif subtype == 2: 226 | """count-obtuse-triangles->1~6""" 227 | 228 | obtuse_count = 0 229 | 230 | # disable warnings 231 | # the angle computation may fail if the points are on a line 232 | warnings.filterwarnings("ignore") 233 | for other_obj in objects: 234 | # skip object A and B 235 | if (other_obj[0] == color1) or (other_obj[0] == color2): 236 | continue 237 | 238 | # get position of 3rd object 239 | C = other_obj[1] 240 | # edge length 241 | a = np.linalg.norm(B - C) 242 | b = np.linalg.norm(C - A) 243 | c = np.linalg.norm(A - B) 244 | # angles by law of cosine 245 | alpha = np.rad2deg(np.arccos((b ** 2 + c ** 2 - a ** 2) / (2 * b * c))) 246 | beta = np.rad2deg(np.arccos((a ** 2 + c ** 2 - b ** 2) / (2 * a * c))) 247 | gamma = np.rad2deg(np.arccos((a ** 2 + b ** 2 - c ** 2) / (2 * a * b))) 248 | max_angle = max(alpha, beta, gamma) 249 | if max_angle >= 90 and max_angle < 180: 250 | obtuse_count += 1 251 | 252 | warnings.filterwarnings("default") 253 | answer = obtuse_count + 4 254 | 255 | ternary_answers.append(answer) 256 | 257 | ternary_relations = (ternary_questions, ternary_answers) 258 | binary_relations = (binary_questions, binary_answers) 259 | norelations = (norel_questions, norel_answers) 260 | 261 | img = img/255. 262 | dataset = (img, ternary_relations, binary_relations, norelations) 263 | return dataset 264 | 265 | 266 | print('building test datasets...') 267 | test_datasets = [build_dataset() for _ in tqdm(range(test_size))] 268 | print('building train datasets...') 269 | train_datasets = [build_dataset() for _ in tqdm(range(train_size))] 270 | 271 | 272 | #img_count = 0 273 | #cv2.imwrite(os.path.join(dirs,'{}.png'.format(img_count)), cv2.resize(train_datasets[0][0]*255, (512,512))) 274 | 275 | 276 | print('saving datasets...') 277 | filename = os.path.join(dirs,'sort-of-clevr.pickle') 278 | with open(filename, 'wb') as f: 279 | pickle.dump((train_datasets, test_datasets), f) 280 | print('datasets saved at {}'.format(filename)) 281 | -------------------------------------------------------------------------------- /sort_of_clevr/sort_of_clevr_splits.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import random 5 | import pickle 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser(description='Generate Sort-of-ClEVR (https://arxiv.org/abs/1706.01427)') 9 | parser.add_argument('--train_size', type=int, default=9800) #9800 10 | parser.add_argument('--test_size', type=int, default=200) #200 11 | parser.add_argument('--image_size', type=int, default=75) #75 12 | parser.add_argument('--object_size', type=int, default=5) 13 | parser.add_argument('--nb_questions', type=int, default=10) # 10 relational and 10 non-rel qst per image 14 | parser.add_argument('--nb_heldout_colors', type=int, default=0) 15 | 16 | args = parser.parse_args() 17 | 18 | train_size = args.train_size 19 | test_size = args.test_size 20 | img_size = args.image_size 21 | size = args.object_size 22 | nb_questions = args.nb_questions 23 | 24 | question_size = 11 # 6 for one-hot vector of color, 2 for question type, 3 for question subtype 25 | dirs = './data' 26 | 27 | all_colors = [ 28 | (0,0,255),##r 29 | (0,255,0),##g 30 | (255,0,0),##b 31 | 
(0,156,255),##o 32 | (128,128,128),##k 33 | (0,255,255)##y 34 | ] 35 | 36 | try: 37 | os.makedirs(dirs) 38 | except: 39 | print('directory {} already exists'.format(dirs)) 40 | 41 | def center_generate(objects): 42 | while True: 43 | pas = True 44 | center = np.random.randint(0+size, img_size - size, 2) 45 | if len(objects) > 0: 46 | for name,c,shape in objects: 47 | if ((center - c) ** 2).sum() < ((size * 2) ** 2): 48 | pas = False 49 | if pas: 50 | return center 51 | 52 | def build_dataset(left_out_color=0): 53 | objects = [] 54 | img = np.ones((img_size,img_size,3)) * 255 55 | 56 | colors = all_colors[left_out_color:] 57 | 58 | for color_id,color in enumerate(colors): 59 | center = center_generate(objects) 60 | if random.random()<0.5: # 2 possible shapes 61 | start = (center[0]-size, center[1]-size) 62 | end = (center[0]+size, center[1]+size) 63 | cv2.rectangle(img, start, end, color, -1) 64 | objects.append((color_id,center,'r')) 65 | else: 66 | center_ = (center[0], center[1]) 67 | cv2.circle(img, center_, size, color, -1) 68 | objects.append((color_id,center,'c')) 69 | 70 | 71 | rel_questions = [] 72 | norel_questions = [] 73 | rel_answers = [] 74 | norel_answers = [] 75 | 76 | """Non-relational questions""" 77 | for _ in range(nb_questions): 78 | question = np.zeros((question_size)) 79 | color = random.randint(0,len(colors) - 1) 80 | question[color] = 1 81 | question[6] = 1 82 | subtype = random.randint(0,2) 83 | question[subtype+8] = 1 84 | norel_questions.append(question) 85 | """Answer : [yes, no, rectangle, circle, r, g, b, o, k, y]""" 86 | if subtype == 0: 87 | """query shape->rectangle/circle""" 88 | if objects[color][2] == 'r': 89 | answer = 2 90 | else: 91 | answer = 3 92 | 93 | elif subtype == 1: 94 | """query horizontal position->yes/no""" 95 | if objects[color][1][0] < img_size / 2: 96 | answer = 0 97 | else: 98 | answer = 1 99 | 100 | elif subtype == 2: 101 | """query vertical position->yes/no""" 102 | if objects[color][1][1] < img_size / 2: 103 | answer = 0 104 | else: 105 | answer = 1 106 | norel_answers.append(answer) 107 | 108 | """Relational questions""" 109 | for i in range(nb_questions): 110 | question = np.zeros((question_size)) 111 | color = random.randint(0,len(colors) - 1) 112 | question[color] = 1 113 | question[7] = 1 114 | subtype = random.randint(0,2) 115 | question[subtype+8] = 1 116 | rel_questions.append(question) 117 | 118 | if subtype == 0: 119 | """closest-to->rectangle/circle""" 120 | my_obj = objects[color][1] 121 | dist_list = [((my_obj - obj[1]) ** 2).sum() for obj in objects] 122 | dist_list[dist_list.index(0)] = 999 123 | closest = dist_list.index(min(dist_list)) 124 | if objects[closest][2] == 'r': 125 | answer = 2 126 | else: 127 | answer = 3 128 | 129 | elif subtype == 1: 130 | """furthest-from->rectangle/circle""" 131 | my_obj = objects[color][1] 132 | dist_list = [((my_obj - obj[1]) ** 2).sum() for obj in objects] 133 | furthest = dist_list.index(max(dist_list)) 134 | if objects[furthest][2] == 'r': 135 | answer = 2 136 | else: 137 | answer = 3 138 | 139 | elif subtype == 2: 140 | """count->1~6""" 141 | my_obj = objects[color][2] 142 | count = -1 143 | for obj in objects: 144 | if obj[2] == my_obj: 145 | count +=1 146 | answer = count+4 147 | 148 | rel_answers.append(answer) 149 | 150 | relations = (rel_questions, rel_answers) 151 | norelations = (norel_questions, norel_answers) 152 | 153 | img = img/255. 
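# each sample pairs the normalised image with its relational and non-relational question/answer sets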
154 | dataset = (img, relations, norelations) 155 | return dataset 156 | 157 | 158 | print('building test datasets...') 159 | test_datasets = [build_dataset() for _ in range(test_size)] 160 | print('building train datasets...') 161 | train_datasets = [build_dataset(left_out_color=args.nb_heldout_colors) for _ in range(train_size)] 162 | print('saving datasets...') 163 | filename = os.path.join(dirs,'sort-of-clevr-{}.pickle'.format(args.nb_heldout_colors)) 164 | with open(filename, 'wb') as f: 165 | pickle.dump((train_datasets, test_datasets), f) 166 | print('datasets saved at {}'.format(filename)) -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/Gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """ 6 | See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with 7 | the corresponding GitHub repo: https://github.com/hendrycks/GELUs 8 | """ 9 | 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | def gelu_accurate(x): 17 | if not hasattr(gelu_accurate, "_a"): 18 | gelu_accurate._a = math.sqrt(2 / math.pi) 19 | return ( 20 | 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 21 | ) 22 | 23 | 24 | def gelu(x: torch.Tensor) -> torch.Tensor: 25 | return torch.nn.functional.gelu(x.float()).type_as(x) 26 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/GroupGRUCell.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from .GroupLinearLayer import GroupLinearLayer 5 | 6 | 7 | class GroupGRUCell(nn.Module): 8 | """ 9 | GroupGRUCell can compute the operation of N GRU Cells at once. 
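Each cell has its own weights via GroupLinearLayer; x is (batch, num_grus, input_size) and hidden is (batch, num_grus, hidden_size).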
10 | """ 11 | def __init__(self, input_size, hidden_size, num_grus): 12 | super(GroupGRUCell, self).__init__() 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.x2h = GroupLinearLayer(input_size, 3 * hidden_size, num_grus) 16 | self.h2h = GroupLinearLayer(hidden_size, 3 * hidden_size, num_grus) 17 | self.reset_parameters() 18 | 19 | 20 | 21 | def reset_parameters(self): 22 | std = 1.0 / math.sqrt(self.hidden_size) 23 | for w in self.parameters(): 24 | w.data = torch.ones(w.data.size())#.uniform_(-std, std) 25 | 26 | def forward(self, x, hidden): 27 | """ 28 | input: x (batch_size, num_grus, input_size) 29 | hidden (batch_size, num_grus, hidden_size) 30 | output: hidden (batch_size, num_grus, hidden_size) 31 | """ 32 | gate_x = self.x2h(x) 33 | gate_h = self.h2h(hidden) 34 | 35 | i_r, i_i, i_n = gate_x.chunk(3, 2) 36 | h_r, h_i, h_n = gate_h.chunk(3, 2) 37 | 38 | 39 | resetgate = torch.sigmoid(i_r + h_r) 40 | inputgate = torch.sigmoid(i_i + h_i) 41 | newgate = torch.tanh(i_n + (resetgate * h_n)) 42 | 43 | hy = newgate + inputgate * (hidden - newgate) 44 | 45 | return hy 46 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/GroupLinearLayer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | class GroupLinearLayer(nn.Module): 6 | def __init__(self, din, dout, num_blocks, bias=True, a = None): 7 | super(GroupLinearLayer, self).__init__() 8 | self.nb = num_blocks 9 | #din = din // num_blocks 10 | #dout = dout // num_blocks 11 | self.dout = dout 12 | if a is None: 13 | a = 1. / math.sqrt(dout) 14 | self.weight = nn.Parameter(torch.FloatTensor(num_blocks,din,dout).uniform_(-a,a)) 15 | self.bias = bias 16 | if bias is True: 17 | self.bias = nn.Parameter(torch.FloatTensor(num_blocks,dout).uniform_(-a,a)) 18 | #self.bias = nn.Parameter(torch.zeros(dout*num_blocks)) 19 | else: 20 | self.bias = None 21 | def forward(self,x): 22 | ts,bs,m = x.shape 23 | #x = x.reshape((ts*bs, self.nb, m//self.nb)) 24 | x = x.permute(1,0,2) 25 | x = torch.bmm(x,self.weight) 26 | x = x.permute(1,0,2) 27 | if not self.bias is None: 28 | x = x + self.bias 29 | #x = x.reshape((ts, bs, self.dout*self.nb)) 30 | return x 31 | 32 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/attention_rim.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import random 7 | from .sparse_attn import Sparse_attention 8 | import torch.nn.functional as F 9 | from .GroupLinearLayer import GroupLinearLayer 10 | from .sparse_grad_attn import Sparse_grad_attention 11 | 12 | 13 | class Identity_2(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, input): 16 | return input * 1.0 17 | def backward(ctx, grad_output): 18 | print(torch.sqrt(torch.sum(torch.pow(grad_output,2)))) 19 | print('+++++++++') 20 | return grad_output * 1.0 21 | 22 | class Identity(torch.autograd.Function): 23 | @staticmethod 24 | def forward(ctx, input): 25 | return input * 1.0 26 | def backward(ctx, grad_output): 27 | print(torch.sqrt(torch.sum(torch.pow(grad_output,2)))) 28 | print('-----------') 29 | return grad_output * 1.0 30 | 31 | class ScaledDotProductAttention(nn.Module): 32 | ''' Scaled Dot-Product Attention ''' 33 | 34 | def __init__(self, temperature, topk = -1, 
grad_sparse=False, attn_dropout=0.1, flag=False): 35 | super().__init__() 36 | self.temperature = temperature 37 | #self.dropout = nn.Dropout(attn_dropout) 38 | self.softmax = nn.Softmax(dim=2) 39 | self.grad_sparse = grad_sparse 40 | self.topk = topk 41 | self.sa = Sparse_attention(top_k=topk) #k=2 42 | self.flag = flag 43 | 44 | def forward(self, q, k, v, mask=None): 45 | 46 | # print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Forward of Scaled Dot Product Attention~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") 47 | # print("q: ", q.size()) 48 | # print("k: ", k.size()) 49 | # print("v: ", v.size()) 50 | # print("k transpose: ", k.transpose(1,2).size()) 51 | # input() 52 | 53 | attn = torch.bmm(q, k.transpose(1, 2)) 54 | attn = attn / self.temperature 55 | 56 | #print('in forward attn shape', attn.shape) 57 | 58 | if mask is not None: 59 | attn = attn.masked_fill(mask, -np.inf) 60 | 61 | if self.flag: 62 | n_b,k_1,k_2 = attn.size() 63 | attn = self.softmax(attn.permute(0,2,1)).reshape(n_b,k_1,k_2) 64 | else: 65 | attn = self.softmax(attn) 66 | 67 | extra_loss = 0.0 68 | 69 | use_sparse = False#False 70 | 71 | if use_sparse: 72 | mb, ins, outs = attn.shape[0], attn.shape[1], attn.shape[2] 73 | if self.flag: 74 | sparse_attn = attn.permute(0,2,1).reshape(mb*outs, ins) 75 | else: 76 | sparse_attn = attn.reshape((mb*ins, outs)) 77 | #print('sparse attn shape 1', sparse_attn.shape) 78 | #sga = Sparse_grad_attention(2) 79 | if self.grad_sparse: 80 | sga = Sparse_grad_attention(self.topk) 81 | sparse_attn = sga(sparse_attn) 82 | else: 83 | sparse_attn = self.sa(sparse_attn) 84 | if self.flag: 85 | sparse_attn = sparse_attn.reshape(mb, outs, ins).permute(0, 2, 1) 86 | else: 87 | sparse_attn = sparse_attn.reshape((mb,ins,outs)) 88 | attn = sparse_attn*1.0 89 | 90 | output = torch.bmm(attn, v) 91 | 92 | return output, attn, extra_loss 93 | 94 | import torch.nn.functional as F 95 | 96 | class MultiHeadAttention(nn.Module): 97 | ''' Multi-Head Attention module ''' 98 | 99 | def __init__(self, n_head, d_model_read, d_model_write, d_model_out, d_k, d_v, grad_sparse, residual=True, dropout=0.1, skip_write=False, flag=False): 100 | super().__init__() 101 | 102 | self.n_head = n_head 103 | self.d_k = d_k 104 | self.d_v = d_v 105 | 106 | # print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Initialize Multi-Head Attention~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") 107 | # print('d model read: ', d_model_read) 108 | # print('d_model_write: ', d_model_write) 109 | # print('d_model_out: ', d_model_out) 110 | # print('n_head: ', n_head) 111 | # print('d_k: ', d_k) 112 | # print('d_v: ', d_v) 113 | # print('num_blocks_read: ', num_blocks_read) 114 | # print('num_blocks_write: ', num_blocks_write) 115 | # input() 116 | 117 | self.GLN_qs = nn.Linear(d_model_read, n_head * d_k) 118 | self.GLN_ks = nn.Linear(d_model_write, n_head * d_k) 119 | self.GLN_vs = nn.Linear(d_model_write, n_head * d_v) 120 | 121 | self.residual = residual 122 | 123 | #self.w_qs = nn.Linear(d_model_read, n_head * d_k) 124 | #self.w_ks = nn.Linear(d_model_write, n_head * d_k) 125 | #self.w_vs = nn.Linear(d_model_write, n_head * d_v) 126 | 127 | #nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 128 | #nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 129 | #nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 130 | 131 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), flag=flag) 132 | #self.layer_norm = nn.LayerNorm(d_model) 133 | 134 | self.gate_fc = nn.Linear(n_head * d_v, 
d_model_out) 135 | 136 | if not skip_write: 137 | self.fc = nn.Linear(n_head * d_v, d_model_out) 138 | else: 139 | self.fc = lambda a: a 140 | 141 | #nn.init.xavier_normal_(self.fc.weight) 142 | 143 | self.dropout = nn.Dropout(dropout) 144 | 145 | self.ln = nn.LayerNorm(d_model_out) 146 | 147 | def forward(self, q, k, v, mask=None): 148 | 149 | #print('attn input shape', q.shape) 150 | 151 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 152 | 153 | sz_b, len_q, _ = q.size() 154 | sz_b, len_k, _ = k.size() 155 | sz_b, len_v, _ = v.size() 156 | 157 | residual = q 158 | 159 | #print('q shape', q.shape) 160 | 161 | # print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Forward of Multi-Head Attention~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~") 162 | # print("q: ", q.size()) 163 | # print("k: ", k.size()) 164 | # print("v: ", v.size()) 165 | # input() 166 | 167 | q = self.GLN_qs(q).view(sz_b, len_q, n_head, d_k) 168 | #q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 169 | k = self.GLN_ks(k).view(sz_b, len_k, n_head, d_k) 170 | v = self.GLN_vs(v).reshape(sz_b, len_v, n_head, d_v) 171 | #v = v.view(sz_b, len_v, n_head, d_v) 172 | 173 | # print("GLN q: ", q.size()) 174 | # print("GLN k: ", k.size()) 175 | # print("GLN v: ", v.size()) 176 | 177 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 178 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 179 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 180 | 181 | # print("Permute q: ", q.size()) 182 | # print("Permute k: ", k.size()) 183 | # print("Permute v: ", v.size()) 184 | 185 | #mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 186 | output, attn, extra_loss = self.attention(q, k, v, mask=None) 187 | 188 | # print("Output: ", output.size()) 189 | # print("Attention: ", attn.size()) 190 | 191 | output = output.view(n_head, sz_b, len_q, d_v) 192 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 193 | 194 | # print("Here Output: ", output.size()) 195 | 196 | #print('output shape before fc', output.shape) 197 | 198 | #TODO: probably shouldn't just apply residual layer in the forward pass. 
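# project the concatenated heads back to d_model_out, then compute a sigmoid gate from the same pre-dropout activations; with residual=True the gate modulates a tanh of the projected output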
199 | 200 | output_init = output*1.0 201 | 202 | output = self.dropout(self.fc(output_init)) 203 | 204 | gate = torch.sigmoid(self.gate_fc(output_init)) 205 | 206 | #output = self.layer_norm(gate * output + (1 - gate) * residual) 207 | #output = gate * output + (1 - gate) * residual 208 | 209 | if self.residual: 210 | output = gate * torch.tanh(output) 211 | else: 212 | #output = self.ln(output) 213 | pass 214 | 215 | # print("Final Output: ", output.size()) 216 | 217 | #output 218 | 219 | #print('attn', attn[0]) 220 | #print('output input diff', output - residual) 221 | 222 | return output, attn, extra_loss 223 | 224 | class PositionwiseFeedForward(nn.Module): 225 | ''' A two-feed-forward-layer module ''' 226 | 227 | def __init__(self, d_in, d_hid, dropout=0.1): 228 | super().__init__() 229 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 230 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 231 | self.layer_norm = nn.LayerNorm(d_in) 232 | self.dropout = nn.Dropout(dropout) 233 | 234 | def forward(self, x): 235 | residual = x 236 | output = x.transpose(1, 2) 237 | output = self.w_2(F.relu(self.w_1(output))) 238 | output = output.transpose(1, 2) 239 | output = self.dropout(output) 240 | output = self.layer_norm(output + residual) 241 | return output 242 | 243 | 244 | class Seq2SeqAttention(nn.Module): 245 | def __init__(self, enc_hid_dim, dec_hid_dim): 246 | super().__init__() 247 | 248 | self.attn = nn.Linear(enc_hid_dim + dec_hid_dim, dec_hid_dim) 249 | self.v = nn.Linear(dec_hid_dim, 1, bias = False) 250 | 251 | def forward(self, hidden, encoder_outputs): 252 | 253 | #hidden = [batch size, dec hid dim] 254 | #encoder_outputs = [src len, batch size, enc hid dim * 2] 255 | 256 | batch_size = encoder_outputs.shape[1] 257 | src_len = encoder_outputs.shape[0] 258 | 259 | #repeat decoder hidden state src_len times 260 | hidden = hidden.unsqueeze(1).repeat(1, src_len, 1) 261 | 262 | encoder_outputs = encoder_outputs.permute(1, 0, 2) 263 | 264 | #hidden = [batch size, src len, dec hid dim] 265 | #encoder_outputs = [batch size, src len, enc hid dim * 2] 266 | 267 | energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) 268 | 269 | #energy = [batch size, src len, dec hid dim] 270 | 271 | attention = self.v(energy).squeeze(2) 272 | 273 | #attention= [batch size, src len] 274 | 275 | return F.softmax(attention, dim=1) 276 | 277 | 278 | if __name__ == "__main__": 279 | 280 | x = torch.randn((64,3,100)) 281 | 282 | mha = MultiHeadAttention(n_head=8, d_model=100, d_k=64, d_v=64) 283 | 284 | out, attn = mha(x,x,x) 285 | 286 | print('out shape', out.shape) 287 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/basic_mha.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .GroupLinearLayer import GroupLinearLayer 5 | 6 | class MemoryAttention(nn.Module): 7 | def __init__(self, n_blocks_query, n_blocks_val, dim_query, dim_val, n_heads=8): 8 | super(MemoryAttention, self).__init__() 9 | 10 | self.n_heads = n_heads 11 | self.n_blocks_val = n_blocks_val 12 | self.dim_val = dim_val 13 | self.block_dim_val = dim_val // self.n_blocks_val 14 | 15 | self.n_blocks_query = n_blocks_query 16 | self.dim_query = dim_query 17 | self.block_dim_query = dim_query // self.n_blocks_query 18 | 19 | self.head_dim = 64 20 | self.scale = self.head_dim ** -0.5 21 | 22 | #self.n_blocks_val * self.block_dim_val 
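# block-wise projections: queries come from the query blocks, keys and values from the value (memory) blocks, and self.final maps the concatenated heads back to the query block dimension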
23 | 24 | self.query_net = GroupLinearLayer(self.block_dim_query, self.head_dim * self.n_heads, n_blocks_query) 25 | self.key_net = GroupLinearLayer(self.block_dim_val, self.head_dim * self.n_heads, n_blocks_val) 26 | self.value_net = GroupLinearLayer(self.block_dim_val, self.head_dim * self.n_heads, n_blocks_val) 27 | self.final = GroupLinearLayer(self.head_dim * self.n_heads, self.block_dim_query, n_blocks_query) 28 | 29 | def forward(self, q, kv): 30 | 31 | #comes in as: bs, pos*emb. 32 | #positions_attend x T*bs x emb 33 | 34 | 35 | #q = q.permute(1,0,2) 36 | #kv = kv.permute(1,0,2) 37 | 38 | #print('kv shape after permute', kv.shape) 39 | 40 | seq_len_q,bsz,_ = q.shape 41 | seq_len_v,bsz,_ = kv.shape 42 | 43 | q = q.reshape((seq_len_q, bsz, self.n_blocks_query * self.block_dim_query)) 44 | 45 | kv = kv.reshape((seq_len_v, bsz, self.n_blocks_val * self.block_dim_val)) 46 | 47 | q = self.query_net(q).view(seq_len_q, bsz, self.n_blocks_query, self.n_heads, self.head_dim) 48 | k = self.key_net(kv).view(seq_len_v, bsz, self.n_blocks_val, self.n_heads, self.head_dim) 49 | v = self.value_net(kv).view(seq_len_v, bsz, self.n_blocks_val, self.n_heads, self.head_dim) 50 | 51 | q = q.transpose(2,3) * self.scale 52 | k = k.transpose(2,3) 53 | v = v.transpose(2,3) 54 | score = torch.matmul(q, k.transpose(3,4)) 55 | #print('score shape', score.shape) 56 | score = F.softmax(score, dim=-1) 57 | out = torch.matmul(score, v).transpose(2,3) 58 | #print('out shape', out.shape) 59 | score = score.mean(dim=2) 60 | 61 | out = out.reshape(seq_len_q, bsz, self.n_blocks_query * self.head_dim * self.n_heads) 62 | out = self.final(out) 63 | out = out.view(seq_len_q, bsz, self.dim_query) 64 | 65 | 66 | return out, score 67 | 68 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/fairseq_dropout.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
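# FairseqDropout behaves like standard dropout during training, but make_generation_fast_ can keep it active at inference time via apply_during_inference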
5 | 6 | import logging 7 | from typing import List, Optional 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class FairseqDropout(nn.Module): 17 | 18 | def __init__(self, p, module_name=None): 19 | super().__init__() 20 | self.p = p 21 | self.module_name = module_name 22 | self.apply_during_inference = False 23 | 24 | def forward(self, x, inplace: bool = False): 25 | if self.training or self.apply_during_inference: 26 | return F.dropout(x, p=self.p, training=True, inplace=inplace) 27 | else: 28 | return x 29 | 30 | def make_generation_fast_( 31 | self, 32 | name: str, 33 | retain_dropout: bool = False, 34 | retain_dropout_modules: Optional[List[str]] = None, 35 | **kwargs 36 | ): 37 | if retain_dropout: 38 | if retain_dropout_modules is not None and self.module_name is None: 39 | logger.warning( 40 | 'Cannot enable dropout during inference for module {} ' 41 | 'because module_name was not set'.format(name) 42 | ) 43 | elif ( 44 | retain_dropout_modules is None # if None, apply to all modules 45 | or self.module_name in retain_dropout_modules 46 | ): 47 | logger.info( 48 | 'Enabling dropout during inference for module: {}'.format(name) 49 | ) 50 | self.apply_during_inference = True 51 | else: 52 | logger.info('Disabling dropout for module: {}'.format(name)) 53 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/group_linear_layer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch.nn.functional as F 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | class GroupLinearLayer(nn.Module): 9 | 10 | def __init__(self, din, dout, num_blocks, bias=True, a = None): 11 | super(GroupLinearLayer, self).__init__() 12 | self.nb = num_blocks 13 | self.dout = dout 14 | 15 | if a is None: 16 | a = 1. 
/ math.sqrt(dout * num_blocks) 17 | 18 | #gain = 1.0 / math.sqrt(2) 19 | #a = gain * math.sqrt(6.0 / (din + dout)) 20 | 21 | self.weight = nn.Parameter(torch.FloatTensor(num_blocks,din,dout).uniform_(-a,a)) 22 | 23 | self.bias = bias 24 | 25 | if bias is True: 26 | self.bias = nn.Parameter(torch.FloatTensor(num_blocks,dout).uniform_(-a,a)) 27 | #self.bias = nn.Parameter(torch.zeros(dout*num_blocks)) 28 | else: 29 | self.bias = None 30 | 31 | def forward(self,x): 32 | 33 | #input: ts x bs x blocks*nhid 34 | #ts*bs , blocks, nhid 35 | #blocks, ts*bs, nhid 36 | ts,bs,m = x.shape 37 | 38 | x = x.reshape((ts*bs, self.nb, m//self.nb)) 39 | x = x.permute(1,0,2) 40 | x = torch.bmm(x,self.weight) 41 | x = x.permute(1,0,2) 42 | 43 | if not self.bias is None: 44 | x = x + self.bias 45 | 46 | x = x.reshape((ts, bs, self.dout*self.nb)) 47 | 48 | #if not self.bias is None: 49 | # x += self.bias 50 | 51 | return x 52 | 53 | class GroupMLP(nn.Module): 54 | """Container module with an encoder, a recurrent module, and a decoder.""" 55 | 56 | def __init__(self, din, dout, num_blocks, dropout=0.1): 57 | super(GroupMLP, self).__init__() 58 | 59 | self.w_1 = nn.Parameter(0.01 * torch.randn(num_blocks,din,dout)) 60 | self.w_2 = nn.Parameter(0.01 * torch.randn(num_blocks,dout,din)) 61 | 62 | self.layer_norm = nn.LayerNorm(din) 63 | self.dropout = nn.Dropout(dropout) 64 | 65 | def forward(self,x): 66 | 67 | residual = x*1.0 68 | x = x.permute(1,0,2) 69 | x = torch.bmm(F.relu(torch.bmm(x,self.w_1)), self.w_2) 70 | x = x.permute(1,0,2) 71 | x = self.dropout(x) 72 | x = self.layer_norm(x + residual) 73 | 74 | return x 75 | 76 | if __name__ == "__main__": 77 | 78 | GLN = GroupLinearLayer(512, 512, 2, bias=True) 79 | 80 | print('params', sum(g.numel() for g in GLN.parameters())) 81 | 82 | #bs, blocks, nhid 83 | x = torch.randn(64,12,2*512) 84 | 85 | print(GLN(x).shape) 86 | 87 | #for p in GLN.parameters(): 88 | # print(p.shape) 89 | 90 | 91 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/isab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class MAB(nn.Module): 7 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 8 | super(MAB, self).__init__() 9 | self.dim_V = dim_V 10 | self.num_heads = num_heads 11 | self.fc_q = nn.Linear(dim_Q, dim_V) 12 | self.fc_k = nn.Linear(dim_K, dim_V) 13 | self.fc_v = nn.Linear(dim_K, dim_V) 14 | if ln: 15 | self.ln0 = nn.LayerNorm(dim_V) 16 | self.ln1 = nn.LayerNorm(dim_V) 17 | self.fc_o = nn.Linear(dim_V, dim_V) 18 | 19 | def forward(self, Q, K): 20 | Q = self.fc_q(Q) 21 | K, V = self.fc_k(K), self.fc_v(K) 22 | 23 | dim_split = self.dim_V // self.num_heads 24 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 25 | K_ = torch.cat(K.split(dim_split, 2), 0) 26 | V_ = torch.cat(V.split(dim_split, 2), 0) 27 | 28 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 29 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 30 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 31 | O = O + F.relu(self.fc_o(O)) 32 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 33 | return O 34 | 35 | class SAB(nn.Module): 36 | def __init__(self, dim_in, dim_out, num_heads, ln=False): 37 | super(SAB, self).__init__() 38 | self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln) 39 | 40 | def forward(self, X): 41 | return self.mab(X, X) 42 | 43 | class 
ISAB(nn.Module): 44 | def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False): 45 | super(ISAB, self).__init__() 46 | self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) 47 | nn.init.xavier_uniform_(self.I) 48 | self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) 49 | self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) 50 | 51 | def forward(self, X): 52 | H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) 53 | return self.mab1(X, H) 54 | 55 | class PMA(nn.Module): 56 | def __init__(self, dim, num_heads, num_seeds, ln=False): 57 | super(PMA, self).__init__() 58 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 59 | nn.init.xavier_uniform_(self.S) 60 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 61 | 62 | def forward(self, X): 63 | return self.mab(self.S.repeat(X.size(0), 1, 1), X) 64 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/layer_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | try: 12 | from apex.normalization import FusedLayerNorm as _FusedLayerNorm 13 | 14 | has_fused_layernorm = True 15 | 16 | class FusedLayerNorm(_FusedLayerNorm): 17 | @torch.jit.unused 18 | def forward(self, x): 19 | if not x.is_cuda: 20 | return super().forward(x) 21 | else: 22 | with torch.cuda.device(x.device): 23 | return super().forward(x) 24 | 25 | except ImportError: 26 | has_fused_layernorm = False 27 | 28 | 29 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): 30 | if not export and torch.cuda.is_available() and has_fused_layernorm: 31 | return FusedLayerNorm(normalized_shape, eps, elementwise_affine) 32 | return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) 33 | 34 | 35 | class Fp32LayerNorm(nn.LayerNorm): 36 | def __init__(self, *args, **kwargs): 37 | super().__init__(*args, **kwargs) 38 | 39 | def forward(self, input): 40 | output = F.layer_norm( 41 | input.float(), 42 | self.normalized_shape, 43 | self.weight.float() if self.weight is not None else None, 44 | self.bias.float() if self.bias is not None else None, 45 | self.eps, 46 | ) 47 | return output.type_as(input) 48 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/pos_enc.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import math 7 | import torch.nn.functional as F 8 | import random 9 | 10 | class PositionEncoder(nn.Module): 11 | def __init__(self, d_model, max_seq_len = 300): 12 | super().__init__() 13 | self.d_model = d_model 14 | # create constant 'pe' matrix with values dependant on 15 | # pos and i 16 | pe = torch.zeros(max_seq_len, d_model) 17 | for pos in range(max_seq_len): 18 | for i in range(0, d_model, 2): 19 | pe[pos, i] = \ 20 | math.sin(pos / (10000 ** ((2 * i)/d_model))) 21 | pe[pos, i + 1] = \ 22 | math.cos(pos / (10000 ** ((2 * (i + 1))/d_model))) 23 | 24 | pe = pe.unsqueeze(0) 25 | self.register_buffer('pe', pe) 26 | 27 | self.pos_emb_weight = nn.Parameter(torch.ones_like(pe)) 28 | 29 | def forward(self, x): 30 | # make embeddings relatively 
larger 31 | 32 | x = x.permute(1,0,2) 33 | 34 | #x = x * math.sqrt(self.d_model) 35 | #add constant to embedding 36 | 37 | seq_len = x.size(1) 38 | 39 | #width x channel 40 | #pe_use = F.interpolate(self.pe.permute(0,2,1), size=seq_len).permute(0,2,1) 41 | 42 | pe_use = Variable(self.pe[:,:seq_len] * F.sigmoid(self.pos_emb_weight[:,:seq_len]), requires_grad=False).cuda() 43 | 44 | #bs x pos x nhid --> bs x nhid x pos --> bs x pos x nhid 45 | 46 | x = x + pe_use 47 | #Variable(pe_use, requires_grad=False).cuda() 48 | 49 | x = x.permute(1,0,2) 50 | 51 | return x 52 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/quant_noise.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def quant_noise(module, p, block_size): 11 | """ 12 | Wraps modules and applies quantization noise to the weights for 13 | subsequent quantization with Iterative Product Quantization as 14 | described in "Training with Quantization Noise for Extreme Model Compression" 15 | 16 | Args: 17 | - module: nn.Module 18 | - p: amount of Quantization Noise 19 | - block_size: size of the blocks for subsequent quantization with iPQ 20 | 21 | Remarks: 22 | - Module weights must have the right sizes wrt the block size 23 | - Only Linear, Embedding and Conv2d modules are supported for the moment 24 | - For more detail on how to quantize by blocks with convolutional weights, 25 | see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" 26 | - We implement the simplest form of noise here as stated in the paper 27 | which consists in randomly dropping blocks 28 | """ 29 | 30 | # if no quantization noise, don't register hook 31 | if p <= 0: 32 | return module 33 | 34 | # supported modules 35 | assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) 36 | 37 | # test whether module.weight has the right sizes wrt block_size 38 | is_conv = module.weight.ndim == 4 39 | 40 | # 2D matrix 41 | if not is_conv: 42 | assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes" 43 | 44 | # 4D matrix 45 | else: 46 | # 1x1 convolutions 47 | if module.kernel_size == (1, 1): 48 | assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes" 49 | # regular convolutions 50 | else: 51 | k = module.kernel_size[0] * module.kernel_size[1] 52 | assert k % block_size == 0, "Kernel size must be a multiple of block size" 53 | 54 | def _forward_pre_hook(mod, input): 55 | # no noise for evaluation 56 | if mod.training: 57 | if not is_conv: 58 | # gather weight and sizes 59 | weight = mod.weight 60 | in_features = weight.size(1) 61 | out_features = weight.size(0) 62 | 63 | # split weight matrix into blocks and randomly drop selected blocks 64 | mask = torch.zeros(in_features // block_size * out_features, device=weight.device) 65 | mask.bernoulli_(p) 66 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) 67 | 68 | else: 69 | # gather weight and sizes 70 | weight = mod.weight 71 | in_channels = mod.in_channels 72 | out_channels = mod.out_channels 73 | 74 | # split weight matrix into blocks and randomly drop selected blocks 75 | if mod.kernel_size == (1, 1): 76 | mask = torch.zeros(int(in_channels // 
block_size * out_channels), device=weight.device) 77 | mask.bernoulli_(p) 78 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) 79 | else: 80 | mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device) 81 | mask.bernoulli_(p) 82 | mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) 83 | 84 | # scale weights and apply mask 85 | mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript 86 | s = 1 / (1 - p) 87 | mod.weight.data = s * weight.masked_fill(mask, 0) 88 | 89 | module.register_forward_pre_hook(_forward_pre_hook) 90 | return module 91 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/set_transformer.py: -------------------------------------------------------------------------------- 1 | from .isab import * 2 | from .pos_enc import PositionEncoder 3 | 4 | class DeepSet(nn.Module): 5 | def __init__(self, dim_input, num_outputs, dim_output, dim_hidden=128): 6 | super(DeepSet, self).__init__() 7 | self.num_outputs = num_outputs 8 | self.dim_output = dim_output 9 | self.enc = nn.Sequential( 10 | nn.Linear(dim_input, dim_hidden), 11 | nn.ReLU(), 12 | nn.Linear(dim_hidden, dim_hidden), 13 | nn.ReLU(), 14 | nn.Linear(dim_hidden, dim_hidden), 15 | nn.ReLU(), 16 | nn.Linear(dim_hidden, dim_hidden)) 17 | self.dec = nn.Sequential( 18 | nn.Linear(dim_hidden, dim_hidden), 19 | nn.ReLU(), 20 | nn.Linear(dim_hidden, dim_hidden), 21 | nn.ReLU(), 22 | nn.Linear(dim_hidden, dim_hidden), 23 | nn.ReLU(), 24 | nn.Linear(dim_hidden, num_outputs*dim_output)) 25 | 26 | def forward(self, X): 27 | X = self.enc(X).mean(-2) 28 | X = self.dec(X).reshape(-1, self.num_outputs, self.dim_output) 29 | return X 30 | 31 | class SetTransformer(nn.Module): 32 | def __init__(self, dim_input, 33 | num_inds=32, dim_hidden=128, num_heads=4, ln=True, num_layers = 4): 34 | super(SetTransformer, self).__init__() 35 | self.pe = PositionEncoder(dim_input) 36 | layers = [] 37 | layers.append(ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln)) 38 | for _ in range(num_layers-1): 39 | layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln)) 40 | self.layers = nn.ModuleList(layers) 41 | # self.enc = nn.Sequential( 42 | # ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), 43 | # ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln), 44 | # ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln), 45 | # ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln)) 46 | 47 | def forward(self, X): 48 | X=X.permute(1,0,2) #self.pe expects T,B,D 49 | X = self.pe(X) 50 | X=X.permute(1,0,2) #layer expects B,T,D 51 | for layer in self.layers: 52 | X=layer(X) 53 | return X 54 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/sparse_attn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy 5 | 6 | class Sparse_attention(nn.Module): 7 | def __init__(self, top_k = 5): 8 | super(Sparse_attention,self).__init__() 9 | top_k += 1 10 | self.top_k = top_k 11 | 12 | def forward(self, attn_s): 13 | 14 | # normalize the attention weights using piece-wise Linear function 15 | # only top k should 16 | attn_plot = [] 17 | # torch.max() returns both value and location 18 | #attn_s_max = torch.max(attn_s, dim = 1)[0] 19 | #attn_w = torch.clamp(attn_s_max, min = 0, max = attn_s_max) 20 | eps = 10e-8 21 | 
time_step = attn_s.size()[1] 22 | if time_step <= self.top_k: 23 | # just make everything greater than 0, and return it 24 | #delta = torch.min(attn_s, dim = 1)[0] 25 | return attn_s 26 | else: 27 | # get top k and return it 28 | # bottom_k = attn_s.size()[1] - self.top_k 29 | # value of the top k elements 30 | #delta = torch.kthvalue(attn_s, bottm_k, dim= 1 )[0] 31 | delta = torch.topk(attn_s, self.top_k, dim= 1)[0][:,-1] + eps 32 | #delta = attn_s_max - torch.topk(attn_s, self.top_k, dim= 1)[0][:,-1] + eps 33 | # normalize 34 | delta = delta.reshape((delta.shape[0],1)) 35 | 36 | 37 | attn_w = attn_s - delta.repeat(1, time_step) 38 | attn_w = torch.clamp(attn_w, min = 0) 39 | attn_w_sum = torch.sum(attn_w, dim = 1, keepdim=True) 40 | attn_w_sum = attn_w_sum + eps 41 | attn_w_normalize = attn_w / attn_w_sum.repeat(1, time_step) 42 | 43 | #print('attn', attn_w_normalize) 44 | 45 | return attn_w_normalize 46 | 47 | 48 | if __name__ == "__main__": 49 | k = 1 50 | print('take top k', k) 51 | sa = Sparse_attention(top_k=k) 52 | 53 | #batch x time 54 | 55 | x = torch.from_numpy(numpy.array([[[0.1, 0.0, 0.3, 0.2, 0.4],[0.5,0.4,0.1,0.0,0.0]]])) 56 | 57 | x = x.reshape((2,5)) 58 | 59 | print('x shape', x.shape) 60 | print('x', x) 61 | 62 | o = sa(x) 63 | 64 | 65 | print('o', o) 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/sparse_grad_attn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Giving an N x M attention matrix, returns the same matrix, 3 | but performs masking to determine where to block gradients. 4 | ''' 5 | 6 | import numpy 7 | import torch 8 | from torch.autograd import Variable 9 | 10 | from .sparse_attn import Sparse_attention 11 | 12 | 13 | class blocked_grad(torch.autograd.Function): 14 | 15 | @staticmethod 16 | def forward(ctx, x, mask): 17 | ctx.save_for_backward(x, mask) 18 | return x 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | x, mask = ctx.saved_tensors 23 | return grad_output * mask, mask * 0.0 24 | 25 | 26 | class Sparse_grad_attention(torch.autograd.Function): 27 | # def __init__(self, top_k): 28 | # super(Sparse_grad_attention,self).__init__() 29 | # 30 | # self.sa = Sparse_attention(top_k=top_k) 31 | 32 | @staticmethod 33 | def forward(ctx, inp, sa): 34 | sparsified = sa(inp) 35 | ctx.save_for_backward(inp, sparsified) 36 | 37 | return inp 38 | 39 | @staticmethod 40 | def backward(ctx, grad_output): 41 | inp, sparsified = ctx.saved_tensors 42 | # print('sparsified', sparsified) 43 | return (grad_output) * (sparsified > 0.0).float() 44 | 45 | 46 | if __name__ == "__main__": 47 | k = 2 48 | sga = Sparse_grad_attention(k) 49 | sa = Sparse_attention(k) 50 | 51 | x = torch.from_numpy(numpy.array([[[0.1, 0.0, 0.3, 0.2, 0.4], 52 | [0.5, 0.4, 0.1, 0.0, 0.0]]])) 53 | x = x.reshape((2, 5)) 54 | 55 | x = Variable(x, requires_grad=True) 56 | 57 | print(x) 58 | print('output', sga(x)) 59 | 60 | (sga(x).sum()).backward() 61 | 62 | print('sparse grad', x.grad) 63 | 64 | x = Variable(x.data, requires_grad=True) 65 | 66 | (sa(x).sum()).backward() 67 | 68 | print('normal grad', x.grad) 69 | -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/transformer_helper.py: -------------------------------------------------------------------------------- 1 | 2 | ''' Define the sublayers in encoder/decoder layer ''' 3 | import numpy as np 4 | import torch 5 | import torch.nn 
as nn 6 | import torch.nn.functional as F 7 | import random 8 | 9 | __author__ = "Yu-Hsiang Huang" 10 | 11 | class ScaledDotProductAttention(nn.Module): 12 | ''' Scaled Dot-Product Attention ''' 13 | 14 | def __init__(self, temperature, attn_dropout=0.1): 15 | super().__init__() 16 | self.temperature = temperature 17 | self.dropout = nn.Dropout(attn_dropout) 18 | 19 | def forward(self, q, k, v, mask=None): 20 | 21 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 22 | if mask is not None: 23 | attn = attn.masked_fill(mask == 0, -1e9) 24 | 25 | attn = self.dropout(F.softmax(attn, dim=-1)) 26 | output = torch.matmul(attn, v) 27 | 28 | return output, attn 29 | 30 | class PositionalEncoding(nn.Module): 31 | 32 | def __init__(self, d_hid, n_position=200): 33 | super(PositionalEncoding, self).__init__() 34 | 35 | # Not a parameter 36 | self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid)) 37 | 38 | def _get_sinusoid_encoding_table(self, n_position, d_hid): 39 | ''' Sinusoid position encoding table ''' 40 | # TODO: make it with torch instead of numpy 41 | 42 | def get_position_angle_vec(position): 43 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 44 | 45 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 46 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 47 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 48 | 49 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 50 | 51 | def forward(self, x): 52 | #if self.train: 53 | # ind = random.randint(0, 160) 54 | #else: 55 | ind = 0 56 | return x + self.pos_table[:, ind:ind + x.size(1)].clone().detach() 57 | 58 | class MultiHeadAttention(nn.Module): 59 | ''' Multi-Head Attention module ''' 60 | 61 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 62 | super().__init__() 63 | 64 | self.n_head = n_head 65 | self.d_k = d_k 66 | self.d_v = d_v 67 | 68 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 69 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 70 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 71 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 72 | 73 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) 74 | 75 | self.dropout = nn.Dropout(dropout) 76 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 77 | 78 | 79 | def forward(self, q, k, v, mask=None): 80 | 81 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 82 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 83 | 84 | residual = q 85 | 86 | # Pass through the pre-attention projection: b x lq x (n*dv) 87 | # Separate different heads: b x lq x n x dv 88 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 89 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 90 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 91 | 92 | # Transpose for attention dot product: b x n x lq x dv 93 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 94 | 95 | if mask is not None: 96 | mask = mask.unsqueeze(1) # For head axis broadcasting. 
97 | 98 | q, attn = self.attention(q, k, v, mask=mask) 99 | 100 | # Transpose to move the head dimension back: b x lq x n x dv 101 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 102 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 103 | q = self.dropout(self.fc(q)) 104 | q += residual 105 | 106 | q = self.layer_norm(q) 107 | 108 | return q, attn 109 | 110 | 111 | class PositionwiseFeedForward(nn.Module): 112 | ''' A two-feed-forward-layer module ''' 113 | 114 | def __init__(self, d_in, d_hid, dropout=0.1): 115 | super().__init__() 116 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 117 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 118 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 119 | self.dropout = nn.Dropout(dropout) 120 | 121 | def forward(self, x): 122 | 123 | residual = x 124 | 125 | 126 | x = self.w_2(F.relu(self.w_1(x))) 127 | x = self.dropout(x) 128 | x += residual 129 | 130 | x = self.layer_norm(x) 131 | 132 | return x 133 | 134 | 135 | class EncoderLayer(nn.Module): 136 | ''' Compose with two layers ''' 137 | 138 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 139 | super(EncoderLayer, self).__init__() 140 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 141 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 142 | 143 | def forward(self, enc_input, slf_attn_mask=None, seperate_queries = None): 144 | enc_output, enc_slf_attn = self.slf_attn( 145 | seperate_queries if seperate_queries is not None else enc_input, enc_input, enc_input, mask=slf_attn_mask) 146 | enc_output = self.pos_ffn(enc_output) 147 | return enc_output, enc_slf_attn -------------------------------------------------------------------------------- /sort_of_clevr/transformer_utilities/transformer_interface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | #from transformer import TransformerEncoder 4 | import types 5 | import math 6 | 7 | args = types.SimpleNamespace() 8 | args.use_module_communication = 'true' 9 | args.encoder_embed_dim = 512 10 | args.encoder_attention_heads = 8 #was 8 11 | args.attention_dropout = 0.1 12 | args.topk_ratio = 1.0 13 | args.dropout = 0.2 14 | args.encoder_normalize_before = True 15 | args.encoder_ffn_embed_dim = 2048 16 | args.use_nfm = 'false' 17 | 18 | from models.transformer_layer import TransformerEncoderLayer, TransformerEncoderLayerVanilla 19 | from models.pos_enc import PositionEncoder 20 | #from transformer_utilities.GroupLinearLayer import GroupLinearLayer 21 | import math 22 | class GroupLinearLayer(nn.Module): 23 | def __init__(self, din, dout, num_blocks, bias=True, a = None): 24 | super(GroupLinearLayer, self).__init__() 25 | self.nb = num_blocks 26 | #din = din // num_blocks 27 | #dout = dout // num_blocks 28 | self.dout = dout 29 | if a is None: 30 | a = 1. 
/ math.sqrt(dout) 31 | self.weight = nn.Parameter(torch.FloatTensor(num_blocks,din,dout).uniform_(-a,a)) 32 | self.bias = bias 33 | if bias is True: 34 | self.bias = nn.Parameter(torch.FloatTensor(num_blocks,dout).uniform_(-a,a)) 35 | #self.bias = nn.Parameter(torch.zeros(dout*num_blocks)) 36 | else: 37 | self.bias = None 38 | def forward(self,x): 39 | ts,bs,m = x.shape 40 | #x = x.reshape((ts*bs, self.nb, m//self.nb)) 41 | x = x.permute(1,0,2) 42 | x = torch.bmm(x,self.weight) 43 | x = x.permute(1,0,2) 44 | if not self.bias is None: 45 | x = x + self.bias 46 | #x = x.reshape((ts, bs, self.dout*self.nb)) 47 | return x 48 | 49 | 50 | 51 | class SelectAttention(nn.Module): 52 | """docstring for SelectAttention""" 53 | def __init__(self, d_read, d_write, d_k = 16, num_read = 5, num_write = 5, share_query = False, share_key = False): 54 | super(SelectAttention, self).__init__() 55 | if not share_key: 56 | self.gll_write = GroupLinearLayer(d_write,d_k, num_write) 57 | else: 58 | self.gll_write = nn.Linear(d_write, d_k) 59 | 60 | if not share_query: 61 | self.gll_read = GroupLinearLayer(d_read,d_k, num_read) 62 | else: 63 | self.gll_read = nn.Linear(d_read, d_k) 64 | 65 | self.temperature = math.sqrt(d_k) 66 | 67 | def forward(self, q, k): 68 | read = self.gll_read(q) 69 | write = self.gll_write(k) 70 | 71 | return torch.bmm(read, write.permute(0, 2, 1)) / self.temperature 72 | 73 | class TransformerEncoder(nn.Module): 74 | 75 | def __init__(self, inp_dim, h_dim, inp_nb, nb, functional = True): 76 | super().__init__() 77 | 78 | args.encoder_embed_dim = h_dim 79 | 80 | print('transformer h_dim', h_dim) 81 | 82 | 83 | 84 | args.encoder_embed_dim = h_dim 85 | self.functional = functional 86 | print('functional? '+str(self.functional)) 87 | if not self.functional: 88 | layer_lst = [] 89 | 90 | args.encoder_embed_dim = h_dim 91 | #layer_lst.append(TransformerEncoderLayer(args=args, nb=inp_nb, blockatt=False, blockatt_memory=True, use_nfm=False, out_proj_dim=h_dim)) 92 | #for j in range(0,6): 93 | # layer_lst.append(TransformerEncoderLayer(args=args, nb=nb, blockatt=False, blockatt_memory=True, use_nfm=False)) 94 | self.enc = TransformerEncoderLayerVanilla(args) 95 | #self.layers = nn.ModuleList(layer_lst) 96 | else: 97 | #args.encoder_embed_dim = inp_dim 98 | #print('init_layer initialize') 99 | #self.init_layer = TransformerEncoderLayerVanilla(args=args, out_proj=h_dim) 100 | args.encoder_embed_dim = h_dim 101 | hidden_dim = args.encoder_embed_dim 102 | print('inp_att initialize') 103 | self.inp_att = TransformerEncoderLayerVanilla(args=args) 104 | print('gru initialize') 105 | self.gru_pool = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(1)]) 106 | self.state_att = TransformerEncoderLayerVanilla(args=args) 107 | self.select_attention = SelectAttention( hidden_dim + hidden_dim, hidden_dim, num_read = 1, num_write = 1) 108 | 109 | self.pe = PositionEncoder(inp_dim) 110 | self.pe_state = PositionEncoder(args.encoder_embed_dim) 111 | 112 | def forward(self, x, mask = None): 113 | 114 | x = x.permute(1, 0, 2) 115 | 116 | x = self.pe(x) 117 | if not self.functional: 118 | """klst = [] 119 | vlst = [] 120 | 121 | initial_state = self.layers[0].memory_layer.initial_state(batch_size=x.shape[0]*x.shape[1]).type(x.dtype).to(x.device) 122 | memory_obj = [initial_state] 123 | 124 | for layer in self.layers: 125 | layer.klst = klst 126 | layer.vlst = vlst 127 | layer.memory_obj = memory_obj 128 | 129 | """ 130 | for i in range(6): 131 | x = self.enc(x, None) 132 | return x.permute(1, 0, 2) 133 
| else: 134 | """ 135 | klst = [] 136 | vlst = [] 137 | 138 | initial_state = self.init_layer.memory_layer.initial_state(batch_size=x.shape[0]*x.shape[1]).type(x.dtype).to(x.device) 139 | memory_obj = [initial_state] 140 | 141 | self.init_layer.klst = klst 142 | self.init_layer.vlst = vlst 143 | self.init_layer.memory_obj = memory_obj 144 | 145 | 146 | self.inp_att.klst = klst 147 | self.inp_att.vlst = vlst 148 | self.inp_att.memory_obj = memory_obj 149 | 150 | self.state_att.klst = klst 151 | self.state_att.vlst = vlst 152 | self.state_att.memory_obj = memory_obj 153 | """ 154 | T, B, D = x.size() 155 | 156 | #x = self.init_layer(x, None) 157 | state = self.pe_state(torch.randn(x.size()).to(x.device)) 158 | 159 | 160 | 161 | for i in range(0, 6): 162 | gru_in = self.inp_att(x, mask, state = state) 163 | gru_in = gru_in.permute(1, 0, 2) 164 | state = state.permute(1, 0, 2) 165 | 166 | gru_in = gru_in.reshape(B * T, -1) 167 | state = state.reshape(B * T, -1) 168 | 169 | gru_outs = [] 170 | 171 | for gru in self.gru_pool: 172 | gru_outs.append(gru(gru_in, state)) 173 | 174 | gru_outs = torch.stack(gru_outs, dim = 1) 175 | 176 | selector = torch.cat((gru_in, state), dim = 1).unsqueeze(1) 177 | 178 | attn_scores = self.select_attention(selector, gru_outs) 179 | 180 | attn_scores = attn_scores.squeeze(1) 181 | 182 | attn_scores = torch.nn.functional.gumbel_softmax(attn_scores, dim = 1, tau = 1.0, hard = True) 183 | attn_scores = attn_scores.unsqueeze(-1) 184 | gru_outs = (gru_outs * attn_scores).sum(dim = 1) 185 | gru_outs_hidden = gru_outs.reshape(B, T, -1) 186 | gru_outs_hidden = gru_outs_hidden.permute(1, 0, 2) 187 | gru_outs_hidden = self.state_att(gru_outs_hidden, mask) 188 | gru_in = gru_in.reshape(B, T, -1).permute(1, 0, 2) 189 | 190 | x = gru_in 191 | state = gru_outs_hidden 192 | 193 | return state.permute(1,0,2) 194 | 195 | 196 | 197 | if __name__ == "__main__": 198 | x = torch.randn(32, 64, 512) 199 | 200 | TE = TransformerEncoder() 201 | 202 | y = TE(x) 203 | 204 | print(y.shape) 205 | 206 | -------------------------------------------------------------------------------- /sort_of_clevr/transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | #from transformer import TransformerEncoder 4 | import types 5 | import math 6 | import numpy as np 7 | 8 | args = types.SimpleNamespace() 9 | args.use_module_communication = 'true' 10 | args.encoder_embed_dim = 512 11 | args.encoder_attention_heads = 8 #was 8 12 | args.attention_dropout = 0.1 13 | args.topk_ratio = 1.0 14 | args.dropout = 0.2 15 | args.encoder_normalize_before = True 16 | args.encoder_ffn_embed_dim = 2048 17 | args.use_nfm = 'false' 18 | args.shared_memory_attention = False 19 | args.self_attention = True 20 | args.mem_slots = 4 21 | args.use_topk = False 22 | args.topk = 3 23 | args.num_steps = 5 24 | 25 | from transformer_utilities.transformer_layer import TransformerEncoderLayer, TransformerEncoderLayerVanilla 26 | from transformer_utilities.pos_enc import PositionEncoder 27 | from transformer_utilities.GroupLinearLayer import GroupLinearLayer 28 | import math 29 | 30 | 31 | class SelectAttention(nn.Module): 32 | """docstring for SelectAttention""" 33 | def __init__(self, d_read, d_write, d_k = 16, num_read = 5, num_write = 5, share_query = False, share_key = False): 34 | super(SelectAttention, self).__init__() 35 | if not share_key: 36 | self.gll_write = GroupLinearLayer(d_write,d_k, num_write) 37 | else: 38 | self.gll_write = 
nn.Linear(d_write, d_k) 39 | 40 | if not share_query: 41 | self.gll_read = GroupLinearLayer(d_read,d_k, num_read) 42 | else: 43 | self.gll_read = nn.Linear(d_read, d_k) 44 | 45 | self.temperature = math.sqrt(d_k) 46 | 47 | def forward(self, q, k): 48 | read = self.gll_read(q) 49 | write = self.gll_write(k) 50 | 51 | return torch.bmm(read, write.permute(0, 2, 1)) / self.temperature 52 | 53 | class TransformerEncoder(nn.Module): 54 | 55 | def __init__(self, 56 | embed_dim, 57 | ffn_dim, 58 | num_layers = 6, 59 | num_heads = 1, 60 | dropout = 0.1, 61 | functional = False, 62 | shared_memory_attention = False, 63 | shared_memory_percentage = 0.1, 64 | share_parameters = False, 65 | mem_slots = 4, 66 | num_attention_schemas = 3, 67 | num_gru_schemas = 3, 68 | schema_specific = False, 69 | use_topk = False, 70 | topk = 3, 71 | num_steps = 5, 72 | null_attention = False, 73 | regressive = False): 74 | super().__init__() 75 | 76 | if schema_specific and (num_gru_schemas != num_attention_schemas): 77 | print('Cannot use schema specific as num_gru_schemas != num_attention_schemas, continuing without') 78 | self.schema_specific = False 79 | else: 80 | self.schema_specific = schema_specific 81 | 82 | args.mem_slots = mem_slots 83 | args.encoder_embed_dim = embed_dim 84 | args.encoder_ffn_embed_dim = ffn_dim 85 | args.encoder_attention_heads = num_heads 86 | args.dropout = dropout 87 | args.shared_memory_attention = shared_memory_attention 88 | args.num_steps = num_steps 89 | args.null_attention = null_attention 90 | args.regressive = regressive 91 | 92 | 93 | self.num_layers = num_layers 94 | self.shared_memory_attention = shared_memory_attention 95 | self.shared_memory_percentage = shared_memory_percentage 96 | 97 | print('transformer embed_dim', embed_dim) 98 | self.functional = functional 99 | print('functional? 
'+str(self.functional)) 100 | if not self.functional: 101 | layer_lst = [] 102 | args.use_topk = use_topk 103 | args.topk = topk 104 | 105 | 106 | args.encoder_embed_dim = embed_dim 107 | self.share_parameters = share_parameters 108 | if share_parameters: 109 | self.enc = TransformerEncoderLayerVanilla(args) 110 | else: 111 | layer_lst = [] 112 | for i in range(self.num_layers): 113 | layer_lst.append(TransformerEncoderLayerVanilla(args)) 114 | print('flmklsd') 115 | self.layers = nn.ModuleList(layer_lst) 116 | else: 117 | #args.encoder_embed_dim = inp_dim 118 | #print('init_layer initialize') 119 | #self.init_layer = TransformerEncoderLayerVanilla(args=args, out_proj=h_dim) 120 | print('NUM GRU SCHEMAS:' + str(num_gru_schemas)) 121 | print('NUM Attention SCHEMAS:' + str(num_attention_schemas)) 122 | print('SCHEMA SPECIFIC:' + str(self.schema_specific)) 123 | args.use_topk = use_topk 124 | args.topk = topk 125 | print('inp_att initialize') 126 | self.num_gru_schemas = num_gru_schemas 127 | self.num_att_schemas = num_attention_schemas 128 | self.schema_stats = np.zeros(self.num_gru_schemas) 129 | args.self_attention = True 130 | self.inp_att = nn.ModuleList([TransformerEncoderLayerVanilla(args=args) for _ in range(num_attention_schemas)]) 131 | self.select_attention_inp_att = SelectAttention( args.encoder_embed_dim, args.encoder_embed_dim, num_read = 1, num_write = num_attention_schemas) 132 | print('gru initialize') 133 | hidden_dim = args.encoder_embed_dim 134 | 135 | 136 | self.gru_pool = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(num_gru_schemas)]) 137 | #args.self_attention = True 138 | #self.state_att = TransformerEncoderLayerVanilla(args=args) 139 | self.select_attention = SelectAttention( hidden_dim + hidden_dim, hidden_dim, num_read = 1, num_write = num_gru_schemas) 140 | 141 | self.pe = PositionEncoder(args.encoder_embed_dim) 142 | self.pe_state = PositionEncoder(args.encoder_embed_dim) 143 | 144 | def forward(self, x, mask = None, num_layers = None): 145 | 146 | x = x.permute(1, 0, 2) 147 | 148 | x = self.pe(x) 149 | 150 | 151 | 152 | if not self.functional: 153 | if self.shared_memory_attention: 154 | memory_size = int(self.shared_memory_percentage * x.size(0)) 155 | 156 | memory = torch.randn(memory_size, 1, x.size(2)).repeat(1 ,x.size(1), 1).to(x.device) 157 | else: 158 | memory = None 159 | if self.shared_memory_attention: 160 | if self.share_parameters: 161 | if self.enc.self_attn.memory is not None: 162 | self.enc.self_attn.init_memory(x.size(1), x.size(0), x.device)#.memory = self.enc.self_attn.memory.detach() 163 | else: 164 | for layer in self.layers: 165 | if layer.self_attn.memory is not None: 166 | layer.self_attn.init_memory(x.size(1), x.device)#.memory = layer.self_attn.memory.detach() 167 | 168 | 169 | for i in range(self.num_layers): 170 | if self.share_parameters: 171 | x, memory = self.enc(x, mask, memory = memory) 172 | else: 173 | x, memory = self.layers[i](x, mask, memory = memory) 174 | return x.permute(1, 0, 2) 175 | else: 176 | 177 | T, B, D = x.size() 178 | 179 | if num_layers is None: 180 | num_layers = self.num_layers 181 | 182 | 183 | #state = self.pe_state(torch.randn(x.size()).to(x.device)) 184 | 185 | if self.shared_memory_attention: 186 | memory_size = int(self.shared_memory_percentage * x.size(0)) 187 | memory_inp = torch.randn( memory_size, 1, x.size(2)).repeat(1, x.size(1), 1).to(x.device) 188 | memory_state = torch.randn(memory_size, 1, x.size(2)).repeat(1, x.size(1), 1).to(x.device) 189 | else: 190 | memory_inp = None 191 
| memory_state = None 192 | 193 | if self.shared_memory_attention: 194 | for inp_att in self.inp_att: 195 | if inp_att.self_attn.memory is not None: 196 | inp_att.self_attn.init_memory(x.size(1), x.device)#memory = inp_att.self_attn.memory.detach() 197 | for i in range(0, num_layers): 198 | gru_ins = [] 199 | for inp_att in self.inp_att: 200 | gru_in, memory_inp = inp_att(x, mask, memory = memory_inp) 201 | gru_ins.append(gru_in.permute(1, 0, 2)) 202 | 203 | gru_ins = torch.stack(gru_ins, dim = 2) 204 | gru_ins = gru_ins.reshape(B * T, -1, D) 205 | 206 | 207 | x = x.permute(1, 0, 2) 208 | x = x.reshape(B * T, -1).unsqueeze(1) 209 | 210 | attn_scores_inp_att = self.select_attention_inp_att(x, gru_ins) 211 | 212 | attn_scores_inp_att = attn_scores_inp_att.squeeze(1) 213 | attn_scores_inp_att = torch.nn.functional.gumbel_softmax(attn_scores_inp_att, dim = 1, hard = True, tau = 0.5) 214 | 215 | attn_scores_inp_att = attn_scores_inp_att.unsqueeze(-1) 216 | 217 | gru_in = (gru_ins * attn_scores_inp_att).sum(dim = 1) 218 | 219 | gru_in = gru_in.reshape(B, T, -1) 220 | x = x.reshape(B, T, -1) 221 | 222 | gru_in = gru_in.reshape(B * T, -1) 223 | x = x.reshape(B * T, -1) 224 | 225 | gru_outs = [] 226 | 227 | for gru in self.gru_pool: 228 | gru_outs.append(gru(gru_in, x)) 229 | 230 | gru_outs = torch.stack(gru_outs, dim = 1) 231 | 232 | selector = torch.cat((gru_in, x), dim = 1).unsqueeze(1) 233 | if not self.schema_specific: 234 | attn_scores = self.select_attention(selector, gru_outs) 235 | 236 | 237 | attn_scores = attn_scores.squeeze(1) 238 | 239 | attn_scores = torch.nn.functional.gumbel_softmax(attn_scores, dim = 1, tau = 1.0, hard = True) 240 | 241 | att_argmax = torch.sum(attn_scores.clone().detach(), dim = 0).cpu().numpy() 242 | 243 | self.schema_stats += att_argmax 244 | 245 | 246 | attn_scores = attn_scores.unsqueeze(-1) 247 | else: 248 | attn_scores = attn_scores_inp_att 249 | att_argmax = torch.sum(attn_scores.squeeze(-1).clone().detach(), dim = 0).cpu().numpy() 250 | 251 | self.schema_stats += att_argmax 252 | 253 | gru_outs = (gru_outs * attn_scores).sum(dim = 1) 254 | gru_outs_hidden = gru_outs.reshape(B, T, -1) 255 | gru_outs_hidden = gru_outs_hidden.permute(1, 0, 2) 256 | #gru_outs_hidden, memory_state = self.state_att(gru_outs_hidden, mask, memory = memory_state) 257 | #gru_in = gru_in.reshape(B, T, -1).permute(1, 0, 2) 258 | #x = gru_in 259 | x = gru_outs_hidden 260 | 261 | return x.permute(1,0,2) 262 | 263 | def print_schema_stats(self): 264 | total = np.sum(self.schema_stats) 265 | for k in range(self.schema_stats.shape[0]): 266 | print('schema ' + str(k) + ' used ' + str(self.schema_stats[k]) + ' out of ' + str(total) + ' times') 267 | 268 | 269 | def reset_schema_stats(self): 270 | self.schema_stats = np.zeros(self.num_gru_schemas) 271 | 272 | 273 | if __name__ == "__main__": 274 | x = torch.randn(8, 20, 256).cuda() 275 | import time 276 | TE1 = TransformerEncoder(256, 512, num_layers = 1, functional = False, num_gru_schemas = 3, num_attention_schemas = 3, schema_specific = False, shared_memory_attention = True, mem_slots = 8, num_steps = 20).cuda() 277 | t1 = time.time() 278 | for i in range(5): 279 | 280 | x = TE1(x) 281 | print(time.time() - t1) 282 | 283 | 284 | x = torch.randn(8, 20, 256).cuda() 285 | import time 286 | TE1 = TransformerEncoder(256, 512, num_layers = 1, functional = False, num_gru_schemas = 3, num_attention_schemas = 3, schema_specific = False, shared_memory_attention = True, mem_slots = 8, num_steps = 20).cuda() 287 | t1 = time.time() 288 | for i in 
range(5): 289 | 290 | x = TE1(x) 291 | print(time.time() - t1) 292 | x = torch.randn(8, 20, 256).cuda() 293 | TE2 = TransformerEncoder(256, 512, num_layers = 1, functional = False, num_gru_schemas = 3, num_attention_schemas = 3, schema_specific = True, shared_memory_attention = False, mem_slots = 8, num_steps = 20).cuda() 294 | t1 = time.time() 295 | for i in range(5): 296 | x = TE2(x) 297 | print(time.time() - t1) 298 | -------------------------------------------------------------------------------- /sort_of_clevr/translator.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | def translate(dataset): 3 | img, (rel_questions, rel_answers), (norel_questions, norel_answers) = dataset 4 | colors = ['red ', 'green ', 'blue ', 'orange ', 'gray ', 'yellow '] 5 | answer_sheet = ['yes', 'no', 'rectangle', 'circle', '1', '2', '3', '4', '5', '6'] 6 | questions = rel_questions + norel_questions 7 | answers = rel_answers + norel_answers 8 | 9 | print(rel_questions) 10 | print(rel_answers) 11 | 12 | 13 | for question,answer in zip(questions,answers): 14 | query = '' 15 | query += colors[question.tolist()[0:6].index(1)] 16 | 17 | if question[6] == 1: 18 | if question[8] == 1: 19 | query += 'shape?' 20 | if question[9] == 1: 21 | query += 'left?' 22 | if question[10] == 1: 23 | query += 'up?' 24 | if question[7] == 1: 25 | if question[8] == 1: 26 | query += 'closest shape?' 27 | if question[9] == 1: 28 | query += 'furthest shape?' 29 | if question[10] == 1: 30 | query += 'count?' 31 | 32 | ans = answer_sheet[answer] 33 | print(query, '==>', ans) 34 | #cv2.imwrite('sample.jpg',(img*255).astype(np.int32)) 35 | cv2.imshow('img',cv2.resize(img,(512,512))) 36 | cv2.waitKey(0) 37 | --------------------------------------------------------------------------------
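The `quant_noise` wrapper in `sort_of_clevr/transformer_utilities/quant_noise.py` registers its block-dropping mask as a forward pre-hook but ships without a usage demo. The sketch below shows one way it could be exercised on a plain `nn.Linear`; it assumes it is run from inside the `sort_of_clevr` folder, and the layer sizes and noise probability are illustrative choices rather than values used by the repository.

```
import torch
import torch.nn as nn

from transformer_utilities.quant_noise import quant_noise

# Illustrative sizes: in_features (32) must be a multiple of block_size (8),
# as asserted inside quant_noise.
layer = quant_noise(nn.Linear(32, 32), p=0.1, block_size=8)

layer.train()            # the pre-hook only drops blocks in training mode
x = torch.randn(4, 32)
y = layer(x)             # random weight blocks are zeroed, the rest rescaled by 1/(1-p)

layer.eval()
y_eval = layer(x)        # the pre-hook is skipped in eval mode
print(y.shape, y_eval.shape)
```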
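In the `functional=True` branch of `TransformerEncoder.forward` in `sort_of_clevr/transformers.py`, each position is routed to exactly one GRU "schema": the candidate updates are scored and a hard one-hot choice is sampled with the straight-through Gumbel-softmax. The following is a self-contained sketch of just that selection step, with random tensors standing in for the `SelectAttention` scores and the per-schema GRU outputs; the shapes are illustrative only.

```
import torch
import torch.nn.functional as F

# Stand-in shapes: 4 positions (batch*time), 3 candidate schemas, hidden size 16.
tokens, n_schemas, d = 4, 3, 16
gru_outs = torch.randn(tokens, n_schemas, d)   # one candidate update per schema
scores = torch.randn(tokens, n_schemas)        # stand-in for SelectAttention(selector, gru_outs)

# Hard Gumbel-softmax: the forward pass picks exactly one schema per position,
# while gradients flow through the soft relaxation (straight-through estimator).
one_hot = F.gumbel_softmax(scores, tau=1.0, hard=True, dim=1)

selected = (gru_outs * one_hot.unsqueeze(-1)).sum(dim=1)   # (tokens, d): the chosen schema's update
print(selected.shape)
```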