├── .idea └── vcs.xml ├── LICENSE ├── README.md ├── experiments ├── _context.py ├── attention.py ├── bias.py ├── convolution.py ├── gconv-simple.py ├── gconvolution.py ├── identity.py ├── memory.py ├── minmal.py ├── sparsity-mlp.py ├── sparsity.py └── transformer.py ├── scripts ├── README.md ├── _context.py ├── generate.mnist.py ├── plot.identity.py ├── plot.sort.py └── plot.sparsity.py ├── setup.py ├── sparse ├── __init__.py ├── layers.py ├── sort.py ├── tensors.py └── util │ ├── __init__.py │ ├── plot.py │ └── util.py ├── test.py └── tests ├── _context.py ├── test_layers.py ├── test_tensors.py ├── test_util.py └── tests.py /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Peter Bloem 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse, adaptive hyperlayers 2 | 3 | This is the codebase that accompanies the paper [Learning sparse transformations through backpropagation](http://www.peterbloem.nl/publications/learning-sparse-transformations). Follow the link for the paper and an annotated slide deck. 4 | 5 | ## Disclaimer 6 | 7 | We are still cleaning up the code, but it should now be relatively readable. Make sure 8 | you have PyTorch 1.0 installed and start by running ```experiments/identity.py```, 9 | which runs the identity experiment: 10 | ``` 11 | python experiments/identity.py -F 12 | ``` 13 | The ```-F``` flag sets all values of the matrix to 1, which makes learning a little easier. 14 | 15 | Feel free to ask me for help by opening an issue, or sending [an email](mailto:sparse@peterbloem.nl). 16 | 17 | The ```archive``` branch contains a snapshot of the code at the time the preprint went up.
18 | 19 | ## Dependencies (probably incomplete) 20 | 21 | * NumPy 22 | * Matplotlib 23 | * PyTorch 1.0 or later 24 | * torchvision 25 | * tensorboardX 26 | -------------------------------------------------------------------------------- /experiments/_context.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 5 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../sparse'))) 6 | 7 | import sparse -------------------------------------------------------------------------------- /experiments/bias.py: -------------------------------------------------------------------------------- 1 | from _context import sparse 2 | 3 | from sparse import util 4 | from util import d 5 | 6 | import torch 7 | 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | import torchvision 11 | from torch import nn 12 | from torch.autograd import Variable 13 | from tqdm import trange 14 | 15 | import matplotlib as mpl 16 | mpl.use('Agg') 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | 20 | import torchvision 21 | import torchvision.transforms as transforms 22 | from torchvision.transforms import ToTensor 23 | from torch.utils.data import TensorDataset, DataLoader 24 | 25 | import torch.distributions as ds 26 | 27 | from argparse import ArgumentParser 28 | from torch.utils.tensorboard import SummaryWriter 29 | 30 | import random, tqdm, sys, math, os 31 | 32 | """ 33 | Experiment to test bias of gradient estimator. Simple encoder/decoder with discrete latent space.
34 | 35 | """ 36 | 37 | 38 | def sample_gumbel(shape, eps=1e-20, cuda=False): # standard Gumbel(0, 1) noise of the given shape: -log(-log(U)) with U ~ Uniform(0, 1), eps-guarded 39 | U = torch.rand(shape, device=d(cuda)) 40 | return -Variable(torch.log(-torch.log(U + eps) + eps)) 41 | 42 | def gumbelize(logits, temperature=1.0): # returns a NEW tensor (logits + Gumbel noise) / temperature; logits itself is not modified 43 | y = logits + sample_gumbel(logits.size(), cuda=logits.is_cuda) 44 | return y / temperature 45 | 46 | def gradient(models): 47 | """ 48 | Returns the gradients of all trainable parameters (requires_grad) of the given models, flattened and concatenated into a single vector 49 | :param models: iterable of modules; each trainable parameter must already have a populated .grad (i.e. call after backward()) 50 | :return: 1D tensor of gradient values, in model/parameter iteration order 51 | """ 52 | gs = [] 53 | for model in models: 54 | for param in model.parameters(): 55 | if param.requires_grad: 56 | gs.append(param.grad.data.view(-1)) 57 | 58 | return torch.cat(gs, dim=0) 59 | 60 | 61 | def num_params(models): 62 | """ 63 | Returns the total number of trainable parameters (those with requires_grad) in the given models 64 | :param models: iterable of modules 65 | :return: the parameter count (int); matches the length of the vector returned by gradient(models) 66 | """ 67 | gs = 0 68 | for model in models: 69 | for param in model.parameters(): 70 | if param.requires_grad: 71 | gs += param.view(-1).size(0) 72 | 73 | return gs 74 | 75 | def clean(axes=None): 76 | 77 | if axes is None: 78 | axes = plt.gca() 79 | 80 | [s.set_visible(False) for s in axes.spines.values()] 81 | axes.tick_params(top=False, bottom=False, left=False, right=False, labelbottom=False, labelleft=False) 82 | 83 | class Encoder(nn.Module): 84 | 85 | def __init__(self, data_size, latent_size=128, depth=3): 86 | super().__init__() 87 | 88 | c, h, w = data_size 89 | cs = [c] + [2**(d+4) for d in range(depth)] 90 | 91 | div = 2 ** depth 92 | 93 | modules = [] 94 | 95 | for d in range(depth): 96 | modules += [ 97 | nn.Conv2d(cs[d], cs[d+1], 3, padding=1), nn.ReLU(), 98 | nn.Conv2d(cs[d+1], cs[d+1], 3, padding=1), nn.ReLU(), 99 | nn.MaxPool2d((2, 2)) 100 | ] 101 | 102 | modules += [ 103 | util.Flatten(), 104 | nn.Linear(cs[-1] * (h//div) * (w//div), 1024), nn.ReLU(), 105 | nn.Linear(1024, latent_size) # encoder produces a cont.
index tuple (ln -1 for the means, 1 for the sigma) 106 | ] 107 | 108 | self.encoder = nn.Sequential(*modules) 109 | 110 | def forward(self, x): 111 | 112 | return self.encoder(x) 113 | 114 | class Decoder(nn.Module): 115 | 116 | def __init__(self, data_size, latent_size=128, depth=3): 117 | super().__init__() 118 | 119 | upmode = 'bilinear' 120 | 121 | c, h, w = data_size 122 | cs = [c] + [2**(d+4) for d in range(depth)] 123 | 124 | div = 2 ** depth 125 | cl = lambda x : int(math.ceil(x)) 126 | 127 | modules = [ 128 | nn.Linear(latent_size, cs[-1] * cl(h/div) * cl(w/div)), nn.ReLU(), 129 | util.Reshape( (cs[-1], cl(h/div), cl(w/div)) ) 130 | ] 131 | 132 | for d in range(depth, 0, -1): 133 | modules += [ 134 | nn.Upsample(scale_factor=2, mode=upmode), 135 | nn.ConvTranspose2d(cs[d], cs[d], 3, padding=1), nn.ReLU(), 136 | nn.ConvTranspose2d(cs[d], cs[d-1], 3, padding=1), nn.ReLU() 137 | ] 138 | 139 | modules += [ 140 | nn.ConvTranspose2d(c, c, (3, 3), padding=1), nn.Sigmoid(), 141 | util.Lambda(lambda x : x[:, :, :h, :w]) # crop out any extra pixels due to rounding errors 142 | ] 143 | self.decoder = nn.Sequential(*modules) 144 | 145 | def forward(self, x): 146 | return self.decoder(x) 147 | 148 | def go(arg): 149 | 150 | try: 151 | arg.bins = int(arg.bins) 152 | except ValueError: 153 | pass 154 | 155 | util.makedirs('./bias/') 156 | 157 | if not os.path.exists('./bias/cached.npz'): 158 | 159 | if arg.seed < 0: 160 | seed = random.randint(0, 1000000) 161 | print('random seed: ', seed) 162 | else: 163 | torch.manual_seed(arg.seed) 164 | 165 | tbw = SummaryWriter(log_dir=arg.tb_dir) 166 | tfms = transforms.Compose([transforms.ToTensor()]) 167 | 168 | if (arg.task == 'mnist'): 169 | 170 | shape = (1, 28, 28) 171 | num_classes = 10 172 | 173 | data = arg.data + os.sep + arg.task 174 | 175 | if arg.final: 176 | train = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=tfms) 177 | trainloader = torch.utils.data.DataLoader(train, 
batch_size=arg.batch_size, shuffle=True, num_workers=0) 178 | 179 | test = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=ToTensor()) 180 | testloader = torch.utils.data.DataLoader(test, batch_size=arg.batch_size, shuffle=False, num_workers=0) 181 | 182 | else: 183 | NUM_TRAIN = 45000 184 | NUM_VAL = 5000 185 | total = NUM_TRAIN + NUM_VAL 186 | 187 | train = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=tfms) 188 | 189 | trainloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(0, NUM_TRAIN, total)) 190 | testloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total)) 191 | 192 | elif (arg.task == 'cifar10'): 193 | 194 | shape = (3, 32, 32) 195 | num_classes = 10 196 | 197 | data = arg.data + os.sep + arg.task 198 | 199 | if arg.final: 200 | train = torchvision.datasets.CIFAR10(root=data, train=True, download=True, transform=tfms) 201 | trainloader = torch.utils.data.DataLoader(train, batch_size=arg.batch, shuffle=True, num_workers=2) 202 | test = torchvision.datasets.CIFAR10(root=data, train=False, download=True, transform=ToTensor()) 203 | testloader = torch.utils.data.DataLoader(test, batch_size=arg.batch, shuffle=False, num_workers=2) 204 | 205 | else: 206 | NUM_TRAIN = 45000 207 | NUM_VAL = 5000 208 | total = NUM_TRAIN + NUM_VAL 209 | 210 | train = torchvision.datasets.CIFAR10(root=data, train=True, download=True, transform=tfms) 211 | 212 | trainloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(0, NUM_TRAIN, total)) 213 | testloader = DataLoader(train, batch_size=arg.batch, 214 | sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total)) 215 | 216 | elif arg.task == 'ffhq': 217 | 218 | transform = ToTensor() 219 | shape = (3, 128, 128) 220 | 221 | trainset = torchvision.datasets.ImageFolder(root=arg.data+os.sep+'train', 222 | transform=transform) 223 | trainloader = torch.utils.data.DataLoader(trainset, 
batch_size=arg.batch, 224 | shuffle=True, num_workers=2) 225 | 226 | testset = torchvision.datasets.ImageFolder(root=arg.data+os.sep+'valid', 227 | transform=transform) 228 | testloader = torch.utils.data.DataLoader(testset, batch_size=arg.batch, 229 | shuffle=False, num_workers=2) 230 | 231 | else: 232 | raise Exception('Task {} not recognized'.format(arg.task)) 233 | 234 | encoder = Encoder(shape, latent_size=arg.latent_size, depth=arg.depth) 235 | decoder = Decoder(shape, latent_size=arg.latent_size, depth=arg.depth) 236 | 237 | if arg.cuda: 238 | encoder.cuda() 239 | decoder.cuda() 240 | 241 | opt = torch.optim.Adam(params=list(encoder.parameters()) + list(decoder.parameters()), lr=arg.lr) 242 | 243 | nparms = num_params([encoder]) 244 | print(f'{nparms} parameters in encoder.') 245 | 246 | seen = 0 247 | l = arg.latent_size 248 | ti = random.sample(range(nparms), arg.num_params) # random indices of parameters for which to test the gradient 249 | k = arg.k 250 | 251 | # Train for a fixed nr of instances (with the true gradient) 252 | for e in range(arg.epochs): 253 | print('epoch', e) 254 | 255 | for i, (inputs, _) in enumerate(trainloader): 256 | 257 | b, c, h, w = inputs.size() 258 | 259 | if arg.cuda: 260 | inputs = inputs.cuda() 261 | 262 | # compute actual gradient 263 | opt.zero_grad() 264 | 265 | latent = encoder(inputs) 266 | latent = F.softmax(latent, dim=1) 267 | 268 | dinp = torch.eye(l, device=d(arg.cuda))[None, :, :].expand(b, l, l).reshape(b*l, l) 269 | dout = decoder(dinp) 270 | 271 | assert dout.size() == (b*l, c, h, w) 272 | 273 | target = inputs.detach()[:, None, :, :, :].expand(b, l, c, h, w).reshape(b*l, c, h, w) 274 | 275 | loss = F.binary_cross_entropy(dout, target, reduction='none') 276 | loss = loss.sum(dim=1).sum(dim=1).sum(dim=1).view(b, l) 277 | 278 | loss = (loss * latent).sum(dim=1).mean() 279 | 280 | loss.backward() 281 | 282 | true_gradient = gradient([encoder, decoder]) 283 | true_gradient = true_gradient[ti] 284 | 285 | 
opt.step() 286 | 287 | inputs, _ = next(iter(trainloader)) 288 | if arg.cuda: 289 | inputs = inputs.cuda() 290 | 291 | b, c, h, w = inputs.size() 292 | 293 | # compute true gradient 294 | opt.zero_grad() 295 | 296 | latent = encoder(inputs) 297 | latent = F.softmax(latent, dim=1) 298 | 299 | dinp = torch.eye(l, device=d(arg.cuda))[None, :, :].expand(b, l, l).reshape(b*l, l) 300 | dout = decoder(dinp) 301 | 302 | assert dout.size() == (b*l, c, h, w) 303 | 304 | target = inputs.detach()[:, None, :, :, :].expand(b, l, c, h, w).reshape(b*l, c, h, w) 305 | 306 | loss = F.binary_cross_entropy(dout, target, reduction='none') 307 | loss = loss.sum(dim=1).sum(dim=1).sum(dim=1).view(b, l) 308 | 309 | loss = (loss * latent).sum(dim=1).mean() 310 | 311 | loss.backward() 312 | 313 | true_gradient = gradient([encoder]) 314 | true_gradient = true_gradient[ti] 315 | 316 | # - Estimate the bias for the uninformed sampler 317 | 318 | uste = torch.zeros((arg.samples, len(ti),), device=d(arg.cuda)) 319 | 320 | # Unbiased, uninformed STE 321 | for s in trange(arg.samples): 322 | opt.zero_grad() 323 | 324 | ks = [random.sample(range(arg.latent_size), k) for _ in range(b)] 325 | ks = torch.tensor(ks, device=d(arg.cuda)) 326 | 327 | latent = encoder(inputs) 328 | latent = torch.gather(latent, dim=1, index=ks); assert latent.size() == (b, k) 329 | latent = F.softmax(latent, dim=1) 330 | 331 | dinp = torch.zeros(size=(b*k, l), device=d(arg.cuda)) 332 | dinp.scatter_(dim=1, index=ks.view(b*k, 1), value=1) 333 | dout = decoder(dinp) 334 | 335 | assert dout.size() == (b * k, c, h, w) 336 | 337 | target = inputs.detach()[:, None, :, :, :].expand(b, k, c, h, w).reshape(b * k, c, h, w) 338 | 339 | loss = F.binary_cross_entropy(dout, target, reduction='none') 340 | loss = loss.sum(dim=1).sum(dim=1).sum(dim=1).view(b, k) 341 | 342 | loss = (loss * latent).sum(dim=1).mean() 343 | 344 | loss.backward() 345 | 346 | samp_gradient = gradient([encoder]) 347 | uste[s, :] = samp_gradient[ti] 348 | 349 | 
del loss 350 | 351 | iste = torch.zeros((arg.samples, len(ti),), device=d(arg.cuda)) 352 | 353 | # Unbiased, informed STE 354 | # This behaves like the USTE, but ensures that the argmax is always included in the sample 355 | for s in trange(arg.samples): 356 | opt.zero_grad() 357 | 358 | latent = encoder(inputs) 359 | 360 | ks = [random.sample(range(arg.latent_size-1), k-1) for _ in range(b)] 361 | ks = torch.tensor(ks, device=d(arg.cuda)) 362 | am = latent.argmax(dim=1, keepdim=True) 363 | ks[ks > am] += 1 364 | 365 | ks = torch.cat([am, ks], dim=1) 366 | 367 | latent = torch.gather(latent, dim=1, index=ks); assert latent.size() == (b, k) 368 | latent = F.softmax(latent, dim=1) 369 | 370 | dinp = torch.zeros(size=(b * k, l), device=d()) 371 | dinp.scatter_(dim=1, index=ks.view(b * k, 1), value=1) 372 | dout = decoder(dinp) 373 | 374 | assert dout.size() == (b * k, c, h, w) 375 | 376 | target = inputs.detach()[:, None, :, :, :].expand(b, k, c, h, w).reshape(b * k, c, h, w) 377 | 378 | loss = F.binary_cross_entropy(dout, target, reduction='none') 379 | loss = loss.sum(dim=1).sum(dim=1).sum(dim=1).view(b, k) 380 | 381 | loss = (loss * latent).sum(dim=1).mean() 382 | 383 | loss.backward() 384 | 385 | samp_gradient = gradient([encoder]) 386 | iste[s, :] = samp_gradient[ti] 387 | 388 | del loss 389 | 390 | # Biased (?) 
gumbel STE 391 | # STE with gumbel noise 392 | 393 | gste = torch.zeros((arg.samples, len(ti),), device=d(arg.cuda)) 394 | 395 | for s in trange(arg.samples): 396 | for _ in range(k): 397 | opt.zero_grad() 398 | 399 | latent = encoder(inputs) 400 | 401 | gumbelize(latent, temperature=arg.gumbel) 402 | latent = F.softmax(latent, dim=1) 403 | 404 | ks = latent.argmax(dim=1, keepdim=True) 405 | 406 | dinp = torch.zeros(size=(b, l), device=d()) 407 | dinp.scatter_(dim=1, index=ks, value=1) 408 | 409 | dinp = (dinp - latent).detach() + latent # straight-through trick 410 | dout = decoder(dinp) 411 | 412 | assert dout.size() == (b, c, h, w) 413 | 414 | target = inputs.detach() 415 | 416 | loss = F.binary_cross_entropy(dout, target, reduction='none') 417 | loss = loss.sum(dim=1).sum(dim=1).sum(dim=1).view(b) 418 | loss = loss.mean() 419 | 420 | loss.backward() 421 | 422 | samp_gradient = gradient([encoder]) 423 | gste[s, :] += samp_gradient[ti] 424 | 425 | del loss 426 | 427 | gste[s, :] /= k 428 | 429 | # Classical STE 430 | # cste = torch.zeros((arg.samples, len(ti),), device=d(arg.cuda)) 431 | # 432 | # for s in trange(arg.samples): 433 | # opt.zero_grad() 434 | # 435 | # latent = encoder(inputs) 436 | # 437 | # # gumbelize(latent, temperature=arg.gumbel) 438 | # dist = ds.Categorical(logits=latent) 439 | # ks = dist.sample()[:, None] 440 | # 441 | # dinp = torch.zeros(size=(b, l), device=d()) 442 | # dinp.scatter_(dim=1, index=ks, value=1) 443 | # 444 | # dinp = (dinp - latent).detach() + latent # straight-through trick 445 | # dout = decoder(dinp) 446 | # 447 | # assert dout.size() == (b, c, h, w) 448 | # 449 | # target = inputs.detach() 450 | # 451 | # loss = F.binary_cross_entropy(dout, target, reduction='none') 452 | # loss = loss.sum(dim=1).sum(dim=1).sum(dim=1).view(b) 453 | # loss = loss.mean() 454 | # 455 | # loss.backward() 456 | # 457 | # samp_gradient = gradient([encoder]) 458 | # cste[s, :] = samp_gradient[ti] 459 | # 460 | # del loss 461 | 462 | uste = 
uste.cpu().numpy() 463 | iste = iste.cpu().numpy() 464 | gste = gste.cpu().numpy() 465 | tgrd = true_gradient.cpu().numpy() 466 | 467 | np.savez_compressed('./bias/cached.npz', uste=uste, iste=iste, gste=gste, tgrd=tgrd) 468 | 469 | else: 470 | res = np.load('./bias/cached.npz') 471 | uste, iste, gste, tgrd = res['uste'], res['iste'], res['gste'], res['tgrd'] 472 | 473 | ind = tgrd != 0.0 474 | print(tgrd.shape, ind) 475 | 476 | print(f'{ind.sum()} derivatives out of {ind.shape} not equal to zero.') 477 | 478 | if not arg.skip: 479 | for nth, i in enumerate( np.arange(ind.shape[0])[ind][:5] ): 480 | 481 | plt.gcf().clear() 482 | 483 | unump = uste[:, i] 484 | inump = iste[:, i] 485 | gnump = gste[:, i] 486 | # cnump = cste[:, i].cpu().numpy() 487 | 488 | ulab = f'uninformed, var={unump.var():.4}' 489 | ilab = f'informed, var={inump.var():.4}' 490 | glab = f'Gumbel STE (t={arg.gumbel}) var={gnump.var():.4}' 491 | # clab = f'Classical STE var={cnump.var():.4}' 492 | 493 | plt.hist([unump, inump, gnump], color=['r', 'g', 'b'], label=[ulab, ilab, glab], bins=arg.bins) 494 | 495 | plt.axvline(x=tgrd[i], color='k', label='true gradient') 496 | plt.axvline(x=unump.mean(), color='r', ls='--') 497 | plt.axvline(x=inump.mean(), color='g', ls='-.') 498 | plt.axvline(x=gnump.mean(), color='b', ls=':') 499 | # plt.axvline(x=cnump.mean(), color='c') 500 | 501 | plt.title(f'estimates for parameter ... 
({uste.shape[0]} samples)') 502 | 503 | plt.legend() 504 | util.basic() 505 | 506 | plt.savefig(f'./bias/histogram.{nth}.pdf') 507 | 508 | 509 | plt.gcf().clear() 510 | 511 | unump = uste[:, ind].mean(axis=0) 512 | inump = iste[:, ind].mean(axis=0) 513 | gnump = gste[:, ind].mean(axis=0) 514 | 515 | tnump = tgrd[ind] 516 | 517 | unump = np.abs(unump - tnump) 518 | inump = np.abs(inump - tnump) 519 | gnump = np.abs(gnump - tnump) 520 | 521 | ulab = f'uninformed, var={unump.var():.4}' 522 | ilab = f'informed, var={inump.var():.4}' 523 | glab = f'gumbel STE (t={arg.gumbel}) var={gnump.var():.4}' 524 | # clab = f'Classical STE var={cnump.var():.4}' 525 | 526 | plt.hist([unump, inump, gnump], color=['r', 'g', 'b'], label=[ulab, ilab, glab], bins=arg.bins) 527 | 528 | plt.axvline(x=unump.mean(), color='r', ls='--') 529 | plt.axvline(x=inump.mean(), color='g', ls='-.') 530 | plt.axvline(x=gnump.mean(), color='b', ls=':') 531 | # plt.axvline(x=cnump.mean(), color='c') 532 | 533 | plt.title(f'Absolute error between true gradient and estimate \n over {ind.sum()} parameters with nonzero gradient.') 534 | 535 | plt.legend() 536 | util.basic() 537 | 538 | if arg.range is not None: 539 | plt.xlim(*arg.range) 540 | 541 | plt.savefig(f'./bias/histogram.all.pdf') 542 | 543 | if __name__ == "__main__": 544 | 545 | parser = ArgumentParser() 546 | 547 | parser.add_argument("-e", "--epochs", 548 | dest="epochs", 549 | help="Number of epochs to train (with the true gradient) before testing the estimator biases.", 550 | default=5, type=int) 551 | 552 | parser.add_argument("-b", "--batch", 553 | dest="batch", 554 | help="Batch size", 555 | default=8, type=int) 556 | 557 | parser.add_argument("-d", "--depth", 558 | dest="depth", 559 | help="Depth of the autoencoder (number of maxpooling operations).", 560 | default=3, type=int) 561 | 562 | parser.add_argument("--num-params", 563 | dest="num_params", 564 | help="Depth", 565 | default=50000, type=int) 566 | 567 | 
parser.add_argument("--task", 568 | dest="task", 569 | help="Dataset to model (mnist, cifar10)", 570 | default='mnist', type=str) 571 | 572 | parser.add_argument("--latent-size", 573 | dest="latent_size", 574 | help="Size of the discrete latent space.", 575 | default=128, type=int) 576 | 577 | parser.add_argument("--samples", 578 | dest="samples", 579 | help="Number of samples to take from the estimator.", 580 | default=100, type=int) 581 | 582 | parser.add_argument("-p", "--plot-every", 583 | dest="plot_every", 584 | help="Number of epochs to wait between plotting", 585 | default=1, type=int) 586 | 587 | parser.add_argument("-k", "--set-size", 588 | dest="k", 589 | help="Size of the sample (the set S). For the gumbel softmax, we average the estimate over k separate samples", 590 | default=5, type=int) 591 | 592 | parser.add_argument("-l", "--learn-rate", 593 | dest="lr", 594 | help="Learning rate", 595 | default=0.0001, type=float) 596 | 597 | parser.add_argument("--limit", 598 | dest="limit", 599 | help="Limit.", 600 | default=None, type=int) 601 | 602 | parser.add_argument("-r", "--seed", 603 | dest="seed", 604 | help="Random seed", 605 | default=0, type=int) 606 | 607 | parser.add_argument("-c", "--cuda", dest="cuda", 608 | help="Whether to use cuda.", 609 | action="store_true") 610 | 611 | parser.add_argument("-D", "--data", dest="data", 612 | help="Data directory", 613 | default='./data') 614 | 615 | parser.add_argument("-f", "--final", dest="final", 616 | help="Whether to run on the real test set (if not included, the validation set is used).", 617 | action="store_true") 618 | 619 | parser.add_argument("-T", "--tb_dir", dest="tb_dir", 620 | help="Data directory", 621 | default=None) 622 | 623 | parser.add_argument("-G", "--gumbel", dest="gumbel", 624 | help="Gumbel temperature.", 625 | default=1.0, type=float) 626 | 627 | parser.add_argument("--range", dest="range", 628 | help="Range for the 'all' plot.", 629 | nargs=2, 630 | default=None, type=float) 631 | 
632 | parser.add_argument("--bins", dest="bins", 633 | help="Nr of bins (or binning strategy).", 634 | default='sturges') 635 | 636 | parser.add_argument("--skip", dest="skip", 637 | help="Skip the per-parameter histograms.", 638 | action="store_true") 639 | 640 | 641 | args = parser.parse_args() 642 | 643 | print('OPTIONS', args) 644 | 645 | go(args) 646 | -------------------------------------------------------------------------------- /experiments/gconv-simple.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import matplotlib as mpl 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from tensorboardX import SummaryWriter 9 | from torch import nn 10 | from torch.autograd import Variable 11 | from tqdm import trange 12 | 13 | import gaussian 14 | import util 15 | from util import sparsemm 16 | 17 | mpl.use('Agg') 18 | import matplotlib.pyplot as plt 19 | import numpy as np 20 | 21 | from argparse import ArgumentParser 22 | 23 | import networkx as nx 24 | 25 | import math 26 | 27 | import torch 28 | 29 | from torch.nn.parameter import Parameter 30 | from torch.nn.modules.module import Module 31 | 32 | """ 33 | Simple Graph convolution experiment. 
Given a set of random vectors, learn to express each as the sum of some of the 34 | others 35 | """ 36 | 37 | def clean(axes=None): 38 | 39 | if axes is None: 40 | axes = plt.gca() 41 | 42 | [s.set_visible(False) for s in axes.spines.values()] 43 | axes.tick_params(top=False, bottom=False, left=False, right=False, labelbottom=False, labelleft=False) 44 | 45 | 46 | def densities(points, means, sigmas): 47 | """ 48 | Compute the unnormalized PDFs of the points under the given MVNs 49 | 50 | (with sigma a diagonal matrix per MVN) 51 | 52 | :param means: 53 | :param sigmas: 54 | :param points: 55 | :return: 56 | """ 57 | 58 | # n: number of MVNs 59 | # d: number of points per MVN 60 | # rank: dim of points 61 | 62 | batchsize, n, rank = points.size() 63 | batchsize, k, rank = means.size() 64 | # batchsize, k, rank = sigmas.size() 65 | 66 | points = points.unsqueeze(2).expand(batchsize, n, k, rank) 67 | means = means.unsqueeze(1).expand_as(points) 68 | sigmas = sigmas.unsqueeze(1).expand_as(points) 69 | 70 | sigmas_squared = torch.sqrt(1.0/(gaussian.EPSILON + sigmas)) 71 | 72 | points = points - means 73 | points = points * sigmas_squared 74 | 75 | # Compute dot products for all points 76 | # -- unroll the batch/n dimensions 77 | points = points.view(-1, 1, rank) 78 | # -- dot prod 79 | products = torch.bmm(points, points.transpose(1,2)) 80 | # -- reconstruct shape 81 | products = products.view(batchsize, n, k) 82 | 83 | num = torch.exp(- 0.5 * products) 84 | 85 | return num 86 | 87 | class MatrixHyperlayer(nn.Module): 88 | """ 89 | Constrained version of the matrix hyperlayer. Each output get exactly k inputs 90 | """ 91 | 92 | def duplicates(self, tuples): 93 | """ 94 | Takes a list of tuples, and for each tuple that occurs mutiple times 95 | marks all but one of the occurences (in the mask that is returned). 
96 | 97 | :param tuples: A size (batch, k, rank) tensor of integer tuples 98 | :return: A size (batch, k) mask indicating the duplicates 99 | """ 100 | b, k, r = tuples.size() 101 | 102 | primes = self.primes[:r] 103 | primes = primes.unsqueeze(0).unsqueeze(0).expand(b, k, r) 104 | unique = ((tuples+1) ** primes).prod(dim=2) # unique identifier for each tuple 105 | 106 | sorted, sort_idx = torch.sort(unique, dim=1) 107 | _, unsort_idx = torch.sort(sort_idx, dim=1) 108 | 109 | mask = sorted[:, 1:] == sorted[:, :-1] 110 | 111 | zs = torch.zeros(b, 1, dtype=torch.uint8, device='cuda' if self.use_cuda else 'cpu') 112 | mask = torch.cat([zs, mask], dim=1) 113 | 114 | return torch.gather(mask, 1, unsort_idx) 115 | 116 | def cuda(self, device_id=None): 117 | 118 | self.use_cuda = True 119 | super().cuda(device_id) 120 | 121 | def __init__(self, in_num, out_num, k, radditional=0, gadditional=0, region=(128,), 122 | sigma_scale=0.2, min_sigma=0.0, fix_value=False): 123 | super().__init__() 124 | 125 | self.min_sigma = min_sigma 126 | self.use_cuda = False 127 | self.in_num = in_num 128 | self.out_num = out_num 129 | self.k = k 130 | self.radditional = radditional 131 | self.region = region 132 | self.gadditional = gadditional 133 | self.sigma_scale = sigma_scale 134 | self.fix_value = fix_value 135 | 136 | self.weights_rank = 2 # implied rank of W 137 | 138 | self.params = Parameter(torch.randn(k * out_num, 3)) 139 | 140 | outs = torch.arange(out_num).unsqueeze(1).expand(out_num, k * (2 + radditional + gadditional)).contiguous().view(-1, 1) 141 | self.register_buffer('outs', outs.long()) 142 | 143 | outs_inf = torch.arange(out_num).unsqueeze(1).expand(out_num, k).contiguous().view(-1, 1) 144 | self.register_buffer('outs_inf', outs_inf.long()) 145 | 146 | self.register_buffer('primes', torch.tensor(util.PRIMES)) 147 | 148 | 149 | def size(self): 150 | return (self.out_num, self.in_num) 151 | 152 | def generate_integer_tuples(self, means,rng=None, use_cuda=False): 153 | 154 | 
dv = 'cuda' if use_cuda else 'cpu' 155 | 156 | c, k, rank = means.size() 157 | 158 | assert rank == 1 159 | # In the following, we cut the first dimension up into chunks of size self.k (for which the row index) 160 | # is the same. This then functions as a kind of 'batch' dimension, allowing us to use the code from 161 | # globalsampling without much adaptation 162 | 163 | """ 164 | Sample the 2 nearest points 165 | """ 166 | 167 | floor_mask = torch.tensor([1, 0], device=dv, dtype=torch.uint8) 168 | fm = floor_mask.unsqueeze(0).unsqueeze(2).expand(c, k, 2, 1) 169 | 170 | neighbor_ints = means.data.unsqueeze(2).expand(c, k, 2, 1).contiguous() 171 | 172 | neighbor_ints[fm] = neighbor_ints[fm].floor() 173 | neighbor_ints[~fm] = neighbor_ints[~fm].ceil() 174 | 175 | neighbor_ints = neighbor_ints.long() 176 | 177 | """ 178 | Sample uniformly from a small range around the given index tuple 179 | """ 180 | rr_ints = torch.cuda.FloatTensor(c, k, self.radditional, 1) if use_cuda else torch.FloatTensor(c, k, self.radditional, 1) 181 | 182 | rr_ints.uniform_() 183 | rr_ints *= (1.0 - gaussian.EPSILON) 184 | 185 | rng = torch.cuda.FloatTensor(rng) if use_cuda else torch.FloatTensor(rng) 186 | 187 | rngxp = rng.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand_as(rr_ints) # bounds of the tensor 188 | rrng = torch.cuda.FloatTensor(self.region) if use_cuda else torch.FloatTensor(self.region) # bounds of the range from which to sample 189 | rrng = rrng.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand_as(rr_ints) 190 | 191 | mns_expand = means.round().unsqueeze(2).expand_as(rr_ints) 192 | 193 | # upper and lower bounds 194 | lower = mns_expand - rrng * 0.5 195 | upper = mns_expand + rrng * 0.5 196 | 197 | # check for any ranges that are out of bounds 198 | idxs = lower < 0.0 199 | lower[idxs] = 0.0 200 | 201 | idxs = upper > rngxp 202 | lower[idxs] = rngxp[idxs] - rrng[idxs] 203 | 204 | rr_ints = (rr_ints * rrng + lower).long() 205 | 206 | """ 207 | Sample uniformly from all index tuples 
208 | """ 209 | g_ints = torch.cuda.FloatTensor(c, k, self.gadditional, 1) if use_cuda else torch.FloatTensor(c, k, self.gadditional, 1) 210 | rngxp = rng.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand_as(g_ints) # bounds of the tensor 211 | 212 | g_ints.uniform_() 213 | g_ints *= (1.0 - gaussian.EPSILON) * rngxp 214 | g_ints = g_ints.long() 215 | 216 | ints = torch.cat([neighbor_ints, rr_ints, g_ints], dim=2) 217 | 218 | return ints.view(c, -1, rank) 219 | 220 | def forward(self, input, train=True): 221 | 222 | ### Compute and unpack output of hypernetwork 223 | 224 | means, sigmas, values = self.hyper(input) 225 | nm = means.size(0) 226 | c = nm // self.k 227 | 228 | means = means.view(c, self.k, 1) 229 | sigmas = sigmas.view(c, self.k, 1) 230 | values = values.view(c, self.k) 231 | 232 | rng = (self.in_num, ) 233 | 234 | assert input.size(0) == self.in_num 235 | 236 | if train: 237 | indices = self.generate_integer_tuples(means, rng=rng, use_cuda=self.use_cuda) 238 | indfl = indices.float() 239 | 240 | # Mask for duplicate indices 241 | dups = self.duplicates(indices) 242 | 243 | props = densities(indfl, means, sigmas).clone() # result has size (c, indices.size(1), means.size(1)) 244 | props[dups] = 0 245 | props = props / props.sum(dim=1, keepdim=True) 246 | 247 | values = values.unsqueeze(1).expand(c, indices.size(1), means.size(1)) 248 | 249 | values = props * values 250 | values = values.sum(dim=2) 251 | 252 | # unroll the batch dimension 253 | indices = indices.view(-1, 1) 254 | values = values.view(-1) 255 | 256 | indices = torch.cat([self.outs, indices.long()], dim=1) 257 | else: 258 | indices = means.round().long().view(-1, 1) 259 | values = values.squeeze().view(-1) 260 | 261 | indices = torch.cat([self.outs_inf, indices.long()], dim=1) 262 | 263 | 264 | if self.use_cuda: 265 | indices = indices.cuda() 266 | 267 | # Kill anything on the diagonal 268 | values[indices[:, 0] == indices[:, 1]] = 0.0 269 | 270 | # if self.symmetric: 271 | # # Add reverse 
direction automatically 272 | # flipped_indices = torch.cat([indices[:, 1].unsqueeze(1), indices[:, 0].unsqueeze(1)], dim=1) 273 | # indices = torch.cat([indices, flipped_indices], dim=0) 274 | # values = torch.cat([values, values], dim=0) 275 | 276 | ### Create the sparse weight tensor 277 | 278 | # Prevent segfault 279 | assert not util.contains_nan(values.data) 280 | 281 | vindices = Variable(indices.t()) 282 | sz = Variable(torch.tensor((self.out_num, self.in_num))) 283 | 284 | spmm = sparsemm(self.use_cuda) 285 | output = spmm(vindices, values, sz, input) 286 | 287 | return output 288 | 289 | def hyper(self, input=None): 290 | """ 291 | Evaluates hypernetwork. 292 | """ 293 | k, width = self.params.size() 294 | 295 | means = F.sigmoid(self.params[:, 0:1]) 296 | 297 | # Limits for each of the w_rank indices 298 | # and scales for the sigmas 299 | s = torch.cuda.FloatTensor((self.in_num,)) if self.use_cuda else torch.FloatTensor((self.in_num,)) 300 | s = Variable(s.contiguous()) 301 | 302 | ss = s.unsqueeze(0) 303 | sm = s - 1 304 | sm = sm.unsqueeze(0) 305 | 306 | means = means * sm.expand_as(means) 307 | 308 | sigmas = nn.functional.softplus(self.params[:, 1:2] + gaussian.SIGMA_BOOST) + gaussian.EPSILON 309 | 310 | values = self.params[:, 2:] # * 0.0 + 1.0 311 | 312 | sigmas = sigmas.expand_as(means) 313 | sigmas = sigmas * ss.expand_as(sigmas) 314 | sigmas = sigmas * self.sigma_scale + self.min_sigma 315 | 316 | return means, sigmas, values * 0.0 + 1.0/self.k if self.fix_value else values 317 | 318 | class GraphConvolution(Module): 319 | """ 320 | Code adapted from pyGCN, see https://github.com/tkipf/pygcn 321 | 322 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 323 | """ 324 | 325 | def __init__(self, in_features, out_features, bias=True, has_weight=True): 326 | 327 | super(GraphConvolution, self).__init__() 328 | self.in_features = in_features 329 | self.out_features = out_features 330 | self.weight = 
Parameter(torch.FloatTensor(in_features, out_features)) if has_weight else None 331 | if bias: 332 | self.bias = Parameter(torch.FloatTensor(out_features)) 333 | else: 334 | self.register_parameter('bias', None) 335 | self.reset_parameters() 336 | 337 | 338 | def reset_parameters(self): 339 | if self.weight is not None: 340 | stdv = 1. / math.sqrt(self.weight.size(1)) 341 | self.weight.data.uniform_(-stdv, stdv) 342 | 343 | if self.bias is not None: 344 | self.bias.data.zero_() # different from the default implementation 345 | 346 | def forward(self, input, adj, train=True): 347 | 348 | if input is None: # The input is the identity matrix 349 | support = self.weight 350 | elif self.weight is not None: 351 | support = torch.mm(input, self.weight) 352 | else: 353 | support = input 354 | 355 | output = adj(support, train=train) 356 | 357 | if self.bias is not None: 358 | return output + self.bias 359 | else: 360 | return output 361 | 362 | class ConvModel(nn.Module): 363 | 364 | def __init__(self, data_size, k, radd=32, gadd=32, range=128, min_sigma=0.0): 365 | super().__init__() 366 | 367 | n, e = data_size 368 | 369 | self.adj = MatrixHyperlayer(n, n, k, radditional=radd, gadditional=gadd, region=(range,), 370 | min_sigma=min_sigma, fix_value=True) 371 | def freeze(self): 372 | for param in self.encoder_conv.parameters(): 373 | param.requires_grad = False 374 | 375 | for param in self.decoder_conv.parameters(): 376 | param.requires_grad = False 377 | 378 | def forward(self, x, depth=1, train=True): 379 | 380 | n, e = x.size() 381 | 382 | results = [] 383 | 384 | for _ in range(1, depth): 385 | x = self.adj(x, train=train) 386 | 387 | results.append(x) 388 | 389 | return results 390 | 391 | def cuda(self): 392 | 393 | super().cuda() 394 | 395 | self.adj.apply(lambda t: t.cuda()) 396 | 397 | PLOT_MAX = 2000 # max number of data points for the latent space plot 398 | 399 | def go(arg): 400 | 401 | MARGIN = 0.1 402 | util.makedirs('./conv-simple/') 403 | 
torch.manual_seed(arg.seed) 404 | 405 | writer = SummaryWriter() 406 | 407 | data = torch.randn(arg.size, arg.width) 408 | 409 | model = ConvModel(data.size(), k=arg.k, 410 | gadd=arg.gadditional, radd=arg.radditional, range=arg.range, 411 | min_sigma=arg.min_sigma) 412 | 413 | if arg.cuda: 414 | model.cuda() 415 | data = data.cuda() 416 | 417 | data, target = Variable(data), Variable(data) 418 | 419 | optimizer = optim.Adam(list(model.parameters()), lr=arg.lr) 420 | n, e = data.size() 421 | 422 | for epoch in trange(arg.epochs): 423 | 424 | optimizer.zero_grad() 425 | 426 | outputs = model(data, depth=arg.depth) 427 | 428 | loss = 0.0 429 | for i, o in enumerate(outputs): 430 | loss += F.mse_loss(o, data) 431 | 432 | # regularize sigmas 433 | _, sigmas, _ = model.adj.hyper() 434 | 435 | reg = sigmas.norm().mean() 436 | 437 | # print(loss.item(), reg.item()) 438 | # sys.exit() 439 | 440 | tloss = loss + 0.0001 * reg 441 | 442 | tloss.backward() 443 | optimizer.step() 444 | 445 | writer.add_scalar('conv-simple/train-tloss', tloss.item(), epoch) 446 | writer.add_scalar('conv-simple/train-loss', loss.item(), epoch) 447 | writer.add_scalar('conv-simple/train-reg', reg.item(), epoch) 448 | 449 | 450 | if epoch % arg.plot_every == 0: 451 | print('data') 452 | print(data[:3, :3].data) 453 | print() 454 | 455 | for o in outputs: 456 | print(o[:3, :3].data) 457 | 458 | # Plot the results 459 | with torch.no_grad(): 460 | 461 | outputs = model(data, depth=arg.depth, train=False) 462 | 463 | plt.figure(figsize=(8, 8)) 464 | 465 | means, sigmas, values = model.adj.hyper() 466 | means, sigmas, values = means.data, sigmas.data, values.data 467 | means = torch.cat([model.adj.outs_inf.data.float(), means], dim=1) 468 | 469 | plt.cla() 470 | 471 | s = model.adj.size() 472 | util.plot1d(means, sigmas, values.squeeze(), shape=s) 473 | plt.xlim((-MARGIN * (s[0] - 1), (s[0] - 1) * (1.0 + MARGIN))) 474 | plt.ylim((-MARGIN * (s[0] - 1), (s[0] - 1) * (1.0 + MARGIN))) 475 | 476 | 
plt.savefig('./conv-simple/means.{:05}.pdf'.format(epoch)) 477 | 478 | print('Finished Training.') 479 | 480 | if __name__ == "__main__": 481 | 482 | parser = ArgumentParser() 483 | 484 | parser.add_argument("-e", "--epochs", 485 | dest="epochs", 486 | help="Number of epochs", 487 | default=1000, type=int) 488 | 489 | parser.add_argument("-W", "--width", 490 | dest="width", 491 | help="Width of the data.", 492 | default=16, type=int) 493 | 494 | parser.add_argument("-k", "--num-points", 495 | dest="k", 496 | help="Number of index tuples", 497 | default=3, type=int) 498 | 499 | parser.add_argument("-S", "--size", 500 | dest="size", 501 | help="Number of data points", 502 | default=128, type=int) 503 | 504 | parser.add_argument("-a", "--gadditional", 505 | dest="gadditional", 506 | help="Number of additional points sampled globally per index-tuple", 507 | default=32, type=int) 508 | 509 | parser.add_argument("-A", "--radditional", 510 | dest="radditional", 511 | help="Number of additional points sampled locally per index-tuple", 512 | default=16, type=int) 513 | 514 | parser.add_argument("-R", "--range", 515 | dest="range", 516 | help="Range in which the local points are sampled", 517 | default=128, type=int) 518 | 519 | parser.add_argument("-d", "--depth", 520 | dest="depth", 521 | help="Number of graph convolutions", 522 | default=5, type=int) 523 | 524 | parser.add_argument("-p", "--plot-every", 525 | dest="plot_every", 526 | help="Numer of epochs to wait between plotting", 527 | default=100, type=int) 528 | 529 | parser.add_argument("-l", "--learn-rate", 530 | dest="lr", 531 | help="Learning rate", 532 | default=0.01, type=float) 533 | 534 | parser.add_argument("-r", "--seed", 535 | dest="seed", 536 | help="Random seed", 537 | default=4, type=int) 538 | 539 | parser.add_argument("-c", "--cuda", dest="cuda", 540 | help="Whether to use cuda.", 541 | action="store_true") 542 | 543 | parser.add_argument("-M", "--min-sigma", 544 | dest="min_sigma", 545 | help="Minimal 
sigma value", 546 | default=0.0, type=float) 547 | 548 | args = parser.parse_args() 549 | 550 | print('OPTIONS', args) 551 | 552 | go(args) 553 | -------------------------------------------------------------------------------- /experiments/identity.py: -------------------------------------------------------------------------------- 1 | from _context import sparse 2 | 3 | from sparse import NASLayer 4 | import sparse.util as util 5 | 6 | import torch, random, sys 7 | from torch.autograd import Variable 8 | import torch.nn.functional as F 9 | from torch import nn, optim 10 | from tqdm import trange 11 | from tensorboardX import SummaryWriter 12 | 13 | import matplotlib as mpl 14 | mpl.use('Agg') 15 | import matplotlib.pyplot as plt 16 | import logging, time, gc, math 17 | import numpy as np 18 | 19 | from scipy.stats import sem 20 | 21 | from argparse import ArgumentParser 22 | 23 | import os 24 | 25 | logging.basicConfig(filename='run.log',level=logging.INFO) 26 | LOG = logging.getLogger() 27 | 28 | """ 29 | Simple experiment: learn the identity function from one tensor to another 30 | """ 31 | 32 | class ReinforceLayer(nn.Module): 33 | """ 34 | Baseline method: use REINFORCE to sample from the continuous index tuples. 
35 | """ 36 | def __init__(self, size, 37 | sigma_scale=0.2, 38 | fix_values=False, 39 | min_sigma=0.0): 40 | 41 | super().__init__() 42 | 43 | self.size = size 44 | self.sigma_scale = sigma_scale 45 | self.fix_values = fix_values 46 | self.min_sigma = min_sigma 47 | 48 | self.pmeans = nn.Parameter(torch.randn(self.size, 2)) 49 | self.psigmas = nn.Parameter(torch.randn(self.size)) 50 | 51 | if fix_values: 52 | self.register_buffer('pvalues', torch.ones(self.size)) 53 | else: 54 | self.pvalues = nn.Parameter(torch.randn(self.size)) 55 | 56 | def is_cuda(self): 57 | return next(self.parameters()).is_cuda 58 | 59 | def hyper(self, x): 60 | 61 | b = x.size(0) 62 | size = (self.size, self.size) 63 | 64 | # Expand parameters along batch dimension 65 | means = self.pmeans[None, :, :].expand(b, self.size, 2) 66 | sigmas = self.psigmas[None, :].expand(b, self.size) 67 | values = self.pvalues[None, :].expand(b, self.size) 68 | 69 | means, sigmas = sparse.transform_means(means, size), sparse.transform_sigmas(sigmas, size) 70 | 71 | return means, sigmas, values 72 | 73 | def forward(self, x): 74 | size = (self.size, self.size) 75 | 76 | means, sigmas, values = self.hyper(x) 77 | 78 | dists = torch.distributions.Normal(means, sigmas) 79 | samples = dists.sample() 80 | 81 | indices = samples.data.round().long() 82 | 83 | # if the sampling puts the indices out of bounds, we just clip to the min and max values 84 | indices[indices < 0] = 0 85 | 86 | rngt = torch.tensor(data=size, device='cuda' if self.is_cuda() else 'cpu') 87 | maxes = rngt[None, None, :].expand_as(means) - 1 88 | indices[indices > maxes] = maxes[indices > maxes] 89 | 90 | y = sparse.contract(indices, values, size, x) 91 | 92 | return y, dists, samples 93 | 94 | def go(arg): 95 | 96 | MARGIN = 0.1 97 | 98 | iterations = arg.iterations if arg.iterations is not None else arg.size * 3000 99 | additional = arg.additional if arg.additional is not None else int(np.floor(np.log2(arg.size)) * arg.size) 100 | 101 | 
torch.manual_seed(arg.seed) 102 | 103 | ndots = iterations // arg.dot_every 104 | 105 | results = np.zeros((arg.reps, ndots)) 106 | 107 | print('Starting size {} with {} additional samples (reinforce={})'.format(arg.size, additional, arg.reinforce)) 108 | w = None 109 | for r in range(arg.reps): 110 | print('repeat {} of {}'.format(r, arg.reps)) 111 | 112 | util.makedirs('./identity/{}'.format(r)) 113 | util.makedirs('./runs/identity/{}'.format(r)) 114 | 115 | if w is not None: 116 | w.close() 117 | w = SummaryWriter(log_dir='./runs/identity/{}/'.format(r)) 118 | 119 | SHAPE = (arg.size,) 120 | 121 | if not arg.reinforce: 122 | model = sparse.NASLayer( 123 | SHAPE, SHAPE, 124 | k=arg.size, 125 | gadditional=additional, 126 | sigma_scale=arg.sigma_scale, 127 | has_bias=False, 128 | fix_values=arg.fix_values, 129 | min_sigma=arg.min_sigma, 130 | region=(arg.rr, arg.rr), 131 | radditional=arg.ca) 132 | else: 133 | model = ReinforceLayer( 134 | arg.size, 135 | fix_values=arg.fix_values 136 | ) 137 | 138 | if arg.cuda: 139 | model.cuda() 140 | 141 | optimizer = optim.Adam(model.parameters(), lr=arg.lr) 142 | 143 | for i in trange(iterations): 144 | model.train(True) 145 | 146 | x = torch.randn((arg.batch,) + SHAPE) 147 | 148 | if arg.cuda: 149 | x = x.cuda() 150 | x = Variable(x) 151 | 152 | if not arg.reinforce: 153 | 154 | if arg.subbatch is None: 155 | optimizer.zero_grad() 156 | 157 | y = model(x) 158 | 159 | loss = F.mse_loss(y, x) 160 | 161 | loss.backward() 162 | optimizer.step() 163 | else: # compute the gradient in multiple passes, useful for large matrices 164 | 165 | optimizer.zero_grad() 166 | 167 | # multiple forward/backward passes, accumulate gradient 168 | seed = (torch.rand(1) * 100000).long().item() 169 | 170 | for fr in range(0, arg.size, arg.subbatch): 171 | to = min(fr + arg.subbatch, arg.size) 172 | 173 | y = model(x, mrange=(fr, to), seed=seed) 174 | 175 | loss = F.mse_loss(y, x) 176 | 177 | loss.backward() 178 | optimizer.step() 179 | 180 | else: 
181 | 182 | optimizer.zero_grad() 183 | 184 | y, dists, actions = model(x) 185 | 186 | mloss = F.mse_loss(y, x, reduce=False).mean(dim=1) 187 | rloss = - dists.log_prob(actions) * - mloss[:, None, None].expand_as(actions) 188 | 189 | loss = rloss.mean() 190 | 191 | loss.backward() 192 | optimizer.step() 193 | 194 | w.add_scalar('identity/loss/', loss.item(), i*arg.batch) 195 | 196 | if i % arg.dot_every == 0: 197 | model.train(False) 198 | 199 | with torch.no_grad(): 200 | losses = [] 201 | for fr in range(0, 10000, arg.batch): 202 | to = min(fr + arg.batch, 10000) 203 | 204 | x = torch.randn(to - fr, arg.size) 205 | 206 | if arg.cuda: 207 | x = x.cuda() 208 | x = Variable(x) 209 | 210 | if not arg.reinforce: 211 | y = model(x) 212 | else: 213 | y, _, _ = model(x) 214 | 215 | losses.append(F.mse_loss(y, x).item()) 216 | 217 | results[r, i//arg.dot_every] = sum(losses)/len(losses) 218 | 219 | if arg.plot_every > 0 and i % arg.plot_every == 0: 220 | plt.figure(figsize=(7, 7)) 221 | 222 | means, sigmas, values = model.hyper(x) 223 | 224 | plt.cla() 225 | util.plot(means, sigmas, values, shape=(SHAPE[0], SHAPE[0])) 226 | plt.xlim((-MARGIN*(SHAPE[0]-1), (SHAPE[0]-1) * (1.0+MARGIN))) 227 | plt.ylim((-MARGIN*(SHAPE[0]-1), (SHAPE[0]-1) * (1.0+MARGIN))) 228 | 229 | plt.savefig('./identity/{}/means{:06}.pdf'.format(r, i)) 230 | 231 | plt.figure(figsize=(10, 4)) 232 | 233 | # for rep in range(results.shape[0]): 234 | # plt.plot(np.arange(ndots) * arg.dot_every, results[rep]) 235 | plt.errorbar( 236 | x=np.arange(ndots) * arg.dot_every, y=np.mean(results, axis=0), yerr=np.std(results, axis=0)) 237 | 238 | ax = plt.gca() 239 | ax.set_ylim(bottom=0) 240 | ax.set_xlabel('iterations') 241 | ax.set_ylabel('mean-squared error') 242 | 243 | util.basic() 244 | 245 | plt.savefig('./identity/results.png') 246 | plt.savefig('./identity/results.pdf') 247 | 248 | np.save('results.{:03d}.{}'.format(arg.size, arg.reinforce), results) 249 | 250 | print('experiments finished') 251 | 252 | if 
__name__ == "__main__": 253 | 254 | ## Parse the command line options 255 | parser = ArgumentParser() 256 | 257 | parser.add_argument("-i", "--iterations", 258 | dest="iterations", 259 | help="Size (nr of dimensions) of the input.", 260 | default=10000, type=int) 261 | 262 | parser.add_argument("-s", "--size", 263 | dest="size", 264 | help="Size (nr of dimensions) of the input.", 265 | default=16, type=int) 266 | 267 | parser.add_argument("-b", "--batch-size", 268 | dest="batch", 269 | help="The batch size.", 270 | default=64, type=int) 271 | 272 | parser.add_argument("-a", "--gadditional", 273 | dest="additional", 274 | help="Number of global additional points sampled ", 275 | default=4, type=int) 276 | 277 | parser.add_argument("-R", "--rrange", 278 | dest="rr", 279 | help="Size of the sampling region around the index tuple.", 280 | default=4, type=int) 281 | 282 | parser.add_argument("-A", "--radditional", 283 | dest="ca", 284 | help="Number of points to sample from the sampling region.", 285 | default=4, type=int) 286 | 287 | parser.add_argument("-C", "--sub-batch", 288 | dest="subbatch", 289 | help="Size for updating in multiple forward/backward passes.", 290 | default=None, type=int) 291 | 292 | parser.add_argument("-c", "--cuda", dest="cuda", 293 | help="Whether to use cuda.", 294 | action="store_true") 295 | 296 | parser.add_argument("-F", "--fix_values", dest="fix_values", 297 | help="Whether to fix the values to 1.", 298 | action="store_true") 299 | 300 | parser.add_argument("-l", "--learn-rate", 301 | dest="lr", 302 | help="Learning rate", 303 | default=0.005, type=float) 304 | 305 | parser.add_argument("-S", "--sigma-scale", 306 | dest="sigma_scale", 307 | help="Sigma scale", 308 | default=0.1, type=float) 309 | 310 | parser.add_argument("-M", "--min_sigma", 311 | dest="min_sigma", 312 | help="Minimum variance for the components.", 313 | default=0.0, type=float) 314 | 315 | parser.add_argument("-p", "--plot-every", 316 | dest="plot_every", 317 | 
help="Plot every x iterations", 318 | default=1000, type=int) 319 | 320 | parser.add_argument("-d", "--dot-every", 321 | dest="dot_every", 322 | help="A dot in the graph for every x iterations", 323 | default=1000, type=int) 324 | 325 | parser.add_argument("--repeats", 326 | dest="reps", 327 | help="Number of repeats.", 328 | default=1, type=int) 329 | 330 | parser.add_argument("--seed", 331 | dest="seed", 332 | help="Random seed.", 333 | default=32, type=int) 334 | 335 | parser.add_argument("-B", "--use-reinforce", dest="reinforce", 336 | help="Use the reinforce baseline instead of the backprop approach.", 337 | action="store_true") 338 | 339 | options = parser.parse_args() 340 | 341 | print('OPTIONS ', options) 342 | LOG.info('OPTIONS ' + str(options)) 343 | 344 | go(options) 345 | -------------------------------------------------------------------------------- /experiments/memory.py: -------------------------------------------------------------------------------- 1 | from _context import sparse 2 | 3 | from sparse import util 4 | from util import d 5 | 6 | import torch 7 | 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | import torchvision 11 | from torch import nn 12 | from torch.autograd import Variable 13 | from tqdm import trange 14 | 15 | import matplotlib as mpl 16 | mpl.use('Agg') 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | 20 | import torchvision 21 | import torchvision.transforms as transforms 22 | from torchvision.transforms import ToTensor 23 | from torch.utils.data import TensorDataset, DataLoader 24 | 25 | from argparse import ArgumentParser 26 | from torch.utils.tensorboard import SummaryWriter 27 | 28 | import random, tqdm, sys, math, os 29 | 30 | """ 31 | Memory layer experiment. Autoencoder with a small set of (learned) codes, arranged in 32 | an nD grid. 33 | 34 | The encoder picks a single code in a sparse manner. 
35 | 36 | """ 37 | 38 | def clean(axes=None): 39 | 40 | if axes is None: 41 | axes = plt.gca() 42 | 43 | [s.set_visible(False) for s in axes.spines.values()] 44 | axes.tick_params(top=False, bottom=False, left=False, right=False, labelbottom=False, labelleft=False) 45 | 46 | class Model(nn.Module): 47 | 48 | def __init__(self, data_size, latent_size=(5, 5, 128), depth=3, gadditional=2, radditional=4, region=0.2, 49 | method='clamp', sigma_scale=1.0, min_sigma=0.01): 50 | super().__init__() 51 | 52 | self.method, self.gadditional, self.radditional = method, gadditional, radditional 53 | self.sigma_scale, self.min_sigma = sigma_scale, min_sigma 54 | 55 | # latent space 56 | self.latent = nn.Parameter(torch.randn(size=latent_size)) 57 | self.region = [int(r*region) for r in latent_size[:-1]] 58 | 59 | ln = len(latent_size) 60 | emb_size = latent_size[-1] 61 | 62 | c, h, w = data_size 63 | 64 | cs = [c] + [2**(d+4) for d in range(depth)] 65 | 66 | div = 2 ** depth 67 | 68 | modules = [] 69 | 70 | for d in range(depth): 71 | modules += [ 72 | nn.Conv2d(cs[d], cs[d+1], 3, padding=1), nn.ReLU(), 73 | nn.Conv2d(cs[d+1], cs[d+1], 3, padding=1), nn.ReLU(), 74 | nn.MaxPool2d((2, 2)) 75 | ] 76 | 77 | modules += [ 78 | util.Flatten(), 79 | nn.Linear(cs[-1] * (h//div) * (w//div), 1024), nn.ReLU(), 80 | nn.Linear(1024, len(latent_size)) # encoder produces a cont. 
index tuple (ln -1 for the means, 1 for the sigma) 81 | ] 82 | 83 | self.encoder = nn.Sequential(*modules) 84 | 85 | upmode = 'bilinear' 86 | cl = lambda x : int(math.ceil(x)) 87 | 88 | 89 | 90 | modules = [ 91 | nn.Linear(emb_size, cs[-1] * cl(h/div) * cl(w/div)), nn.ReLU(), 92 | util.Reshape( (cs[-1], cl(h/div), cl(w/div)) ) 93 | ] 94 | 95 | for d in range(depth, 0, -1): 96 | modules += [ 97 | nn.Upsample(scale_factor=2, mode=upmode), 98 | nn.ConvTranspose2d(cs[d], cs[d], 3, padding=1), nn.ReLU(), 99 | nn.ConvTranspose2d(cs[d], cs[d-1], 3, padding=1), nn.ReLU() 100 | ] 101 | 102 | modules += [ 103 | nn.ConvTranspose2d(c, c, (3, 3), padding=1), nn.Sigmoid(), 104 | util.Lambda(lambda x : x[:, :, :h, :w]) # crop out any extra pixels due to rounding errors 105 | ] 106 | self.decoder = nn.Sequential(*modules) 107 | 108 | self.smp = True 109 | 110 | def sample(self, smp): 111 | self.smp = smp 112 | 113 | def forward(self, x): 114 | 115 | b, c, h, w = x.size() 116 | 117 | params = self.encoder(x) 118 | ls = self.latent.size() 119 | s, e = ls[:-1], ls[-1] 120 | 121 | assert params.size() == (b, len(ls)) 122 | 123 | means = sparse.transform_means(params[:, None, None, :-1], s, method=self.method) 124 | sigmas = sparse.transform_sigmas(params[:, None, None, -1], s, min_sigma=self.min_sigma) * self.sigma_scale 125 | 126 | if self.smp: 127 | 128 | indices = sparse.ngenerate(means, self.gadditional, self.radditional, rng=s, relative_range=self.region, cuda=x.is_cuda) 129 | vs = (2**len(s) + self.radditional + self.gadditional) 130 | 131 | assert indices.size() == (b, 1, vs, len(s)), f'{indices.size()}, {(b, 1, vs, len(s))}' 132 | indfl = indices.float() 133 | 134 | # Mask for duplicate indices 135 | dups = util.nduplicates(indices).to(torch.bool) 136 | 137 | # compute (unnormalized) densities under the given MVNs (proportions) 138 | props = sparse.densities(indfl, means, sigmas).clone() 139 | assert props.size() == (b, 1, vs, 1) #? 
140 | 141 | props[dups, :] = 0 142 | props = props / props.sum(dim=2, keepdim=True) # normalize over all points of a given index tuple 143 | 144 | weights = props.sum(dim=-1) # - sum out the MVNs 145 | 146 | assert indices.size() == (b, 1, vs, len(s)) 147 | assert weights.size() == (b, 1, vs) 148 | 149 | indices, weights = indices.squeeze(1), weights.squeeze(1) 150 | 151 | else: 152 | vs = 1 153 | indices = means.floor().to(torch.long).detach().squeeze(1) 154 | 155 | # Select a single code from the latent space (per instance in batch). 156 | # When sampling, this is a weighted sum, when not sampling, just one. 157 | indices = indices.view(b*vs, len(s)) 158 | 159 | # checks to prevent segfaults 160 | if util.contains_nan(indices): 161 | 162 | print(params) 163 | raise Exception('Indices contain NaN') 164 | 165 | if indices[:, 0].max() >= s[0] or indices[:, 1].max() >= s[1]: 166 | 167 | print(indices.max()) 168 | print(params) 169 | raise Exception('Indices out of bounds') 170 | 171 | if len(s) == 1: 172 | code = self.latent[indices[:, 0], :] 173 | elif len(s) == 2: 174 | code = self.latent[indices[:, 0], indices[:, 1], :] 175 | elif len(s) == 3: 176 | code = self.latent[indices[:, 0], indices[:, 1], indices[:, 2], :] 177 | else: 178 | raise Exception(f'Dimensionality above 3 not supported.') 179 | # - ugly hack, until I figure out how to do this for n dimensions 180 | 181 | assert code.size() == (b*vs, e), f'{code.size()} --- {(b*vs, e)}' 182 | 183 | if self.smp: 184 | code = code.view(b, vs, e) 185 | code = code * weights[:, :, None] 186 | code = code.sum(dim=1) 187 | else: 188 | code = code.view(b, e) 189 | 190 | assert code.size() == (b, e) 191 | 192 | # Decode 193 | result = self.decoder(code) 194 | 195 | assert result.size() == (b, c, h, w), f'{result.size()} --- {(b, c, h, w)}' 196 | 197 | return result 198 | 199 | def go(arg): 200 | 201 | util.makedirs('./memory/') 202 | 203 | if arg.seed < 0: 204 | seed = random.randint(0, 1000000) 205 | print('random seed: 
', seed) 206 | else: 207 | torch.manual_seed(arg.seed) 208 | 209 | tbw = SummaryWriter(log_dir=arg.tb_dir) 210 | tfms = transforms.Compose([transforms.ToTensor()]) 211 | 212 | if (arg.task == 'mnist'): 213 | 214 | shape = (1, 28, 28) 215 | num_classes = 10 216 | 217 | data = arg.data + os.sep + arg.task 218 | 219 | if arg.final: 220 | train = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=tfms) 221 | trainloader = torch.utils.data.DataLoader(train, batch_size=arg.batch_size, shuffle=True, num_workers=0) 222 | 223 | test = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=ToTensor()) 224 | testloader = torch.utils.data.DataLoader(test, batch_size=arg.batch_size, shuffle=False, num_workers=0) 225 | 226 | else: 227 | NUM_TRAIN = 45000 228 | NUM_VAL = 5000 229 | total = NUM_TRAIN + NUM_VAL 230 | 231 | train = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=tfms) 232 | 233 | trainloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(0, NUM_TRAIN, total)) 234 | testloader = DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total)) 235 | 236 | elif (arg.task == 'cifar10'): 237 | 238 | shape = (3, 32, 32) 239 | num_classes = 10 240 | 241 | data = arg.data + os.sep + arg.task 242 | 243 | if arg.final: 244 | train = torchvision.datasets.CIFAR10(root=data, train=True, download=True, transform=tfms) 245 | trainloader = torch.utils.data.DataLoader(train, batch_size=arg.batch, shuffle=True, num_workers=2) 246 | test = torchvision.datasets.CIFAR10(root=data, train=False, download=True, transform=ToTensor()) 247 | testloader = torch.utils.data.DataLoader(test, batch_size=arg.batch, shuffle=False, num_workers=2) 248 | 249 | else: 250 | NUM_TRAIN = 45000 251 | NUM_VAL = 5000 252 | total = NUM_TRAIN + NUM_VAL 253 | 254 | train = torchvision.datasets.CIFAR10(root=data, train=True, download=True, transform=tfms) 255 | 256 | trainloader = 
DataLoader(train, batch_size=arg.batch, sampler=util.ChunkSampler(0, NUM_TRAIN, total)) 257 | testloader = DataLoader(train, batch_size=arg.batch, 258 | sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total)) 259 | 260 | elif arg.task == 'ffhq': 261 | 262 | transform = ToTensor() 263 | shape = (3, 128, 128) 264 | 265 | trainset = torchvision.datasets.ImageFolder(root=arg.data+os.sep+'train', 266 | transform=transform) 267 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=arg.batch, 268 | shuffle=True, num_workers=2) 269 | 270 | testset = torchvision.datasets.ImageFolder(root=arg.data+os.sep+'valid', 271 | transform=transform) 272 | testloader = torch.utils.data.DataLoader(testset, batch_size=arg.batch, 273 | shuffle=False, num_workers=2) 274 | 275 | else: 276 | raise Exception('Task {} not recognized'.format(arg.task)) 277 | 278 | model = Model( 279 | data_size=shape, latent_size=arg.lgrid, gadditional=arg.gadditional, 280 | radditional=arg.radditional, region=arg.region, method=arg.edges, 281 | sigma_scale=arg.sigma_scale, min_sigma=arg.min_sigma) 282 | 283 | if arg.cuda: 284 | model.cuda() 285 | 286 | opt = torch.optim.Adam(params=model.parameters(), lr=arg.lr) 287 | 288 | seen = 0 289 | for e in range(arg.epochs): 290 | print('epoch', e) 291 | 292 | model.train(True) 293 | 294 | for i, (inputs, _) in enumerate(tqdm.tqdm(trainloader, 0)): 295 | 296 | if arg.limit is not None and i > arg.limit: 297 | break 298 | 299 | b, c, h, w = inputs.size() 300 | seen += b 301 | 302 | model.sample(random.random() < arg.sample_prob) # use sampling only on some proportion of batches 303 | 304 | if arg.cuda: 305 | inputs = inputs.cuda() 306 | 307 | inputs = Variable(inputs) 308 | 309 | opt.zero_grad() 310 | 311 | outputs = model(inputs) 312 | 313 | loss = F.binary_cross_entropy(outputs, inputs.detach()) 314 | 315 | loss.backward() 316 | 317 | opt.step() 318 | 319 | tbw.add_scalar('memory/loss', loss.item()/b, seen) 320 | 321 | if e % arg.plot_every == 0 and 
len(arg.lgrid) == 3: 322 | with torch.no_grad(): 323 | 324 | codes = model.latent.data.view(-1, arg.lgrid[-1]) 325 | images = model.decoder(codes) 326 | 327 | h, w = arg.lgrid[:2] 328 | 329 | plt.figure(figsize=(w, h)) 330 | 331 | s = 1 332 | for i in range(h): 333 | for j in range(w): 334 | 335 | ax = plt.subplot(h, w, s) 336 | ax.imshow(images[s-1].permute(1, 2, 0).squeeze().cpu(), cmap='Greys_r') 337 | 338 | clean(ax) 339 | 340 | s += 1 341 | 342 | plt.savefig(f'memory/latent.{e:03}.pdf') 343 | 344 | 345 | if __name__ == "__main__": 346 | 347 | parser = ArgumentParser() 348 | 349 | parser.add_argument("-e", "--epochs", 350 | dest="epochs", 351 | help="Number of epochs", 352 | default=250, type=int) 353 | 354 | parser.add_argument("-b", "--batch", 355 | dest="batch", 356 | help="Batch size", 357 | default=64, type=int) 358 | 359 | parser.add_argument("-d", "--depth", 360 | dest="depth", 361 | help="Depth", 362 | default=3, type=int) 363 | 364 | parser.add_argument("--task", 365 | dest="task", 366 | help="Dataset to model (mnist, cifar10)", 367 | default='mnist', type=str) 368 | 369 | parser.add_argument("--latent-grid", 370 | dest="lgrid", 371 | help="Dimensionality of the latent codes. 
The last dimension represents the latent vector dimension.", 372 | nargs='+', 373 | default=[25, 25, 64], type=int) 374 | 375 | parser.add_argument("-a", "--gadditional", 376 | dest="gadditional", 377 | help="Number of additional points sampled globally per index-tuple", 378 | default=2, type=int) 379 | 380 | parser.add_argument("-A", "--radditional", 381 | dest="radditional", 382 | help="Number of additional points sampled locally per index-tuple", 383 | default=4, type=int) 384 | 385 | parser.add_argument("-R", "--range", 386 | dest="region", 387 | help="Range in which the local points are sampled (as a proportion of the whole space)", 388 | default=0.2, type=float) 389 | 390 | parser.add_argument("-p", "--plot-every", 391 | dest="plot_every", 392 | help="Numer of epochs to wait between plotting", 393 | default=1, type=int) 394 | 395 | parser.add_argument("-l", "--learn-rate", 396 | dest="lr", 397 | help="Learning rate", 398 | default=0.0001, type=float) 399 | 400 | parser.add_argument("--limit", 401 | dest="limit", 402 | help="Limit.", 403 | default=None, type=int) 404 | 405 | parser.add_argument("-r", "--seed", 406 | dest="seed", 407 | help="Random seed", 408 | default=0, type=int) 409 | 410 | parser.add_argument("-c", "--cuda", dest="cuda", 411 | help="Whether to use cuda.", 412 | action="store_true") 413 | 414 | parser.add_argument("-D", "--data", dest="data", 415 | help="Data directory", 416 | default='./data') 417 | 418 | parser.add_argument("--sample-prob", 419 | dest="sample_prob", 420 | help="Sample probability (with this probability we sample index tuples).", 421 | default=0.5, type=float) 422 | 423 | parser.add_argument("--edges", dest="edges", 424 | help="Which operator to use to fit continuous index tuples to the required range.", 425 | default='clamp', type=str) 426 | 427 | parser.add_argument("-f", "--final", dest="final", 428 | help="Whether to run on the real test set (if not included, the validation set is used).", 429 | action="store_true") 430 
| 431 | parser.add_argument("-T", "--tb_dir", dest="tb_dir", 432 | help="Data directory", 433 | default=None) 434 | 435 | parser.add_argument("-M", "--min-sigma", 436 | dest="min_sigma", 437 | help="Minimum value of sigma.", 438 | default=0.01, type=float) 439 | 440 | parser.add_argument("-S", "--sigma-scale", 441 | dest="sigma_scale", 442 | help="Scalar applied to sigmas.", 443 | default=0.5, type=float) 444 | 445 | args = parser.parse_args() 446 | 447 | print('OPTIONS', args) 448 | 449 | go(args) 450 | -------------------------------------------------------------------------------- /experiments/minmal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | import math 5 | # Minimal script that reproduces a segfault in loss.backward() (see the comments below); it only ever feeds random inputs. 6 | class Flatten(nn.Module): 7 | # Flattens all non-batch dimensions: (batch, ...) -> (batch, -1). 8 | def forward(self, input): 9 | return input.view(input.size(0), -1) 10 | 11 | class ResBlock(nn.Module): 12 | """ 13 | Inverted residual block (as in mobilenetv2): a 1x1 conv expanding c -> wide channels (default 6 * c), a kernel x kernel conv (depthwise when grouped=True), a 1x1 conv projecting back to c channels, then BatchNorm and ReLU. forward() adds the input back as a skip connection. 14 | """ 15 | def __init__(self, c, wide=None, kernel=3, grouped=True): 16 | super().__init__() 17 | 18 | wide = 6 * c if wide is None else wide 19 | padding = int(math.floor(kernel/2)) # keeps the spatial size unchanged for odd kernel sizes 20 | 21 | self.convs = nn.Sequential( 22 | nn.Conv2d(c, wide, kernel_size=1), 23 | nn.Conv2d(wide, wide, kernel_size=kernel, padding=padding, groups=wide if grouped else 1), 24 | nn.Conv2d(wide, c, kernel_size=1), 25 | nn.BatchNorm2d(c), nn.ReLU() 26 | ) 27 | 28 | def forward(self, x): 29 | 30 | return self.convs(x) + x # without the skip connection, the segfault happens immediately 31 | 32 | # Expects (batch, 3, 32, 32) inputs (see the loop below): 32x32 -> MaxPool(2) -> 16x16 -> MaxPool(16) -> 1x1, so Flatten yields 16 features. 
33 | model = nn.Sequential( 34 | nn.Conv2d(3, 32, kernel_size=3, padding=1), 35 | nn.MaxPool2d(kernel_size=2), 36 | ResBlock(32, grouped=True), nn.Conv2d(32, 16, kernel_size=1), 37 | ResBlock(16, grouped=False), nn.Conv2d(16, 16, kernel_size=1), 38 | nn.MaxPool2d(kernel_size=16), 39 | Flatten(), 40 | nn.Linear(16, 10), 41 | nn.Softmax(dim=-1) 42 | ) 43 | 44 | 45 | opt = torch.optim.SGD(lr=0.000001, params=model.parameters()) # SGD
and Adam both segfault 46 | 47 | torch.manual_seed(0) # seeded after the model is built, so only the random inputs below are reproducible 48 | 49 | for i in range(1000): 50 | 51 | print(i) # progress counter, presumably to show how many steps run before the crash 52 | 53 | opt.zero_grad() 54 | 55 | x = Variable(torch.randn(64, 3, 32, 32)) 56 | x = model(x) 57 | loss = x.sum() # arbitrary scalar loss; only the backward pass matters here 58 | 59 | loss.backward() # segfault here 60 | 61 | opt.step() 62 | -------------------------------------------------------------------------------- /experiments/sparsity-mlp.py: -------------------------------------------------------------------------------- 1 | from _context import sparse 2 | 3 | import torch, torchvision 4 | import numpy as np 5 | 6 | from torchvision.transforms import ToTensor 7 | from torch.utils.data import DataLoader 8 | 9 | from torch.autograd import Variable 10 | 11 | from torch import nn 12 | import torch.nn.functional as F 13 | 14 | from argparse import ArgumentParser 15 | 16 | import os, sys, math 17 | 18 | from sparse import util, NASLayer, Convolution 19 | 20 | from tqdm import trange, tqdm 21 | 22 | from tensorboardX import SummaryWriter 23 | 24 | """ 25 | This experiment trains a simple, fully connected three-layer MLP, following the baseline from Louizos 2018. 26 | We aim to show that in the very low density regime, the sparse layer is a competitive approach. 27 | 28 | The tasks are simple classification on mnist, cifar10 and cifar100. 29 | 30 | TODO: Test temp version.
31 | 32 | """ 33 | 34 | BATCH_SIZES = [128] 35 | LEARNING_RATES = [0.00005, 0.0001, 0.0005, 0.001, 0.005] 36 | 37 | def getrng(p, size): 38 | return [max(1, int(math.floor(p * s))) for s in size] 39 | 40 | def getmodel(arg, insize, numcls): 41 | 42 | h1, h2 = arg.hidden 43 | 44 | # if arg.method == 'l1' or arg.method == 'lp': 45 | # 46 | # one = nn.Linear(util.prod(insize), h1) 47 | # two = nn.Linear(h1, h2) 48 | # three = nn.Linear(h2, numcls) 49 | # 50 | # model = nn.Sequential( 51 | # util.Flatten(), 52 | # one, nn.Sigmoid(), 53 | # two, nn.Sigmoid(), 54 | # three, nn.Softmax() 55 | # ) 56 | 57 | if arg.method == 'nas': 58 | """ 59 | Non-templated NAS model 60 | """ 61 | 62 | rng = getrng(arg.range[0], (h1, ) + insize) 63 | 64 | c = arg.k[0] 65 | 66 | one = NASLayer( 67 | in_size=insize, out_size=(h1,), k=h1*c, 68 | gadditional=arg.gadditional[0], radditional=arg.radditional[0], region=rng, has_bias=True, 69 | fix_values=arg.fix_values, 70 | min_sigma=arg.min_sigma, 71 | template=None, 72 | learn_cols=None, 73 | chunk_size=c 74 | ) 75 | 76 | rng = getrng(arg.range[1], (h2, h1)) 77 | c = arg.k[1] 78 | 79 | two = NASLayer( 80 | in_size=(h1,), out_size=(h2,), k=h2*c, 81 | gadditional=arg.gadditional[1], radditional=arg.radditional[1], region=rng, has_bias=True, 82 | fix_values=arg.fix_values, 83 | min_sigma=arg.min_sigma, 84 | template=None, 85 | learn_cols=None, 86 | chunk_size=c 87 | ) 88 | 89 | rng = getrng(arg.range[2], (numcls, h2)) 90 | c = arg.k[2] 91 | 92 | three = NASLayer( 93 | in_size=(h2,), out_size=(numcls,), k=numcls*c, 94 | gadditional=arg.gadditional[2], radditional=arg.radditional[2], region=rng, has_bias=True, 95 | fix_values=arg.fix_values, 96 | min_sigma=arg.min_sigma, 97 | template=None, 98 | learn_cols=None, 99 | chunk_size=c 100 | ) 101 | 102 | model = nn.Sequential( 103 | one, nn.Sigmoid(), 104 | two, nn.Sigmoid(), 105 | three, nn.Softmax(), 106 | ) 107 | 108 | elif arg.method == 'nas-temp': 109 | """ 110 | Templated NAS model. 
Fixed output dimensions. 111 | """ 112 | 113 | rng = getrng(arg.range[0], (insize[1], insize[2])) 114 | c = arg.k[0] 115 | 116 | template = torch.arange(h1, dtype=torch.long)[:, None].expand(h1, c).contiguous().view(h1*c, 1) 117 | template = torch.cat([template, torch.zeros(h1*c, 3, dtype=torch.long)], dim=1) 118 | 119 | one = NASLayer( 120 | in_size=insize, out_size=(h1,), k=h1*c, 121 | gadditional=arg.gadditional[0], radditional=arg.radditional[0], region=rng, has_bias=True, 122 | fix_values=arg.fix_values, 123 | min_sigma=arg.min_sigma, 124 | template=template, 125 | learn_cols=(1, 2, 3) if insize[0] > 1 else (2, 3), 126 | chunk_size=c 127 | ) 128 | 129 | rng = getrng(arg.range[1], (h1, )) 130 | c = arg.k[1] 131 | 132 | template = torch.arange(h2, dtype=torch.long)[:, None].expand(h2, c).contiguous().view(h2 * c, 1) 133 | template = torch.cat([template, torch.zeros(h2*c, 1, dtype=torch.long)], dim=1) 134 | 135 | two = NASLayer( 136 | in_size=(h1,), out_size=(h2,), k=h2*c, 137 | gadditional=arg.gadditional[1], radditional=arg.radditional[1], region=rng, has_bias=True, 138 | fix_values=arg.fix_values, 139 | min_sigma=arg.min_sigma, 140 | template=template, 141 | learn_cols=(1,), 142 | chunk_size=c 143 | ) 144 | 145 | rng = getrng(arg.range[2], (h2, )) 146 | c = arg.k[2] 147 | 148 | template = torch.arange(numcls, dtype=torch.long)[:, None].expand(numcls, c).contiguous().view(numcls * c, 1) 149 | template = torch.cat([template, torch.zeros(numcls*c, 1, dtype=torch.long)], dim=1) 150 | 151 | three = NASLayer( 152 | in_size=(h2,), out_size=(numcls,), k=numcls*c, 153 | gadditional=arg.gadditional[2], radditional=arg.radditional[2], region=rng, has_bias=True, 154 | fix_values=arg.fix_values, 155 | min_sigma=arg.min_sigma, 156 | template=template, 157 | learn_cols=(1,), 158 | chunk_size=c 159 | ) 160 | 161 | model = nn.Sequential( 162 | one, nn.Sigmoid(), 163 | two, nn.Sigmoid(), 164 | three, nn.Softmax(), 165 | ) 166 | elif arg.method == 'nas-conv': 167 | """ 168 | 
Convolutional NAS model. 169 | """ 170 | c1, c2 = h1, h2 171 | 172 | one = Convolution(in_size=(1, 28, 28), out_channels=c1, k=arg.k[0], kernel_size=7, 173 | gadditional=arg.gadditional[0], radditional=arg.radditional[1], rprop=arg.range[0], 174 | min_sigma=arg.min_sigma, sigma_scale=arg.sigma_scale, 175 | fix_values=arg.fix_values, has_bias=True) 176 | 177 | two = Convolution(in_size=(c1, 14, 14), out_channels=c2, k=arg.k[1], kernel_size=7, 178 | gadditional=arg.gadditional[1], radditional=arg.radditional[1], rprop=arg.range[1], 179 | min_sigma=arg.min_sigma, sigma_scale=arg.sigma_scale, 180 | fix_values=arg.fix_values, has_bias=True) 181 | 182 | three = Convolution(in_size=(c2, 7, 7), out_channels=numcls, k=arg.k[2], kernel_size=7, 183 | gadditional=arg.gadditional[2], radditional=arg.radditional[2], rprop=arg.range[2], 184 | min_sigma=arg.min_sigma, sigma_scale=arg.sigma_scale, 185 | fix_values=arg.fix_values, has_bias=True) 186 | 187 | model = nn.Sequential( 188 | one, nn.Sigmoid(), nn.MaxPool2d(2), 189 | two, nn.Sigmoid(), nn.MaxPool2d(2), 190 | three, nn.Sigmoid(), nn.MaxPool2d(2), 191 | util.Lambda(lambda x : x.mean(dim=-1).mean(dim=-1)), # global average pool 192 | nn.Softmax() 193 | ) 194 | 195 | elif arg.method == 'conv': 196 | c1, c2 = h1, h2 197 | 198 | one = nn.Conv2d(insize[0], c1, kernel_size=3, padding=1) 199 | two = nn.Conv2d(c1, c2, kernel_size=3, padding=1) 200 | three = nn.Conv2d(c2, numcls, kernel_size=3, padding=1) 201 | 202 | model = nn.Sequential( 203 | one, nn.Sigmoid(), nn.MaxPool2d(2), 204 | two, nn.Sigmoid(), nn.MaxPool2d(2), 205 | three, nn.Sigmoid(), nn.MaxPool2d(2), 206 | util.Lambda(lambda x : x.mean(dim=-1).mean(dim=-1)), # global average pool 207 | nn.Softmax() 208 | ) 209 | 210 | elif arg.method == 'one': 211 | """ 212 | Convolutional NAS model. 
213 | """ 214 | c1, c2 = h1, h2 215 | 216 | one = Convolution(in_size=(1, 28, 28), out_channels=c1, k=arg.k[0], kernel_size=7, 217 | gadditional=arg.gadditional[0], radditional=arg.radditional[1], rprop=arg.range[0], 218 | min_sigma=arg.min_sigma, sigma_scale=arg.sigma_scale, 219 | fix_values=arg.fix_values, has_bias=True) 220 | 221 | two = nn.Conv2d(c1, c2, kernel_size=3, padding=1) 222 | three = nn.Conv2d(c2, numcls, kernel_size=3, padding=1) 223 | 224 | model = nn.Sequential( 225 | one, nn.Sigmoid(), nn.MaxPool2d(2), 226 | two, nn.Sigmoid(), nn.MaxPool2d(2), 227 | three, nn.Sigmoid(), nn.MaxPool2d(2), 228 | util.Lambda(lambda x : x.mean(dim=-1).mean(dim=-1)), # global average pool 229 | nn.Softmax() 230 | ) 231 | elif arg.method == 'two': 232 | """ 233 | Convolutional NAS model. 234 | """ 235 | c1, c2 = h1, h2 236 | 237 | one = Convolution(in_size=(1, 28, 28), out_channels=c1, k=arg.k[0], kernel_size=7, 238 | gadditional=arg.gadditional[0], radditional=arg.radditional[1], rprop=arg.range[0], 239 | min_sigma=arg.min_sigma, sigma_scale=arg.sigma_scale, 240 | fix_values=arg.fix_values, has_bias=True) 241 | 242 | two = Convolution(in_size=(c1, 14, 14), out_channels=c2, k=arg.k[1], kernel_size=7, 243 | gadditional=arg.gadditional[1], radditional=arg.radditional[1], rprop=arg.range[1], 244 | min_sigma=arg.min_sigma, sigma_scale=arg.sigma_scale, 245 | fix_values=arg.fix_values, has_bias=True) 246 | 247 | three = nn.Conv2d(c2, numcls, kernel_size=3, padding=1) 248 | 249 | model = nn.Sequential( 250 | one, nn.Sigmoid(), nn.MaxPool2d(2), 251 | two, nn.Sigmoid(), nn.MaxPool2d(2), 252 | three, nn.Sigmoid(), nn.MaxPool2d(2), 253 | util.Lambda(lambda x : x.mean(dim=-1).mean(dim=-1)), # global average pool 254 | nn.Softmax() 255 | ) 256 | 257 | else: 258 | raise Exception('Method {} not recognized'.format(arg.method)) 259 | 260 | if arg.cuda: 261 | model.cuda() 262 | 263 | return model, one, two, three 264 | 265 | def single(arg): 266 | 267 | tbw = SummaryWriter() 268 | # 269 
| # lambd = torch.logspace(arg.rfrom, arg.rto, arg.rnum)[arg.control].item() 270 | # 271 | # print('lambda ', lambd) 272 | 273 | # Grid search over batch size/learning rate 274 | # -- Set up model 275 | 276 | insize = (1, 28, 28) if arg.task == 'mnist' else (3, 32, 32) 277 | numcls = 100 if arg.task == 'cifar100' else 10 278 | 279 | # Repeat runs with chosen hyperparameters 280 | accuracies = [] 281 | densities = [] 282 | 283 | for _ in trange(arg.repeats): 284 | 285 | if arg.task == 'mnist': 286 | if arg.final: 287 | data = arg.data + os.sep + arg.task 288 | 289 | train = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=ToTensor()) 290 | trainloader = torch.utils.data.DataLoader(train, batch_size=arg.bs, shuffle=True, num_workers=2) 291 | 292 | test = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=ToTensor()) 293 | testloader = torch.utils.data.DataLoader(test, batch_size=arg.bs, shuffle=False, num_workers=2) 294 | else: 295 | 296 | NUM_TRAIN = 45000 297 | NUM_VAL = 5000 298 | total = NUM_TRAIN + NUM_VAL 299 | 300 | train = torchvision.datasets.MNIST(root=arg.data, train=True, download=True, transform=ToTensor()) 301 | 302 | trainloader = DataLoader(train, batch_size=arg.bs, sampler=util.ChunkSampler(0, NUM_TRAIN, total)) 303 | testloader = DataLoader(train, batch_size=arg.bs, 304 | sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total)) 305 | 306 | elif (arg.task == 'cifar10'): 307 | 308 | data = arg.data + os.sep + arg.task 309 | 310 | if arg.final: 311 | train = torchvision.datasets.CIFAR10(root=data, train=True, download=True, transform=ToTensor()) 312 | trainloader = torch.utils.data.DataLoader(train, batch_size=arg.bs, shuffle=True, num_workers=2) 313 | test = torchvision.datasets.CIFAR10(root=data, train=False, download=True, transform=ToTensor()) 314 | testloader = torch.utils.data.DataLoader(test, batch_size=arg.bs, shuffle=False, num_workers=2) 315 | 316 | else: 317 | NUM_TRAIN = 45000 318 | NUM_VAL = 
5000 319 | total = NUM_TRAIN + NUM_VAL 320 | 321 | train = torchvision.datasets.CIFAR10(root=data, train=True, download=True, transform=ToTensor()) 322 | 323 | trainloader = DataLoader(train, batch_size=arg.bs, sampler=util.ChunkSampler(0, NUM_TRAIN, total)) 324 | testloader = DataLoader(train, batch_size=arg.bs, 325 | sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total)) 326 | 327 | else: 328 | raise Exception('Task {} not recognized'.format(arg.task)) 329 | 330 | model, one, two, three = getmodel(arg, insize, numcls) # new model 331 | opt = torch.optim.Adam(model.parameters(), lr=arg.lr) 332 | 333 | # Train for fixed number of epochs 334 | i = 0 335 | for e in range(arg.epochs): 336 | 337 | model.train(True) 338 | for input, labels in tqdm(trainloader): 339 | opt.zero_grad() 340 | 341 | if arg.cuda: 342 | input, labels = input.cuda(), labels.cuda() 343 | input, labels = Variable(input), Variable(labels) 344 | 345 | output = model(input) 346 | 347 | loss = F.cross_entropy(output, labels) 348 | loss.backward() 349 | 350 | tbw.add_scalar('sparsity/loss', loss.data.item(), i * arg.bs) 351 | i += 1 352 | 353 | opt.step() 354 | 355 | # Compute accuracy on test set 356 | with torch.no_grad(): 357 | model.train(False) 358 | 359 | total, correct = 0.0, 0.0 360 | for input, labels in testloader: 361 | opt.zero_grad() 362 | 363 | if arg.cuda: 364 | input, labels = input.cuda(), labels.cuda() 365 | input, labels = Variable(input), Variable(labels) 366 | 367 | output = model(input) 368 | 369 | outcls = output.argmax(dim=1) 370 | 371 | total += outcls.size(0) 372 | correct += (outcls == labels).sum().item() 373 | 374 | acc = correct / float(total) 375 | 376 | print('\nepoch {}: {}\n'.format(e, acc)) 377 | tbw.add_scalar('sparsity/test acc', acc, e) 378 | 379 | # # Compute density 380 | # total = util.prod(insize) * arg.hidden 381 | # 382 | # kt = arg.control[] * (arg.hidden[0] + + numcls) 383 | # 384 | # if arg.method == 'l1' or arg.method == 'lp': 385 | # density = 
(one.weight > 0.0001).sum().item() / float(total) 386 | # elif arg.method == 'nas' or arg.method == 'nas-temp': 387 | # density = kt / total 388 | # else: 389 | # raise Exception('Method {} not recognized'.format(arg.method)) 390 | # 391 | # accuracies.append(acc) 392 | # densities.append(density) 393 | # 394 | # print('accuracies: ', accuracies) 395 | # print('densities: ', densities) 396 | # 397 | # if arg.method == 'lp': 398 | # if arg.p == 0.2: 399 | # name = 'l5' 400 | # elif arg.p == 0.5: 401 | # name = 'l2' 402 | # elif arg.p == 1.0: 403 | # name = 'l1' 404 | # else: 405 | # name = 'l' + arg.p 406 | # else: 407 | # name = arg.method 408 | # 409 | # # Save to CSV 410 | # np.savetxt( 411 | # 'results.{}.{}.csv'.format(name, arg.control), 412 | # torch.cat([ 413 | # torch.tensor(accuracies, dtype=torch.float)[:, None], 414 | # torch.tensor(densities, dtype=torch.float)[:, None] 415 | # ], dim=1).numpy(), 416 | # ) 417 | 418 | print('Finished') 419 | 420 | if __name__ == "__main__": 421 | 422 | parser = ArgumentParser() 423 | 424 | parser.add_argument("-H", "--hidden", 425 | dest="hidden", 426 | nargs=2, 427 | help="Sizes of the two hidden layers", 428 | default=[300, 100], 429 | type=int) 430 | 431 | parser.add_argument("-k", "--points-per-out", 432 | dest="k", 433 | nargs=3, 434 | help="Number of sparse points for each output node.", 435 | default=[1, 1, 1], type=int) 436 | 437 | parser.add_argument("-l", "--lr", 438 | dest="lr", 439 | help="Learning rate (ignored in sweep)", 440 | default=0.001, type=float) 441 | 442 | parser.add_argument("-b", "--batch ", 443 | dest="bs", 444 | help="Batch size (ignored in sweep)", 445 | default=64, type=int) 446 | 447 | parser.add_argument("-e", "--epochs", 448 | dest="epochs", 449 | help="Number of epochs", 450 | default=50, type=int) 451 | 452 | parser.add_argument("-m", "--method", 453 | dest="method", 454 | help="Method to use (lp, nas) ", 455 | default='nas-temp', type=str) 456 | 457 | parser.add_argument("-P", 
"--lp-p", 458 | dest="p", 459 | help="Exponent in lp reg", 460 | default=2.0, type=float) 461 | 462 | parser.add_argument("-t", "--task", 463 | dest="task", 464 | help="Task to use (mnist, cifar10, cifar100) ", 465 | default='mnist', type=str) 466 | 467 | parser.add_argument("-a", "--gadditional", 468 | dest="gadditional", 469 | nargs=3, 470 | help="Number of additional points sampled globally per index-tuple (NAS)", 471 | default=[32, 6, 2], type=int) 472 | 473 | parser.add_argument("-A", "--radditional", 474 | dest="radditional", 475 | nargs=3, 476 | help="Number of additional points sampled locally per index-tuple (NAS)", 477 | default=[32, 6, 2], type=int) 478 | 479 | parser.add_argument("-R", "--range", 480 | dest="range", 481 | nargs=3, 482 | help="Range in which the local points are sampled (NAS)", 483 | default=[0.3, 0.2, 0.2], type=float) 484 | 485 | parser.add_argument("-r", "--repeats", 486 | dest="repeats", 487 | help="Number of times to repeat the final experiment (once the hyperparameters are chosen).", 488 | default=10, type=int) 489 | 490 | parser.add_argument("--seed", 491 | dest="seed", 492 | help="Random seed", 493 | default=4, type=int) 494 | 495 | parser.add_argument("-c", "--cuda", dest="cuda", 496 | help="Whether to use cuda.", 497 | action="store_true") 498 | 499 | parser.add_argument("-D", "--data", dest="data", 500 | help="Data directory", 501 | default='./data') 502 | 503 | parser.add_argument("-M", "--min-sigma", 504 | dest="min_sigma", 505 | help="Minimal sigma value", 506 | default=0.01, type=float) 507 | 508 | parser.add_argument("-S", "--sigma-scale", 509 | dest="sigma_scale", 510 | help="Sigma scale", 511 | default=0.1, type=float) 512 | 513 | parser.add_argument("--rfrom", 514 | dest="rfrom", 515 | help="Minimal control value (for lp baselines)", 516 | default=0.00001, type=float) 517 | 518 | parser.add_argument("--rto", 519 | dest="rto", 520 | help="Maximal control value (for lp baselines)", 521 | default=1.0, type=float) 522 | 
523 | parser.add_argument("--rnum", 524 | dest="rnum", 525 | help="Number of control parameters (for lp baseline)", 526 | default=10, type=int) 527 | 528 | parser.add_argument("-f", "--final", dest="final", 529 | help="Whether to run on the real test set.", 530 | action="store_true") 531 | 532 | parser.add_argument("-F", "--fix-values", dest="fix_values", 533 | help="Whether to fix all values to 1 in the NAS model.", 534 | action="store_true") 535 | 536 | args = parser.parse_args() 537 | 538 | print('OPTIONS', args) 539 | 540 | single(args) 541 | -------------------------------------------------------------------------------- /experiments/sparsity.py: -------------------------------------------------------------------------------- 1 | from _context import sparse 2 | 3 | import torch, torchvision 4 | import numpy as np 5 | 6 | from torchvision.transforms import ToTensor 7 | from torch.utils.data import DataLoader 8 | 9 | from torch.autograd import Variable 10 | 11 | from torch import nn 12 | import torch.nn.functional as F 13 | 14 | from argparse import ArgumentParser 15 | 16 | import os, sys 17 | 18 | from sparse import util, NASLayer 19 | 20 | from tqdm import trange, tqdm 21 | 22 | from tensorboardX import SummaryWriter 23 | 24 | """ 25 | This experiment trains a simple, fully connected two-layer MLP, using different methods of inducing sparsity, and 26 | measures the density of the resulting weight matrices (the number of non-zero weights, divided by the total). 27 | 28 | We aim to show that in the very low density regime, the sparse layer is a competitive approach. 29 | 30 | The tasks are simple classification on mnist, cifar10 and cifar100. 31 | 32 | TODO: Test temp version. 
33 | 34 | """ 35 | # Hyperparameter grid searched by sweep(). 36 | BATCH_SIZES = [128] 37 | LEARNING_RATES = [0.00005, 0.0001, 0.0005, 0.001, 0.005] 38 | # getmodel: builds the two-layer model for arg.method and returns (model, one, two); points is the number of sparse points k of the NAS layer. 39 | def getmodel(arg, insize, numcls, points): 40 | 41 | if arg.method == 'l1' or arg.method == 'lp': 42 | # Dense baseline: sparsity comes only from the l1/lp penalty on one.weight that the training loops add to the loss. 43 | one = nn.Linear(util.prod(insize), arg.hidden) 44 | two = nn.Linear(arg.hidden, numcls) 45 | 46 | model = nn.Sequential( 47 | util.Flatten(), 48 | one, nn.Sigmoid(), 49 | two, nn.Softmax() 50 | ) 51 | 52 | elif arg.method == 'nas': 53 | # one region entry per index column; presumably (output, channel, row, column) -- TODO confirm against NASLayer 54 | rng = ( 55 | min(arg.hidden, arg.range), 1, 56 | arg.range, arg.range) 57 | 58 | one = NASLayer( 59 | in_size=insize, out_size=(arg.hidden,), k=points, 60 | fix_values=arg.fix_values, 61 | gadditional=arg.gadditional, radditional=arg.radditional, region=rng, has_bias=True, 62 | min_sigma=arg.min_sigma 63 | ) 64 | 65 | two = nn.Linear(arg.hidden, numcls) 66 | 67 | model = nn.Sequential( 68 | one, nn.Sigmoid(), 69 | two, nn.Softmax() 70 | ) 71 | elif arg.method == 'nas-temp': 72 | """ 73 | Templated NAS model. Fixed in one dimension 74 | """ 75 | # two region entries, one per learned column in learn_cols=(2, 3) below (presumably the spatial indices) 76 | rng = (arg.range, arg.range) 77 | # c points per hidden unit, matching points = hidden * (control + 1) as computed in sweep()/single() 78 | h, c = arg.hidden, arg.control+1 79 | # template column 0 fixes each point's hidden unit; the three input-index columns start at 0 80 | template = torch.arange(h, dtype=torch.long)[:, None].expand(h, c).contiguous().view(h*c, 1) 81 | template = torch.cat([template, torch.zeros(h*c, 3, dtype=torch.long)], dim=1) 82 | 83 | one = NASLayer( 84 | in_size=insize, out_size=(arg.hidden,), k=points, 85 | gadditional=arg.gadditional, radditional=arg.radditional, region=rng, has_bias=True, 86 | fix_values=arg.fix_values, 87 | min_sigma=arg.min_sigma, 88 | template=template, 89 | learn_cols=(2, 3), 90 | chunk_size=c 91 | ) 92 | 93 | two = nn.Linear(arg.hidden, numcls) 94 | 95 | model = nn.Sequential( 96 | one, nn.Sigmoid(), 97 | two, nn.Softmax() 98 | ) 99 | 100 | else: 101 | raise Exception('Method {} not recognized'.format(arg.method)) 102 | 103 | if arg.cuda: 104 | model.cuda() 105 | 106 | return model, one, two 107 | # sweep: grid search over BATCH_SIZES x LEARNING_RATES on a 45k/5k validation split, then arg.repeats final runs on the MNIST test set; writes out.<method>.<control>.csv. 108 | def sweep(arg): 109 | # NOTE(review): torch.logspace treats rfrom/rto as base-10 exponents, so the defaults give lambdas between 10**0.00001 and 10**1 -- confirm this is intended. 110 | lambd = torch.logspace(arg.rfrom, arg.rto, arg.rnum)[arg.control].item()
111 | points = arg.hidden * (arg.control + 1) # NAS control variable 112 | 113 | print('lambda ', lambd) 114 | 115 | # Grid search over batch size/learning rate 116 | # -- Set up model 117 | 118 | insize = (1, 28, 28) if arg.task == 'mnist' else (3, 32, 32) 119 | numcls = 100 if arg.task == 'cifar100' else 10 120 | 121 | ## Perform a grid search over batch size and learning rate 122 | 123 | print('Starting hyperparameter selection') 124 | bestacc = -1.0 125 | bestbs, bstlr = -1, -1.0 126 | 127 | for batch_size in BATCH_SIZES: 128 | 129 | # Load data with validation set 130 | if arg.task == 'mnist': 131 | 132 | NUM_TRAIN = 45000 133 | NUM_VAL = 5000 134 | total = NUM_TRAIN + NUM_VAL 135 | 136 | train = torchvision.datasets.MNIST(root=arg.data, train=True, download=True, transform=ToTensor()) 137 | 138 | trainloader = DataLoader(train, batch_size=batch_size, sampler=util.ChunkSampler(0, NUM_TRAIN, total)) 139 | testloader = DataLoader(train, batch_size=batch_size, sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total)) 140 | else: 141 | raise Exception('Task {} not recognized'.format(arg.task)) 142 | 143 | for lr in LEARNING_RATES: 144 | print('lr {}, bs {}'.format(lr, batch_size)) 145 | 146 | model, one, two = getmodel(arg, insize, numcls, points) 147 | opt = torch.optim.Adam(model.parameters(), lr=lr) 148 | 149 | # Train for fixed number of epochs 150 | for e in trange(arg.epochs): 151 | for input, labels in trainloader: 152 | opt.zero_grad() 153 | 154 | if arg.cuda: 155 | input, labels = input.cuda(), labels.cuda() 156 | input, labels = Variable(input), Variable(labels) 157 | 158 | output = model(input) 159 | 160 | loss = F.cross_entropy(output, labels) 161 | 162 | if arg.method == 'l1': 163 | l1 = one.weight.norm(p=1) 164 | loss = loss + lambd * l1 165 | if arg.method == 'lp': 166 | lp = one.weight.norm(p=arg.p) 167 | loss = loss + lambd * lp 168 | 169 | loss.backward() 170 | 171 | opt.step() 172 | 173 | # Compute accuracy on validation set 174 | with 
torch.no_grad(): 175 | model.train(False) 176 | 177 | total, correct = 0.0, 0.0 178 | for input, labels in testloader: 179 | opt.zero_grad() 180 | 181 | if arg.cuda: 182 | input, labels = input.cuda(), labels.cuda() 183 | input, labels = Variable(input), Variable(labels) 184 | 185 | output = model(input) 186 | outcls = output.argmax(dim=1) 187 | 188 | total += outcls.size(0) 189 | correct += (outcls == labels).sum().item() 190 | 191 | print(correct, total) 192 | acc = correct / float(total) 193 | 194 | print('lr {}, bs {}: {} acc'.format(lr, batch_size, acc)) 195 | 196 | if acc > bestacc: 197 | bestbs, bestlr = batch_size, lr 198 | 199 | print('Hyperparameter selection finished. Best learning rate: {}. Best batch size: {}.'.format(bestlr, bestbs)) 200 | 201 | # Repeat runs with chosen hyperparameters 202 | accuracies = [] 203 | densities = [] 204 | 205 | for r in trange(arg.repeats): 206 | 207 | if (arg.task == 'mnist'): 208 | data = arg.data + os.sep + arg.task 209 | 210 | train = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=ToTensor()) 211 | trainloader = torch.utils.data.DataLoader(train, batch_size=bestbs, shuffle=True, num_workers=2) 212 | 213 | test = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=ToTensor()) 214 | testloader = torch.utils.data.DataLoader(test, batch_size=bestbs, shuffle=False, num_workers=2) 215 | else: 216 | raise Exception('Task {} not recognized'.format(arg.task)) 217 | 218 | model, one, two = getmodel(arg, insize, numcls, points) # new model 219 | opt = torch.optim.Adam(model.parameters(), lr=lr) 220 | 221 | # Train for fixed number of epochs 222 | for e in range(arg.epochs): 223 | for input, labels in trainloader: 224 | opt.zero_grad() 225 | 226 | if arg.cuda: 227 | input, labels = input.cuda(), labels.cuda() 228 | input, labels = Variable(input), Variable(labels) 229 | 230 | output = model(input) 231 | 232 | loss = F.cross_entropy(output, labels) 233 | 234 | if arg.method == 
'l1': 235 | l1 = one.weight.norm(p=1) 236 | loss = loss + lambd * l1 237 | if arg.method == 'lp': 238 | lp = one.weight.norm(p=arg.p) 239 | loss = loss + lambd * lp 240 | 241 | loss.backward() 242 | 243 | opt.step() 244 | 245 | # Compute accuracy on test set 246 | with torch.no_grad(): 247 | model.train(False) 248 | 249 | total, correct = 0.0, 0.0 250 | for input, labels in testloader: 251 | opt.zero_grad() 252 | 253 | if arg.cuda: 254 | input, labels = input.cuda(), labels.cuda() 255 | input, labels = Variable(input), Variable(labels) 256 | 257 | output = model(input) 258 | outcls = output.argmax(dim=1) 259 | 260 | total += outcls.size(0) 261 | correct += (outcls == labels).sum().item() 262 | 263 | acc = correct / float(total) 264 | 265 | # Compute density 266 | total = util.prod(insize) * arg.hidden 267 | 268 | if arg.method == 'l1' or arg.method== 'lp': 269 | density = (one.weight.data.abs() > 0.0001).sum().item() / float(total) 270 | elif arg.method == 'nas': 271 | density = (points)/float(total) 272 | else: 273 | raise Exception('Method {} not recognized'.format(arg.task)) 274 | 275 | accuracies.append(acc) 276 | densities.append(density) 277 | 278 | print('accuracies: ', accuracies) 279 | print('densities: ', densities) 280 | 281 | # Save to CSV 282 | np.savetxt( 283 | 'out.{}.{}.csv'.format(arg.method, arg.control), 284 | torch.cat([ 285 | torch.tensor(accuracies, dtype=torch.float)[:, None], 286 | torch.tensor(densities, dtype=torch.float)[:, None] 287 | ], dim=1).numpy(), 288 | ) 289 | 290 | print('Finished') 291 | 292 | def single(arg): 293 | 294 | tbw = SummaryWriter() 295 | 296 | lambd = torch.logspace(arg.rfrom, arg.rto, arg.rnum)[arg.control].item() 297 | points = arg.hidden * (arg.control + 1) # NAS control variable 298 | 299 | print('lambda ', lambd) 300 | 301 | # Grid search over batch size/learning rate 302 | # -- Set up model 303 | 304 | insize = (1, 28, 28) if arg.task == 'mnist' else (3, 32, 32) 305 | numcls = 100 if arg.task == 'cifar100' else 
10 306 | 307 | # Repeat runs with chosen hyperparameters 308 | accuracies = [] 309 | densities = [] 310 | 311 | for r in trange(arg.repeats): 312 | 313 | if arg.task == 'mnist': 314 | if arg.final: 315 | data = arg.data + os.sep + arg.task 316 | 317 | train = torchvision.datasets.MNIST(root=data, train=True, download=True, transform=ToTensor()) 318 | trainloader = torch.utils.data.DataLoader(train, batch_size=arg.bs, shuffle=True, num_workers=2) 319 | 320 | test = torchvision.datasets.MNIST(root=data, train=False, download=True, transform=ToTensor()) 321 | testloader = torch.utils.data.DataLoader(test, batch_size=arg.bs, shuffle=False, num_workers=2) 322 | else: 323 | 324 | NUM_TRAIN = 45000 325 | NUM_VAL = 5000 326 | total = NUM_TRAIN + NUM_VAL 327 | 328 | train = torchvision.datasets.MNIST(root=arg.data, train=True, download=True, transform=ToTensor()) 329 | 330 | trainloader = DataLoader(train, batch_size=arg.bs, sampler=util.ChunkSampler(0, NUM_TRAIN, total)) 331 | testloader = DataLoader(train, batch_size=arg.bs, 332 | sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total)) 333 | 334 | elif (arg.task == 'cifar10'): 335 | 336 | data = arg.data + os.sep + arg.task 337 | 338 | if arg.final: 339 | train = torchvision.datasets.CIFAR10(root=data, train=True, download=True, transform=ToTensor()) 340 | trainloader = torch.utils.data.DataLoader(train, batch_size=arg.bs, shuffle=True, num_workers=2) 341 | test = torchvision.datasets.CIFAR10(root=data, train=False, download=True, transform=ToTensor()) 342 | testloader = torch.utils.data.DataLoader(test, batch_size=arg.bs, shuffle=False, num_workers=2) 343 | 344 | else: 345 | NUM_TRAIN = 45000 346 | NUM_VAL = 5000 347 | total = NUM_TRAIN + NUM_VAL 348 | 349 | train = torchvision.datasets.CIFAR10(root=data, train=True, download=True, transform=ToTensor()) 350 | 351 | trainloader = DataLoader(train, batch_size=arg.bs, sampler=util.ChunkSampler(0, NUM_TRAIN, total)) 352 | testloader = DataLoader(train, batch_size=arg.bs, 353 | 
sampler=util.ChunkSampler(NUM_TRAIN, NUM_VAL, total)) 354 | 355 | shape = (3, 32, 32) 356 | num_classes = 10 357 | 358 | else: 359 | raise Exception('Task {} not recognized'.format(arg.task)) 360 | 361 | model, one, two = getmodel(arg, insize, numcls, points) # new model 362 | opt = torch.optim.Adam(model.parameters(), lr=arg.lr) 363 | 364 | # Train for fixed number of epochs 365 | i = 0 366 | for e in range(arg.epochs): 367 | 368 | model.train(True) 369 | for input, labels in tqdm(trainloader): 370 | opt.zero_grad() 371 | 372 | if arg.cuda: 373 | input, labels = input.cuda(), labels.cuda() 374 | input, labels = Variable(input), Variable(labels) 375 | 376 | output = model(input) 377 | 378 | loss = F.cross_entropy(output, labels) 379 | 380 | if arg.method == 'l1': 381 | l1 = one.weight.norm(p=1) 382 | loss = loss + lambd * l1 383 | if arg.method == 'lp': 384 | lp = one.weight.norm(p=arg.p) 385 | loss = loss + lambd * lp 386 | 387 | loss.backward() 388 | 389 | tbw.add_scalar('sparsity/loss', loss.data.item(), i * arg.bs) 390 | i += 1 391 | 392 | opt.step() 393 | 394 | # Compute accuracy on test set 395 | with torch.no_grad(): 396 | model.train(False) 397 | 398 | total, correct = 0.0, 0.0 399 | for input, labels in testloader: 400 | opt.zero_grad() 401 | 402 | if arg.cuda: 403 | input, labels = input.cuda(), labels.cuda() 404 | input, labels = Variable(input), Variable(labels) 405 | 406 | output = model(input) 407 | 408 | outcls = output.argmax(dim=1) 409 | 410 | total += outcls.size(0) 411 | correct += (outcls == labels).sum().item() 412 | 413 | acc = correct / float(total) 414 | 415 | print('\nepoch {}: {}\n'.format(e, acc)) 416 | tbw.add_scalar('sparsity/test acc', acc, e) 417 | 418 | # Compute density 419 | total = util.prod(insize) * arg.hidden 420 | 421 | 422 | if arg.method == 'l1' or arg.method == 'lp': 423 | density = (one.weight > 0.0001).sum().item() / float(total) 424 | elif arg.method == 'nas' or arg.method == 'nas-temp': 425 | density = (points) / total 
426 | else: 427 | raise Exception('Method {} not recognized'.format(arg.method)) 428 | 429 | accuracies.append(acc) 430 | densities.append(density) 431 | 432 | print('accuracies: ', accuracies) 433 | print('densities: ', densities) 434 | 435 | if arg.method == 'lp': 436 | if arg.p == 0.2: 437 | name = 'l5' 438 | elif arg.p == 0.5: 439 | name = 'l2' 440 | elif arg.p == 1.0: 441 | name = 'l1' 442 | else: 443 | name = 'l' + arg.p 444 | else: 445 | name = arg.method 446 | 447 | # Save to CSV 448 | np.savetxt( 449 | 'results.{}.{}.csv'.format(name, arg.control), 450 | torch.cat([ 451 | torch.tensor(accuracies, dtype=torch.float)[:, None], 452 | torch.tensor(densities, dtype=torch.float)[:, None] 453 | ], dim=1).numpy(), 454 | ) 455 | 456 | print('Finished') 457 | 458 | if __name__ == "__main__": 459 | 460 | parser = ArgumentParser() 461 | 462 | parser.add_argument("-C", "--control", 463 | dest="control", 464 | help="Control parameter. For l1, lambda = 10^(-5+c). For NAS, k=hidden*(c+1)", 465 | default=0, type=int) 466 | 467 | parser.add_argument("-l", "--lr", 468 | dest="lr", 469 | help="Learning rate (ignored in sweep)", 470 | default=None, type=float) 471 | 472 | parser.add_argument("-b", "--batch ", 473 | dest="bs", 474 | help="Batch size (ignored in sweep)", 475 | default=None, type=int) 476 | 477 | parser.add_argument("-e", "--epochs", 478 | dest="epochs", 479 | help="Number of epochs", 480 | default=50, type=int) 481 | 482 | parser.add_argument("-m", "--method", 483 | dest="method", 484 | help="Method to use (lp, nas) ", 485 | default='l1', type=str) 486 | 487 | parser.add_argument("-P", "--lp-p", 488 | dest="p", 489 | help="Exponent in lp reg", 490 | default=2.0, type=float) 491 | 492 | parser.add_argument("-t", "--task", 493 | dest="task", 494 | help="Task to use (mnist, cifar10, cifar100) ", 495 | default='mnist', type=str) 496 | 497 | parser.add_argument("-H", "--hidden-size", 498 | dest="hidden", 499 | help="size of the hidden layers", 500 | default=64, 
type=int) 501 | 502 | parser.add_argument("-a", "--gadditional", 503 | dest="gadditional", 504 | help="Number of additional points sampled globally per index-tuple (NAS)", 505 | default=2, type=int) 506 | 507 | parser.add_argument("-A", "--radditional", 508 | dest="radditional", 509 | help="Number of additional points sampled locally per index-tuple (NAS)", 510 | default=2, type=int) 511 | 512 | parser.add_argument("-R", "--range", 513 | dest="range", 514 | help="Range in which the local points are sampled (NAS)", 515 | default=4, type=int) 516 | 517 | parser.add_argument("-r", "--repeats", 518 | dest="repeats", 519 | help="Number of times to repeat the final experiment (once the hyperparameters are chosen).", 520 | default=10, type=int) 521 | 522 | parser.add_argument("--seed", 523 | dest="seed", 524 | help="Random seed", 525 | default=4, type=int) 526 | 527 | parser.add_argument("-c", "--cuda", dest="cuda", 528 | help="Whether to use cuda.", 529 | action="store_true") 530 | 531 | parser.add_argument("-D", "--data", dest="data", 532 | help="Data directory", 533 | default='./data') 534 | 535 | parser.add_argument("-M", "--min-sigma", 536 | dest="min_sigma", 537 | help="Minimal sigma value", 538 | default=0.01, type=float) 539 | 540 | parser.add_argument("--rfrom", 541 | dest="rfrom", 542 | help="Minimal control value (for lp baselines)", 543 | default=0.00001, type=float) 544 | 545 | parser.add_argument("--rto", 546 | dest="rto", 547 | help="Maximal control value (for lp baselines)", 548 | default=1.0, type=float) 549 | 550 | parser.add_argument("--rnum", 551 | dest="rnum", 552 | help="Number of control parameters (for lp baseline)", 553 | default=10, type=int) 554 | 555 | parser.add_argument("-f", "--final", dest="final", 556 | help="Whether to run on the real test set.", 557 | action="store_true") 558 | 559 | parser.add_argument("-F", "--fix-values", dest="fix_values", 560 | help="Whether to fix all values to 1 in the NAS model.", 561 | action="store_true") 562 | 
563 | parser.add_argument("--sweep", dest="sweep", 564 | help="Whether to run a rull parameter sweep over batch size/learn rate.", 565 | action="store_true") 566 | 567 | args = parser.parse_args() 568 | 569 | print('OPTIONS', args) 570 | 571 | if args.sweep: 572 | sweep(args) 573 | else: 574 | single(args) 575 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | This directory holds the code used to create the plots in the paper. Most likely, this code will not work out-of-the 2 | box. If you want to use it, you'll have to copy it and adapt it for your use case. -------------------------------------------------------------------------------- /scripts/_context.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 5 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../sparse'))) 6 | 7 | import sparse -------------------------------------------------------------------------------- /scripts/generate.mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchvision.transforms import ToTensor 4 | 5 | from _context import sparse 6 | from sparse import util 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | 11 | import random, sys 12 | 13 | from PIL import Image 14 | 15 | from argparse import ArgumentParser 16 | 17 | from math import ceil 18 | 19 | from collections import Counter 20 | 21 | """ 22 | Generate rotated, and scaled version of MNIST 23 | """ 24 | 25 | def paste(background, foreground, scale=[1.5, 2.0]): 26 | 27 | rh, rw = background.size 28 | 29 | sch, scw = scale 30 | new_size = (int(ceil(foreground.size[0] * sch)), int(ceil(foreground.size[1] * scw))) 31 | try: 
32 | foreground = foreground.resize(new_size, resample=Image.BICUBIC) 33 | except: 34 | print(f'pasting with size {new_size} failed. ({(sch, scw)})') 35 | sys.exit() 36 | 37 | # Rotate the foreground 38 | angle_degrees = random.randint(0, 359) 39 | foreground = foreground.rotate(angle_degrees, resample=Image.BICUBIC, expand=True) 40 | 41 | h, w = foreground.size 42 | h, w = rh - h, rw - w 43 | h, w = random.randint(0, h), random.randint(0, w) 44 | 45 | background.paste(foreground, box=(h, w), mask=foreground) 46 | 47 | def make_image(b, images, res=100, noise=10, scale=[1.0, 2.0], aspect=2.0): 48 | """ 49 | 50 | Extract the b-th image from the batch of images, and place it into a 100x100 image, rotated and scaled 51 | with noise extracted from other images. 52 | 53 | :param b: 54 | :param images: 55 | :param res: 56 | :return: 57 | """ 58 | 59 | 60 | # Sample a uniform scale for the noise and target 61 | sr = scale[1] - scale[0] 62 | ar = aspect - 1.0 63 | sc = random.random() * sr + scale[0] 64 | ap = random.random() * ar + 1.0 65 | 66 | sch, scw = sc, sc 67 | if random.choice([True, False]): 68 | sch *= ap 69 | else: 70 | scw *= ap 71 | 72 | background = Image.new(mode='RGB', size=(res, res)) 73 | 74 | # generate random patch size 75 | nm = 14 76 | nh, nw = random.randint(4, nm), random.randint(4, nm) 77 | 78 | # Paste noise 79 | for i in range(noise): 80 | 81 | # select another image 82 | ind = random.randint(0, images.size(0)-2) 83 | if ind == b: 84 | ind += 1 85 | 86 | # clip out a random nh x nw patch 87 | h, w = random.randint(0, 28-nh), random.randint(0, 28-nw) 88 | nump = (images[ind, 0, h:h+nh, h:h+nw].numpy() * 255).astype('uint8').squeeze() 89 | patch = Image.fromarray(nump) 90 | 91 | paste(background, patch, scale=(sch, scw)) 92 | 93 | # Paste image 94 | 95 | nump = (images[b, 0, :, :].numpy() * 255).astype('uint8').squeeze() 96 | 97 | foreground = Image.fromarray(nump) 98 | 99 | paste(background, foreground, scale=(sch, scw)) 100 | 101 | return 
background 102 | 103 | def go(arg): 104 | 105 | # make directories 106 | for i in range(10): 107 | util.makedirs('./mnist-rsc/train/{}/'.format(i)) 108 | util.makedirs('./mnist-rsc/test/{}/'.format(i)) 109 | 110 | train = torchvision.datasets.MNIST(root=arg.data, train=True, download=True, transform=ToTensor()) 111 | trainloader = torch.utils.data.DataLoader(train, batch_size=arg.batch, shuffle=True, num_workers=2) 112 | 113 | test = torchvision.datasets.MNIST(root=arg.data, train=False, download=True, transform=ToTensor()) 114 | testloader = torch.utils.data.DataLoader(test, batch_size=arg.batch, shuffle=True, num_workers=2) 115 | 116 | indices = Counter() 117 | 118 | for images, labels in tqdm(trainloader): 119 | 120 | batch_size = labels.size(0) 121 | 122 | for b in range(batch_size): 123 | image = make_image(b, images, res=arg.res, noise=arg.noise, scale=arg.scale) 124 | label = int(labels[b].item()) 125 | 126 | image.save('./mnist-rsc/train/{}/{:06}.png'.format(label, indices[label])) 127 | 128 | indices[label] += 1 129 | 130 | indices = Counter() 131 | 132 | for images, labels in tqdm(testloader): 133 | 134 | batch_size = labels.size(0) 135 | 136 | for b in range(batch_size): 137 | image = make_image(b, images, res=arg.res, noise=arg.noise, scale=arg.scale) 138 | label = int(labels[b].item()) 139 | 140 | image.save('./mnist-rsc/test/{}/{:06}.png'.format(label, indices[label])) 141 | 142 | indices[label] += 1 143 | 144 | if __name__ == "__main__": 145 | 146 | ## Parse the command line options 147 | parser = ArgumentParser() 148 | 149 | parser.add_argument("-D", "--data", dest="data", 150 | help="Data directory", 151 | default='./data/') 152 | 153 | parser.add_argument("-b", "--batch-size", 154 | dest="batch", 155 | help="The batch size.", 156 | default=256, type=int) 157 | 158 | parser.add_argument("-r", "--resolution", 159 | dest="res", 160 | help="Resolution (one side, images are always square).", 161 | default=100, type=int) 162 | 163 | 
parser.add_argument("-n", "--noise", 164 | dest="noise", 165 | help="Number of noise patches to add.", 166 | default=10, type=int) 167 | 168 | 169 | parser.add_argument("-s", "--scale", 170 | dest="scale", 171 | help="Min/max scale multiplier.", 172 | nargs=2, 173 | default=[1.0, 2.0], type=float) 174 | 175 | parser.add_argument("-a", "--aspect", 176 | dest="aspect", 177 | help="Min/max aspect multiplier.", 178 | default=2.0, type=float) 179 | 180 | options = parser.parse_args() 181 | 182 | print('OPTIONS ', options) 183 | 184 | go(options) -------------------------------------------------------------------------------- /scripts/plot.identity.py: -------------------------------------------------------------------------------- 1 | from _context import sparse 2 | from sparse import util 3 | 4 | import matplotlib as mpl 5 | mpl.use('Agg') 6 | import matplotlib.pyplot as plt 7 | import logging, time, gc 8 | import numpy as np 9 | 10 | import pickle 11 | 12 | """ 13 | Used to create Figure 2 in the paper 14 | """ 15 | 16 | plt.figure(figsize=(10.5, 3.5)) 17 | plt.clf() 18 | 19 | files = [ 20 | 'results.004.True.npy', 21 | 'results.008.True.npy', 22 | 'results.016.True.npy', 23 | 'results.008.False.npy', 24 | 'results.016.False.npy', 25 | 'results.032.False.npy', 26 | 'results.064.False.npy', 27 | 'results.128.False.npy', 28 | ] 29 | sizes = [4, 8, 16, 8, 16, 32, 64, 128] 30 | itss = [120_000, 120_000, 120_000, 5_000, 5_000, 5_000, 10_000, 20_000] 31 | des = [1000, 1000, 1000, 100, 100, 100, 100, 500] 32 | reinfs = [True, True, True, False, False, False, False, False] 33 | lrs = [0.01, 0.001, 0.001, 0.005, 0.005, 0.005, 0.005, 0.005] 34 | 35 | norm = mpl.colors.Normalize(vmin=2, vmax=7) 36 | cmap = plt.get_cmap('Set1') 37 | map = mpl.cm.ScalarMappable(norm=norm, cmap=cmap) 38 | 39 | handles, labels = [], [] 40 | 41 | for si, (file, size, iterations, dot_every, reinf, lr) in enumerate(zip(files, sizes, itss, des, reinfs, lrs)): 42 | res = np.load(file) 43 | print('size ', 
size, 'reinf', reinf, res.shape) 44 | 45 | color = map.to_rgba(np.log2(size)) 46 | ndots = iterations // dot_every 47 | 48 | print(ndots) 49 | 50 | lbl = '{0}x{0}, r={1}'.format(size, res.shape[0]) 51 | if res.shape[0] > 1: 52 | h = plt.errorbar( 53 | x=np.arange(ndots) * dot_every, y=np.mean(res, axis=0), yerr=np.std(res, axis=0), 54 | label=lbl, color=color, linestyle='--' if reinf else '-', alpha=0.2 if reinf else 1.0) 55 | handles.append(h) 56 | 57 | else: 58 | h = plt.plot( 59 | np.arange(ndots) * dot_every, np.mean(res, axis=0), 60 | label=lbl, color=color, linestyle='--' if reinf else '-') 61 | handles.append(h[0]) 62 | 63 | labels.append(lbl) 64 | 65 | ax = plt.gca() 66 | ax.set_ylim(bottom=0) 67 | ax.set_xlabel('iterations') 68 | ax.set_ylabel('mean-squared error') 69 | ax.legend(handles, labels, loc='upper left', bbox_to_anchor= (0.96, 1.0), ncol=1, 70 | borderaxespad=0, frameon=False) 71 | 72 | ax.set_xlim(100, 120_000) 73 | ax.set_xscale('log') 74 | 75 | util.basic() 76 | 77 | ax.spines["bottom"].set_visible(False) 78 | 79 | plt.tight_layout() 80 | 81 | plt.savefig('identity.pdf', dpi=600) -------------------------------------------------------------------------------- /scripts/plot.sort.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | mpl.use('Agg') 3 | import matplotlib.pyplot as plt 4 | import util, logging, time, gc 5 | import numpy as np 6 | 7 | import pickle 8 | 9 | """ 10 | Used to create Figure 6 in the paper 11 | 12 | """ 13 | 14 | plt.figure(figsize=(10.5, 3.5)) 15 | plt.clf() 16 | 17 | files = [ 18 | 'results.4.np.npy', 19 | 'results.8.np.npy', 20 | 'results.16.np.npy' 21 | ] 22 | sizes = [4, 8, 16] 23 | itss = [30_000, 60_000, 120_000] 24 | des = [500, 500, 1000] 25 | 26 | norm = mpl.colors.Normalize(vmin=2, vmax=6) 27 | cmap = plt.get_cmap('Set1') 28 | map = mpl.cm.ScalarMappable(norm=norm, cmap=cmap) 29 | 30 | labels = [] 31 | handles = [] 32 | 33 | for si, (file, size, 
iterations, dot_every) in enumerate(zip(files, sizes, itss, des)): 34 | res = np.load('./paper/sort/' + file) 35 | res = 1.0 - res # acc to error 36 | print('size ', size, 'reinf', res.shape) 37 | 38 | color = map.to_rgba(np.log2(size)) 39 | ndots = iterations // dot_every 40 | 41 | # print(size, np.mean(res, axis=0)) 42 | 43 | labels.append('{0}x{0}, r={1}'.format(size, res.shape[0])) 44 | if res.shape[0] > 1: 45 | h = plt.errorbar( 46 | x=np.arange(ndots) * dot_every, y=np.mean(res, axis=0), yerr=np.std(res, axis=0), 47 | color=color) 48 | handles.append(h) 49 | 50 | else: 51 | h = plt.plot( 52 | np.arange(ndots) * dot_every, np.mean(res, axis=0), 53 | color=color) 54 | 55 | handles.append(h[0]) 56 | 57 | ax = plt.gca() 58 | ax.set_ylim((0, 1)) 59 | ax.set_xlabel('iterations') 60 | ax.set_ylabel('error') 61 | ax.legend(handles, labels, loc='upper left', bbox_to_anchor= (0.96, 1.0), ncol=1, 62 | borderaxespad=0, frameon=False) 63 | 64 | util.basic() 65 | ax.spines["bottom"].set_visible(False) 66 | 67 | plt.tight_layout() 68 | 69 | plt.savefig('./paper/sort/sort.pdf', dpi=600) -------------------------------------------------------------------------------- /scripts/plot.sparsity.py: -------------------------------------------------------------------------------- 1 | from _context import sparse 2 | from sparse import util 3 | 4 | import matplotlib as mpl 5 | mpl.use('Agg') 6 | import matplotlib.pyplot as plt 7 | import logging, time, gc 8 | import numpy as np 9 | 10 | import pickle 11 | 12 | """ 13 | Plot the results of the sparsity experiment. 
14 | """ 15 | 16 | 17 | tasks = ['mnist', 'cifar'] 18 | 19 | models = ['nas-temp', 'l5', 'l2', 'l1'] 20 | name = {'nas-temp':'sparse layer', 'l1':'$l^1$', 'l2':'$l^\\frac{1}{2}$', 'l5':'$l^\\frac{1}{5}$'} 21 | controls = 5 22 | 23 | norm = mpl.colors.Normalize(vmin=0, vmax=5) 24 | cmap = plt.get_cmap('Set1') 25 | map = mpl.cm.ScalarMappable(norm=norm, cmap=cmap) 26 | 27 | handles, labels = [], [] 28 | 29 | plt.figure(figsize=(10.5, 3.5)) 30 | plt.clf() 31 | 32 | 33 | 34 | for ti, task in enumerate(tasks): 35 | 36 | ax = plt.subplot(1, 2, ti + 1) 37 | 38 | 39 | for mi, model in enumerate(models): 40 | 41 | densities = [] 42 | accuracies = [] 43 | 44 | dstds = [] 45 | astds = [] 46 | 47 | for c in range(controls): 48 | try: 49 | res = np.genfromtxt('./{}/results.{}.{}.csv'.format(task, model, c)) 50 | if len(res.shape) == 1: 51 | res = res[None, :] 52 | 53 | assert len(res.shape) == 2 and res.shape[1] == 2 54 | 55 | repeats = res.shape[0] 56 | 57 | if repeats == 1: 58 | accuracies.append( res[0, 0]) 59 | densities.append(res[0, 1]) 60 | else: 61 | accuracies.append( res[:, 0].mean()) 62 | densities.append( res[:, 1].mean()) 63 | 64 | astds.append( res[:, 0].std()) 65 | dstds.append( res[:, 1].std()) 66 | except: 67 | print('could not load file ./{}/results.{}.{}.csv'.format(task, model, c)) 68 | 69 | color = map.to_rgba(mi) 70 | 71 | lbl = '{}, r={}'.format(name[model], repeats) 72 | if ti == 0: 73 | labels.append(lbl) 74 | 75 | if len(dstds) == 0: 76 | h = ax.plot( 77 | densities, accuracies, 78 | label=lbl, linestyle='-' if model == 'nas-temp' else ':', marker='s', markersize=2) 79 | if ti == 0: 80 | handles.append(h[0]) 81 | else: 82 | h = ax.errorbar( 83 | x=densities, y=accuracies, xerr=dstds, yerr=astds, 84 | label=lbl, linestyle='-' if model == 'nas-temp' else ':', marker='s', markersize=2) 85 | 86 | if ti == 0: 87 | handles.append(h) 88 | 89 | 90 | ax.set_xlabel('density') 91 | ax.set_xscale('log') 92 | 93 | ax.set_ylim(0, 1) 94 | # if task == 'mnist': 95 | # 
ax.set_ylim(0.6, 1.0) 96 | # if task == 'cifar': 97 | # ax.set_ylim(0.2, 0.5) 98 | 99 | ax.set_title(task) 100 | 101 | if ti ==0: 102 | ax.set_ylabel('accuracy') 103 | 104 | if ti ==1 : 105 | ax.legend(handles, labels, loc='upper left', bbox_to_anchor= (0.96, 1.0), ncol=1, 106 | borderaxespad=0, frameon=False) 107 | 108 | util.basic() 109 | 110 | 111 | plt.tight_layout() 112 | 113 | plt.savefig('sparsity.pdf'.format(task), dpi=600) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='sparse', 4 | version='0.1', 5 | description='Sparse layer', 6 | url='http://github.com/MaestroGraph/sparse-hyper', 7 | author='Peter Bloem', 8 | author_email='sparse@peterbloem.nlß', 9 | license='MIT', 10 | packages=['sparse'], 11 | install_requires=[ 12 | 'matplotlib', 13 | 'torch', 14 | 'tqdm' 15 | ], 16 | zip_safe=False) -------------------------------------------------------------------------------- /sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from .sort import Split, SortLayer 2 | 3 | from .layers import SparseLayer, NASLayer, Convolution, transform_means, transform_sigmas 4 | from .layers import ngenerate, transform_means, densities 5 | 6 | from .tensors import contract, logsoftmax, batchmm, simple_normalize 7 | 8 | # from .tensors import flatten_indices_mat 9 | -------------------------------------------------------------------------------- /sparse/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from torch.nn import Parameter 5 | from torch import FloatTensor, LongTensor 6 | 7 | import abc, itertools, math, types 8 | from numpy import prod 9 | 10 | import torch.nn.functional as F 11 | 12 | import tensors 13 | 14 | from 
sparse import util 15 | from sparse.util import Bias, sparsemult, contains_nan, bmult, nduplicates, d 16 | 17 | import sys 18 | import random 19 | 20 | import numpy as np 21 | 22 | from enum import Enum 23 | 24 | # added to the sigmas to prevent NaN 25 | EPSILON = 10e-7 26 | SIGMA_BOOST = 2.0 27 | 28 | 29 | """ 30 | Core implementation of the sparse (hyper)layer as an abstract class (SparseLayer). 31 | 32 | """ 33 | 34 | def densities(points, means, sigmas): 35 | """ 36 | Compute the unnormalized probability densities of a given set of points for a 37 | given set of multivariate normal distrbutions (MVNs) 38 | 39 | :param means: (b, n, c, r) tensor of n vectors of dimension r (in a batch of size b) 40 | representing the means of n MVNs 41 | :param sigmas: (b, k, l, r) tensor of n vectors of dimension r (in a batch of size b) 42 | representing the diagonal covariance matrix of n MVNs 43 | :param points: The points for which to compute the probabilioty densities 44 | :return: (b, k, n) tensor containing the density of every point under every MVN 45 | """ 46 | 47 | # n: number of MVNs 48 | # rank: dim of points 49 | 50 | # i: number of integer index tuples sampled per chunk 51 | # k: number of continuous index tuples per chunk 52 | # c: number of chunks 53 | 54 | c, i, rank = points.size()[-3:] 55 | c, k, rank = means.size()[-3:] 56 | 57 | pref = points.size()[:-3] 58 | assert pref == means.size()[:-3] 59 | 60 | points = points.unsqueeze(-2).expand( *(pref + (c, i, k, rank)) ) 61 | means = means.unsqueeze(-3).expand_as(points) 62 | sigmas = sigmas.unsqueeze(-3).expand_as(points) 63 | 64 | sigmas_squared = torch.sqrt(1.0/(EPSILON+sigmas)) 65 | 66 | points = points - means 67 | points = points * sigmas_squared 68 | 69 | # Compute dot products for all points 70 | # -- unroll the pref/c/k/l dimensions 71 | points = points.view(-1, 1, rank) 72 | # -- dot prod 73 | 74 | # print(points) 75 | products = torch.bmm(points, points.transpose(1, 2)) 76 | # -- reconstruct shape 77 
| products = products.view( *(pref + (c, i, k)) ) 78 | 79 | num = torch.exp(- 0.5 * products) # the numerator of the Gaussian density 80 | 81 | return num 82 | 83 | def transform_means(means, size, method='sigmoid'): 84 | """ 85 | Transforms raw parameters for the index tuples (with values in (-inf, inf)) into parameters within the bound of the 86 | dimensions of the tensor. 87 | 88 | In the case of a templated sparse layer, these parameters and the corresponding size tuple deascribe only the learned 89 | subtensor. 90 | 91 | :param means: (..., rank) tensor of raw parameter values 92 | :param size: Tuple describing the tensor dimensions. 93 | :return: (..., rank) 94 | """ 95 | 96 | # Compute upper bounds 97 | s = torch.tensor(list(size), dtype=torch.float, device=d(means)) - 1 98 | s = util.unsqueezen(s, len(means.size()) - 1) 99 | s = s.expand_as(means) 100 | 101 | # Scale to [0, 1] 102 | if method == 'modulo': 103 | means = means.remainder(s) 104 | 105 | return means 106 | 107 | if method == 'clamp': 108 | means = torch.max(means, torch.zeros(means.size(), device=d(means))) 109 | means = torch.min(means, s) 110 | 111 | return means 112 | 113 | means = torch.sigmoid(means) 114 | 115 | return means * s 116 | 117 | def transform_sigmas(sigmas, size, min_sigma=EPSILON): 118 | """ 119 | Transforms raw parameters for the conv matrices (with values in (-inf, inf)) into positive values, scaled proportional 120 | to the dimensions of the tensor. Note: each sigma is parametrized by a single value, which is expanded to a vector to 121 | fit the diagonal of the covariance matrix. 122 | 123 | In the case of a templated sparse layer, these parameters and the corresponing size tuple deascribe only the learned 124 | subtensor. 125 | 126 | :param sigmas: (..., ) matrix of raw sigma values 127 | :param size: Tuple describing the tensor dimensions. 128 | :param min_sigma: Minimal sigma value. 
129 | :return:(..., rank) sigma values 130 | """ 131 | ssize = sigmas.size() 132 | r = len(size) 133 | 134 | # Scale to [0, 1] 135 | sigmas = F.softplus(sigmas + SIGMA_BOOST) + min_sigma 136 | # sigmas = sigmas[:, :, None].expand(b, k, r) 137 | sigmas = sigmas.unsqueeze(-1).expand(*(ssize + (r, ))) 138 | 139 | # Compute upper bounds 140 | s = torch.tensor(list(size), dtype=torch.float, device='cuda' if sigmas.is_cuda else 'cpu') 141 | s = util.unsqueezen(s, len(sigmas.size()) - 1) 142 | s = s.expand_as(sigmas) 143 | 144 | return sigmas * s 145 | 146 | class SparseLayer(nn.Module): 147 | """ 148 | Abstract class for the (templated) hyperlayer. Implement by defining a hypernetwork, and returning it from the 149 | hyper() method. See NASLayer for an implementation without hypernetwork. 150 | 151 | The templated hyperlayer takes certain columns of its index-tuple matrix as fixed (the template), and others as 152 | learnable. Imagine a neural network layer where the connections to the output nodes are fixed, but the connections to 153 | the input nodes can be learned. 154 | 155 | For a non-templated hypernetwork (all columns learnable), just leave the template parameters None. 156 | """ 157 | @abc.abstractmethod 158 | def hyper(self, input): 159 | """ 160 | Applies the hypernetwork, and returns the continuous index tuples, with their associated sigmas and values. 161 | 162 | :param input: The input to the hyperlayer. 163 | :return: A triple: (means, sigmas, values) 164 | """ 165 | raise NotImplementedError 166 | 167 | def __init__(self, in_rank, out_size, 168 | temp_indices=None, 169 | learn_cols=None, 170 | chunk_size=None, 171 | gadditional=0, radditional=0, region=None, 172 | bias_type=Bias.DENSE): 173 | """ 174 | :param in_rank: Nr of dimensions in the input. The specific size may vary between inputs. 175 | :param out_size: Tuple describing the size of the output. 176 | :param temp_indices: The template describing the fixed part of the tuple index-tuple matrix. 
None for a 177 | non-templated hyperlayer. 178 | :param learn_cols: Which columns of the template are 'free' (to be learned). The rest are fixed. None for a 179 | non-templated hyperlayer. 180 | :param chunk_size: Size of the "context" of generating integer index tuples. Duplicates are removed withing the 181 | same context. The list of continuous index tuples is chunked into contexts of this size. If none, the whole 182 | list counts as a single context. This is mostly useful in combination with templating. 183 | :param gadditional: Number of points to sample globally per index tuple 184 | :param radditional: Number of points to sample locally per index tuple 185 | :param region: Tuple describing the size of the region over which the local additional points are sampled (must 186 | be smaller than the size of the tensor). 187 | :param bias_type: The type of bias of the sparse layer (none, dense or sparse). 188 | :param subsample: 189 | """ 190 | 191 | super().__init__() 192 | rank = in_rank + len(out_size) 193 | 194 | assert learn_cols is None or len(region) == len(learn_cols), "Region should span as many dimensions as there are learnable columns" 195 | 196 | self.in_rank = in_rank 197 | self.out_size = out_size # without batch dimension 198 | self.gadditional = gadditional 199 | self.radditional = radditional 200 | self.region = region 201 | self.chunk_size = chunk_size 202 | 203 | self.bias_type = bias_type 204 | self.learn_cols = learn_cols if learn_cols is not None else range(rank) 205 | 206 | self.templated = temp_indices is not None 207 | 208 | # create a tensor with all binary sequences of length 'out_rank' as rows 209 | # (this will be used to compute the nearby integer-indices of a float-index). 210 | self.register_buffer('floor_mask', floor_mask(len(self.learn_cols))) 211 | 212 | if self.templated: 213 | # template for the index matrix containing the hardwired connections 214 | # The learned parts can be set to zero; they will be overriden. 
215 | assert temp_indices.size(1) == in_rank + len(out_size) 216 | 217 | self.register_buffer('temp_indices', temp_indices) 218 | 219 | def is_cuda(self): 220 | return next(self.parameters()).is_cuda 221 | 222 | def forward(self, input, mrange=None, seed=None, **kwargs): 223 | """ 224 | 225 | :param input: 226 | :param mrange: Specifies a subrange of index tuples to compute the gradient over. This is helpful for gradient 227 | accumulation methods. This doesn;t work together with templating. 228 | :param seed: 229 | :param kwargs: 230 | :return: 231 | """ 232 | 233 | assert mrange is None or not self.templated, "Templating and gradient accumulation do not work together" 234 | 235 | ### Compute and unpack output of hypernetwork 236 | 237 | bias = None 238 | 239 | if self.bias_type == Bias.NONE: 240 | means, sigmas, values = self.hyper(input, **kwargs) 241 | elif self.bias_type == Bias.DENSE: 242 | means, sigmas, values, bias = self.hyper(input, **kwargs) 243 | elif self.bias_type == Bias.SPARSE: 244 | raise Exception('Sparse bias not supported yet.') 245 | else: 246 | raise Exception('bias type {} not recognized.'.format(self.bias_type)) 247 | 248 | b, n, r = means.size() 249 | dv = 'cuda' if self.is_cuda() else 'cpu' 250 | 251 | # We divide the list of index tuples into 'chunks'. Each chunk represents a kind of context: 252 | # - duplicate integer index tuples within the chunk are removed 253 | # - proportions are normalized over all index tuples within the chunk 254 | # This is useful in the templated setting. If no chunk size is requested, we just add a singleton dimension. 
255 | k = self.chunk_size if self.chunk_size is not None else n # chunk size 256 | c = n // k # number of chunks 257 | 258 | means, sigmas, values = means.view(b, c, k, r), sigmas.view(b, c, k, r), values.view(b, c, k) 259 | 260 | assert b == input.size(0), 'input batch size ({}) should match parameter batch size ({}).'.format(input.size(0), b) 261 | 262 | # max values allowed for each column in the index matrix 263 | fullrange = self.out_size + input.size()[1:] 264 | subrange = [fullrange[r] for r in self.learn_cols] # submatrix for the learnable dimensions 265 | 266 | if not self.training: 267 | indices = means.view(b, c*k, r).round().long() 268 | 269 | else: 270 | if mrange is not None: # only compute the gradient for a subset of index tuples 271 | fr, to = mrange 272 | 273 | # sample = random.sample(range(nm), self.subsample) # the means we will learn for 274 | ids = torch.zeros((k,), dtype=torch.uint8, device=dv) 275 | ids[fr:to] = 1 276 | 277 | means, means_out = means[:, :, ids, :], means[:, :, ~ids, :] 278 | sigmas, sigmas_out = sigmas[:, :, ids, :], sigmas[:, :, ~ids, :] 279 | values, values_out = values[:, :, ids], values[:, :, ~ids] 280 | 281 | # These should not get a gradient, since their means aren't being sampled for 282 | # (their gradient will be computed in other passes) 283 | means_out, sigmas_out, values_out = means_out.detach(), sigmas_out.detach(), values_out.detach() 284 | 285 | indices = generate_integer_tuples(means, self.gadditional, self.radditional, rng=subrange, relative_range=self.region, seed=seed, cuda=self.is_cuda()) 286 | indfl = indices.float() 287 | 288 | # Mask for duplicate indices 289 | dups = nduplicates(indices) 290 | 291 | # compute (unnormalized) densities under the given MVNs (proportions) 292 | props = densities(indfl, means, sigmas).clone() # result has size (b, c, i, k), i = indices[2] 293 | props[dups, :] = 0 294 | props = props / props.sum(dim=2, keepdim=True) # normalize over all points of a given index tuple 295 | 
296 | # Weight the values by the proportions 297 | values = values[:, :, None, :].expand_as(props) 298 | 299 | values = props * values 300 | values = values.sum(dim=3) 301 | 302 | if mrange is not None: 303 | indices_out = means_out.data.round().long() 304 | # 305 | # print(indices.size(), indices_out.size()) 306 | # print(values.size(), values_out.size()) 307 | # sys.exit() 308 | 309 | indices = torch.cat([indices, indices_out], dim=2) 310 | values = torch.cat([values, values_out], dim=2) 311 | 312 | # remove the chunk dimensions 313 | indices, values = indices.view(b, -1 , r), values.view(b, -1) 314 | 315 | if self.templated: 316 | # stitch the generated indices into the template 317 | b, l, r = indices.size() 318 | h, w = self.temp_indices.size() 319 | template = self.temp_indices[None, :, None, :].expand(b, h, l//h, w) 320 | template = template.contiguous().view(b, l, w) 321 | 322 | template[:, :, self.learn_cols] = indices 323 | indices = template 324 | 325 | # if self.is_cuda(): 326 | # indices = indices.cuda() 327 | 328 | size = self.out_size + input.size()[1:] 329 | 330 | output = tensors.contract(indices, values, size, input) 331 | 332 | if self.bias_type == Bias.DENSE: 333 | return output + bias 334 | return output 335 | 336 | class NASLayer(SparseLayer): 337 | """ 338 | Sparse layer with free sparse parameters, no hypernetwork, no template. 339 | """ 340 | 341 | def __init__(self, in_size, out_size, k, 342 | sigma_scale=0.2, 343 | fix_values=False, has_bias=False, 344 | min_sigma=0.0, 345 | gadditional=0, 346 | region=None, 347 | radditional=None, 348 | template=None, 349 | learn_cols=None, 350 | chunk_size=None): 351 | """ 352 | 353 | :param in_size: 354 | :param out_size: 355 | :param k: 356 | :param sigma_scale: 357 | :param fix_values: 358 | :param has_bias: 359 | :param min_sigma: 360 | :param gadditional: 361 | :param region: 362 | :param radditional: 363 | :param clamp: 364 | :param template: LongTensor Template for the matrix of index tuples. 
Learnable columns are updated through backprop 365 | other values are taken from the template. 366 | :param learn_cols: tuple of integers. Learnable columns of the template. 367 | 368 | """ 369 | 370 | super().__init__(in_rank=len(in_size), 371 | out_size=out_size, 372 | bias_type=Bias.DENSE if has_bias else Bias.NONE, 373 | gadditional=gadditional, 374 | radditional=radditional, 375 | region=region, 376 | temp_indices=template, 377 | learn_cols=learn_cols, 378 | chunk_size=chunk_size) 379 | 380 | self.k = k 381 | self.in_size = in_size 382 | self.out_size = out_size 383 | self.sigma_scale = sigma_scale 384 | self.fix_values = fix_values 385 | self.has_bias = has_bias 386 | self.min_sigma = min_sigma 387 | 388 | self.rank = len(in_size) + len(out_size) 389 | 390 | imeans = torch.randn(k, self.rank if template is None else len(learn_cols)) 391 | isigmas = torch.randn(k) 392 | 393 | self.pmeans = Parameter(imeans) 394 | self.psigmas = Parameter(isigmas) 395 | 396 | if fix_values: 397 | self.register_buffer('pvalues', torch.ones(k)) 398 | else: 399 | self.pvalues = Parameter(torch.randn(k)) 400 | 401 | if self.has_bias: 402 | self.bias = Parameter(torch.zeros(*out_size)) 403 | 404 | def hyper(self, input, **kwargs): 405 | """ 406 | Evaluates hypernetwork. 
407 | """ 408 | 409 | b = input.size(0) 410 | size = self.out_size + input.size()[1:] # total dimensions of the weight tensor 411 | 412 | if self.learn_cols is not None: 413 | size = [size[l] for l in self.learn_cols] 414 | 415 | k, r = self.pmeans.size() 416 | 417 | # Expand parameters along batch dimension 418 | means = self.pmeans[None, :, :].expand(b, k, r) 419 | sigmas = self.psigmas[None, :].expand(b, k) 420 | values = self.pvalues[None, :].expand(b, k) 421 | 422 | means, sigmas = transform_means(means, size), transform_sigmas(sigmas, size, min_sigma=self.min_sigma) * self.sigma_scale 423 | 424 | if self.has_bias: 425 | return means, sigmas, values, self.bias 426 | return means, sigmas, values 427 | 428 | class Convolution(nn.Module): 429 | """ 430 | A non-adaptive hyperlayer that mimics a convolution. That is, the basic structure of the layer is a convolution, but 431 | instead of connecting every input in the patch to every output channel, we connect them sparsely, with parameters 432 | learned by the hyperlayer. 433 | 434 | The parameters are the same for each instance of the convolution kernel, but they are sampled separately for each. 435 | 436 | The hyperlayer is _templated_ that is, each connection is fixed to one output node. There are k connections per 437 | output node. 438 | 439 | The stride is always 1, padding is always added to ensure that the output resolution is the same as the input 440 | resolution. 441 | 442 | """ 443 | 444 | def __init__(self, in_size, out_channels, k, kernel_size=3, 445 | gadditional=2, radditional=2, rprop=0.2, 446 | min_sigma=0.0, 447 | sigma_scale=0.1, 448 | fix_values=False, 449 | has_bias=True): 450 | """ 451 | :param in_size: Channels and resolution of the input 452 | :param out_size: Tuple describing the size of the output. 453 | :param k: Number of points sampled per instance of the kernel. 
454 | :param kernel_size: Size of the (square) kernel., 455 | 456 | :param gadditional: Number of points to sample globally per index tuple 457 | :param radditional: Number of points to sample locally per index tuple 458 | :param rprop: Describes the region over which the local samples are taken, as a proportion of the channels 459 | :param bias_type: The type of bias of the sparse layer (none, dense or sparse). 460 | :param subsample: 461 | """ 462 | 463 | super().__init__() 464 | 465 | rank = 6 466 | 467 | self.in_size = in_size 468 | self.out_size = (out_channels,) + in_size[1:] 469 | self.kernel_size = kernel_size 470 | self.gadditional = gadditional 471 | self.radditional = radditional 472 | self.region = (max(1, math.floor(rprop * in_size[0])), kernel_size-1, kernel_size-1) 473 | 474 | self.min_sigma = min_sigma 475 | self.sigma_scale = sigma_scale 476 | 477 | self.has_bias = has_bias 478 | 479 | self.pad = nn.ZeroPad2d(kernel_size // 2) 480 | 481 | 482 | self.means = nn.Parameter(torch.randn(out_channels, k, 3)) 483 | self.sigmas = nn.Parameter(torch.randn(out_channels, k)) 484 | self.values = None if fix_values else nn.Parameter(torch.randn(out_channels, k)) 485 | 486 | # out_indices = torch.LongTensor(list(np.ndindex( (in_size[1:]) ))) 487 | # self.register_buffer('out_indices', out_indices) 488 | 489 | template = torch.LongTensor(list(np.ndindex( (out_channels, in_size[1], in_size[2]) ))) 490 | assert template.size() == (prod((out_channels, in_size[1], in_size[2])), 3) 491 | template = F.pad(template, (0, 3)) 492 | self.register_buffer('template', template) 493 | 494 | if self.has_bias: 495 | self.bias = Parameter(torch.randn(*self.out_size)) 496 | 497 | def hyper(self, x): 498 | """ 499 | Returns the means, sigmas and values for a _single_ kernel. The same kernel is applied at every position (but 500 | with fresh samples). 
501 | 502 | :param x: 503 | :return: 504 | """ 505 | b = x.size(0) 506 | 507 | size = (self.in_size[0], self.kernel_size, self.kernel_size) 508 | 509 | o, k, r = self.means.size() 510 | 511 | # Expand parameters along batch dimension 512 | means = self.means[None, :, :].expand(b, o, k, r) 513 | sigmas = self.sigmas[None, :].expand(b, o, k) 514 | values = self.values[None, :].expand(b, o, k) 515 | 516 | means, sigmas = transform_means(means, size), transform_sigmas(sigmas, size, min_sigma=self.min_sigma) * self.sigma_scale 517 | 518 | return means, sigmas, values 519 | 520 | def forward(self, x): 521 | dv = 'cuda' if self.template.is_cuda else 'cpu' 522 | 523 | # get continuous parameters 524 | means, sigmas, values = self.hyper(x) 525 | 526 | # zero pad 527 | x = self.pad(x) 528 | 529 | b, o, k, r = means.size() 530 | assert sigmas.size() == (b, o, k, r) 531 | assert values.size() == (b, o, k) 532 | 533 | # number of instances of the convolution kernel 534 | nk = self.in_size[1] * self.in_size[2] 535 | 536 | # expand for all kernels 537 | means = means [:, :, None, :, :].expand(b, o, nk, k, r) 538 | sigmas = sigmas[:, :, None, :, :].expand(b, o, nk, k, r) 539 | values = values[:, :, None, :] .expand(b, o, nk, k) 540 | 541 | if not self.training: 542 | indices = means.round().long() 543 | 544 | l = k 545 | 546 | else: 547 | # sample integer index tuples 548 | # print(means.size()) 549 | indices = ngenerate(means, 550 | self.gadditional, self.radditional, 551 | relative_range=self.region, 552 | rng=(self.in_size[0], self.kernel_size, self.kernel_size), 553 | cuda=means.is_cuda) 554 | 555 | # for i in range(indices.contiguous().view(-1, 3).size(0)): 556 | # print(indices.contiguous().view(-1, 3)[i, :]) 557 | # sys.exit() 558 | 559 | # print('indices', indices.size()) 560 | indfl = indices.float() 561 | 562 | b, o, nk, l, r = indices.size() 563 | assert l == k * (2 ** r + self.gadditional + self.radditional) 564 | assert nk == self.in_size[1] * self.in_size[2] 565 | 
566 | # mask for duplicate indices 567 | dups = nduplicates(indices) 568 | 569 | # compute unnormalized densities (proportions) under the given MVNs 570 | props = densities(indfl, means, sigmas).clone() # result has size (..., c, i, k), i = indices[2] 571 | # print('densities', props.size()) 572 | # print(util.contains_nan(props)) 573 | 574 | props[dups, :] = 0 575 | # print('... ', props.size()) 576 | 577 | props = props / props.sum(dim=-2, keepdim=True) # normalize over all points of a given index tuple 578 | 579 | # print(util.contains_nan(props)) 580 | # sys.exit() 581 | 582 | # Weight the values by the proportions 583 | values = values[:, :, :, None, :].expand_as(props) 584 | 585 | values = props * values 586 | values = values.sum(dim=4) 587 | 588 | template = self.template[None, :, None, :].expand(b, self.out_size[0]*self.in_size[1]*self.in_size[2], l, 6) 589 | template = template.view(b, self.out_size[0], nk, l, 6) 590 | 591 | template[:, :, :, :, 3:] = indices 592 | 593 | indices = template.contiguous().view(b, self.out_size[0] * nk * l, 6) 594 | offsets = indices[:, :, 1:3] 595 | 596 | # for i in range(indices.view(-1, 6).size(0)): 597 | # print(indices.view(-1, 6)[i, :], values.view(-1)[i].data) 598 | # sys.exit() 599 | 600 | indices[:, :, 4:] = indices[:, :, 4:] + offsets 601 | 602 | values = values.contiguous().view(b, self.out_size[0] * nk * l) 603 | 604 | # apply tensor 605 | size = self.out_size + x.size()[1:] 606 | 607 | assert (indices.view(-1, 6).max(dim=0)[0] >= torch.tensor(size, device=dv)).sum() == 0, "Max values of indices ({}) out of bounds ({})".format(indices.view(-1, 6).max(dim=0)[0], size) 608 | 609 | output = tensors.contract(indices, values, size, x) 610 | 611 | if self.has_bias: 612 | return output + self.bias 613 | return output 614 | 615 | FLOOR_MASKS = {} 616 | def floor_mask(num_cols, cuda=False): 617 | if num_cols not in FLOOR_MASKS: 618 | lsts = [[int(b) for b in bools] for bools in itertools.product([True, False], 
repeat=num_cols)] 619 | FLOOR_MASKS[num_cols] = torch.BoolTensor(lsts, device='cpu') 620 | 621 | if cuda: 622 | return FLOOR_MASKS[num_cols].cuda() 623 | return FLOOR_MASKS[num_cols] 624 | 625 | def generate_integer_tuples(means, gadditional, ladditional, rng=None, relative_range=None, seed=None, cuda=False, fm=None): 626 | """ 627 | Takes continuous-valued index tuples, and generates integer-valued index tuples. 628 | 629 | The returned matrix of ints is not a Variable (just a plain LongTensor). Autograd of the real valued indices passes 630 | through the values alone, not the integer indices used to instantiate the sparse matrix. 631 | 632 | :param ind: A Variable containing a matrix of N by K, where K is the number of indices. 633 | :param val: A Variable containing a vector of length N containing the values corresponding to the given indices 634 | :return: a triple (ints, props, vals). ints is an N*2^K by K matrix representing the N*2^K integer index-tuples that can 635 | be made by flooring or ceiling the indices in 'ind'. 'props' is a vector of length N*2^K, which indicates how 636 | much of the original value each integer index-tuple receives (based on the distance to the real-valued 637 | index-tuple). vals is vector of length N*2^K, containing the value of the corresponding real-valued index-tuple 638 | (ie. vals just repeats each value in the input 'val' 2^K times). 
639 | """ 640 | 641 | b, k, c, rank = means.size() 642 | FT = torch.cuda.FloatTensor if cuda else torch.FloatTensor 643 | 644 | if seed is not None: 645 | torch.manual_seed(seed) 646 | 647 | """ 648 | Generate neighbor tuples 649 | """ 650 | if fm is None: 651 | fm = floor_mask(rank, cuda) 652 | fm = fm[None, None, None, :].expand(b, k, c, 2 ** rank, rank) 653 | 654 | neighbor_ints = means.data[:, :, :, None, :].expand(b, k, c, 2 ** rank, rank).contiguous() 655 | 656 | neighbor_ints[fm] = neighbor_ints[fm].floor() 657 | neighbor_ints[~fm] = neighbor_ints[~fm].ceil() 658 | 659 | neighbor_ints = neighbor_ints.long() 660 | 661 | """ 662 | Sample uniformly from all integer tuples 663 | """ 664 | 665 | global_ints = FT(b, k, c, gadditional, rank) 666 | 667 | global_ints.uniform_() 668 | global_ints *= (1.0 - EPSILON) 669 | 670 | rng = FT(rng) 671 | rngxp = rng.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand_as(global_ints) 672 | 673 | global_ints = torch.floor(global_ints * rngxp).long() 674 | 675 | """ 676 | Sample uniformly from a small range around the given index tuple 677 | """ 678 | local_ints = FT(b, k, c, ladditional, rank) 679 | 680 | local_ints.uniform_() 681 | local_ints *= (1.0 - EPSILON) 682 | 683 | rngxp = rng[None, None, None, :].expand_as(local_ints) # bounds of the tensor 684 | 685 | rrng = FT(relative_range) # bounds of the range from which to sample 686 | rrng = rrng[None, None, None, :].expand_as(local_ints) 687 | 688 | # print(means.size()) 689 | mns_expand = means.round().unsqueeze(3).expand_as(local_ints) 690 | 691 | # upper and lower bounds 692 | lower = mns_expand - rrng * 0.5 693 | upper = mns_expand + rrng * 0.5 694 | 695 | # check for any ranges that are out of bounds 696 | idxs = lower < 0.0 697 | lower[idxs] = 0.0 698 | 699 | idxs = upper > rngxp 700 | lower[idxs] = rngxp[idxs] - rrng[idxs] 701 | 702 | local_ints = (local_ints * rrng + lower).long() 703 | 704 | all = torch.cat([neighbor_ints, global_ints, local_ints] , dim=3) 705 | 706 | 
return all.view(b, k, -1, rank) # combine all indices sampled within a chunk 707 | 708 | 709 | def ngenerate(means, gadditional, ladditional, rng=None, relative_range=None, seed=None, cuda=False, fm=None, epsilon=EPSILON): 710 | """ 711 | 712 | Generates random integer index tuples based on continuous parameters. 713 | 714 | :param epsilon: The random bumbers are based on uniform samples in (0, 1-epsilon). Note that 715 | in some cases epsilon needs to be relatively big (e.g. 10-5) 716 | 717 | """ 718 | 719 | b = means.size(0) 720 | k, c, rank = means.size()[-3:] 721 | pref = means.size()[:-1] 722 | 723 | FT = torch.cuda.FloatTensor if cuda else torch.FloatTensor 724 | 725 | rng = FT(tuple(rng)) 726 | # - the tuple() is there in case a torch.Size() object is passed (which causes torch to 727 | # interpret the argument as the size of the tensor rather than its content). 728 | 729 | bounds = util.unsqueezen(rng, len(pref) + 1).long() # index bound with unsqueezed dims for broadcasting 730 | 731 | if seed is not None: 732 | torch.manual_seed(seed) 733 | 734 | """ 735 | Generate neighbor tuples 736 | """ 737 | if fm is None: 738 | fm = floor_mask(rank, cuda) 739 | 740 | size = pref + (2**rank, rank) 741 | fm = util.unsqueezen(fm, len(size) - 2).expand(size) 742 | 743 | neighbor_ints = means.data.unsqueeze(-2).expand(*size).contiguous() 744 | 745 | neighbor_ints[fm] = neighbor_ints[fm].floor() 746 | neighbor_ints[~fm] = neighbor_ints[~fm].ceil() 747 | 748 | neighbor_ints = neighbor_ints.long() 749 | 750 | assert (neighbor_ints >= bounds).sum() == 0, 'One of the neighbor indices is outside the tensor bounds' 751 | 752 | """ 753 | Sample uniformly from all integer tuples 754 | """ 755 | gsize = pref + (gadditional, rank) 756 | global_ints = FT(*gsize) 757 | 758 | global_ints.uniform_() 759 | global_ints *= (1.0 - epsilon) 760 | 761 | rngxp = util.unsqueezen(rng, len(gsize) - 1).expand_as(global_ints) 762 | 763 | global_ints = torch.floor(global_ints * rngxp).long() 764 | 
765 | assert (global_ints >= bounds).sum() == 0, 'One of the global sampled indices is outside the tensor bounds' 766 | 767 | """ 768 | Sample uniformly from a small range around the given index tuple 769 | """ 770 | lsize = pref + (ladditional, rank) 771 | local_ints = FT(*lsize) 772 | 773 | local_ints.uniform_() 774 | local_ints *= (1.0 - epsilon) 775 | 776 | rngxp = util.unsqueezen(rng, len(lsize) - 1).expand_as(local_ints) # bounds of the tensor 777 | 778 | rrng = FT(relative_range) # bounds of the range from which to sample 779 | rrng = util.unsqueezen(rrng, len(lsize) - 1).expand_as(local_ints) 780 | 781 | # print(means.size()) 782 | mns_expand = means.round().unsqueeze(-2).expand_as(local_ints) 783 | 784 | # upper and lower bounds 785 | lower = mns_expand - rrng * 0.5 786 | upper = mns_expand + rrng * 0.5 787 | 788 | # check for any ranges that are out of bounds 789 | idxs = lower < 0.0 790 | lower[idxs] = 0.0 791 | 792 | idxs = upper > rngxp 793 | lower[idxs] = rngxp[idxs] - rrng[idxs] 794 | 795 | cached = local_ints.clone() 796 | local_ints = (local_ints * rrng + lower).long() 797 | 798 | assert (local_ints >= bounds).sum() == 0, f'One of the local sampled indices is outside the tensor bounds (this may mean the epsilon is too small)' \ 799 | f'\n max sampled {(cached * rrng).max().item()}, rounded {(cached * rrng).max().long().item()} max lower limit {lower.max().item()}' \ 800 | f'\n sum {((cached * rrng).max() + lower.max()).item()}' \ 801 | f'\n rounds to {((cached * rrng).max() + lower.max()).long().item()}' 802 | #f'\n {means}\n {local_ints}\n {cached * rrng}' 803 | 804 | all = torch.cat([neighbor_ints, global_ints, local_ints] , dim=-2) 805 | 806 | fsize = pref[:-1] + (-1, rank) 807 | 808 | return all.view(*fsize) # combine all indices sampled within a chunk -------------------------------------------------------------------------------- /sparse/sort.py: -------------------------------------------------------------------------------- 1 | import torch 
2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | import util 6 | import numpy as np 7 | 8 | """ 9 | Modules to implement differentiable quicksort. 10 | 11 | ```Split``` implements the half-permutation. 12 | 13 | ```SortLayer``` chains these into quicksort. 14 | """ 15 | 16 | class Split(nn.Module): 17 | """ 18 | A split matrix moves the elements of the input to either the top or the bottom 19 | half of a subsection of the output, but keeps the ordering intact otherwise. 20 | 21 | For depth 0, each element is moved to the top or bottom half of the output. For 22 | depth 1 each element is moved to the top or bottom half of its current half of 23 | the matrix and so on. 24 | 25 | """ 26 | def __init__(self, size, depth, additional=1, sigma_scale=0.1, sigma_floor=0.0): 27 | super().__init__() 28 | 29 | template = torch.LongTensor(range(size)).unsqueeze(1).expand(size, 2) 30 | self.register_buffer('template', template) 31 | 32 | self.size = size 33 | self.depth = depth 34 | self.sigma_scale = sigma_scale 35 | self.sigma_floor = sigma_floor 36 | self.additional = additional 37 | 38 | def duplicates(self, tuples): 39 | """ 40 | Takes a list of tuples, and for each tuple that occurs mutiple times 41 | marks all but one of the occurences (in the mask that is returned). 
42 | 43 | :param tuples: A size (batch, k, rank) tensor of integer tuples 44 | :return: A size (batch, k) mask indicating the duplicates 45 | """ 46 | b, k, r = tuples.size() 47 | 48 | # unique = ((tuples.float() + 1) ** primes).prod(dim=2) # unique identifier for each tuple 49 | unique = util.unique(tuples.view(b*k, r)).squeeze().view(b, k) 50 | 51 | sorted, sort_idx = torch.sort(unique, dim=1) 52 | _, unsort_idx = torch.sort(sort_idx, dim=1) 53 | 54 | mask = sorted[:, 1:] == sorted[:, :-1] 55 | # mask = mask.view(b, k - 1) 56 | 57 | zs = torch.zeros(b, 1, dtype=torch.uint8, device='cuda' if tuples.is_cuda else 'cpu') 58 | mask = torch.cat([zs, mask], dim=1) 59 | 60 | return torch.gather(mask, 1, unsort_idx) 61 | 62 | def generate_integer_tuples(self, offset, additional=16): 63 | 64 | b, s = offset.size() 65 | 66 | choices = offset.round().byte()[:, None, :] 67 | 68 | if additional > 0: 69 | sampled = util.sample_offsets(b, additional, s, self.depth, cuda=offset.is_cuda) 70 | # sampled = ~ choices 71 | 72 | choices = torch.cat([choices, sampled], dim=1).byte() 73 | 74 | return self.generate(choices, offset) 75 | 76 | def generate(self, choices, offset): 77 | 78 | b, n, s = choices.size() 79 | 80 | offset = offset[:, None, :].expand(b, n, s) 81 | 82 | probs = offset.clone() 83 | probs[~ choices] = 1.0 - probs[~ choices] 84 | # prob now contains the probability (under offset) of the choices made 85 | probs = probs.prod(dim=2, keepdim=True).expand(b, n, s).contiguous() 86 | 87 | # Generate indices from the chosen offset 88 | indices = util.split(choices, self.depth) 89 | 90 | if n > 1: 91 | dups = self.duplicates(indices) 92 | 93 | probs = probs.clone() 94 | probs[dups] = 0.0 95 | 96 | probs = probs / probs.sum(dim=1, keepdim=True) 97 | 98 | return indices, probs 99 | 100 | def forward(self, input, keys, offset, train=True, reverse=False, verbose=False): 101 | 102 | if train: 103 | indices, probs = self.generate_integer_tuples(offset, self.additional) 104 | else: 105 
| indices, probs = self.generate_integer_tuples(offset, 0) 106 | 107 | if verbose: 108 | print(indices[0, 0]) 109 | 110 | indices = indices.detach() 111 | b, n, s = indices.size() 112 | 113 | template = self.template[None, None, :, :].expand(b, n, s, 2).contiguous() 114 | if not reverse: # normal half-permutation 115 | template[:, :, :, 0] = indices 116 | else: # reverse the permutation 117 | template[:, :, :, 1] = indices 118 | indices = template 119 | 120 | indices = indices.contiguous().view(b, -1, 2) 121 | probs = probs.contiguous().view(b, -1) 122 | 123 | output = util.batchmm(indices, probs, (s, s), input) 124 | 125 | keys_out = util.batchmm(indices, probs, (s, s), keys[:, :, None]).squeeze() 126 | 127 | return output, keys_out 128 | 129 | class SortLayer(nn.Module): 130 | """ 131 | 132 | """ 133 | def __init__(self, size, additional=0, sigma_scale=0.1, sigma_floor=0.0, certainty=10.0): 134 | super().__init__() 135 | 136 | mdepth = int(np.log2(size)) 137 | 138 | self.layers = nn.ModuleList() 139 | for d in range(mdepth): 140 | self.layers.append(Split(size, d, additional, sigma_scale, sigma_floor)) 141 | 142 | # self.certainty = nn.Parameter(torch.tensor([certainty])) 143 | self.register_buffer('certainty', torch.tensor([certainty])) 144 | 145 | # self.offset = nn.Sequential( 146 | # util.Lambda(lambda x : x[:, 0] - x[:, 1]), 147 | # util.Lambda(lambda x : x * self.certainty), 148 | # nn.Sigmoid() 149 | # ) 150 | 151 | def forward(self, x, keys, target=None, train=True, verbose=False): 152 | 153 | xs = [x] 154 | targets = [target] 155 | offsets = [] 156 | 157 | b, s, z = x.size() 158 | b, s = keys.size() 159 | 160 | t = target 161 | 162 | for d, split in enumerate(self.layers): 163 | 164 | buckets = keys[:, :, None].view(b, 2**d, -1) 165 | 166 | # compute pivots 167 | pivots = buckets.view(b*2**d, -1) 168 | pivots = median(pivots, keepdim=True) 169 | pivots = pivots.view(b, 2 ** d, -1).expand_as(buckets) 170 | 171 | pivots = pivots.contiguous().view(b, 
-1).expand_as(keys) 172 | 173 | # compute offsets by comparing values to pivots 174 | if train: 175 | offset = keys - pivots 176 | offset = F.sigmoid(offset * self.certainty) 177 | else: 178 | offset = (keys > pivots).float() 179 | 180 | # offset = offset.round() # DEBUG 181 | offsets.append(offset) 182 | 183 | 184 | x, keys = split(x, keys, offset, train=train, verbose=verbose) 185 | xs.append(x) 186 | 187 | if verbose: 188 | print('o', offset[0]) 189 | print('k', keys[0]) 190 | 191 | 192 | if target is not None: 193 | for split, offset in zip(self.layers[::-1], offsets[::-1]): 194 | t, _ = split(t, keys, offset, train=train, reverse=True) 195 | targets.insert(0, t) 196 | 197 | if target is None: 198 | return x, keys 199 | 200 | return xs, targets, keys 201 | 202 | def median(x, keepdim=False): 203 | b, s = x.size() 204 | 205 | y = x.sort(dim=1)[0][:, s//2-1:s//2+1].mean(dim=1, keepdim=keepdim) 206 | 207 | return y 208 | 209 | if __name__ == '__main__': 210 | 211 | x = torch.randn(3, 4) 212 | print(x) 213 | print(median(x)) -------------------------------------------------------------------------------- /sparse/tensors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import FloatTensor, LongTensor 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | from sparse.util import prod 7 | import util, sys 8 | 9 | from util import d 10 | 11 | """ 12 | Utility functions for manipulation tensors 13 | """ 14 | 15 | def flatten_indices_mat(indices, in_shape, out_shape): 16 | """ 17 | Turns a n NxK matrix of N index-tuples for a tensor T of rank K into an Nx2 matrix M of index-tuples for a _matrix_ 18 | that is created by flattening the first 'in_shape' dimensions into the vertical dimension of M and the remaining 19 | dimensions in the the horizontal dimension of M. 
20 | :param indices: Long tensor 21 | :param in_rank: 22 | :return: (1) A matrix of size N by 2, (2) the dimensions of M 23 | """ 24 | 25 | batchsize, n, rank = indices.size() 26 | 27 | inrank = len(in_shape) 28 | outrank = len(out_shape) 29 | 30 | result = torch.cuda.LongTensor(batchsize, n, 2) if indices.is_cuda else LongTensor(batchsize, n, 2) 31 | 32 | left = fi_matrix(indices[:, :, 0:outrank], out_shape) # i index of the weight matrix 33 | right = fi_matrix(indices[:, :, outrank:rank], in_shape) # j index 34 | 35 | result = torch.cat([left.unsqueeze(2), right.unsqueeze(2)], dim=2) 36 | 37 | return result, LongTensor((prod(out_shape), prod(in_shape))) 38 | 39 | def fi_matrix(indices, shape): 40 | batchsize, rows, rank = indices.size() 41 | 42 | prod = torch.LongTensor(rank).fill_(1) 43 | 44 | if indices.is_cuda: 45 | prod = prod.cuda() 46 | 47 | for i in range(rank): 48 | prod[i] = 1 49 | for j in range(i + 1, len(shape)): 50 | prod[i] *= shape[j] 51 | 52 | indices = indices * prod.unsqueeze(0).unsqueeze(0).expand_as(indices) 53 | 54 | return indices.sum(dim=2) 55 | 56 | def contract(indices, values, size, x, cuda=None): 57 | """ 58 | Performs a contraction (generalized matrix multiplication) of a sparse tensor with and input x. 59 | 60 | The contraction is defined so that every element of the output is the sum of every element of the input multiplied 61 | once by a unique element from the tensor (that is, like a fully connected neural network layer). See the paper for 62 | details. 
63 | 64 | :param indices: (b, k, r)-tensor describing indices of b sparse tensors of rank r 65 | :param values: (b, k)-tes=nsor with the corresponding values 66 | :param size: 67 | :param x: 68 | :return: 69 | """ 70 | # translate tensor indices to matrix indices 71 | if cuda is None: 72 | cuda = indices.is_cuda 73 | 74 | b, k, r = indices.size() 75 | 76 | # size is equal to out_size + x.size() 77 | in_size = x.size()[1:] 78 | out_size = size[:-len(in_size)] 79 | 80 | assert len(out_size) + len(in_size) == r 81 | 82 | # Flatten into a matrix multiplication 83 | mindices, flat_size = flatten_indices_mat(indices, x.size()[1:], out_size) 84 | x_flat = x.view(b, -1, 1) 85 | 86 | # Prevent segfault 87 | assert mindices.min() >= 0, 'negative index in flattened indices: {} \n {} \n Original indices {} \n {}'.format(mindices.size(), mindices, indices.size(), indices) 88 | assert not util.contains_nan(values.data), 'NaN in values:\n {}'.format(values) 89 | 90 | y_flat = batchmm(mindices, values, flat_size, x_flat, cuda) 91 | 92 | return y_flat.view(b, *out_size) # reshape y into a tensor 93 | 94 | 95 | def sparsemm(use_cuda): 96 | """ 97 | :param use_cuda: 98 | :return: 99 | """ 100 | return SparseMMGPU.apply if use_cuda else SparseMMCPU.apply 101 | 102 | 103 | class SparseMMCPU(torch.autograd.Function): 104 | 105 | """ 106 | Sparse matrix multiplication with gradients over the value-vector 107 | 108 | Does not work with batch dim. 
109 | """ 110 | 111 | @staticmethod 112 | def forward(ctx, indices, values, size, xmatrix): 113 | 114 | # print(type(size), size, list(size), intlist(size)) 115 | # print(indices.size(), values.size(), torch.Size(intlist(size))) 116 | 117 | matrix = torch.sparse.FloatTensor(indices, values, torch.Size(intlist(size))) 118 | 119 | ctx.indices, ctx.matrix, ctx.xmatrix = indices, matrix, xmatrix 120 | 121 | return torch.mm(matrix, xmatrix) 122 | 123 | @staticmethod 124 | def backward(ctx, grad_output): 125 | grad_output = grad_output.data 126 | 127 | # -- this will break recursive autograd, but it's the only way to get grad over sparse matrices 128 | 129 | i_ixs = ctx.indices[0,:] 130 | j_ixs = ctx.indices[1,:] 131 | output_select = grad_output[i_ixs, :] 132 | xmatrix_select = ctx.xmatrix[j_ixs, :] 133 | 134 | grad_values = (output_select * xmatrix_select).sum(dim=1) 135 | 136 | grad_xmatrix = torch.mm(ctx.matrix.t(), grad_output) 137 | return None, Variable(grad_values), None, Variable(grad_xmatrix) 138 | 139 | class SparseMMGPU(torch.autograd.Function): 140 | 141 | """ 142 | Sparse matrix multiplication with gradients over the value-vector 143 | 144 | Does not work with batch dim. 
145 | """ 146 | 147 | @staticmethod 148 | def forward(ctx, indices, values, size, xmatrix): 149 | 150 | # print(type(size), size, list(size), intlist(size)) 151 | 152 | matrix = torch.cuda.sparse.FloatTensor(indices, values, torch.Size(intlist(size))) 153 | 154 | ctx.indices, ctx.matrix, ctx.xmatrix = indices, matrix, xmatrix 155 | 156 | return torch.mm(matrix, xmatrix) 157 | 158 | @staticmethod 159 | def backward(ctx, grad_output): 160 | grad_output = grad_output.data 161 | 162 | # -- this will break recursive autograd, but it's the only way to get grad over sparse matrices 163 | 164 | i_ixs = ctx.indices[0,:] 165 | j_ixs = ctx.indices[1,:] 166 | output_select = grad_output[i_ixs] 167 | xmatrix_select = ctx.xmatrix[j_ixs] 168 | 169 | grad_values = (output_select * xmatrix_select).sum(dim=1) 170 | 171 | grad_xmatrix = torch.mm(ctx.matrix.t(), grad_output) 172 | return None, Variable(grad_values), None, Variable(grad_xmatrix) 173 | 174 | def batchmm(indices, values, size, xmatrix, cuda=None): 175 | """ 176 | Multiply a batch of sparse matrices (indices, values, size) with a batch of dense matrices (xmatrix) 177 | 178 | :param indices: 179 | :param values: 180 | :param size: 181 | :param xmatrix: 182 | :return: 183 | """ 184 | 185 | if cuda is None: 186 | cuda = indices.is_cuda 187 | 188 | b, n, r = indices.size() 189 | dv = 'cuda' if cuda else 'cpu' 190 | 191 | height, width = size 192 | 193 | size = torch.tensor(size, device=dv, dtype=torch.long) 194 | 195 | bmult = size[None, None, :].expand(b, n, 2) 196 | m = torch.arange(b, device=dv, dtype=torch.long)[:, None, None].expand(b, n, 2) 197 | 198 | bindices = (m * bmult).view(b*n, r) + indices.view(b*n, r) 199 | 200 | bfsize = Variable(size * b) 201 | bvalues = values.contiguous().view(-1) 202 | 203 | b, w, z = xmatrix.size() 204 | bxmatrix = xmatrix.view(-1, z) 205 | 206 | sm = sparsemm(cuda) 207 | 208 | result = sm(bindices.t(), bvalues, bfsize, bxmatrix) 209 | 210 | return result.view(b, height, -1) 211 | 212 | 
def intlist(tensor): 213 | """ 214 | A slow and stupid way to turn a tensor into an iterable over ints 215 | :param tensor: 216 | :return: 217 | """ 218 | if type(tensor) is list: 219 | return tensor 220 | 221 | tensor = tensor.squeeze() 222 | 223 | assert len(tensor.size()) == 1 224 | 225 | s = tensor.size()[0] 226 | 227 | l = [None] * s 228 | for i in range(s): 229 | l[i] = int(tensor[i]) 230 | 231 | return l 232 | 233 | def accuracy(output, labels): 234 | preds = output.max(1)[1].type_as(labels) 235 | correct = preds.eq(labels).double() 236 | correct = correct.sum() 237 | return correct / len(labels) 238 | 239 | def simple_normalize(indices, values, size, row=True, method='softplus', cuda=torch.cuda.is_available()): 240 | """ 241 | Simple softmax-style normalization with 242 | 243 | :param indices: 244 | :param values: 245 | :param size: 246 | :param row: 247 | :return: 248 | """ 249 | epsilon = 1e-7 250 | 251 | if method == 'softplus': 252 | values = F.softplus(values) 253 | elif method == 'abs': 254 | values = values.abs() 255 | elif method == 'relu': 256 | values = F.relu(values) 257 | else: 258 | raise Exception(f'Method {method} not recognized') 259 | 260 | sums = sum(indices, values, size, row=row) 261 | 262 | return (values/(sums + epsilon)) 263 | 264 | # -- stable(ish) softmax 265 | def logsoftmax(indices, values, size, its=10, p=2, method='iteration', row=True, cuda=torch.cuda.is_available()): 266 | """ 267 | Row or column log-softmaxes a sparse matrix (using logsumexp trick) 268 | :param indices: 269 | :param values: 270 | :param size: 271 | :param row: 272 | :return: 273 | """ 274 | epsilon = 1e-7 275 | 276 | if method == 'naive': 277 | values = values.exp() 278 | sums = sum(indices, values, size, row=row) 279 | 280 | return (values/(sums + epsilon)).log() 281 | 282 | if method == 'pnorm': 283 | maxes = rowpnorm(indices, values, size, p=p) 284 | elif method == 'iteration': 285 | maxes = itmax(indices, values, size,its=its, p=p) 286 | else: 287 | raise 
Exception('Max method {} not recognized'.format(method)) 288 | 289 | mvalues = torch.exp(values - maxes) 290 | 291 | sums = sum(indices, mvalues, size, row=row) # row/column sums] 292 | 293 | return mvalues.log() - sums.log() 294 | 295 | def rowpnorm(indices, values, size, p, row=True): 296 | """ 297 | Row or column p-norms a sparse matrix 298 | :param indices: 299 | :param values: 300 | :param size: 301 | :param row: 302 | :return: 303 | """ 304 | pvalues = torch.pow(values, p) 305 | sums = sum(indices, pvalues, size, row=row) 306 | 307 | return torch.pow(sums, 1.0/p) 308 | 309 | def itmax(indices, values, size, its=10, p=2, row=True): 310 | """ 311 | Iterative computation of row max 312 | 313 | :param indices: 314 | :param values: 315 | :param size: 316 | :param p: 317 | :param row: 318 | :param cuda: 319 | :return: 320 | """ 321 | 322 | epsilon = 0.00000001 323 | 324 | # create an initial vector with all values made positive 325 | # weights = values - values.min() 326 | weights = F.softplus(values) 327 | weights = weights / (sum(indices, weights, size) + epsilon) 328 | 329 | # iterate, weights converges to a one-hot vector 330 | for i in range(its): 331 | weights = weights.pow(p) 332 | 333 | sums = sum(indices, weights, size, row=row) # row/column sums 334 | weights = weights/sums 335 | 336 | return sum(indices, values * weights, size, row=row) 337 | 338 | def sum(indices, values, size, row=True): 339 | """ 340 | Sum the rows or columns of a sparse matrix, and redistribute the 341 | results back to the non-sparse row/column entries 342 | 343 | Arguments are interpreted as defining sparse matrix. Any extra dimensions 344 | as treated as batch. 
345 | 346 | :return: 347 | """ 348 | 349 | assert len(indices.size()) == len(values.size()) + 1 350 | 351 | if len(indices.size()) == 2: 352 | # add batch dim 353 | indices = indices[None, :, :] 354 | values = values[None, :] 355 | bdims = None 356 | else: 357 | # fold up batch dim 358 | bdims = indices.size()[:-2] 359 | k, r = indices.size()[-2:] 360 | assert bdims == values.size()[:-1] 361 | assert values.size()[-1] == k 362 | 363 | indices = indices.view(-1, k, r) 364 | values = values.view(-1, k) 365 | 366 | b, k, r = indices.size() 367 | 368 | if row: 369 | ones = torch.ones((size[1], 1), device=d(indices)) 370 | else: 371 | ones = torch.ones((size[0], 1), device=d(indices)) 372 | # transpose the matrix 373 | indices = torch.cat([indices[:, :, 1:2], indices[:, :, 0:1]], dim=1) 374 | 375 | s, _ = ones.size() 376 | ones = ones[None, :, :].expand(b, s, 1).contiguous() 377 | 378 | sums = batchmm(indices, values, size, ones) # row/column sums 379 | bindex = torch.arange(b, device=d(indices))[:, None].expand(b, indices.size(1)) 380 | sums = sums[bindex, indices[:, :, 0], 0] 381 | 382 | if bdims is None: 383 | return sums.view(k) 384 | 385 | return sums.view(*bdims + (k,)) 386 | -------------------------------------------------------------------------------- /sparse/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import \ 2 | makedirs, prod, contains_nan, contains_inf, bmult, duplicates, nduplicates, \ 3 | sparsemult, \ 4 | xent, unique, \ 5 | Bias, ChunkSampler, Flatten, Reshape, Debug, Lambda, \ 6 | od, prod, inv, logit, \ 7 | wrapmod, interpolation_grid, unsqueezen, \ 8 | sample_offsets, split, \ 9 | CConv2d, \ 10 | tic, toc, d, here, flip, coordinates, schedule 11 | 12 | from .plot import plot, plot1d, basic, clean -------------------------------------------------------------------------------- /sparse/util/plot.py: -------------------------------------------------------------------------------- 1 
| import torch 2 | 3 | import matplotlib as mpl 4 | mpl.use('Agg') 5 | import matplotlib.pyplot as plt 6 | from matplotlib.patches import Circle, Wedge, Polygon, Ellipse, Rectangle 7 | from matplotlib.collections import PatchCollection 8 | from matplotlib.axes import Axes 9 | 10 | import numpy as np 11 | 12 | from torch import nn 13 | 14 | import sys 15 | 16 | 17 | def clean(axes=None): 18 | 19 | if axes is None: 20 | axes = plt.gca() 21 | 22 | axes.spines["right"].set_visible(False) 23 | axes.spines["top"].set_visible(False) 24 | axes.spines["bottom"].set_visible(False) 25 | axes.spines["left"].set_visible(False) 26 | 27 | # axes.get_xaxis().set_tick_params(which='both', top='off', bottom='off', labelbottom='off') 28 | # axes.get_yaxis().set_tick_params(which='both', left='off', right='off') 29 | 30 | 31 | def basic(axes=None): 32 | 33 | if axes is None: 34 | axes = plt.gca() 35 | 36 | axes.spines["right"].set_visible(False) 37 | axes.spines["top"].set_visible(False) 38 | axes.spines["bottom"].set_visible(True) 39 | axes.spines["left"].set_visible(True) 40 | 41 | axes.get_xaxis().set_tick_params(which='both', top='off', bottom='on', labelbottom='on') 42 | axes.get_yaxis().set_tick_params(which='both', left='on', right='off') 43 | 44 | def plot(means, sigmas, values, shape=None, axes=None, flip_y=None, alpha_global=1.0, tanh=True): 45 | """ 46 | :param means: 47 | :param sigmas: 48 | :param values: 49 | :param shape: 50 | :param axes: 51 | :param flip_y: If not None, interpreted as the max y value. y values in the scatterplot are 52 | flipped so that the max is equal to zero and vice versa. 
53 | :return: 54 | """ 55 | 56 | b, n, d = means.size() 57 | 58 | means = means.data[0, :, :].cpu().numpy() 59 | sigmas = sigmas.data[0, :].cpu().numpy() 60 | 61 | tcolor = not isinstance(values, float) 62 | 63 | if tcolor: 64 | values = values.tanh().data[0, :].cpu().numpy() if tanh else values.data[0, :].cpu().numpy() 65 | 66 | if flip_y is not None: 67 | means[:, 0] = flip_y - means[:, 0] 68 | 69 | norm = mpl.colors.Normalize(vmin=-1.0, vmax=1.0) 70 | cmap = mpl.cm.RdYlBu 71 | map = mpl.cm.ScalarMappable(norm=norm, cmap=cmap) 72 | 73 | if axes is None: 74 | axes = plt.gca() 75 | 76 | colors = [] 77 | for i in range(n): 78 | color = map.to_rgba(values[i] if tcolor else values) 79 | 80 | alpha = min(0.8, max(0.05, ((sigmas[i, 0] * sigmas[i, 0])+1.0)**-2)) * alpha_global 81 | axes.add_patch(Ellipse((means[i, 1], means[i, 0]), width=sigmas[i,1], height=sigmas[i,0], color=color, alpha=alpha, linewidth=0)) 82 | colors.append(color) 83 | 84 | axes.scatter(means[:, 1], means[:, 0], s=5, c=colors, zorder=100, linewidth=0, edgecolor='k', alpha=alpha_global) 85 | 86 | if shape is not None: 87 | 88 | m = max(shape) 89 | step = 1 if m < 100 else m//25 90 | 91 | # gray points for the integer index tuples 92 | x, y = np.mgrid[0:shape[0]:step, 0:shape[1]:step] 93 | axes.scatter(x.ravel(), y.ravel(), c='k', s=5, marker='D', zorder=-100, linewidth=0, alpha=0.1* alpha_global) 94 | 95 | axes.spines['right'].set_visible(False) 96 | axes.spines['top'].set_visible(False) 97 | axes.spines['bottom'].set_visible(False) 98 | axes.spines['left'].set_visible(False) 99 | # 100 | # 101 | # def plot1d(means, sigmas, values, shape=None, axes=None): 102 | # 103 | # h = 0.1 104 | # 105 | # n, d = means.size() 106 | # 107 | # means = means.cpu().numpy() 108 | # sigmas = sigmas.cpu().numpy() 109 | # values = nn.functional.tanh(values).data.cpu().numpy() 110 | # 111 | # norm = mpl.colors.Normalize(vmin=-1.0, vmax=1.0) 112 | # cmap = mpl.cm.RdYlBu 113 | # map = mpl.cm.ScalarMappable(norm=norm, 
cmap=cmap) 114 | # 115 | # if axes is None: 116 | # axes = plt.gca() 117 | # 118 | # colors = [] 119 | # for i in range(n): 120 | # color = map.to_rgba(values[i]) 121 | # alpha = 0.7 # max(0.05, (sigmas[i, 0]+1.0)**-1) 122 | # axes.add_patch(Rectangle(xy=(means[i, 1] - sigmas[i, 0]*0.5, means[i, 0] - h*0.5), width=sigmas[i,0] , height=h, color=color, alpha=alpha, linewidth=0)) 123 | # colors.append(color) 124 | # 125 | # axes.scatter(means[:, 1], means[:, 0], c=colors, zorder=100, linewidth=0, s=5) 126 | # 127 | # if shape is not None: 128 | # 129 | # m = max(shape) 130 | # step = 1 if m < 100 else m//25 131 | # 132 | # # gray points for the integer index tuples 133 | # x, y = np.mgrid[0:shape[0]:step, 0:shape[1]:step] 134 | # axes.scatter(x.ravel(), y.ravel(), c='k', s=5, marker='D', zorder=-100, linewidth=0, alpha=0.1) 135 | # 136 | # axes.spines['right'].set_visible(False) 137 | # axes.spines['top'].set_visible(False) 138 | # axes.spines['bottom'].set_visible(False) 139 | # axes.spines['left'].set_visible(False) 140 | # 141 | 142 | def plot1d(means, sigmas, values, shape=None, axes=None): 143 | 144 | h = 0.1 145 | 146 | n, d = means.size() 147 | 148 | means = means.cpu().numpy() 149 | sigmas = sigmas.cpu().numpy() 150 | values = torch.tanh(values).data.cpu().numpy() 151 | 152 | norm = mpl.colors.Normalize(vmin=-1.0, vmax=1.0) 153 | cmap = mpl.cm.RdYlBu 154 | map = mpl.cm.ScalarMappable(norm=norm, cmap=cmap) 155 | 156 | if axes is None: 157 | axes = plt.gca() 158 | 159 | colors = [] 160 | for i in range(n): 161 | 162 | color = map.to_rgba(values[i]) 163 | alpha = 0.7 # max(0.05, (sigmas[i, 0]+1.0)**-1) 164 | axes.add_patch(Rectangle(xy=(means[i, 1] - sigmas[i, 0]*0.5, means[i, 0] - h*0.5), width=sigmas[i,0] , height=h, color=color, alpha=alpha, linewidth=0)) 165 | colors.append(color) 166 | 167 | axes.scatter(means[:, 1], means[:, 0], c=colors, zorder=100, linewidth=0, s=3) 168 | 169 | if shape is not None: 170 | 171 | m = max(shape) 172 | step = 1 if m < 100 
else m//25 173 | 174 | # gray points for the integer index tuples 175 | x, y = np.mgrid[0:shape[0]:step, 0:shape[1]:step] 176 | axes.scatter(x.ravel(), y.ravel(), c='k', s=5, marker='D', zorder=-100, linewidth=0, alpha=0.1) 177 | 178 | axes.spines['right'].set_visible(False) 179 | axes.spines['top'].set_visible(False) 180 | axes.spines['bottom'].set_visible(False) 181 | axes.spines['left'].set_visible(False) 182 | 183 | def plot1dvert(means, sigmas, values, shape=None, axes=None): 184 | 185 | h = 0.1 186 | 187 | n, d = means.size() 188 | 189 | means = means.cpu().numpy() 190 | sigmas = sigmas.cpu().numpy() 191 | values = nn.functional.tanh(values).data.cpu().numpy() 192 | 193 | norm = mpl.colors.Normalize(vmin=-1.0, vmax=1.0) 194 | cmap = mpl.cm.RdYlBu 195 | map = mpl.cm.ScalarMappable(norm=norm, cmap=cmap) 196 | 197 | if axes is None: 198 | axes = plt.gca() 199 | 200 | colors = [] 201 | for i in range(n): 202 | color = map.to_rgba(values[i]) 203 | alpha = 0.7 # max(0.05, (sigmas[i, 0]+1.0)**-1) 204 | axes.add_patch(Rectangle(xy=(means[i, 1] - h*0.5, means[i, 0] - sigmas[i, 0]*0.5), width=h , height=sigmas[i,0], color=color, alpha=alpha, linewidth=0)) 205 | colors.append(color) 206 | 207 | axes.scatter(means[:, 1], means[:, 0], c=colors, zorder=100, linewidth=0, s=3) 208 | 209 | if shape is not None: 210 | 211 | m = max(shape) 212 | step = 1 if m < 100 else m//25 213 | 214 | # gray points for the integer index tuples 215 | x, y = np.mgrid[0:shape[0]:step, 0:shape[1]:step] 216 | axes.scatter(x.ravel(), y.ravel(), c='k', s=5, marker='D', zorder=-100, linewidth=0, alpha=0.1) 217 | 218 | axes.spines['right'].set_visible(False) 219 | axes.spines['top'].set_visible(False) 220 | axes.spines['bottom'].set_visible(False) 221 | axes.spines['left'].set_visible(False) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sparse 2 | 3 | 
print(dir(sparse)) -------------------------------------------------------------------------------- /tests/_context.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 5 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../sparse'))) 6 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../sparse/util'))) 7 | 8 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | import _context 2 | 3 | import unittest 4 | import torch 5 | #import sparse.layers.densities 6 | 7 | import layers 8 | 9 | class TestLayers(unittest.TestCase): 10 | 11 | def test_densities(self): 12 | 13 | means = torch.tensor([[[0.0]]]) 14 | sigmas = torch.tensor([[[1.0]]]) 15 | points = torch.tensor([[[0.0]]]) 16 | 17 | density = layers.densities(points, means, sigmas) 18 | self.assertAlmostEqual(1.0, density, places=7) 19 | 20 | means = torch.tensor([[[0.0, 0.0], [1.0, 1.0]],[[2.0, 2.0], [4.0, 4.0]]]) 21 | sigmas = torch.tensor([[[1.0, 1.0], [1.0, 1.0]],[[1.0, 1.0], [1.0, 1.0]]]) 22 | points = torch.tensor([[[0.0, 0.0], [1.0, 1.0]],[[2.0, 2.0], [4.0, 4.0]]]) 23 | 24 | density = layers.densities(points, means, sigmas) 25 | 26 | self.assertEquals((2, 2, 2), density.size()) 27 | self.assertAlmostEqual(1.0, density[0, 0, 0], places=7) 28 | 29 | means = torch.randn(3, 5, 16) 30 | sigmas = torch.randn(3, 5, 16).abs() 31 | points = torch.randn(3, 7, 16) 32 | 33 | density = layers.densities(points, means, sigmas) 34 | 35 | self.assertEquals((3, 7, 5), density.size()) 36 | 37 | def test_ngenerate(self): 38 | 39 | means = torch.randn(6, 2, 3) 40 | sigmas = torch.randn(6, 2) 41 | values = torch.randn(6, 2) 42 | 43 | b = 5 44 | size = (64, 128, 32) 45 | 46 | ms = means.size() 
47 | 48 | xp = (5, ) + means.size() 49 | # Expand parameters along batch dimension 50 | means = means.expand(*xp) 51 | sigmas = sigmas.expand(*xp[:-1]) 52 | values = values.expand(*xp[:-1]) 53 | 54 | means, sigmas = layers.transform_means(means, size), layers.transform_sigmas(sigmas, size) 55 | 56 | indices_old = layers.generate_integer_tuples(means, 57 | 2, 2, 58 | relative_range=(4, 4, 4), 59 | rng=size, 60 | cuda=means.is_cuda) 61 | 62 | 63 | indices_new = layers.ngenerate(means, 64 | 2, 2, 65 | relative_range=(4, 4, 4), 66 | rng=size, 67 | cuda=means.is_cuda) 68 | 69 | assert indices_old.size() == indices_new.size() 70 | 71 | for i in range(indices_new.view(-1, 3).size(0)): 72 | print(indices_new.view(-1, 3)[i]) 73 | 74 | def test_conv(self): 75 | 76 | x = torch.ones(1, 4, 3, 3) 77 | 78 | c = layers.Convolution((4, 3, 3), 4, k=2, rprop=.5, gadditional=2, radditional=2) 79 | 80 | print(c(x)) 81 | 82 | if __name__ == '__main__': 83 | unittest.main() -------------------------------------------------------------------------------- /tests/test_tensors.py: -------------------------------------------------------------------------------- 1 | import _context 2 | 3 | import unittest 4 | import torch 5 | from torch.autograd import Variable 6 | #import sparse.layers.densities 7 | 8 | import tensors, time 9 | 10 | 11 | def sample(nindices=2*256+2*8, size=(256, 256), var=1.0): 12 | assert len(size) == 2 13 | 14 | indices = (torch.rand(nindices, 2) * torch.tensor(size)[None, :].float()).long() 15 | values = torch.randn(nindices) * var 16 | 17 | return indices, values 18 | 19 | class TestTensors(unittest.TestCase): 20 | 21 | def test_sum(self): 22 | size = (5, 5) 23 | 24 | # create a batch of sparse matrices 25 | samples = [sample(nindices=3, size=size) for _ in range(3)] 26 | indices, values = [s[0][None, :, :] for s in samples], [s[1][None, :] for s in samples] 27 | 28 | indices, values = torch.cat(indices, dim=0), torch.cat(values, dim=0) 29 | 30 | print(indices) 31 | 
print(values) 32 | print('res', tensors.sum(indices, values, size)) 33 | 34 | def test_log_softmax(self): 35 | 36 | size = (5, 5) 37 | 38 | # create a batch of sparse matrices 39 | samples = [sample(nindices=3, size=size) for _ in range(3)] 40 | indices, values = [s[0][None, :, :] for s in samples], [s[1][None, :] for s in samples] 41 | 42 | indices, values = torch.cat(indices, dim=0), torch.cat(values, dim=0) 43 | 44 | print('res', tensors.logsoftmax(indices, values, size, method='naive').exp()) 45 | print('res', tensors.logsoftmax(indices, values, size, method='iteration').exp()) 46 | 47 | def test(self): 48 | 49 | a = Variable(torch.randn(1), requires_grad=True) 50 | x = Variable(torch.randn(15000, 15000)) 51 | 52 | x = x * a 53 | x = x / 2 54 | 55 | loss = x.sum() 56 | 57 | loss.backward() 58 | time.sleep(600) 59 | -------------------------------------------------------------------------------- /tests/test_util.py: -------------------------------------------------------------------------------- 1 | import _context 2 | 3 | import unittest 4 | import torch, sys 5 | from torch.autograd import Variable 6 | from torch import nn 7 | import torch.nn.functional as F 8 | #import sparse.layers.densities 9 | 10 | import util, sys 11 | 12 | class TestLayers(unittest.TestCase): 13 | 14 | def test_unique(self): 15 | r = util.unique(torch.tensor( [[1,2,3,4],[4,3,2,1],[1,2,3,4]] )) 16 | 17 | self.assertEqual((3, 1), r.size()) 18 | self.assertEqual(r[0], r[2]) 19 | self.assertNotEqual(r[0], r[1]) 20 | self.assertNotEqual(r[1], r[2]) 21 | 22 | r = util.nunique(torch.tensor( [[[1,2,3,4],[4,3,2,1],[1,2,3,4]]] )) 23 | 24 | self.assertEqual((1, 3), r.size()) 25 | self.assertEqual(r[0, 0], r[0, 2]) 26 | self.assertNotEqual(r[0, 0], r[0, 1]) 27 | self.assertNotEqual(r[0, 1], r[0, 2]) 28 | 29 | def test_duplicates(self): 30 | 31 | tuples = torch.tensor([ 32 | [[5, 5], [1, 1], [2, 3], [1, 1]], 33 | [[3, 2], [3, 2], [5, 5], [5, 5]] 34 | ]) 35 | dedup = torch.tensor([ 36 | [[5, 5], [1, 
1], [2, 3], [0, 0]], 37 | [[3, 2], [0, 0], [5, 5], [0, 0]] 38 | ]) 39 | 40 | 41 | dup = util.duplicates(tuples) 42 | tuples[dup, :] = tuples[dup, :] * 0 43 | self.assertEqual( (tuples != dedup).sum(), 0) 44 | 45 | tuples = torch.tensor([[ 46 | [3, 1], 47 | [3, 2], 48 | [3, 1], 49 | [0, 3], 50 | [0, 2], 51 | [3, 0], 52 | [0, 3], 53 | [0, 0]]]) 54 | 55 | self.assertEqual([0, 0, 1, 0, 0, 0, 1, 0], list(util.duplicates(tuples).view(-1))) 56 | 57 | def test_nduplicates(self): 58 | 59 | # some tuples 60 | tuples = torch.tensor([ 61 | [[5, 5], [1, 1], [2, 3], [1, 1]], 62 | [[3, 2], [3, 2], [5, 5], [5, 5]] 63 | ]) 64 | 65 | # what they should look like after masking out the duplicates 66 | dedup = torch.tensor([ 67 | [[5, 5], [1, 1], [2, 3], [0, 0]], 68 | [[3, 2], [0, 0], [5, 5], [0, 0]] 69 | ]) 70 | 71 | # add a load of dimensions 72 | tuples = tuples[None, None, None, :, :, :].expand(3, 5, 7, 2, 4, 2).contiguous() 73 | dedup = dedup[None, None, None, :, :, :].expand(3, 5, 7, 2, 4, 2).contiguous() 74 | 75 | # find the duplicates 76 | dup = util.nduplicates(tuples) 77 | 78 | # mask them out 79 | tuples[dup, :] = tuples[dup, :] * 0 80 | self.assertEqual((tuples != dedup).sum(), 0) # assert equal to expected 81 | 82 | # second test: explicitly test the bitmask returned by nduplicates 83 | tuples = torch.tensor([[ 84 | [3, 1], 85 | [3, 2], 86 | [3, 1], 87 | [0, 3], 88 | [0, 2], 89 | [3, 0], 90 | [0, 3], 91 | [0, 0]]]) 92 | 93 | tuples = tuples[None, None, None, :, :, :].expand(8, 1, 7, 1, 8, 2).contiguous() 94 | 95 | self.assertEqual([0, 0, 1, 0, 0, 0, 1, 0], list(util.nduplicates(tuples)[0, 0, 0, :, :].view(-1))) 96 | 97 | # third test: single element tuples 98 | tuples = torch.tensor([ 99 | [[5], [1], [2], [1]], 100 | [[3], [3], [5], [5]] 101 | ]) 102 | dedup = torch.tensor([ 103 | [[5], [1], [2], [0]], 104 | [[3], [0], [5], [0]] 105 | ]) 106 | 107 | tuples = tuples[None, None, None, :, :, :].expand(3, 5, 7, 2, 4, 2).contiguous() 108 | dedup = dedup[None, None, None, :, :, 
:].expand(3, 5, 7, 2, 4, 2).contiguous() 109 | 110 | dup = util.nduplicates(tuples) 111 | 112 | tuples[dup, :] = tuples[dup, :] * 0 113 | self.assertEqual((tuples != dedup).sum(), 0) 114 | 115 | def test_nduplicates_recursion(self): 116 | """ 117 | Reproducing observed recursion error 118 | :return: 119 | """ 120 | 121 | # tensor of 6 1-tuples 122 | tuples = torch.tensor( 123 | [[[[74], 124 | [75], 125 | [175], 126 | [246], 127 | [72], 128 | [72]]]]) 129 | 130 | dedup = torch.tensor( 131 | [[[[74], 132 | [75], 133 | [175], 134 | [246], 135 | [72], 136 | [0]]]]) 137 | 138 | dup = util.nduplicates(tuples) 139 | 140 | tuples[dup, :] = tuples[dup, :] * 0 141 | self.assertEqual((tuples != dedup).sum(), 0) 142 | 143 | def test_unique_recursion(self): 144 | """ 145 | Reproducing observed recursion error 146 | :return: 147 | """ 148 | 149 | # tensor of 6 1-tuples 150 | tuples = torch.tensor( 151 | [[74], 152 | [75], 153 | [175], 154 | [246], 155 | [72], 156 | [72]]) 157 | dup = util.unique(tuples) 158 | 159 | def test_wrapmod(self): 160 | 161 | self.assertAlmostEqual(util.wrapmod(torch.tensor([9.1]), 9).item(), 0.1, places=5) 162 | 163 | self.assertAlmostEqual(util.wrapmod(torch.tensor([-9.1]), 9).item(), 8.9, places=5) 164 | 165 | self.assertAlmostEqual(util.wrapmod(torch.tensor([-0.1]), 9).item(), 8.9, places=5) 166 | 167 | self.assertAlmostEqual(util.wrapmod(torch.tensor([10.0, -9.1]), 9)[1].item(), 8.9, places=5) 168 | 169 | def test_interpolation_grid(self): 170 | 171 | g = util.interpolation_grid() 172 | self.assertEqual( (torch.abs(g.sum(dim=2) - 1.0) > 0.0001).sum(), 0) 173 | 174 | g = util.interpolation_grid((3, 3)) 175 | self.assertAlmostEqual(g[0, 0, 0], 1.0, 5) 176 | self.assertAlmostEqual(g[0, 2, 1], 1.0, 5) 177 | self.assertAlmostEqual(g[2, 2, 2], 1.0, 5) 178 | self.assertAlmostEqual(g[2, 0, 3], 1.0, 5) 179 | 180 | def test_coordconv(self): 181 | 182 | """ 183 | Performs the regression experiment from https://arxiv.org/abs/1807.03247 184 | 185 | Tests whether 
the coord_conv layer behaves as expected. Baseline false should not converge to zero, baseline true 186 | should. 187 | 188 | :return: 189 | """ 190 | 191 | RES = 64 192 | B, H, W = 16, RES, RES 193 | BATCHES = 16000 194 | 195 | for baseline in (False, True): 196 | print('baseline', baseline) 197 | 198 | C = nn.Conv2d if baseline else util.CConv2d 199 | 200 | model = nn.Sequential( 201 | C(1, 8, stride=1, kernel_size=1, padding=0), 202 | #nn.ReLU(), 203 | nn.Conv2d(8, 8, stride=1, kernel_size=1, padding=0), 204 | #nn.ReLU(), 205 | nn.Conv2d(8, 8, stride=1, kernel_size=1, padding=0), 206 | #nn.ReLU(), 207 | nn.Conv2d(8, 8, stride=1, kernel_size=3, padding=1), 208 | #nn.ReLU(), 209 | nn.Conv2d(8, 2, stride=1, kernel_size=3, padding=1), 210 | nn.MaxPool2d(kernel_size=(RES, RES)), 211 | # util.Lambda(lambda x : x.squeeze()) 212 | ) 213 | 214 | opt = torch.optim.Adam(lr=0.005, params=model.parameters()) 215 | 216 | for it in range(BATCHES): 217 | 218 | # generate data 219 | 220 | target = (torch.rand(size=(B, 2)) * RES).to(torch.long) 221 | 222 | hf = RES // 2 223 | 224 | test_indices = (target[:, 0] < hf) & (target[:, 1] < hf) 225 | 226 | target_train = target[~ test_indices, :] 227 | target_test = target[ test_indices, :] 228 | 229 | input_train = torch.zeros(target_train.size(0), H, W) 230 | input_test = torch.zeros( target_test.size(0), H, W) 231 | 232 | input_train[torch.arange(input_train.size(0), dtype=torch.long), target_train[:, 0], target_train[:, 1]] = 1 233 | input_test [torch.arange(input_test.size(0), dtype=torch.long), target_test [:, 0], target_test [:, 1]] = 1 234 | 235 | input_train = input_train[:, None, :, :] 236 | input_test = input_test[:, None, :, :] 237 | 238 | input_train, input_test, target_train, target_test = Variable(input_train), Variable(input_test), Variable(target_train), Variable(target_test) 239 | 240 | if input_train.size(0) > 0: 241 | opt.zero_grad() 242 | # list(model.modules())[1].master.weight.retain_grad() 243 | 244 | out = 
model(input_train) 245 | loss = F.mse_loss(out, target_train.to(torch.float)) 246 | 247 | loss.backward() 248 | opt.step() 249 | else: 250 | print('size zero batch') 251 | 252 | if it % 2000 == 0 and input_test.size(0) > 0: 253 | with torch.no_grad(): 254 | out = model(input_test) 255 | tloss = F.mse_loss(out, target_test.to(torch.float)) 256 | 257 | print(f'train loss at {it}: { loss.item():.04}') 258 | print(f' test loss at {it}: {tloss.item():.04}') 259 | 260 | # print(list(model.modules())[1].master.weight.grad) 261 | 262 | def test_flip(self): 263 | x = torch.rand(6, 3, 9, 3, 2) 264 | 265 | x = util.flip(x) 266 | f = x.view(-1, 2) 267 | self.assertEqual( (f[:, 0] < f[:, 1]).sum(), 0) 268 | self.assertEqual(len(x.size()), 5) 269 | 270 | x = torch.randint(10, size=(3, 128, 2)) 271 | x = util.flip(x) 272 | f = x.view(-1, 2) 273 | self.assertEqual( (f[:, 0] < f[:, 1]).sum(), 0) 274 | self.assertEqual(len(x.size()), 3) 275 | 276 | def test_schedule(self): 277 | 278 | sched = {10: 0.1, 90:0.99} 279 | 280 | for e in range(100): 281 | print(util.schedule(e, sched)) 282 | 283 | 284 | if __name__ == '__main__': 285 | unittest.main() -------------------------------------------------------------------------------- /tests/tests.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # 3 | # def test_fi(): 4 | # input = torch.LongTensor([[0, 0], [0, 1], [1, 0], [1, 1]]) 5 | # expected = torch.LongTensor([0, 1, 2, 3]) 6 | # 7 | # actual = gaussian.fi(input, (2,2)) 8 | # 9 | # print(actual) 10 | # 11 | # 12 | # def test_fi_mat(): 13 | # input = torch.LongTensor([[[0, 0], [0, 1], [1, 0], [1, 1]]]) 14 | # expected = torch.LongTensor([0, 1, 2, 3]) 15 | # 16 | # actual = gaussian.fi_matrix(input, torch.LongTensor((2, 2))) 17 | # 18 | # print(actual) 19 | # 20 | # def test_sort(): 21 | # indices = torch.LongTensor([[[6, 3], [1, 2]], [[5, 8], [1, 3]]]) 22 | # vals = torch.FloatTensor([[0.1, 0.2], [0.3, 0.4]]) 23 | # 24 | # 
hyper.sort(indices, vals) 25 | # 26 | # print(indices) 27 | # print(vals) 28 | # 29 | # 30 | # 31 | # if __name__ == '__main__': 32 | # # unittest.main() 33 | # 34 | # test_fi() 35 | # test_fi_mat() --------------------------------------------------------------------------------