├── LICENSE ├── README.md ├── RKNet.py ├── modules ├── ClippedModule.py ├── ConnectingLayer.py ├── DoubleLayer.py ├── DoubleSymLayer.py ├── PreactDoubleLayer.py ├── TvNorm.py ├── rk1.py ├── rk4.py ├── runModuleTests.py ├── testConnectingLayer.py ├── testDoubleLayer.py ├── testDoubleSymLayer.py ├── testPreactDoubleLayer.py ├── testRK1Block.py └── testRK4Block.py ├── naming.md ├── paper.bib ├── paper.md ├── paper.pdf ├── requirements.txt ├── results └── default_RK4_DoubleSym_cifar10_2019_06_11_00_36_36.txt ├── setup.md ├── startup.sh └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 EmoryMLIP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # DynamicBlocks 3 | PyTorch implementations of a generalized ResNet-inspired network architecture which allows for broader experimentation. 4 | 5 | ### Getting Started 6 | 7 | Start up your python environment and install the python packages stored in requirements.txt: 8 | 9 | ``` 10 | pip3 install -r requirements.txt 11 | ``` 12 | 13 | Run the default network (RK4 scheme using a doubleSymLayer on the CIFAR-10 dataset): 14 | ``` 15 | python3 RKNet.py 16 | ``` 17 | 18 | [further setup details](setup.md) 19 | 20 | [naming convention](naming.md) 21 | 22 | ### References 23 | 24 | The concepts behind the networks implemented by this toolbox are detailed in: 25 | 26 | Lars Ruthotto and Eldad Haber (2018). [Deep Neural Networks Motivated by Partial Differential Equations](https://arxiv.org/abs/1804.04272). arXiv.org. 27 | 28 | Eldad Haber and Lars Ruthotto (2017). [Stable architectures for deep neural networks](https://doi.org/10.1088/1361-6420/aa9a90). Inverse Problems, 34(1). 29 | 30 | Bo Chang, Lili Meng, Eldad Haber, Lars Ruthotto, David Begert, and Elliot Holtham (2018). [Reversible architectures for arbitrarily deep residual neural networks](https://arxiv.org/abs/1709.03698). Presented at the Thirty-Second AAAI Conference on Artificial Intelligence. 31 | 32 | 33 | 34 | ### Acknowledgements 35 | 36 | This material is in part based upon work supported by the US Israel National Science Foundation Grant Number 2018209 and the US National Science Foundation under Grant Number DMS-1751636. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation. 37 | -------------------------------------------------------------------------------- /RKNet.py: -------------------------------------------------------------------------------- 1 | # RKNet.py 2 | 3 | from utils import * 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.init as init 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | import torch.optim as optim 11 | from torch.autograd import Variable 12 | import time 13 | import datetime 14 | import math 15 | import copy 16 | 17 | import os 18 | import sys 19 | 20 | from modules.ConnectingLayer import ConnectingLayer 21 | from modules.rk4 import rk4 22 | from modules.rk1 import rk1 23 | from modules.TvNorm import TvNorm 24 | from modules.DoubleSymLayer import DoubleSymLayer 25 | from modules.DoubleLayer import DoubleLayer 26 | 27 | from modules.ClippedModule import * 28 | 29 | class RKNet(ClippedModule): 30 | def __init__(self, tY, tTheta, nChanIn, nClasses, vFeat, dynamicScheme=rk4, layer=DoubleSymLayer, 31 | openParams=None, dynamicParams=None, connParams=None, linear=None): 32 | """ 33 | The main workhorse network. Consists of an opening layer (a single layer), multiple dynamicBlocks each 34 | of which is connected by a ConnectingLayer (a single layer), then 35 | 36 | :param tY: list, time discretization for the states , [0,1,2,3,4] 37 | :param tTheta: list, time sdiscretization for the controls, [0,1,2,3,4] 38 | :param nChanIn: int, number of input channels (RGB = 3) 39 | :param nClasses: int, number of output classes 40 | :param vFeat: list, number of channels per dynamic block...last value for final connecting layer [16,32,64,64] 41 | :param dynamicScheme: module, the time integrator / ODE solver (ex: rk1, rk4) 42 | :param layer: module, the primary layer in the dynamic blocks (ex: DoubleLayer, DoubleSymLayer) 43 | :param openParams: dict, the params dict to pass to the opening layer 44 | :param dynamicParams: dict, the params dict to pass to the dynamic scheme 45 | also, can pass strings "TvNorm", "NoNorm", or "Batch" for the corresponding norm with the default 46 | :param connParams: dict, the params dict to pass to the connecting layers 47 | :param linear: module, a nn.Linear module already initialized to be used as final fully connected layer 48 | """ 49 | super().__init__() 50 | self.nBlocks = len(vFeat) - 1 51 | self.vFeat = vFeat 52 | self.tY = tY 53 | self.tTheta = tTheta 54 | 55 | if dynamicParams is None or dynamicParams=="TvNorm" or dynamicParams=="Batch" or dynamicParams=="NoNorm": 56 | dynamicParams = [dynamicParams] * self.nBlocks 57 | if connParams is None: 58 | connParams = [None] * self.nBlocks 59 | 60 | 61 | self.dynamicBlocks = nn.ModuleList([]) 62 | self.connectors = nn.ModuleList([]) 63 | 64 | act = nn.ReLU() 65 | 66 | if openParams is None: 67 | openParams = { 'act': act, 68 | 'normLayer': nn.BatchNorm2d(num_features=vFeat[0], eps=1e-4), 69 | 'conv': nn.Conv2d(nChanIn, vFeat[0], kernel_size=3, padding=1, stride=1)} 70 | self.open = ConnectingLayer([nChanIn, vFeat[0]], params=openParams) 71 | 72 | 73 | self.dynamicBlocks = nn.ModuleList([]) 74 | self.connectors = nn.ModuleList([]) 75 | 76 | for blk in range(self.nBlocks): 77 | 78 | if dynamicParams[blk] is None or dynamicParams[blk]=="TvNorm": 79 | # default setup for the dynamic block 80 | dynamicParams[blk] = {'act': act, 81 | 'vFeat' : vFeat[blk], 82 | 'normLayer': TvNorm(vFeat[blk], eps=1e-4)} 83 | elif dynamicParams[blk]=="Batch": 84 | dynamicParams[blk] = {'act': act, 85 | 'vFeat': vFeat[blk], 86 | 'normLayer': nn.BatchNorm2d(vFeat[blk])} 87 | elif dynamicParams[blk]=="NoNorm": 88 | dynamicParams[blk] = {'act': act, 89 | 'vFeat': vFeat[blk]} 90 | 91 | 92 | self.dynamicBlocks.append(dynamicScheme(tTheta, tY, layer, layerParams = dynamicParams[blk] )) 93 | 94 | if connParams[blk] is None: 95 | connParams[blk] = {'act': act, 96 | 'normLayer': nn.BatchNorm2d(num_features=vFeat[blk+1], eps=1e-4), 97 | 'conv': nn.Conv2d(vFeat[blk], vFeat[blk+1], kernel_size=1, padding=0, stride=1)} 98 | 99 | connLayer = ConnectingLayer( [vFeat[blk],vFeat[blk+1]] , params = connParams[blk]) 100 | self.connectors.append(connLayer) 101 | 102 | if linear is None: 103 | self.linear = nn.Linear(vFeat[-1], nClasses) 104 | self.linear.weight.data = normalInit(self.linear.weight.shape) 105 | self.linear.bias.data = self.linear.bias.data*0 106 | else: 107 | self.linear = linear 108 | 109 | 110 | def forward(self, x): 111 | 112 | x = self.open(x) 113 | 114 | for blk in range(self.nBlocks): 115 | 116 | x = self.dynamicBlocks[blk](x) 117 | x = self.connectors[blk](x) 118 | 119 | if blk < self.nBlocks - 1: 120 | # x = self.connectors[blk](x) # move connectors here ?????????????? Kind of a big change???????? 121 | x = F.avg_pool2d(x, 2) # TODO: make the pooling operator abstract for max vs avg pooling etc. 122 | else: 123 | # average each channel 124 | x = F.avg_pool2d(x, x.shape[2:4]) 125 | 126 | x = x.view(x.size(0), -1) 127 | x = self.linear(x) 128 | 129 | return x 130 | 131 | 132 | def setup_checkpointing(self,optimizer): 133 | 134 | self.optimizers=[] 135 | 136 | opt1=copy.deepcopy(optimizer) 137 | opt1.defaults=optimizer.defaults 138 | opt1.param_groups=[] 139 | 140 | opt1.add_param_group({"params": self.open.parameters()}) 141 | self.optimizers.append(opt1) 142 | for blk in range(self.nBlocks): 143 | opt = copy.deepcopy(optimizer) 144 | opt.defaults = optimizer.defaults 145 | opt.param_groups = [] 146 | 147 | opt.add_param_group({"params":self.dynamicBlocks[blk].parameters()}) 148 | opt.add_param_group({"params":self.connectors[blk].parameters()}) 149 | self.optimizers.append(opt) 150 | opt = copy.deepcopy(optimizer) 151 | opt.defaults = optimizer.defaults 152 | opt.param_groups = [] 153 | 154 | opt.add_param_group({'params': self.linear.weight}) 155 | self.optimizers.append(opt) 156 | 157 | def checkpoint_train(self,data,optimizer,device,regularization=0): 158 | self.setup_checkpointing(optimizer) 159 | for optimizer in self.optimizers: 160 | optimizer.zero_grad() 161 | y=[] 162 | inputs, labels = data 163 | self.train() # set model to train mode 164 | with torch.no_grad(): 165 | inputs, labels = inputs.to(device), labels.to(device) 166 | # forward + backward + optimize 167 | 168 | x = self.open(inputs) 169 | y.append(x) # checkpointing 170 | for blk in range(self.nBlocks): 171 | x = self.dynamicBlocks[blk](x) 172 | x = self.connectors[blk](x) 173 | if blk < self.nBlocks - 1: 174 | x = F.avg_pool2d(x, 2) 175 | else: 176 | x = F.avg_pool2d(x, x.shape[2:4]) 177 | 178 | # x = x.view(x.size(0), -1) 179 | # x = self.linear(x) 180 | 181 | y.append(x) # checkpointing 182 | 183 | with torch.enable_grad(): 184 | y_tag=Variable(y[-1],requires_grad=True) 185 | 186 | # net.linear.weight.data = torch.transpose(net.W[0:-1, :], 0, 1) 187 | # net.linear.bias.data = net.W[-1, :] 188 | 189 | 190 | loss, Si = misfitW(y_tag, torch.transpose(self.linear.weight,0,1), labels, device) 191 | final_loss=loss.item() 192 | 193 | 194 | loss.backward() 195 | prev_grad = y_tag.grad 196 | 197 | ## check what we have on hand 198 | self.optimizers[-1].step() 199 | self.optimizers[-1].zero_grad() 200 | for blk in reversed(range(self.nBlocks)): 201 | y_tag = Variable(y[-2], requires_grad=True) 202 | x = self.dynamicBlocks[blk](y_tag) 203 | x = self.connectors[blk](x) 204 | if blk < self.nBlocks - 1: 205 | x = F.avg_pool2d(x, 2) # TODO: make the pooling operator abstract for max vs avg pooling etc. 206 | else: 207 | # average each channel 208 | x = F.avg_pool2d(x, x.shape[2:4]) 209 | loss = torch.dot(x.view(-1),prev_grad.view(-1)) 210 | loss=loss+regularization*sum(self.dynamicBlocks[blk].weight_variance()) 211 | y.pop() 212 | loss.backward() 213 | prev_grad = y_tag.grad 214 | 215 | self.optimizers[1+blk].step() 216 | self.optimizers[1+blk].zero_grad() 217 | 218 | 219 | x = self.open(inputs) 220 | loss = torch.dot(x.view(-1),prev_grad.view(-1)) 221 | y.pop() 222 | loss.backward() 223 | self.optimizers[0].step() 224 | self.optimizers[0].zero_grad() 225 | 226 | _ , numCorrect, numTotal = getAccuracy(Si,labels) 227 | torch.cuda.empty_cache() 228 | return final_loss , numCorrect, numTotal 229 | 230 | 231 | 232 | def checkpoint_train_debug(self,data,optimizer,device,net_other): 233 | net_other.train() # set model to train mode 234 | 235 | inputs, labels = data 236 | optimizer.zero_grad() 237 | 238 | inputs, labels = inputs.to(device), labels.to(device) 239 | z = net_other.forward(inputs) 240 | aa=[] 241 | aa.append(inputs) 242 | x = self.open(inputs) 243 | aa.append(x) 244 | for blk in range(self.nBlocks): 245 | 246 | x = self.dynamicBlocks[blk](x) 247 | x = self.connectors[blk](x) 248 | 249 | if blk < self.nBlocks - 1: 250 | x = F.avg_pool2d(x, 2) # TODO: make the pooling operator abstract for max vs avg pooling etc. 251 | else: 252 | # average each channel 253 | x = F.avg_pool2d(x, x.shape[2:4]) 254 | aa.append(x) 255 | 256 | 257 | loss_other, Si = misfit(z, net_other.linear.weight, labels, device) 258 | 259 | loss_other.backward() 260 | 261 | 262 | self.setup_checkpointing(optimizer) 263 | for optimizer in self.optimizers: 264 | optimizer.zero_grad() 265 | y=[] 266 | inputs, labels = data 267 | self.train() # set model to train mode 268 | with torch.no_grad(): 269 | inputs, labels = inputs.to(device), labels.to(device) 270 | # forward + backward + optimize 271 | y.append(inputs) # checkpointing 272 | x = self.open(inputs) 273 | y.append(x) # checkpointing 274 | for blk in range(self.nBlocks): 275 | x = self.dynamicBlocks[blk](x) 276 | x = self.connectors[blk](x) 277 | if blk < self.nBlocks - 1: 278 | x = F.avg_pool2d(x, 2) 279 | else: 280 | x = F.avg_pool2d(x, x.shape[2:4]) 281 | 282 | y.append(x) # checkpointing 283 | 284 | with torch.enable_grad(): 285 | y[-1]=Variable(y[-1],requires_grad=True) 286 | 287 | 288 | for reg,chkpt in zip(aa,y): 289 | print("forward vs checkpoint:") 290 | print(torch.norm(reg - chkpt)) 291 | print(reg.shape) 292 | 293 | 294 | loss, Si = misfit(y[-1], self.linear.weight, labels, device) 295 | final_loss=loss.item() 296 | 297 | 298 | print("loss is:") 299 | print(torch.norm(loss-loss_other)) 300 | 301 | 302 | loss.backward() 303 | 304 | 305 | print("linear.weight:") 306 | print(torch.norm(self.linear.weight-net_other.linear.weight)) 307 | print(torch.norm(self.linear.weight.grad-net_other.linear.weight.grad)) 308 | 309 | 310 | 311 | 312 | ## check what we have on hand 313 | self.optimizers[-1].step() 314 | self.optimizers[-1].zero_grad() 315 | for blk in reversed(range(self.nBlocks)): 316 | y[-2] = Variable(y[-2], requires_grad=True) 317 | x = self.dynamicBlocks[blk](y[-2]) 318 | x = self.connectors[blk](x) 319 | if blk < self.nBlocks - 1: 320 | x = F.avg_pool2d(x, 2) # TODO: make the pooling operator abstract for max vs avg pooling etc. 321 | else: 322 | # average each channel 323 | x = F.avg_pool2d(x, x.shape[2:4]) 324 | loss = torch.dot(x.view(-1),y[-1].grad.view(-1)) 325 | loss.backward() 326 | y.pop() 327 | 328 | print("connectors " + str(blk)+ ":") 329 | P1=self.connectors[blk].parameters() 330 | P2=net_other.connectors[blk].parameters() 331 | for p1,p2 in zip(P1,P2): 332 | print(torch.norm(p1.grad-p2.grad)/torch.norm(p1.grad)) 333 | print(p1.shape) 334 | print("dynamicBlocks " + str(blk)+ ":") 335 | P1=self.dynamicBlocks[blk].parameters() 336 | P2=net_other.dynamicBlocks[blk].parameters() 337 | for p1,p2 in zip(P1,P2): 338 | print(torch.norm(p1.grad-p2.grad)/torch.norm(p1.grad)) 339 | print(p1.shape) 340 | 341 | self.optimizers[1+blk].step() 342 | self.optimizers[1+blk].zero_grad() 343 | 344 | 345 | x = self.open(y[0]) 346 | loss = torch.dot(x.view(-1),y[1].grad.view(-1)) 347 | y.pop() 348 | loss.backward() 349 | print("Opening :") 350 | P1 = self.open.parameters() 351 | P2 = net_other.open.parameters() 352 | for p1, p2 in zip(P1, P2): 353 | print(torch.norm(p1.grad - p2.grad)/torch.norm(p1.grad)) 354 | print("finished") 355 | self.optimizers[0].step() 356 | self.optimizers[0].zero_grad() 357 | y.pop() 358 | _ , numCorrect, numTotal = getAccuracy(Si,labels) 359 | torch.cuda.empty_cache() 360 | return final_loss , numCorrect, numTotal 361 | 362 | 363 | def regularization(self): 364 | reg=0 365 | for block in self.dynamicBlocks: 366 | reg = reg + sum(block.weight_variance()) 367 | #print(reg) 368 | return reg 369 | 370 | 371 | def runRKNet(sDataset, sTitle, net=None, 372 | tTheta=None, tY=None, vFeat = None, 373 | dynamicScheme=rk4, dynamicParams=None, layer=DoubleSymLayer, 374 | writeFile=True, gpu=0, fTikh=4e-4 , fMomentum=0.9, bNesterov=False, 375 | transTrain=None, transTest=None, lr=None, batchSize=None, 376 | regularization_param=0.0,checkpointing=False,reparametrization_epochs=None, 377 | nTrain=None, nVal=None, percentVal=0.2, 378 | loaderTrain=None, loaderVal=None, loaderTest=None, nChanIn=3, nClasses=None): 379 | """ 380 | 381 | :param sDataset: string, name of dataset ; 'cifar10' , 'stl10' , 'cifar100' 382 | :param sTitle: string, prefix to name output files 383 | :param net: Module, the network if want different from the default network 384 | :param tTheta: list , time discretization for the controls (where the weights are) 385 | :param tY: list , time discretization for the states (where the layers are) 386 | :param vFeat: list , the width/ number of channels for each dynamic block (last entry for last connecting block) 387 | :param dynamicScheme: Module name, time integrator/ ODE solver (ex: rk1, rk4) 388 | :param layer: Module name, layer used for dynamic block (ex: DoubleLayer, DoubleSymLayer) 389 | :param writeFile: boolean, True means write the log to a text file, False means print to terminal 390 | :param gpu: int , which gpu to use to operate (must be < number of gpus accessible) 391 | :param fTikh: float, weight decay / tikhonov regularization scalar for optimizer 392 | :param fMomentum: float, momentum value for optimizer 393 | :param bNesterov: boolean, True means use Nesterov, False means no Nesterov for optimizer 394 | :param lr: numpy array, learning rate per epoch (length=number of epochs) 395 | :param batchSize: int, number of input examples (images) per batch 396 | :param transTrain: pytorch transform, defines the augmentation and conversion to Tensor for training set 397 | :param transTest: pytorch transform, defines the augmentation and conversion to Tensor for testing/validation set 398 | :param regularization_param: float, for the regularization in time within each dynamic block, the loss 399 | becomes: Loss + regularization_param * regMetric(prevLayer,currLayer) (regMetric defined in utils) 400 | :param checkpointing: boolean, True means implement checkpointing, False runs normally 401 | :param reparametrization_epochs: 402 | 403 | To use entire dataset, leave nTrain and nVal as None. nTrain and nVal values override percentVal. 404 | :param nTrain: int, number of training examples (must be <= total examples in training set) 405 | :param nVal: int, number of validation examples ( must be <= total examples in validation set) 406 | :param percentVal: float, percentage of training set that will be split out as validation 407 | 408 | if passing one of the following, must pass all 5: 409 | loaderTrain, loaderVal, loaderTest, nChanIn, and nClasses 410 | :param loaderTrain: dataloader for training data 411 | :param loaderVal: dataloader for validation data 412 | :param loaderTest: dataloader for testing data 413 | :param nChanIn: int, number of input channels 414 | :param nClasses: int, number of output classes 415 | :return: net, sBasePath, device 416 | net: the model 417 | sBasePath: string, path used to name files; (ex. sBasePath_model.pt, sBasePath.txt) 418 | device: torch device used to train the model 419 | """ 420 | 421 | 422 | start = time.time() # timing 423 | 424 | sStartTime = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 425 | print('start time: ', sStartTime) 426 | sBasePath = 'results/' + sTitle + '_' + sDataset + '_' + sStartTime 427 | print('file: ', sBasePath) 428 | 429 | 430 | # get defaults ----------------------------------------------------------------------------- 431 | vFeatDef, transTrainDef, transTestDef, batchSizeDef, lrDef, tThetaDef, nTrainDef, nValDef, device = \ 432 | getRunDefaults(sDataset, gpu=gpu) 433 | 434 | if writeFile: # write output to a text file instead of stdout 435 | sys.stdout = open(sBasePath + '.txt' , 'wt' ) 436 | 437 | 438 | # For the option to reparametrize during training: 439 | if reparametrization_epochs is not None: 440 | tY_array=tY 441 | tTheta_array=tTheta 442 | tY=tY[0] 443 | tTheta=tTheta[0] 444 | else: 445 | tY_array = None 446 | tTheta_array = None 447 | 448 | if net is None: 449 | if vFeat is None: 450 | vFeat = vFeatDef 451 | if tTheta is None: 452 | tTheta = tThetaDef 453 | else: 454 | vFeat = net.vFeat 455 | tTheta = net.tTheta 456 | tY = net.tY 457 | 458 | if transTrain is None: 459 | transTrain = transTrainDef 460 | if transTest is None: 461 | transTest = transTestDef 462 | if batchSize is None: 463 | batchSize = batchSizeDef 464 | if lr is None: 465 | lr = lrDef 466 | if nTrain is None: 467 | nTrain = nTrainDef 468 | if nVal is None: 469 | nVal = nValDef 470 | if tY is None: 471 | tY = tTheta 472 | 473 | 474 | # percentVal = 0.20 # use 20% of the training set as validation 475 | # fTikh = 4e-4 # tikhonov regularization / weight decay value 476 | # fMomentum = 0.9 477 | # bNesterov = False 478 | 479 | #------------------------------------------------------------------------------------------- 480 | 481 | nEpoch = lr.size 482 | print("device", device) 483 | 484 | print('time steps Y: ' , tY) 485 | print('time steps theta: ' , tTheta) 486 | print('no. of channels: ' , vFeat) 487 | print('no. of epochs: ' , nEpoch) 488 | print('ODE reg param: ' , regularization_param) 489 | print('Use checkpointing: ' , checkpointing) 490 | 491 | 492 | # assume that if trainLoader is passed, then all loaders are passed 493 | if loaderTrain is None: 494 | if nTrain is not None and nVal is not None: 495 | loaderTrain, loaderVal, loaderTest, nChanIn, nClasses = \ 496 | getDataLoaders(sDataset, batchSize, device, 497 | batchSizeTest=batchSize, percentVal=percentVal, 498 | transformTrain=transTrain, transformTest=transTest, 499 | nTrain=nTrain, nVal=nVal) 500 | else: 501 | loaderTrain, loaderVal, loaderTest, nChanIn, nClasses = \ 502 | getDataLoaders( sDataset , batchSize, device, 503 | batchSizeTest=batchSize, percentVal=percentVal, 504 | transformTrain=transTrain, transformTest=transTest, 505 | nTrain=None, nVal=None ) 506 | 507 | 508 | 509 | 510 | 511 | # print the matrix presenting the channels and conv sizes for every layer 512 | # printFeatureMatrix(nt, nBlocks, vFeat) 513 | 514 | if net is None: 515 | net = RKNet(tY, tTheta, nChanIn, nClasses, vFeat, dynamicScheme=dynamicScheme, dynamicParams=dynamicParams, layer=layer ) 516 | 517 | net.to(device) 518 | 519 | # count number of weights and layers 520 | nWeights = sum(p.numel() for p in net.parameters()) 521 | nLayers = len(list(net.parameters())) 522 | print('Training ',nWeights,' weights in ', nLayers, ' layers',flush=True) 523 | 524 | # separate out the norms 525 | normParams = [] 526 | convParams = [] 527 | for name, param in net.named_parameters(): 528 | if 'normLayer' in name: 529 | normParams.append(param) 530 | else: 531 | convParams.append(param) 532 | 533 | allParams = [ {'params': normParams, 'weight_decay': 0 }, 534 | {'params': convParams} ] 535 | 536 | optimizer = optim.SGD( allParams, lr=lr.item(0), momentum=fMomentum, weight_decay=fTikh, nesterov=bNesterov) 537 | print('optimizer = SGD with momentum=%.1e, weight_decay=%.1e, nesterov=%d, batchSize=%d' 538 | % (fMomentum, fTikh, bNesterov, batchSize)) 539 | 540 | # train the network 541 | sPathOpt = trainNet(net, optimizer, lr, nEpoch, device, sBasePath, loaderTrain, loaderVal, nMini=1, verbose=True, 542 | regularization_param=regularization_param,checkpointing=checkpointing) 543 | 544 | end = time.time() 545 | print('Time elapsed: ' , end-start) 546 | print('Training complete. Now testing...') 547 | 548 | #--------------TESTING----------------- 549 | # evaluate performance on test set 550 | 551 | # load best model weights based on validation accuracy 552 | net = torch.load(sPathOpt + '.pt') 553 | testLoss, testAcc = evalNet(net, device, loaderTest) 554 | 555 | print('\ntesting loss: %-9.2e testing accuracy: %-9.2f' % 556 | ( testLoss, testAcc * 100 ) ) 557 | 558 | print(net) # just so we can see what does what 559 | 560 | return net, sBasePath, device 561 | 562 | print('done') 563 | 564 | 565 | 566 | 567 | def runDefault(): 568 | runRKNet( sDataset='cifar10', sTitle = 'default_RK4_DoubleSym', writeFile=False) 569 | 570 | def runDefaultRegularized(): 571 | runRKNet( sDataset='stl10', sTitle = 'regularizedBN_RK1_Double', dynamicScheme=rk1, layer = DoubleLayer, 572 | writeFile=True,regularization_param=0.01, checkpointing=False, gpu=1, dynamicParams="Batch") 573 | 574 | 575 | # For Testing!! 576 | if __name__ == '__main__': 577 | runDefault() 578 | # runDefaultRegularized() 579 | 580 | 581 | -------------------------------------------------------------------------------- /modules/ClippedModule.py: -------------------------------------------------------------------------------- 1 | # ClippedModule.py 2 | # 5/2/19 3 | 4 | from utils import * 5 | import torch.nn as nn 6 | 7 | class ClippedModule(nn.Module): 8 | """ 9 | Extend nn.Module to include max and min values for bound constraints / clipping 10 | """ 11 | 12 | def __init__(self): 13 | super().__init__() 14 | self.minConv = -0.5 15 | self.maxConv = 0.5 16 | self.minDef = -1.5 17 | self.maxDef = 1.5 18 | 19 | 20 | def setClipValues(self, minConv = -0.5, maxConv=0.5, minDef=-1.5, maxDef=1.5): 21 | """ 22 | set box constraints 23 | :param minConv: float, lower bound for convolutions 24 | :param maxConv: float, upper bound for convolutions 25 | :param minDef: float, lower bound for all other parameters 26 | :param maxDef: float, upper bound for all other parameters 27 | """ 28 | self.minConv = minConv 29 | self.maxConv = maxConv 30 | self.minDef = minDef 31 | self.maxDef = maxDef 32 | 33 | def calcClipValues(self,h,nPixels,nChan): 34 | """ calculation for setting bound constraints....not tuned yet""" 35 | # devlop some multiplier to adjust the defaults 36 | mult =1/h # larger time step, reduce the constraints 37 | mult = mult / (math.sqrt(nPixels)) 38 | mult = mult * (500/ nChan**2) 39 | 40 | minConv = -1 41 | maxConv = 1 42 | self.setClipValues( minConv=mult*minConv, maxConv=mult*maxConv, minDef=-1.5, maxDef=1.5) 43 | 44 | 45 | 46 | 47 | def clip(self): 48 | """project values onto box constraints""" 49 | # assume conv range is subset of default range 50 | 51 | if hasattr(self,'conv'): 52 | self.conv.weight.data.clamp_(min=self.minConv, max=self.maxConv) 53 | if self.conv.bias is not None: 54 | self.conv.bias.data.clamp_(min=self.minConv, max=self.maxConv) 55 | 56 | if hasattr(self, 'conv1'): 57 | self.conv1.weight.data.clamp_(min=self.minConv, max=self.maxConv) 58 | if self.conv1.bias is not None: 59 | self.conv1.bias.data.clamp_(min=self.minConv, max=self.maxConv) 60 | 61 | if hasattr(self, 'conv2'): 62 | self.conv2.weight.data.clamp_(min=self.minConv, max=self.maxConv) 63 | if self.conv2.bias is not None: 64 | self.conv2.bias.data.clamp_(min=self.minConv, max=self.maxConv) 65 | 66 | if hasattr(self, 'weight'): 67 | w = self.weight.data 68 | w.clamp_(min=self.minDef, max=self.maxDef) 69 | # if hasattr(module, 'bias'): # bias can be quite large( 3.0+ after removal of batch norm) 70 | # if module.bias is not None: 71 | # w = module.bias.data 72 | # w.clamp_(min=self.min, max=self.max) 73 | 74 | for module in self.children(): # just immediate children 75 | if hasattr(module, 'clip'): 76 | module.clip() 77 | else: # this may not be robust...only goes two levels down...trying to get to the normLayers 78 | if hasattr(module, 'weight'): 79 | w = module.weight.data 80 | w.clamp_(min=self.minDef, max=self.maxDef) 81 | # if hasattr(module, 'bias'): # bias can be quite large( 3.0+ after removal of batch norm) 82 | # if module.bias is not None: 83 | # w = module.bias.data 84 | # w.clamp_(min=self.min, max=self.max) 85 | for child in module.children(): 86 | if hasattr(child, 'clip'): 87 | child.clip() 88 | 89 | 90 | return self 91 | 92 | 93 | 94 | if __name__ == "__main__": 95 | from modules.DoubleLayer import * 96 | from RKNet import * 97 | 98 | # mini-test 99 | 100 | l = DoubleLayer([4,4], {'normLayer':nn.BatchNorm2d(4) } ) 101 | 102 | l.setClipValues( minConv=-0.1, maxConv=0.1, minDef=-1.5, maxDef=1.5) 103 | l.conv1.weight.data = l.conv1.weight.data * 0 + 4 104 | l.conv1.bias.data = l.conv1.bias.data * 0 + 4 105 | l.conv2.weight.data = l.conv2.weight.data * 0 + 4 106 | l.conv2.bias.data = l.conv2.bias.data * 0 + 7 107 | l.normLayer1.weight.data = l.normLayer1.weight.data * 0 + 5 108 | l.normLayer2.weight.data = l.normLayer2.weight.data * 0 + 2 109 | l.normLayer1.bias.data = l.normLayer1.bias.data * 0 + 3 110 | l.normLayer2.bias.data = l.normLayer2.bias.data * 0 111 | 112 | l.clip() 113 | 114 | for n, p in l.named_parameters(): 115 | print(n, p.data) 116 | # assert(torch.max(p.data) == ?????) maxConv or maxDef 117 | 118 | 119 | l.conv1.weight.data = l.conv1.weight.data * 0 - 4 120 | l.conv1.bias.data = l.conv1.bias.data * 0 - 4 121 | l.conv2.weight.data = l.conv2.weight.data * 0 - 4 122 | l.conv2.bias.data = l.conv2.bias.data * 0 - 4 123 | l.normLayer1.weight.data = l.normLayer1.weight.data * 0 - 4 124 | l.normLayer2.weight.data = l.normLayer2.weight.data * 0 - 4 125 | l.normLayer1.bias.data = l.normLayer1.bias.data * 0 - 4 126 | l.normLayer1.bias.data = l.normLayer1.bias.data * 0 - 4 127 | 128 | l.clip() 129 | 130 | for n, p in l.named_parameters(): 131 | print(n, p.data) 132 | # assert(torch.max(p.data) == ?????) maxConv or maxDef 133 | 134 | 135 | # ----------------------------------------- 136 | net = RKNet(tY=[0,1,2,3,4], tTheta=[0,1,2,3,4], nChanIn=3, 137 | nClasses=10, vFeat=[4,7,7], dynamicScheme = rk1, layer = DoubleSymLayer) 138 | 139 | for n, p in net.named_parameters(): 140 | p.data = p.data*0 + 5 141 | 142 | 143 | net.clip() 144 | 145 | for n, p in net.named_parameters(): 146 | if torch.max(p.data) > 4: 147 | print(n, p.data) 148 | 149 | print('code finished running') 150 | 151 | -------------------------------------------------------------------------------- /modules/ConnectingLayer.py: -------------------------------------------------------------------------------- 1 | # ConnectingLayer.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | from utils import normalInit, misfitW, conv3x3 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import copy 9 | 10 | from modules.ClippedModule import * 11 | 12 | 13 | class ConnectingLayer(ClippedModule): 14 | """ 15 | Implementation of the resizing connecting layer (changing the number of channels) 16 | 17 | Attributes: 18 | conv (sub-module): convolution class, default is 3x3 2Dconvolution 19 | act (sub-module): activation function, defualt is ReLU() 20 | normLayer (sub-module): normalization with affine bias and weight, default is no normalization 21 | 22 | Typical attributes for the children: 23 | conv.weight (Parameter): dims (nChanOut,nChanIn,3,3) for default 2DConvolution from nChanIn -> nChanOut channels 24 | conv.bias (Parameter): vector, dims (nChanIn) 25 | normLayer.weight (Parameter): vector, dims (nChanOut) affine scaling 26 | normLayer.bias (Parameter): vector, dims (nChanOut) affine scaling bias 27 | 28 | """ 29 | def __init__(self, vFeat, params={}): 30 | """ 31 | :param vFeat: 2-item list of number of expected channels and number of channels to return, [nChanIn,nChanOut] 32 | :param params: dict of possible parameters ( 'conv' , 'act', 'szKernel' , 'normLayer' ) 33 | """ 34 | super().__init__() 35 | nChanIn = vFeat[0] 36 | nChanOut = vFeat[1] 37 | 38 | # defaults 39 | szKernel = 3 40 | stride = 1 41 | padding = 1 42 | self.conv = nn.Conv2d(in_channels=nChanIn, kernel_size=szKernel, 43 | out_channels=nChanOut, stride=stride, padding=padding) 44 | self.act = nn.ReLU() 45 | 46 | # overwrite from params where necessary 47 | if 'conv' in params.keys(): 48 | self.conv = copy.deepcopy(params.get('conv')) 49 | szKernel = self.conv.kernel_size[0] # doesn't allow for rectangular kernels 50 | stride = self.conv.stride 51 | padding = self.conv.padding 52 | elif 'szKernel' in params.keys(): 53 | szKernel = params.get('szKernel') 54 | self.conv = nn.Conv2d(in_channels=nChanIn, kernel_size=szKernel, 55 | out_channels=nChanOut, stride=stride, padding=padding) 56 | 57 | if 'act' in params.keys(): 58 | self.act = params.get('act') 59 | 60 | if 'normLayer' in params.keys(): 61 | self.normLayer = copy.deepcopy(params.get('normLayer')) 62 | self.normLayer.weight.data = torch.ones(nChanOut) 63 | # self.normLayer.bias.data = torch.zeros(nChanOut) # this may be redundant 64 | 65 | 66 | self.conv.weight.data = normalInit(self.conv.weight.data.shape) 67 | # make sure the biases are 0 68 | if self.conv.bias is not None: 69 | self.conv.bias.data *= 0 70 | 71 | 72 | 73 | def forward(self, x): 74 | z = self.conv(x) 75 | if hasattr(self, 'normLayer'): 76 | z = self.normLayer(z) 77 | z = self.act(z) 78 | return z 79 | 80 | 81 | 82 | if __name__ == "__main__": 83 | print('ConnectingLayer test\n') 84 | import testConnectingLayer # this should run the test 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /modules/DoubleLayer.py: -------------------------------------------------------------------------------- 1 | # DoubleLayer.py 2 | # 2/5/19 3 | 4 | from utils import * 5 | import torch.nn as nn 6 | import copy 7 | from modules.ClippedModule import * 8 | 9 | 10 | class DoubleLayer(ClippedModule): 11 | """ 12 | Implementation of the double layer, also referred to as a Basic ResNet Block. 13 | 14 | act2( N2( K2( act1( N1( K1(Y) ))))) 15 | 16 | Attributes: 17 | conv1 (sub-module): convolution class, default is 3x3 2Dconvolution 18 | conv2 (sub-module): '' 19 | act1 (sub-module): activation function, default is ReLU() 20 | act2 (sub-module): '' 21 | normLayer1 (sub-module): normalization with affine bias and weight, default is no normalization 22 | normLayer2 (sub-module): '' 23 | 24 | Typical attributes for the children: 25 | conv#.weight (Parameter): dims (nChanOut,nChanIn,3,3) for default 2DConvolution from nChanIn -> nChanOut channels 26 | conv#.bias (Parameter): vector, dims (nChanIn) 27 | normLayer#.weight (Parameter): vector, dims (nChanOut) affine scaling 28 | normLayer#.bias (Parameter): vector, dims (nChanOut) affine scaling bias 29 | """ 30 | 31 | def __init__(self, vFeat, params={}): 32 | """ 33 | :param vFeat: 2-item list of number of expected input channels and number of channels to return, [nChanIn,nChanOut] 34 | :param params: dict of possible parameters ( 'conv1' , 'conv2', 'act1' , 'act2' , 'normLayer1' , 'normLayer2' ) 35 | """ 36 | super().__init__() 37 | if type(vFeat) is not list: # assume its one number 38 | vFeat = [vFeat, vFeat] 39 | nChanIn = vFeat[0] 40 | nChanOut = vFeat[1] 41 | 42 | # defaults 43 | szKernel = 3 44 | stride = 1 45 | padding = 1 46 | # be cognisant of where you initialize...make sure each is initialized or deep-copied 47 | self.conv1 = nn.Conv2d(in_channels=nChanIn, kernel_size=szKernel, 48 | out_channels=nChanOut, stride=stride, padding=padding) 49 | self.conv2 = nn.Conv2d(in_channels=nChanIn, kernel_size=szKernel, 50 | out_channels=nChanOut, stride=stride, padding=padding) 51 | self.act1 = nn.ReLU() 52 | self.act2 = nn.ReLU() 53 | 54 | # overwrite from params where necessary 55 | if 'conv1' in params.keys(): 56 | self.conv1 = copy.deepcopy(params.get('conv1')) 57 | if 'conv2' in params.keys(): 58 | self.conv2 = copy.deepcopy(params.get('conv2')) 59 | 60 | if 'act1' in params.keys(): 61 | self.act1 = params.get('act1') 62 | if 'act2' in params.keys(): 63 | self.act2 = params.get('act2') 64 | 65 | if 'normLayer1' in params.keys(): 66 | self.normLayer1 = copy.deepcopy(params.get('normLayer1')) 67 | self.normLayer1.weight.data = torch.ones(nChanOut) 68 | self.normLayer1.bias.data = torch.zeros(nChanOut) # this may be redundant 69 | if 'normLayer2' in params.keys(): 70 | self.normLayer2 = copy.deepcopy(params.get('normLayer2')) 71 | self.normLayer2.weight.data = torch.ones(nChanOut) 72 | self.normLayer2.bias.data = torch.zeros(nChanOut) # this may be redundant 73 | 74 | # assume if blah is passed instead of blah1 and blah2, then the user wants them the same 75 | if 'conv' in params.keys(): 76 | self.conv1 = copy.deepcopy(params.get('conv')) 77 | self.conv2 = copy.deepcopy(self.conv1) 78 | if 'act' in params.keys(): 79 | self.act1 = params.get('act') 80 | self.act2 = copy.deepcopy(self.act1) 81 | if 'normLayer' in params.keys(): 82 | self.normLayer1 = copy.deepcopy(params.get('normLayer')) 83 | self.normLayer1.weight.data = torch.ones(nChanOut) 84 | #self.normLayer1.bias.data = torch.zeros(nChanOut) 85 | self.normLayer2 = copy.deepcopy(self.normLayer1) 86 | 87 | self.conv1.weight.data = normalInit(self.conv1.weight.data.shape) 88 | self.conv2.weight.data = normalInit(self.conv2.weight.data.shape) 89 | # for this demonstration, make sure the biases are 0 90 | if self.conv1.bias is not None: 91 | self.conv1.bias.data *= 0 92 | if self.conv2.bias is not None: 93 | self.conv2.bias.data *= 0 94 | 95 | 96 | def forward(self,x): 97 | z = self.conv1(x) 98 | if hasattr(self, 'normLayer1'): 99 | z = self.normLayer1(z) 100 | z = self.act1(z) 101 | 102 | z = self.conv2(z) 103 | if hasattr(self, 'normLayer2'): 104 | z = self.normLayer2(z) 105 | z = self.act2(z) 106 | 107 | return z 108 | 109 | 110 | def weight_variance(self,other): 111 | """apply regularization in time""" 112 | value = 0 113 | value += regMetric( nn.utils.convert_parameters.parameters_to_vector(self.parameters()) , 114 | nn.utils.convert_parameters.parameters_to_vector(other.parameters()) ) 115 | return value 116 | 117 | if __name__ == "__main__": 118 | print('DoubleLayer test\n') 119 | import testDoubleLayer # this should run the test 120 | -------------------------------------------------------------------------------- /modules/DoubleSymLayer.py: -------------------------------------------------------------------------------- 1 | # DoubleSymLayer.py 2 | 3 | from utils import * 4 | import torch.nn as nn 5 | import copy 6 | from modules.ClippedModule import * 7 | 8 | class DoubleSymLayer(ClippedModule): 9 | """ 10 | Implementation of the double symmetric layer, also referred to as a Parabolic Layer. 11 | 12 | - K^T ( act( N( K(Y)))) 13 | 14 | Attributes: 15 | conv (sub-module): convolution class, default is 3x3 2Dconvolution 16 | act (sub-module): activation function, default is ReLU() 17 | normLayer (sub-module): normalization with affine bias and weight, default is no normalization 18 | 19 | Typical attributes for the children: 20 | conv.weight (Parameter): dims (nChanOut,nChanIn,3,3) for default 2DConvolution from nChanIn -> nChanOut channels 21 | conv.bias (Parameter): vector, dims (nChanIn) 22 | normLayer.weight (Parameter): vector, dims (nChanOut) affine scaling 23 | normLayer.bias (Parameter): vector, dims (nChanOut) affine scaling bias 24 | """ 25 | 26 | def __init__(self, vFeat, params={}): 27 | super().__init__() 28 | if type(vFeat) is not list: # assume its one number 29 | vFeat = [vFeat, vFeat] 30 | nChanIn = vFeat[0] 31 | nChanOut = vFeat[1] 32 | 33 | # defaults 34 | szKernel = 3 35 | stride = 1 36 | padding = 1 37 | self.conv = nn.Conv2d(in_channels=nChanIn, kernel_size=szKernel, 38 | out_channels=nChanOut, stride=stride, padding=padding) 39 | self.act = nn.ReLU() 40 | 41 | # overwrite from params where necessary 42 | 43 | if 'conv' in params.keys(): 44 | self.conv = copy.deepcopy(params.get('conv')) 45 | szKernel = self.conv.kernel_size[0] # doesn't allow for rectangular kernels 46 | stride = self.conv.stride 47 | padding = self.conv.padding 48 | 49 | if 'szKernel' in params.keys(): 50 | szKernel = params.get('szKernel') 51 | 52 | if 'act' in params.keys(): 53 | self.act = params.get('act') 54 | 55 | if 'normLayer' in params.keys(): 56 | self.normLayer = copy.deepcopy(params.get('normLayer')) 57 | self.normLayer.weight.data = torch.ones(nChanOut) 58 | #self.normLayer.bias.data = torch.zeros(nChanOut) # this may be redundant 59 | 60 | self.convt = nn.ConvTranspose2d(in_channels=nChanOut, kernel_size=szKernel, 61 | out_channels=nChanIn, stride=stride, padding=padding) 62 | 63 | # conv and convt need to share weights 64 | self.weight = nn.Parameter( normalInit([vFeat[1], vFeat[0], szKernel, szKernel]) , requires_grad=True) 65 | self.conv.weight = self.weight 66 | self.convt.weight = self.weight 67 | 68 | # initiate the biases to 0 69 | if self.conv.bias is not None: 70 | self.conv.bias.data *= 0 71 | if self.convt.bias is not None: 72 | self.convt.bias.data *= 0 73 | 74 | def forward(self,x): 75 | z = self.conv(x) 76 | if hasattr(self, 'normLayer'): 77 | z = self.normLayer(z) 78 | z = self.act(z) 79 | z = - self.convt(z) 80 | return z 81 | 82 | def calcClipValues(self,h,nPixels,nChan): 83 | """DoubleSym should have bound constraints half of those in DoubleLayer""" 84 | super().calcClipValues(h,nPixels, nChan) 85 | # self.minDef = 0.5*self.minDef 86 | # self.maxDef = 0.5*self.maxDef 87 | self.minConv = 0.5*self.minConv 88 | self.maxConv = 0.5*self.maxConv 89 | 90 | 91 | def weight_variance(self,other): 92 | """apply regularization in time""" 93 | value = 0 94 | value += regMetric( nn.utils.convert_parameters.parameters_to_vector(self.parameters()) , 95 | nn.utils.convert_parameters.parameters_to_vector(other.parameters()) ) 96 | return value 97 | 98 | if __name__ == "__main__": 99 | print('DoubleSymLayer test\n') 100 | import testDoubleSymLayer # this should run the test 101 | -------------------------------------------------------------------------------- /modules/PreactDoubleLayer.py: -------------------------------------------------------------------------------- 1 | # PreactDoubleLayer.py 2 | # 2/5/19 3 | 4 | from utils import * 5 | import torch.nn as nn 6 | import copy 7 | 8 | from modules.DoubleLayer import DoubleLayer 9 | 10 | 11 | class PreactDoubleLayer(DoubleLayer): 12 | """ pre-activated version of the DoubleLayer 13 | 14 | N2( act2( K2( N1( act1( K1(Y) ))))) 15 | """ 16 | 17 | def __init__(self, vFeat, params={}): 18 | super().__init__(vFeat, params=params) 19 | 20 | def forward(self,x): 21 | z = self.act1(x) 22 | z = self.conv1(z) 23 | if hasattr(self, 'normLayer1'): 24 | z = self.normLayer1(z) 25 | 26 | z = self.act2(z) 27 | z = self.conv2(z) 28 | if hasattr(self, 'normLayer2'): 29 | z = self.normLayer2(z) 30 | 31 | return z 32 | 33 | 34 | if __name__ == "__main__": 35 | print('PreactDoubleLayer test\n') 36 | import testPreactDoubleLayer # this should run the test 37 | -------------------------------------------------------------------------------- /modules/TvNorm.py: -------------------------------------------------------------------------------- 1 | # TvNorm.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class TvNorm(nn.Module): 7 | """ 8 | normalization using the total variation; idea is to normalize pixel-wise by the length of the feature vector, i.e., 9 | MATLAB notation: 10 | z = diag( 1/ sqrt( sum(x.^2,3)+eps)) x 11 | 12 | Attributes: 13 | eps: small float so no division by 0 14 | weight: scaling weight for the affine transformation 15 | bias: bias for the affine transformation 16 | 17 | """ 18 | def __init__(self, nChan, eps=1e-4): 19 | """ 20 | :param nChan: number of channels for the data you expect to normalize 21 | :param eps: small float so no division by 0 22 | """ 23 | super().__init__() 24 | 25 | self.eps = eps 26 | 27 | # Tv Norm has no tuning of the scaling weights 28 | # self.weight = nn.Parameter(torch.ones(nChan)) 29 | self.register_buffer('weight', torch.ones(nChan)) 30 | 31 | self.bias = nn.Parameter(torch.zeros(nChan)) 32 | 33 | 34 | def forward(self,x): 35 | """ 36 | :param x: inputs tensor, second dim is channels 37 | example dims: (num images in the batch , num channels, height , width) 38 | :return: normalized version with same dimensions as x 39 | """ 40 | z = torch.pow(x, 2) 41 | z = torch.div(x, torch.sqrt(torch.sum(z, dim=1, keepdim=True) + self.eps)) 42 | # assumes that Tensor is formatted ( something , no. of channels, something, something, etc.) 43 | 44 | if self.weight is not None: 45 | w = self.weight.unsqueeze(0) # add first dimension 46 | w = w.unsqueeze(-1) # add last dimension 47 | w = w.unsqueeze(-1) # add last dimension 48 | z = z * w 49 | if self.bias is not None: 50 | b = self.bias.unsqueeze(0) # add first dimension 51 | b = b.unsqueeze(-1) # add last dimension 52 | b = b.unsqueeze(-1) # add last dimension 53 | z = z + b 54 | 55 | return z 56 | 57 | -------------------------------------------------------------------------------- /modules/rk1.py: -------------------------------------------------------------------------------- 1 | # rk1.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | from utils import interpLayer1Dpreallocated 6 | import copy 7 | from modules.ClippedModule import * 8 | 9 | 10 | class rk1(ClippedModule): 11 | """ 12 | Implementation of the Runge-Kutta 1 Dynamic Block 13 | 14 | Attributes: 15 | interpolate: function, the type of interpolation to perform (default is 1D) 16 | tTheta: list, time discretization of the controls 17 | tY: list, time discretization of the states 18 | nTheta: int, number of control layers 19 | hY: list, the step sizes between consecutive time discretization points of tY 20 | midY: list, the midpoints between subsequent tY 21 | controlLayers (sub-module list): all elements are instances of layer representing the control layers, with associated weights 22 | stateLayers (sub-module list): all elements are instances of layer representing the state layers, with no Parameters 23 | """ 24 | def __init__(self, tTheta, tY, layer, layerParams={}): 25 | 26 | super().__init__() 27 | 28 | self.interpolate = interpLayer1Dpreallocated 29 | 30 | if 'vFeat' in layerParams.keys(): 31 | vFeat = layerParams.get('vFeat') 32 | del layerParams['vFeat'] 33 | else: 34 | print('rk4: Error: layerParams must have \'vFeat\' declared') 35 | exit(1) 36 | 37 | 38 | self.tTheta = tTheta 39 | self.tY = tY 40 | self.nTheta = len(tTheta) 41 | 42 | # set up control layers....these need to contain the nn.Parameters and will be used for 43 | # the interpolation of the state layers 44 | self.controlLayers = nn.ModuleList([layer(vFeat, copy.deepcopy(layerParams)) for i in range(self.nTheta)]) 45 | # nn.ModuleList works like a python list of Modules 46 | 47 | # the hi and intermediate Ys for interpolation 48 | self.hY = [next - curr for curr, next in zip(self.tY, self.tY[1:])] 49 | self.midY = [curr + h / 2 for curr, h in zip(self.tY, self.hY)] 50 | # self.stateLayers = [] 51 | self.stateLayers = self.interpolate(self.controlLayers, self.tTheta, self.tY[0: -1]) 52 | 53 | 54 | def setTimeSteps(self, tY, tTheta, blankLayer=None, lastUnused=False, interpolate=None): 55 | """ 56 | have the capability to set tY and tTheta between forward passes 57 | 58 | :param tY: new state layer discretization 59 | :param tTheta: new control layer discretization 60 | :param blankLayer: supply a blank control layer for when new tTheta differs from prev tTheta 61 | :param lastUnused: pass True when the last layer of tTheta is unused (ex. when prev tTheta= prev tY=[0,1,2,3,4]) 62 | in these cases, we want to exclude this untuned layer bc it could negatively influence the interpolation 63 | """ 64 | 65 | # TODO: give ability to vary boundary conditions 66 | # hard-coded boundary conditions are to maintain no change from nearest knot 67 | 68 | if interpolate is None: 69 | if hasattr(self, 'interpolate'): 70 | interpolate = self.interpolate 71 | else: # patch for old networks 72 | self.interpolate = interpLayer1Dpreallocated 73 | 74 | 75 | if self.tTheta != tTheta: 76 | if blankLayer is None: 77 | print('must provide blankLayer if changing tTheta') # TODO: maybe can make more robust: 78 | # check if lengths are the same 79 | print('OR wait for torch to fix: \'Only Tensors created explicitly by the user support deepcopy\'') 80 | 81 | if lastUnused: # last control Layer is not used in RK1 82 | self.controlLayers = interpolate(self.controlLayers[0:-1], self.tTheta[0:-1], tTheta, 83 | blankLayer=blankLayer, bKeepParam=True) 84 | else: 85 | self.controlLayers = interpolate(self.controlLayers, self.tTheta, tTheta, blankLayer=blankLayer, 86 | bKeepParam = True) 87 | 88 | self.tY = tY 89 | self.tTheta = tTheta 90 | # the hi and intermediate Ys for interpolation 91 | self.hY = [next - curr for curr, next in zip(self.tY, self.tY[1:])] 92 | self.midY = [curr + h / 2 for curr, h in zip(self.tY, self.hY)] 93 | 94 | 95 | if lastUnused: 96 | self.stateLayers = interpolate(self.controlLayers[0:-1], self.tTheta[0:-1], self.tY[0: -1],blankLayer=blankLayer) 97 | 98 | else: # for rk1 don't use the last point 99 | self.stateLayers = interpolate(self.controlLayers, self.tTheta, self.tY[0: -1],blankLayer=blankLayer) 100 | 101 | def forward(self, x): 102 | if not hasattr(self, 'interpolate'): # patch for old networks 103 | self.interpolate = interpLayer1Dpreallocated 104 | 105 | nY = len(self.tY) 106 | 107 | self.stateLayers = self.interpolate(self.controlLayers, self.tTheta, self.tY[0: -1], self.stateLayers) 108 | 109 | for k in range(nY-1): 110 | # interpolate the fixed control layers (the Parameters) at tTheta locations to get the 111 | # interpolated state layers at the tY locations 112 | 113 | hi = self.hY[k] 114 | 115 | z1 = self.stateLayers[k](x) 116 | 117 | x = x + hi * z1 118 | 119 | return x 120 | 121 | def train(self, mode=True): 122 | """set train/eval mode for state layers b/c they're not in a ModuleList""" 123 | self.training = mode 124 | nY = len(self.tY) 125 | 126 | for k in range(nY - 1): 127 | for module in self.stateLayers: 128 | module.train(mode) 129 | 130 | for module in self.controlLayers: 131 | module.train(mode) 132 | 133 | return self 134 | 135 | 136 | def weight_variance(self): 137 | """apply regularization in time""" 138 | weight_diffs=[] 139 | nTheta = len(self.tTheta) 140 | for k in range(nTheta - 1): 141 | fReg = self.controlLayers[k].weight_variance(self.controlLayers[k+1]).unsqueeze(0) 142 | weight_diffs.append( self.hY[k] * fReg ) # h * ||x-y||_p 143 | return weight_diffs 144 | 145 | 146 | 147 | 148 | if __name__ == "__main__": 149 | print('rk1 block test\n') 150 | import testRK1Block # this should run the test 151 | 152 | -------------------------------------------------------------------------------- /modules/rk4.py: -------------------------------------------------------------------------------- 1 | # rk4.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | from utils import interpLayer1Dpreallocated 6 | from modules.ClippedModule import * 7 | 8 | class rk4(ClippedModule): 9 | """ 10 | Implementation of the Runge-Kutta 4 Dynamic Block 11 | 12 | Attributes: 13 | interpolate: function, the type of interpolation to perform (default is 1D) 14 | tTheta: list, time discretization of the controls 15 | tY: list, time discretization of the states 16 | nTheta: int, number of control layers 17 | hY: list, the step sizes between consecutive time discretization points of tY 18 | midY: list, the midpoints between subsequent tY 19 | controlLayers (sub-module list): all elements are instances of layer representing the control layers, with associated weights 20 | stateLayers (sub-module list): all elements are instances of layer representing the state layers, with no Parameters 21 | """ 22 | def __init__(self, tTheta, tY, layer, layerParams={}): 23 | super().__init__() 24 | 25 | self.interpolate = interpLayer1Dpreallocated 26 | 27 | if 'vFeat' in layerParams.keys(): 28 | vFeat = layerParams.get('vFeat') 29 | del layerParams['vFeat'] 30 | else: 31 | print('rk4: Error: layerParams must have \'vFeat\' declared') 32 | exit(1) 33 | 34 | 35 | self.tTheta = tTheta 36 | self.tY = tY 37 | self.nTheta = len(tTheta) 38 | 39 | # set up control layers....these need to contain the nn.Parameters and will be used for 40 | # the interpolation of the state layers 41 | self.controlLayers = nn.ModuleList([layer(vFeat, layerParams) for i in range(self.nTheta)]) 42 | # nn.ModuleList works like a python list of Modules 43 | 44 | # the hi and intermediate Ys for interpolation 45 | self.hY = [next - curr for curr, next in zip(self.tY, self.tY[1:])] 46 | self.midY = [curr + h / 2 for curr, h in zip(self.tY, self.hY)] 47 | self.stateLayers=[] 48 | for k in range(len(self.tY) - 1): 49 | self.stateLayers.append( self.interpolate(self.controlLayers, self.tTheta, 50 | [self.tY[k],self.midY[k],self.midY[k],self.tY[k+1]])) 51 | 52 | def setTimeSteps(self, tY, tTheta, blankLayer=None, interpolate=None): 53 | """ 54 | have the capability to set tY and tTheta between forward passes 55 | 56 | :param tY: list, time discretizations for the states 57 | :param tTheta: list, time discretizations for the controls 58 | :param blankLayer: module, blueprint control layer needed to be passed when adding control layers 59 | :param interpolate: function, interpolation method (default is 1D interpolation) 60 | """ 61 | 62 | if interpolate is None: 63 | if hasattr(self, 'interpolate'): 64 | interpolate = self.interpolate 65 | else: # patch for old networks 66 | self.interpolate = interpLayer1Dpreallocated 67 | 68 | self.controlLayers =interpolate(self.controlLayers, self.tTheta,tTheta, blankLayer=blankLayer, bKeepParam=True) 69 | 70 | self.tY = tY 71 | self.tTheta = tTheta 72 | # the hi and intermediate Ys for interpolation 73 | self.hY = [next - curr for curr, next in zip(self.tY, self.tY[1:])] 74 | self.midY = [curr + h / 2 for curr, h in zip(self.tY, self.hY)] 75 | 76 | # if rediscretizing with a batch norm, the interpolation will set running_mean and variance to 0 and 1 77 | self.stateLayers = [] 78 | for k in range(len(self.tY) - 1): 79 | self.stateLayers.append(interpolate(self.controlLayers, self.tTheta, 80 | [self.tY[k], self.midY[k], self.midY[k], self.tY[k + 1]],blankLayer=blankLayer)) 81 | 82 | def forward(self, x): 83 | if not hasattr(self, 'interpolate'): # patch for old networks 84 | self.interpolate = interpLayer1Dpreallocated 85 | 86 | nY = len(self.tY) 87 | # interpolate the fixed control layers (the Parameters) at tTheta locations to get the 88 | # interpolated state layers at the tY locations 89 | 90 | for k in range(nY - 1): 91 | # adjust the weights that will be used for the layers....curr, mid, next 92 | self.interpolate(self.controlLayers, self.tTheta, 93 | [self.tY[k], self.midY[k], self.midY[k], self.tY[k + 1]],self.stateLayers[k]) 94 | 95 | # INSIGHT FROM PLOTTING REGULARIZED MODEL: the runMean and runVar, of both midY are similar, 96 | # as are those for k+1 and then the next k 97 | 98 | 99 | hi = self.hY[k] 100 | 101 | # first intermediate step 102 | z1 = self.stateLayers[k][0](x) 103 | 104 | # second intermediate step 105 | z2 = self.stateLayers[k][1](x + z1 * (hi / 2)) 106 | 107 | # third intermediate step 108 | z3 = self.stateLayers[k][2](x + z2 * (hi / 2)) 109 | 110 | # fourth intermediate 111 | z4 = self.stateLayers[k][3](x + z3 * hi) 112 | 113 | x = x + (hi/6) * (z1 + 2 * z2 + 2 * z3 + z4) 114 | 115 | return x 116 | 117 | def train(self, mode=True): 118 | """set train/eval mode for state layers b/c they're not in a ModuleList""" 119 | self.training = mode 120 | nY = len(self.tY) 121 | 122 | for k in range(nY - 1): 123 | for module in self.stateLayers[k]: 124 | module.train(mode) 125 | 126 | for module in self.controlLayers: 127 | module.train(mode) 128 | 129 | return self 130 | 131 | def weight_variance(self): 132 | """apply regularization in time""" 133 | weight_diffs=[] 134 | nTheta = len(self.tTheta) 135 | for k in range(nTheta - 1): 136 | fReg = self.controlLayers[k].weight_variance(self.controlLayers[k+1]).unsqueeze(0) 137 | weight_diffs.append( self.hY[k] * fReg ) # h * ||x-y||_p 138 | return torch.cat(weight_diffs) 139 | 140 | 141 | 142 | 143 | 144 | if __name__ == "__main__": 145 | print('rk4 block test\n') 146 | import testRK4Block # this should run the test 147 | 148 | -------------------------------------------------------------------------------- /modules/runModuleTests.py: -------------------------------------------------------------------------------- 1 | import modules.testConnectingLayer 2 | import modules.testDoubleLayer 3 | import modules.testDoubleSymLayer 4 | import modules.testPreactDoubleLayer 5 | import modules.testRK1Block 6 | import modules.testRK4Block -------------------------------------------------------------------------------- /modules/testConnectingLayer.py: -------------------------------------------------------------------------------- 1 | 2 | from modules.ConnectingLayer import ConnectingLayer 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from utils import * 8 | import torch.optim as optim 9 | 10 | 11 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 12 | vFeat = [3, 4] 13 | nChan = vFeat[0] 14 | nClasses = 5 15 | 16 | # random batch 17 | x = normalInit([10, 3, 32, 32]).to(device) # (numImages, numChans, image height, image width) 18 | W = normalInit([nChan + 1, nClasses]).to(device) # plus one for the bias 19 | labels = torch.LongTensor([1, 2, 3, 4, 3, 2, 1, 0, 2, 3]).to(device) 20 | Kconnect = normalInit([nChan, nChan, 1, 1]).to(device) 21 | 22 | # ---------------------------------------------------------------------- 23 | # new approach 24 | paramsStruct = {'normLayer': nn.BatchNorm2d(num_features=vFeat[1]), 25 | 'conv': nn.Conv2d(in_channels=vFeat[0], out_channels=vFeat[1], kernel_size=3, padding=1, stride=1)} 26 | 27 | net = ConnectingLayer(vFeat, params=paramsStruct) 28 | net.to(device) 29 | origK = net.conv.weight.data.clone().to(device) # for old method 30 | K = nn.Parameter(origK.clone()).to(device) 31 | optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9, weight_decay=0, nesterov=False) 32 | 33 | optimizer.zero_grad() 34 | y1 = net.forward(x) 35 | y1 = F.avg_pool2d(y1, x.shape[2:4]) 36 | 37 | loss1, _ = misfitW(y1, W, labels, device) 38 | loss1.backward() 39 | optimizer.step() 40 | 41 | # ---------------------------------------------------------------------- 42 | def compareFunc(x, Kopen, normFlag='batch', eps=1e-5, device=torch.device('cpu')): 43 | x = conv3x3(x, Kopen) 44 | if normFlag is 'batch': 45 | x = F.batch_norm(x, 46 | running_mean = torch.zeros(Kopen.size(0)).to(device), 47 | running_var = torch.ones(Kopen.size(0)).to(device), 48 | weight = torch.ones(Kopen.size(0)).to(device), 49 | bias = torch.zeros(Kopen.size(0)).to(device), 50 | training = True, eps=eps) 51 | elif normFlag is 'instance': 52 | x = F.instance_norm(x, weight=weightBias[0], bias=weightBias[1]) 53 | x = F.relu(x) 54 | return x 55 | # old method 56 | optimParams = [{'params': K}] 57 | nWeights = 0 58 | optimizer = optim.SGD(optimParams, lr=1e-1, momentum=0.9, weight_decay=0, nesterov=False) 59 | 60 | optimizer.zero_grad() 61 | 62 | y2 = compareFunc(x, K, device=device) 63 | 64 | y2 = F.avg_pool2d(y2, x.shape[2:4]) 65 | 66 | loss2, _ = misfitW(y2, W, labels, device) 67 | loss2.backward() 68 | optimizer.step() 69 | 70 | # ---------------------------------------------------------------------- 71 | 72 | # print('layer 2-norm difference:', torch.norm(y2 - y1, p=2).data) 73 | # print('loss 2-norm difference: ', torch.norm(loss2 - loss1, p=2).data) 74 | # print('K 2-norm difference: ', torch.norm(net.conv.weight.data - K.data, p=2).data) 75 | # print('K update: ', torch.norm(origK - K.data, p=2).data) 76 | 77 | 78 | tol = 1e-7 79 | assert(torch.norm(y2 - y1, p=2).data < tol) 80 | assert(torch.norm(loss2 - loss1, p=2).data < tol) 81 | assert(torch.norm(net.conv.weight.data - K.data, p=2).data < tol) 82 | 83 | assert( torch.norm(origK - K.data, p=2).data > 1e-4) # want > 0 84 | print('tests passed') 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /modules/testDoubleLayer.py: -------------------------------------------------------------------------------- 1 | 2 | from modules.DoubleLayer import DoubleLayer 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from utils import * 8 | import torch.optim as optim 9 | 10 | 11 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 12 | vFeat = [4, 4] 13 | nChan = vFeat[0] 14 | nClasses = 5 15 | 16 | # random batch 17 | x = normalInit([10, nChan, 32, 32]).to(device) # (numImages, numChans, image height, image width) 18 | W = normalInit([nChan + 1, nClasses]).to(device) # plus one for the bias 19 | labels = torch.LongTensor([1, 2, 3, 4, 3, 2, 1, 0, 2, 3]).to(device) 20 | Kconnect = normalInit([nChan, nChan, 1, 1]).to(device) 21 | 22 | # ---------------------------------------------------------------------- 23 | # new approach 24 | # paramsStruct = {'normLayer1': nn.BatchNorm2d(num_features=nChan), 25 | # 'normLayer2': nn.BatchNorm2d(num_features=nChan), 26 | # 'conv1': nn.Conv2d(in_channels=nChan, out_channels=nChan, kernel_size=3, padding=1, stride=1), 27 | # 'conv2': nn.Conv2d(in_channels=nChan, out_channels=nChan, kernel_size=3, padding=1, stride=1)} 28 | paramsStruct = {'normLayer': nn.BatchNorm2d(num_features=nChan), 29 | 'conv': nn.Conv2d(in_channels=nChan, out_channels=nChan, kernel_size=3, padding=1, stride=1)} 30 | 31 | net = DoubleLayer(vFeat, params=paramsStruct) 32 | net.to(device) 33 | origK1 = net.conv1.weight.data.clone() # for old method 34 | origK2 = net.conv2.weight.data.clone() # for old method 35 | K1 = nn.Parameter(origK1.clone()) 36 | K2 = nn.Parameter(origK2.clone()) 37 | optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9, weight_decay=0, nesterov=False) 38 | 39 | optimizer.zero_grad() 40 | y1 = net.forward(x) 41 | y1 = F.avg_pool2d(y1, x.shape[2:4]) 42 | 43 | loss1, _ = misfitW(y1, W, labels, device) 44 | loss1.backward() 45 | optimizer.step() 46 | 47 | # ---------------------------------------------------------------------- 48 | def compareFunc(x, K1, K2 ,device): # functional DoubleLayer 49 | z = conv3x3(x, K1) 50 | z = F.batch_norm(z, running_mean=torch.zeros(K1.size(0) ,device=device), 51 | running_var=torch.ones(K1.size(0) ,device=device), training=True) 52 | z = F.relu(z) 53 | z = conv3x3(z, K2) 54 | z = F.batch_norm(z, running_mean=torch.zeros(K2.size(0), device=device), 55 | running_var=torch.ones(K2.size(0), device=device), training=True) 56 | z = F.relu(z) 57 | return z 58 | # old method 59 | optimParams = [{'params': K1}, {'params':K2}] 60 | nWeights = 0 61 | optimizer = optim.SGD(optimParams, lr=1e-1, momentum=0.9, weight_decay=0, nesterov=False) 62 | 63 | optimizer.zero_grad() 64 | 65 | y2 = compareFunc(x, K1,K2, device) 66 | 67 | y2 = F.avg_pool2d(y2, x.shape[2:4]) 68 | 69 | loss2, _ = misfitW(y2, W, labels, device) 70 | loss2.backward() 71 | optimizer.step() 72 | 73 | # ---------------------------------------------------------------------- 74 | 75 | # print('layer 2-norm difference:', torch.norm(y2 - y1, p=2).data) # want = 0 76 | # print('loss 2-norm difference: ', torch.norm(loss2 - loss1, p=2).data) # want = 0 77 | # print('K1 2-norm difference:', torch.norm(net.conv1.weight.data - K1.data, p=2).data) # want = 0 78 | # print('K2 2-norm difference:', torch.norm(net.conv2.weight.data - K2.data, p=2).data) # want = 0 79 | # print('K1 update: ', torch.norm(origK1 - K1.data, p=2).data) # want > 0 80 | # print('K2 update: ', torch.norm(origK2 - K2.data, p=2).data) # want > 0 81 | 82 | tol = 1e-5 83 | assert(torch.norm(y2 - y1, p=2).data < tol) 84 | assert(torch.norm(loss2 - loss1, p=2).data < tol) 85 | assert( torch.norm(net.conv1.weight.data - K1.data, p=2).data < tol ) 86 | assert( torch.norm(net.conv2.weight.data - K2.data, p=2).data < tol ) 87 | 88 | assert( torch.norm(origK1 - K1.data, p=2).data > 1e-4) 89 | assert( torch.norm(origK2 - K2.data, p=2).data > 1e-4) 90 | print('tests passed') 91 | 92 | -------------------------------------------------------------------------------- /modules/testDoubleSymLayer.py: -------------------------------------------------------------------------------- 1 | 2 | from modules.DoubleSymLayer import DoubleSymLayer 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from utils import * 8 | import torch.optim as optim 9 | 10 | 11 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 12 | vFeat = [3, 4] 13 | nChan = vFeat[0] 14 | nClasses = 5 15 | 16 | # random batch 17 | x = normalInit([10, 3, 32, 32]).to(device) # (numImages, numChans, image height, image width) 18 | W = normalInit([nChan + 1, nClasses]).to(device) # plus one for the bias 19 | labels = torch.LongTensor([1, 2, 3, 4, 3, 2, 1, 0, 2, 3]).to(device) 20 | Kconnect = normalInit([nChan, nChan, 1, 1]).to(device) 21 | 22 | # ---------------------------------------------------------------------- 23 | # new approach 24 | paramsStruct = {'normLayer': nn.BatchNorm2d(num_features=4), 25 | 'conv': nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1, stride=1)} 26 | 27 | net = DoubleSymLayer(vFeat, params=paramsStruct) 28 | net.to(device) 29 | origK = net.weight.data.clone() # for old method 30 | K = nn.Parameter(origK.clone()) 31 | optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9, weight_decay=0, nesterov=False) 32 | 33 | optimizer.zero_grad() 34 | y1 = net.forward(x) 35 | y1 = F.avg_pool2d(y1, x.shape[2:4]) 36 | 37 | loss1, _ = misfitW(y1, W, labels, device) 38 | loss1.backward() 39 | optimizer.step() 40 | 41 | 42 | # check that both convolutions point to the same parameter 43 | # print(net.convt.weight is net.conv.weight) # true when pointing to the same object 44 | # print(net.convt.weight is net.conv.weight.clone()) # false because different objects 45 | # print(net.convt.weight==net.conv.weight.clone()) # all true because each element is the same value 46 | 47 | assert(net.convt.weight is net.conv.weight) 48 | assert(not net.convt.weight is net.conv.weight.clone() ) 49 | 50 | 51 | # ---------------------------------------------------------------------- 52 | def compareFunc(x, K ,device): # functional DoubleSymLayer 53 | z = conv3x3(x, K) 54 | z = F.batch_norm(z, running_mean=torch.zeros(K.size(0) ,device=device), 55 | running_var=torch.ones(K.size(0) ,device=device), training=True) 56 | z = F.relu(z) 57 | z = - convt3x3(z, K) 58 | return z 59 | # old method 60 | optimParams = [{'params': K}] 61 | nWeights = 0 62 | optimizer = optim.SGD(optimParams, lr=1e-1, momentum=0.9, weight_decay=0, nesterov=False) 63 | 64 | optimizer.zero_grad() 65 | 66 | y2 = compareFunc(x, K, device) 67 | 68 | y2 = F.avg_pool2d(y2, x.shape[2:4]) 69 | 70 | loss2, _ = misfitW(y2, W, labels, device) 71 | loss2.backward() 72 | optimizer.step() 73 | 74 | # ---------------------------------------------------------------------- 75 | 76 | # print('layer 2-norm difference:', torch.norm(y2 - y1, p=2).data) # want = 0 77 | # print('loss 2-norm difference: ', torch.norm(loss2 - loss1, p=2).data) # want = 0 78 | # print('K 2-norm difference: ', torch.norm(net.weight.data - K.data, p=2).data) # want = 0 79 | # print('K update: ', torch.norm(origK - K.data, p=2).data) # want > 0 80 | 81 | tol = 1e-5 82 | assert(torch.norm(y2 - y1, p=2).data < tol) 83 | assert(torch.norm(loss2 - loss1, p=2).data < tol) 84 | assert(torch.norm(net.weight.data - K.data, p=2).data < tol) 85 | 86 | assert(torch.norm(origK - K.data, p=2).data > 1e-4) # want > 0 87 | print('tests passed') 88 | 89 | 90 | -------------------------------------------------------------------------------- /modules/testPreactDoubleLayer.py: -------------------------------------------------------------------------------- 1 | 2 | from modules.PreactDoubleLayer import PreactDoubleLayer 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from utils import * 8 | import torch.optim as optim 9 | 10 | 11 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 12 | vFeat = [4, 4] 13 | nChan = vFeat[0] 14 | nClasses = 5 15 | 16 | # random batch 17 | x = normalInit([10, nChan, 32, 32]).to(device) # (numImages, numChans, image height, image width) 18 | W = normalInit([nChan + 1, nClasses]).to(device) # plus one for the bias 19 | labels = torch.LongTensor([1, 2, 3, 4, 3, 2, 1, 0, 2, 3]).to(device) 20 | Kconnect = normalInit([nChan, nChan, 1, 1]).to(device) 21 | 22 | # ---------------------------------------------------------------------- 23 | # new approach 24 | paramsStruct = {'normLayer1': nn.BatchNorm2d(num_features=nChan), 25 | 'normLayer2': nn.BatchNorm2d(num_features=nChan), 26 | 'conv1': nn.Conv2d(in_channels=nChan, out_channels=nChan, kernel_size=3, padding=1, stride=1), 27 | 'conv2': nn.Conv2d(in_channels=nChan, out_channels=nChan, kernel_size=3, padding=1, stride=1)} 28 | 29 | net = PreactDoubleLayer(vFeat, params=paramsStruct) 30 | net.to(device) 31 | origK1 = net.conv1.weight.data.clone() # for old method 32 | origK2 = net.conv2.weight.data.clone() # for old method 33 | K1 = nn.Parameter(origK1.clone()) 34 | K2 = nn.Parameter(origK2.clone()) 35 | optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9, weight_decay=0, nesterov=False) 36 | 37 | optimizer.zero_grad() 38 | y1 = net.forward(x) 39 | y1 = F.avg_pool2d(y1, x.shape[2:4]) 40 | 41 | loss1, _ = misfitW(y1, W, labels, device) 42 | loss1.backward() 43 | optimizer.step() 44 | 45 | # ---------------------------------------------------------------------- 46 | def compareFunc(x, K1, K2 ,device): # functional Preactivated DoubleLayer 47 | z = F.relu(x) 48 | z = conv3x3(z, K1) 49 | z = F.batch_norm(z, running_mean=torch.zeros(K1.size(0) ,device=device), 50 | running_var=torch.ones(K1.size(0) ,device=device), training=True) 51 | z = F.relu(z) 52 | z = conv3x3(z, K2) 53 | z = F.batch_norm(z, running_mean=torch.zeros(K2.size(0), device=device), 54 | running_var=torch.ones(K2.size(0), device=device), training=True) 55 | return z 56 | # old method 57 | optimParams = [{'params': K1}, {'params':K2}] 58 | nWeights = 0 59 | optimizer = optim.SGD(optimParams, lr=1e-1, momentum=0.9, weight_decay=0, nesterov=False) 60 | 61 | optimizer.zero_grad() 62 | 63 | y2 = compareFunc(x, K1,K2, device) 64 | 65 | y2 = F.avg_pool2d(y2, x.shape[2:4]) 66 | 67 | loss2, _ = misfitW(y2, W, labels, device) 68 | loss2.backward() 69 | optimizer.step() 70 | 71 | # ---------------------------------------------------------------------- 72 | 73 | # print('layer 2-norm difference:', torch.norm(y2 - y1, p=2).data) # want = 0 74 | # print('loss 2-norm difference: ', torch.norm(loss2 - loss1, p=2).data) # want = 0 75 | # print('K1 2-norm difference:', torch.norm(net.conv1.weight.data - K1.data, p=2).data) # want = 0 76 | # print('K2 2-norm difference:', torch.norm(net.conv2.weight.data - K2.data, p=2).data) # want = 0 77 | # print('K1 update: ',torch.norm(origK1 - K1.data, p=2).data) # want > 0 78 | # print('K2 update: ',torch.norm(origK2 - K2.data, p=2).data) # want > 0 79 | 80 | tol = 1e-5 81 | assert(torch.norm(y2 - y1, p=2).data < tol) 82 | assert(torch.norm(loss2 - loss1, p=2).data < tol) 83 | assert( torch.norm(net.conv1.weight.data - K1.data, p=2).data < tol ) 84 | assert( torch.norm(net.conv2.weight.data - K2.data, p=2).data < tol ) 85 | 86 | assert( torch.norm(origK1 - K1.data, p=2).data > 1e-4) 87 | assert( torch.norm(origK2 - K2.data, p=2).data > 1e-4) 88 | print('tests passed') 89 | -------------------------------------------------------------------------------- /modules/testRK1Block.py: -------------------------------------------------------------------------------- 1 | # testRK1Block.py 2 | 3 | from modules.rk1 import * 4 | from utils import * 5 | from modules.DoubleSymLayer import DoubleSymLayer 6 | 7 | 8 | torch.set_printoptions(precision=10) 9 | 10 | 11 | def DoubleSymLayerBatchNorm(x, K): 12 | z = conv3x3(x, K) 13 | z = F.batch_norm(z, running_mean=torch.zeros(K.size(0)), 14 | running_var=torch.ones(K.size(0)), training=True) 15 | z = F.relu(z) 16 | z = - convt3x3(z, K) 17 | return z 18 | 19 | 20 | # OLD one 21 | def RK1DoubleSymBlock(x,Kresnet,tTheta, tY): # weightBias = [None]*2*nt 22 | """ 23 | Implementation of a Runge-Kutta 4 block that uses the DoubleSym Layer at each of the steps 24 | """ 25 | nt = len(tY)-1 # should also be the len(Kresnet)-1) 26 | 27 | for k in range(nt): 28 | tYk = tY[k] 29 | hi = tY[k + 1] - tYk 30 | 31 | interpK = inter1D(Kresnet, tTheta, [tYk]) 32 | 33 | # first intermediate step 34 | z1 = DoubleSymLayerBatchNorm(x, interpK[0]) 35 | 36 | x = x + hi*z1 37 | 38 | return x 39 | 40 | 41 | # if __name__ == "__main__": 42 | 43 | tTheta = [0 , 2] 44 | tY = [0,1,2] 45 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 46 | device = torch.device('cpu') 47 | nChan = 2 48 | nClasses = 5 49 | x = normalInit([10 , nChan , 32 , 32]).to(device) # (numImages, numChans, image height, image width) 50 | 51 | 52 | labels = torch.LongTensor([1 ,2 ,3 ,4 ,3 ,2 ,1 ,0 ,2 ,3]).to(device) 53 | Kconnect = normalInit([nChan, nChan, 1, 1]).to(device) 54 | 55 | 56 | 57 | # dsLayer = DoubleSymLayer( [nChan,7] , params={'norm': nn.BatchNorm2d(nChan)}) 58 | dsLayer = DoubleSymLayer 59 | layerParams = { 'vFeat': [nChan ,nChan+1] , 60 | 'act' : nn.ReLU(), 61 | 'normLayer': nn.BatchNorm2d(nChan+1)} 62 | # layerParams = {'vFeat': [ nChan,3], 'act': nn.ReLU()} 63 | net = rk1( tTheta, tY, dsLayer, layerParams=layerParams) 64 | # for old method 65 | origWeights =[] 66 | for i in range(len(net.controlLayers)): 67 | origWeights.append(net.controlLayers[i].weight.data.clone()) 68 | 69 | 70 | net.to(device) 71 | 72 | stdv = 0.1 73 | W = stdv * torch.randn(nChan +1, nClasses).to(device) # plus one for the bias 74 | optimizer = optim.SGD(net.parameters(), lr=1, momentum=0.9, weight_decay=0, nesterov=False) 75 | 76 | optimizer.zero_grad() 77 | y1 = net.forward(x) 78 | y3 = F.avg_pool2d(y1, x.shape[2:4]) 79 | 80 | loss1, _ = misfitW(y3, W, labels, device) 81 | loss1.backward() 82 | optimizer.step() 83 | 84 | 85 | # -------------- 86 | 87 | # use same initializations 88 | K = [] 89 | for i in range(len(origWeights)): 90 | K.append( nn.Parameter(origWeights[i].clone()) ) 91 | 92 | params = [] 93 | nWeights = 0 94 | params, nWeights = appendTensorList(params, K, nWeights) 95 | optimizer2 = optim.SGD(params, lr=1, momentum=0.9, weight_decay=0, nesterov=False) 96 | 97 | # for i in range(len(K)): 98 | # K[i] = K[i].to(device) 99 | 100 | optimizer2.zero_grad() 101 | 102 | class compareNet(nn.Module): 103 | def __init__(self): 104 | super().__init__() 105 | def forward(self,x, K, tTheta, tY): 106 | y2 = RK1DoubleSymBlock(x, K, tTheta, tY) 107 | 108 | y4 = F.avg_pool2d(y2, x.shape[2:4]) 109 | 110 | return y2,y4 111 | 112 | net2 = compareNet() 113 | y2,y4 = net2.forward(x,K,tTheta,tY) 114 | 115 | loss2, _ = misfitW(y4, W, labels, device) 116 | loss2.backward() 117 | optimizer2.step() 118 | 119 | 120 | # print('block 2-norm difference:', torch.norm(y2 - y1, p=2).data) # want = 0 121 | # print('net 2-norm difference:', torch.norm(y4 - y3, p=2).data) # want = 0 122 | # print('loss 2-norm difference:', torch.norm(loss2 - loss1, p=2).data) # want = 0 123 | # print('K 2-norm difference:', listNorm([net.controlLayers[i].weight for i in range(len(net.controlLayers))] ,K)) 124 | # # want = 0 125 | # print('K update: ', listNorm(origWeights ,K)) # want > 0 126 | 127 | 128 | tol = 5e-6 129 | assert(torch.norm(y2 - y1, p=2).data < tol) 130 | assert(torch.norm(y4 - y3, p=2).data < tol) 131 | assert(torch.norm(loss2 - loss1, p=2).data < tol) 132 | assert( listNorm([net.controlLayers[i].weight for i in range(len(net.controlLayers))] ,K) < tol) 133 | 134 | assert( listNorm(origWeights ,K) > 1e-4) # want > 0 135 | print('tests passed') -------------------------------------------------------------------------------- /modules/testRK4Block.py: -------------------------------------------------------------------------------- 1 | from modules.rk4 import * 2 | from utils import * 3 | from modules.DoubleSymLayer import DoubleSymLayer 4 | 5 | 6 | 7 | 8 | torch.set_printoptions(precision=10) 9 | 10 | 11 | def DoubleSymLayerBatchNorm(x, K): 12 | z = conv3x3(x, K) 13 | z = F.batch_norm(z, running_mean=torch.zeros(K.size(0)), 14 | running_var=torch.ones(K.size(0)), training=True) 15 | z = F.relu(z) 16 | z = - convt3x3(z, K) 17 | return z 18 | 19 | 20 | # OLD one 21 | def rk4DoubleSymBlock(x,Kresnet, tTheta, tY, act=F.relu, weightBias = [None] , tvnorm=True): 22 | """ 23 | Implementation of a Runge-Kutta 4 block that uses the DoubleSym Layer at each of the steps 24 | """ 25 | nt = len(tY)-1 26 | 27 | for k in range(nt): 28 | tYk = tY[k] 29 | hi = tY[k + 1] - tYk 30 | 31 | if weightBias[0] is not None: # really should check if all are None 32 | interpWeight = inter1D(weightBias[0::2], tTheta, [tYk, tYk + hi / 2, tYk + hi]) # scaling weight 33 | interpBias = inter1D(weightBias[1::2], tTheta, [tYk, tYk + hi / 2, tYk + hi]) # bias 34 | else: 35 | interpWeight = [None,None,None] 36 | interpBias = [None,None,None] 37 | 38 | interpK = inter1D(Kresnet, tTheta, [tYk , tYk+hi/2 , tYk+hi ]) 39 | 40 | # first intermediate step 41 | z1 = DoubleSymLayerBatchNorm( x , interpK[0] ) 42 | 43 | # second intermediate step 44 | z2 = DoubleSymLayerBatchNorm( x + z1*(hi/2) , interpK[1]) 45 | 46 | # third intermediate step 47 | z3 = DoubleSymLayerBatchNorm( x + z2*(hi/2) , interpK[1] ) 48 | 49 | # fourth intermediate 50 | z4 = DoubleSymLayerBatchNorm( x + z3*hi , interpK[2] ) 51 | 52 | x = x + hi*(z1 + 2*z2 + 2*z3 + z4)/6 53 | 54 | return x 55 | 56 | 57 | # if __name__ == "__main__": 58 | 59 | tTheta = [0 , 2] 60 | tY = [0,1,2] 61 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 62 | device = torch.device('cpu') 63 | nChan = 2 64 | nClasses = 5 65 | x = normalInit([10 , nChan , 32 , 32]).to(device) # (numImages, numChans, image height, image width) 66 | 67 | 68 | labels = torch.LongTensor([1 ,2 ,3 ,4 ,3 ,2 ,1 ,0 ,2 ,3]).to(device) 69 | Kconnect = normalInit([nChan, nChan, 1, 1]).to(device) 70 | 71 | 72 | 73 | # dsLayer = DoubleSymLayer( [nChan,7] , params={'norm': nn.BatchNorm2d(nChan)}) 74 | dsLayer = DoubleSymLayer 75 | layerParams = { 'vFeat': [nChan ,nChan+1] , 76 | 'act' : nn.ReLU(), 77 | 'normLayer': nn.BatchNorm2d(nChan+1)} 78 | # layerParams = {'vFeat': [ nChan,3], 'act': nn.ReLU()} 79 | net = rk4( tTheta, tY, dsLayer, layerParams=layerParams) 80 | # for old method 81 | origWeights =[] 82 | for i in range(len(net.controlLayers)): 83 | origWeights.append(net.controlLayers[i].weight.data.clone()) 84 | 85 | 86 | net.to(device) 87 | 88 | stdv = 0.1 89 | W = stdv * torch.randn(nChan +1, nClasses).to(device) # plus one for the bias 90 | optimizer = optim.SGD(net.parameters(), lr=1, momentum=0.9, weight_decay=0, nesterov=False) 91 | 92 | optimizer.zero_grad() 93 | y1 = net.forward(x) 94 | y3 = F.avg_pool2d(y1, x.shape[2:4]) 95 | 96 | loss1, _ = misfitW(y3, W, labels, device) 97 | loss1.backward() 98 | optimizer.step() 99 | 100 | 101 | # -------------- 102 | 103 | # use same initializations 104 | K = [] 105 | for i in range(len(origWeights)): 106 | K.append( nn.Parameter(origWeights[i].clone()) ) 107 | 108 | params = [] 109 | nWeights = 0 110 | params, nWeights = appendTensorList(params, K, nWeights) 111 | optimizer2 = optim.SGD(params, lr=1, momentum=0.9, weight_decay=0, nesterov=False) 112 | 113 | # for i in range(len(K)): 114 | # K[i] = K[i].to(device) 115 | 116 | optimizer2.zero_grad() 117 | 118 | class compareNet(nn.Module): 119 | def __init__(self): 120 | super().__init__() 121 | def forward(self,x, K, tTheta, tY): 122 | y2 = rk4DoubleSymBlock(x, K, tTheta, tY, act=F.relu, weightBias=[None], tvnorm=False) 123 | 124 | y4 = F.avg_pool2d(y2, x.shape[2:4]) 125 | 126 | return y2,y4 127 | 128 | net2 = compareNet() 129 | y2,y4 = net2.forward(x,K,tTheta,tY) 130 | 131 | loss2, _ = misfitW(y4, W, labels, device) 132 | loss2.backward() 133 | optimizer2.step() 134 | 135 | 136 | # print('block 2-norm difference:', torch.norm(y2 - y1, p=2).data) # want = 0 137 | # print('net 2-norm difference:', torch.norm(y4 - y3, p=2).data) # want = 0 138 | # print('loss 2-norm difference:', torch.norm(loss2 - loss1, p=2).data) # want = 0 139 | # print('K 2-norm difference:', listNorm([net.controlLayers[i].weight for i in range(len(net.controlLayers))] ,K)) 140 | # # want = 0 141 | # print('K update: ', listNorm(origWeights ,K)) # want > 0 142 | 143 | tol = 1e-4 144 | assert(torch.norm(y2 - y1, p=2).data < tol) 145 | assert(torch.norm(y4 - y3, p=2).data < tol) 146 | assert(torch.norm(loss2 - loss1, p=2).data < tol) 147 | assert( listNorm([net.controlLayers[i].weight for i in range(len(net.controlLayers))] ,K) < tol) 148 | 149 | assert( listNorm(origWeights ,K) > 1e-4) # want > 0 150 | print('tests passed') 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | # # make auto-differentiation graph 161 | # from graphModel import make_dot 162 | # 163 | # g = make_dot(y3, net.state_dict()) 164 | # g.view() 165 | 166 | 167 | # # see all the parameters.....net.parameters() is a generator 168 | # print('\nparameters of the model:\n') 169 | # for i in enumerate(net.parameters()): 170 | # print(i) 171 | -------------------------------------------------------------------------------- /naming.md: -------------------------------------------------------------------------------- 1 | #### Naming convention 2 | 3 | nBlah - integer or 'number of' Blah 4 | 5 | fBlah - float 6 | 7 | sBlah - string 8 | 9 | vBlah - vector 10 | 11 | mBlah - matrix 12 | 13 | idxBlah - indices 14 | 15 | npBlah - numpy item 16 | 17 | Our default is to represent most variables as tensors. 18 | 19 | 20 | ### Common variables 21 | 22 | tY - time points in Y (these correspond to the state layers) 23 | 24 | tTheta - time points in theta (these correspond to the control layers) 25 | 26 | 27 | -------------------------------------------------------------------------------- /paper.bib: -------------------------------------------------------------------------------- 1 | 2 | 3 | @inproceedings{chang2018reversible, 4 | title={Reversible Architectures for Arbitrarily Deep Residual Neural Networks}, 5 | author={Chang, Bo and Meng, Lili and Haber, Eldad and Ruthotto, Lars and Begert, David and Holtham, Elliot}, 6 | booktitle = {AAAI Conference on {AI}}, 7 | url={https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/viewPaper/16517}, 8 | year={2018} 9 | } 10 | 11 | 12 | @incollection{chen2018neural, 13 | title = {Neural Ordinary Differential Equations}, 14 | author = {Chen, Tian Qi and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David K}, 15 | booktitle = {Advances in Neural Information Processing Systems 31}, 16 | pages = {6571--6583}, 17 | year = {2018}, 18 | url = {http://papers.nips.cc/paper/7892-neural-ordinary-differential-equations.pdf} 19 | } 20 | 21 | @article{gholami2019anode, 22 | author={Gholami, Amir and Keutzer, Kurt and Biros, George}, 23 | title = {{ANODE:} Unconditionally Accurate Memory-Efficient Gradients for Neural 24 | ODEs}, 25 | journal = {arXiv:1902.10298}, 26 | year = {2019}, 27 | url = {http://arxiv.org/abs/1902.10298}, 28 | archivePrefix = {arXiv}, 29 | eprint = {1902.10298} 30 | } 31 | 32 | 33 | @InProceedings{haber2019imexnet, 34 | title = {{IMEX}net A Forward Stable Deep Neural Network}, 35 | author = {Haber, Eldad and Lensink, Keegan and Treister, Eran and Ruthotto, Lars}, 36 | booktitle = {Proceedings of the 36th International Conference on Machine Learning}, 37 | pages = {2525--2534}, 38 | year = {2019}, 39 | volume = {97}, 40 | month = {09--15 Jun}, 41 | pdf = {http://proceedings.mlr.press/v97/haber19a/haber19a.pdf}, 42 | url = {http://proceedings.mlr.press/v97/haber19a.html}, 43 | } 44 | 45 | @article{haber2017stable, 46 | doi = {10.1088/1361-6420/aa9a90}, 47 | url = {https://doi.org/10.1088%2F1361-6420%2Faa9a90}, 48 | year = 2017, 49 | month = {dec}, 50 | publisher = {{IOP} Publishing}, 51 | volume = {34}, 52 | number = {1}, 53 | pages = {014004}, 54 | author = {Eldad Haber and Lars Ruthotto}, 55 | title = {Stable architectures for deep neural networks}, 56 | journal = {Inverse Problems}, 57 | annote = {Supported by NSF DMS 1522599} 58 | } 59 | 60 | @inproceedings{he2016deep, 61 | title={Deep residual learning for image recognition}, 62 | author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, 63 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 64 | pages={770--778}, 65 | year={2016}, 66 | url={https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf} 67 | } 68 | 69 | @InProceedings{lu2018, 70 | title = {Beyond Finite Layer Neural Networks: Bridging Deep Architectures and Numerical Differential Equations}, 71 | author = {Lu, Yiping and Zhong, Aoxiao and Li, Quanzheng and Dong, Bin}, 72 | booktitle = {Proceedings of the 35th International Conference on Machine Learning}, 73 | pages = {3276--3285}, 74 | year = {2018}, 75 | volume = {80}, 76 | month = {10--15 Jul}, 77 | pdf = {http://proceedings.mlr.press/v80/lu18d/lu18d.pdf}, 78 | url = {http://proceedings.mlr.press/v80/lu18d.html}, 79 | } 80 | 81 | @article{ruthotto2018deep, 82 | title={Deep Neural Networks motivated by Partial Differential Equations}, 83 | author={Ruthotto, Lars and Haber, Eldad}, 84 | journal={arXiv:1804.04272}, 85 | year={2018}, 86 | url={https://arxiv.org/abs/1804.04272}, 87 | annote ={Supported by NSF DMS 1522599} 88 | } 89 | 90 | 91 | @article{weinan2017, 92 | author = {Weinan, E.}, 93 | title = {{A Proposal on Machine Learning via Dynamical Systems}}, 94 | journal = {Comm. Math. Statist.}, 95 | year = {2017}, 96 | volume = {5}, 97 | number = {1}, 98 | pages = {1--11}, 99 | publisher = {Springer Berlin Heidelberg}, 100 | doi = {10.1007/s40304-017-0103-z}, 101 | url = {http://link.springer.com/10.1007/s40304-017-0103-z} 102 | } 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /paper.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 'DynamicBlocks: A Generalized ResNet Architecture' 3 | tags: 4 | - Pytorch 5 | - PDE-based Machine Learning 6 | - Discretize-Optimize 7 | authors: 8 | - name: Derek Onken 9 | orcid: 0000-0002-4640-767X 10 | affiliation: 1 # (Multiple affiliations must be quoted) 11 | - name: Simion Novikov 12 | affiliation: 2 13 | - name: Eran Treister 14 | affiliation: 2 15 | - name: Eldad Haber 16 | affiliation: 3 17 | - name: Lars Ruthotto 18 | affiliation: "4,1" 19 | affiliations: 20 | - name: Department of Computer Science, Emory University 21 | index: 1 22 | - name: Ben-Gurion University of the Negev 23 | index: 2 24 | - name: University of British-Columbia 25 | index: 3 26 | - name: Department of Mathematics, Emory University 27 | index: 4 28 | date: 9 June 2019 29 | bibliography: paper.bib 30 | --- 31 | 32 | # Summary 33 | 34 | Deep Residual Neural Networks (ResNets) demonstrate impressive performance on several image classification tasks [@he2016deep]. ResNets feature a skip connection, which observably increases a model's robustness to vanishing and exploding gradients. ResNets also possess interesting theoretical properties; for example, the forward propagation through a ResNet can be interpreted as an explicit Euler method applied to a nonlinear ordinary differential equation (ODE) [@weinan2017; @haber2017stable]. Similarly, @ruthotto2018deep introduces a similar interpretation of convolutional ResNets as partial differential equations (PDE). These insights provide more theoretical narrative and insight while motivating new network architectures from different types of differential equations, e.g., reversible hyperbolic network [@chang2018reversible] or parabolic networks [@ruthotto2018deep]. Recent attention focuses on improving the time integrators used in forward propagation, e.g., higher-order single and multistep methods [@lu2018], black-box time integrators [@chen2018neural], and semi-implicit discretizations [@haber2019imexnet]. 35 | 36 | This toolbox exists primarily to facilitate further research and development in convolutional residual neural networks motivated by PDEs. To this end, we generalize the notation of ResNets, referring to each of its parts that amend a PDE interpretation (i.e., each set of several consecutive convolutional layers of fixed number of channels) as a ``dynamic block``. A dynamic block can then be compactly described by a layer function and parameters of the time integrator (e.g., time discretization points). We provide the flexibility to model several state-of-the-art networks by combining several dynamic blocks (acting on different image resolutions and number of channels) through connective units (i.e., a convolutional layer to change the number of channels followed by a pooling layer). 37 | 38 | Our ``DynamicBlocks`` toolbox provides a general framework to experiment with different time discretizations and PDE solvers. In its first version, we include capabilities to obtain a more general version of ResNet based on a forward Euler (Runge-Kutta 1) module that can handle arbitrary, non-uniform time steps. By allowing different time discretizations for the features and parameters, we provide the ability to decouple the states (features) and controls (parameters). Through this decoupling, we can determine the separate relationships of weights and layers with model performance. Furthermore, we include a fourth-order accurate Runge-Kuttta 4 block to demonstrate the generalizability of the architecture to accept other PDE solvers. 39 | 40 | In the training, we primarily focus on discretize-optimize learning methods, which are popular in optimal control and whose favorable properties for ResNets are shown in @gholami2019anode. However, dynamic blocks can also employ optimize-then-discretize approaches as in Neural ODEs [@chen2018neural]. 41 | 42 | # Acknowledgements 43 | 44 | This material is in part based upon work supported by the National Science Foundation under 45 | Grant Number DMS-1751636. Any opinions, findings, and conclusions or recommendations expressed 46 | in this material are those of the authors and do not necessarily reflect the views of the 47 | National Science Foundation. 48 | 49 | # References -------------------------------------------------------------------------------- /paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmoryMLIP/DynamicBlocks/52acc9fbc1a2640c6ac8922fa18105279ccaea97/paper.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter==1.0.0 2 | matplotlib==3.0.2 3 | numpy==1.15.2 4 | scipy==1.1.0 5 | torch==1.0.1 6 | torchvision==0.2.1 7 | -------------------------------------------------------------------------------- /results/default_RK4_DoubleSym_cifar10_2019_06_11_00_36_36.txt: -------------------------------------------------------------------------------- 1 | device cuda:0 2 | time steps Y: [0, 1, 2, 3, 4] 3 | time steps theta: [0, 1, 2, 3, 4] 4 | no. of channels: [16, 32, 64, 64] 5 | no. of epochs: 120 6 | ODE reg param: 0.0 7 | Use checkpointing: False 8 | Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz 9 | Files already downloaded and verified 10 | Files already downloaded and verified 11 | training on 40000 images ; validating on 10000 images 12 | Training 251866 weights in 78 layers 13 | optimizer = SGD with momentum=9.0e-01, weight_decay=4.0e-04, nesterov=0, batchSize=125 14 | epoch time |runMean| |runVar| lr |params| avgLoss acc valLoss valAcc 15 | 1 81.3 4.01e+00 1.13e+01 1.00e-01 1.20e+01 1.68e+00 37.43 1.77e+00 39.75 16 | 2 81.4 2.26e+00 1.16e+00 1.00e-01 1.02e+01 1.29e+00 53.05 1.88e+00 39.89 17 | 3 81.2 1.29e+00 5.12e-01 1.00e-01 8.94e+00 1.08e+00 61.24 1.79e+00 48.46 18 | 4 81.2 1.21e+00 4.47e-01 1.00e-01 8.23e+00 9.79e-01 64.99 1.93e+00 47.54 19 | 5 80.2 1.13e+00 5.44e-01 1.00e-01 8.10e+00 9.10e-01 67.41 1.54e+00 54.43 20 | 6 80.9 1.57e+00 2.85e-01 1.00e-01 8.20e+00 8.54e-01 69.64 1.25e+00 58.08 21 | 7 81.2 1.47e+00 3.37e-01 1.00e-01 8.31e+00 7.98e-01 72.00 1.40e+00 56.32 22 | 8 81.0 1.39e+00 3.94e-01 1.00e-01 8.25e+00 7.53e-01 73.52 2.53e+00 40.00 23 | 9 81.5 1.12e+00 2.53e-01 1.00e-01 8.37e+00 7.37e-01 74.40 2.12e+00 49.23 24 | 10 81.0 1.24e+00 2.40e-01 1.00e-01 8.30e+00 6.95e-01 75.92 1.25e+00 58.90 25 | 11 80.8 1.57e+00 3.01e-01 1.00e-01 8.31e+00 6.75e-01 76.51 1.20e+00 61.67 26 | 12 81.1 1.44e+00 3.41e-01 1.00e-01 8.38e+00 6.55e-01 77.25 1.39e+00 61.13 27 | 13 81.3 1.28e+00 3.48e-01 1.00e-01 8.37e+00 6.41e-01 77.85 1.63e+00 52.88 28 | 14 81.4 1.35e+00 3.05e-01 1.00e-01 8.41e+00 6.24e-01 78.55 9.59e-01 67.18 29 | 15 81.4 1.26e+00 2.61e-01 1.00e-01 8.54e+00 6.16e-01 78.56 1.43e+00 59.09 30 | 16 81.3 1.37e+00 2.51e-01 1.00e-01 8.62e+00 6.10e-01 78.88 1.79e+00 56.17 31 | 17 81.4 1.45e+00 2.45e-01 1.00e-01 8.61e+00 6.00e-01 79.49 1.41e+00 59.91 32 | 18 80.9 1.52e+00 3.17e-01 1.00e-01 8.69e+00 5.93e-01 79.56 8.96e-01 70.41 33 | 19 81.1 1.40e+00 1.97e-01 1.00e-01 8.57e+00 5.86e-01 79.67 1.43e+00 56.84 34 | 20 81.1 1.56e+00 2.91e-01 1.00e-01 8.69e+00 5.78e-01 80.04 1.70e+00 52.18 35 | 21 81.1 1.62e+00 2.84e-01 1.00e-01 8.80e+00 5.72e-01 80.28 1.27e+00 62.04 36 | 22 81.1 1.46e+00 2.91e-01 1.00e-01 8.83e+00 5.68e-01 80.53 1.51e+00 58.77 37 | 23 81.3 1.50e+00 3.28e-01 1.00e-01 8.87e+00 5.62e-01 80.60 1.48e+00 60.29 38 | 24 81.1 1.39e+00 2.79e-01 1.00e-01 8.80e+00 5.60e-01 80.67 1.40e+00 57.79 39 | 25 80.9 1.49e+00 2.68e-01 1.00e-01 8.78e+00 5.48e-01 81.19 1.51e+00 59.04 40 | 26 81.4 1.59e+00 1.89e-01 1.00e-01 8.78e+00 5.48e-01 80.97 1.75e+00 56.00 41 | 27 81.0 1.58e+00 1.88e-01 1.00e-01 8.87e+00 5.40e-01 81.32 1.42e+00 62.08 42 | 28 81.2 1.33e+00 2.21e-01 1.00e-01 8.89e+00 5.41e-01 81.41 1.17e+00 64.33 43 | 29 81.1 1.29e+00 1.79e-01 1.00e-01 8.98e+00 5.44e-01 81.25 7.93e-01 72.70 44 | 30 81.5 1.33e+00 2.20e-01 1.00e-01 8.97e+00 5.35e-01 81.61 8.57e-01 71.78 45 | 31 81.1 1.54e+00 2.10e-01 1.00e-01 8.97e+00 5.32e-01 81.77 1.25e+00 61.04 46 | 32 80.8 1.51e+00 1.91e-01 1.00e-01 8.91e+00 5.28e-01 81.79 1.14e+00 64.05 47 | 33 80.9 1.24e+00 2.00e-01 1.00e-01 9.02e+00 5.30e-01 81.77 1.18e+00 62.48 48 | 34 80.7 1.30e+00 2.31e-01 1.00e-01 9.03e+00 5.29e-01 81.87 1.82e+00 52.72 49 | 35 80.8 1.65e+00 1.99e-01 1.00e-01 8.91e+00 5.25e-01 82.04 1.11e+00 64.51 50 | 36 80.8 1.54e+00 1.93e-01 1.00e-01 8.97e+00 5.11e-01 82.50 1.17e+00 63.89 51 | 37 80.2 1.40e+00 2.08e-01 1.00e-01 9.05e+00 5.24e-01 82.03 1.32e+00 63.24 52 | 38 80.7 1.60e+00 2.42e-01 1.00e-01 9.09e+00 5.19e-01 82.08 1.28e+00 66.75 53 | 39 80.8 1.68e+00 2.85e-01 1.00e-01 9.17e+00 5.17e-01 82.05 1.22e+00 64.89 54 | 40 80.7 1.27e+00 2.22e-01 1.00e-01 9.10e+00 5.15e-01 82.46 1.00e+00 70.43 55 | 41 80.7 1.38e+00 1.69e-01 1.00e-01 9.14e+00 5.08e-01 82.51 1.05e+00 66.63 56 | 42 81.3 1.53e+00 1.65e-01 1.00e-01 9.18e+00 5.15e-01 82.14 8.60e-01 71.01 57 | 43 81.0 1.36e+00 2.85e-01 1.00e-01 9.16e+00 5.08e-01 82.41 1.91e+00 54.35 58 | 44 80.8 1.56e+00 2.37e-01 1.00e-01 9.16e+00 5.06e-01 82.69 1.31e+00 63.01 59 | 45 81.1 1.64e+00 2.07e-01 1.00e-01 9.24e+00 5.11e-01 82.52 1.23e+00 64.42 60 | 46 81.3 1.46e+00 1.77e-01 1.00e-01 9.18e+00 5.13e-01 82.30 2.48e+00 49.32 61 | 47 81.1 1.43e+00 2.16e-01 1.00e-01 9.19e+00 4.98e-01 82.91 1.70e+00 49.98 62 | 48 81.0 1.40e+00 2.91e-01 1.00e-01 9.26e+00 5.02e-01 82.63 2.02e+00 54.94 63 | 49 81.2 1.57e+00 2.80e-01 1.00e-01 9.28e+00 5.01e-01 82.74 1.09e+00 66.99 64 | 50 81.1 1.43e+00 1.40e-01 1.00e-01 9.26e+00 5.00e-01 82.78 1.52e+00 56.58 65 | 51 80.9 1.40e+00 1.55e-01 1.00e-01 9.21e+00 4.97e-01 82.93 1.96e+00 53.74 66 | 52 81.1 1.42e+00 2.35e-01 1.00e-01 9.20e+00 5.04e-01 82.77 1.36e+00 64.37 67 | 53 81.3 1.47e+00 2.15e-01 1.00e-01 9.23e+00 4.95e-01 83.03 1.20e+00 64.22 68 | 54 81.2 1.57e+00 1.89e-01 1.00e-01 9.32e+00 4.98e-01 82.75 1.11e+00 67.47 69 | 55 81.2 1.67e+00 2.12e-01 1.00e-01 9.36e+00 4.93e-01 83.11 9.78e-01 69.81 70 | 56 80.9 1.40e+00 2.71e-01 1.00e-01 9.31e+00 4.95e-01 82.92 1.08e+00 67.40 71 | 57 81.2 1.52e+00 2.28e-01 1.00e-01 9.30e+00 4.95e-01 82.97 8.79e-01 70.88 72 | 58 81.2 1.38e+00 2.37e-01 1.00e-01 9.29e+00 4.93e-01 82.96 1.42e+00 58.73 73 | 59 81.3 1.37e+00 1.94e-01 1.00e-01 9.42e+00 4.95e-01 82.82 1.17e+00 65.31 74 | 60 81.5 1.40e+00 1.83e-01 1.00e-01 9.43e+00 4.91e-01 83.14 2.25e+00 46.26 75 | 61 81.3 9.62e-01 1.67e-01 1.00e-02 2.04e+00 3.43e-01 88.39 3.88e-01 86.69 76 | 62 81.5 5.20e-01 7.98e-02 1.00e-02 1.43e+00 2.96e-01 89.91 3.93e-01 86.70 77 | 63 81.0 3.88e-01 7.69e-02 1.00e-02 1.33e+00 2.78e-01 90.68 3.64e-01 87.43 78 | 64 80.8 3.53e-01 6.08e-02 1.00e-02 1.33e+00 2.66e-01 90.94 3.73e-01 87.69 79 | 65 80.4 3.78e-01 7.48e-02 1.00e-02 1.33e+00 2.58e-01 91.16 3.57e-01 87.81 80 | 66 80.3 4.01e-01 6.49e-02 1.00e-02 1.35e+00 2.50e-01 91.56 3.76e-01 87.39 81 | 67 80.6 3.68e-01 6.52e-02 1.00e-02 1.35e+00 2.41e-01 91.89 3.70e-01 87.50 82 | 68 80.6 3.21e-01 5.62e-02 1.00e-02 1.36e+00 2.34e-01 92.14 3.78e-01 87.69 83 | 69 80.5 3.81e-01 5.07e-02 1.00e-02 1.40e+00 2.32e-01 92.04 3.88e-01 86.78 84 | 70 80.7 3.83e-01 5.79e-02 1.00e-02 1.41e+00 2.27e-01 92.39 3.65e-01 87.58 85 | 71 80.8 3.82e-01 5.20e-02 1.00e-02 1.43e+00 2.25e-01 92.42 4.11e-01 86.18 86 | 72 80.8 3.05e-01 4.84e-02 1.00e-02 1.46e+00 2.18e-01 92.64 3.88e-01 86.95 87 | 73 80.1 4.11e-01 4.47e-02 1.00e-02 1.49e+00 2.14e-01 92.78 4.34e-01 85.67 88 | 74 80.9 4.39e-01 4.17e-02 1.00e-02 1.50e+00 2.13e-01 92.81 3.74e-01 87.51 89 | 75 80.9 3.55e-01 5.35e-02 1.00e-02 1.53e+00 2.12e-01 92.72 4.18e-01 86.55 90 | 76 81.1 2.86e-01 3.81e-02 1.00e-02 1.52e+00 2.06e-01 93.02 3.94e-01 87.01 91 | 77 81.0 3.15e-01 4.60e-02 1.00e-02 1.56e+00 2.03e-01 93.15 3.64e-01 87.82 92 | 78 81.4 3.22e-01 3.93e-02 1.00e-02 1.58e+00 2.01e-01 93.09 5.10e-01 84.12 93 | 79 81.2 2.99e-01 4.52e-02 1.00e-02 1.62e+00 1.97e-01 93.35 3.94e-01 87.12 94 | 80 81.0 3.20e-01 3.22e-02 1.00e-02 1.63e+00 1.96e-01 93.40 4.80e-01 85.09 95 | 81 81.6 1.70e-01 1.64e-02 1.00e-03 3.06e-01 1.67e-01 94.61 3.24e-01 89.23 96 | 82 81.2 7.54e-02 1.52e-02 1.00e-03 2.14e-01 1.56e-01 94.99 3.23e-01 89.25 97 | 83 81.1 8.13e-02 9.73e-03 1.00e-03 2.00e-01 1.53e-01 95.14 3.18e-01 89.37 98 | 84 81.4 8.62e-02 1.10e-02 1.00e-03 1.99e-01 1.49e-01 95.25 3.24e-01 89.32 99 | 85 81.2 7.55e-02 1.92e-02 1.00e-03 1.95e-01 1.48e-01 95.19 3.18e-01 89.37 100 | 86 81.6 7.07e-02 1.61e-02 1.00e-03 1.93e-01 1.45e-01 95.50 3.22e-01 89.22 101 | 87 81.3 6.76e-02 1.19e-02 1.00e-03 1.93e-01 1.47e-01 95.24 3.17e-01 89.25 102 | 88 81.4 7.02e-02 1.72e-02 1.00e-03 1.92e-01 1.45e-01 95.41 3.20e-01 89.30 103 | 89 81.2 5.81e-02 1.90e-02 1.00e-03 1.90e-01 1.42e-01 95.49 3.19e-01 89.38 104 | 90 81.2 6.40e-02 9.34e-03 1.00e-03 1.94e-01 1.42e-01 95.45 3.20e-01 89.43 105 | 91 81.0 6.50e-02 1.32e-02 1.00e-03 1.94e-01 1.43e-01 95.43 3.18e-01 89.41 106 | 92 81.1 6.72e-02 1.44e-02 1.00e-03 1.95e-01 1.40e-01 95.52 3.17e-01 89.42 107 | 93 81.1 6.22e-02 1.20e-02 1.00e-03 1.89e-01 1.39e-01 95.56 3.20e-01 89.54 108 | 94 81.0 7.47e-02 1.01e-02 1.00e-03 1.91e-01 1.38e-01 95.62 3.22e-01 89.49 109 | 95 81.0 6.78e-02 1.28e-02 1.00e-03 1.93e-01 1.36e-01 95.69 3.21e-01 89.26 110 | 96 81.2 6.32e-02 1.12e-02 1.00e-03 1.88e-01 1.33e-01 95.96 3.21e-01 89.54 111 | 97 81.0 6.65e-02 1.54e-02 1.00e-03 1.93e-01 1.36e-01 95.70 3.24e-01 89.24 112 | 98 81.1 5.62e-02 1.05e-02 1.00e-03 1.93e-01 1.34e-01 95.89 3.20e-01 89.36 113 | 99 80.2 5.63e-02 1.57e-02 1.00e-03 1.93e-01 1.36e-01 95.62 3.24e-01 89.37 114 | 100 81.0 6.01e-02 1.66e-02 1.00e-03 1.98e-01 1.36e-01 95.64 3.24e-01 89.43 115 | 101 80.8 2.03e-02 6.12e-03 1.00e-04 2.85e-02 1.31e-01 95.94 3.18e-01 89.41 116 | 102 81.4 2.13e-02 1.31e-02 1.00e-04 2.46e-02 1.32e-01 95.86 3.19e-01 89.36 117 | 103 81.1 2.21e-02 1.82e-02 1.00e-04 2.27e-02 1.32e-01 95.82 3.21e-01 89.47 118 | 104 81.2 1.35e-02 1.20e-02 1.00e-04 2.31e-02 1.30e-01 95.99 3.20e-01 89.39 119 | 105 80.9 1.50e-02 1.17e-02 1.00e-04 2.23e-02 1.27e-01 96.14 3.18e-01 89.59 120 | 106 81.1 1.52e-02 8.73e-03 1.00e-04 2.15e-02 1.29e-01 96.00 3.21e-01 89.38 121 | 107 81.0 1.29e-02 1.21e-02 1.00e-04 2.20e-02 1.28e-01 95.98 3.20e-01 89.29 122 | 108 81.0 1.74e-02 1.86e-02 1.00e-04 2.24e-02 1.31e-01 95.94 3.19e-01 89.46 123 | 109 81.1 1.39e-02 8.17e-03 1.00e-04 2.13e-02 1.31e-01 95.73 3.18e-01 89.58 124 | 110 81.0 1.95e-02 6.11e-03 1.00e-04 2.17e-02 1.28e-01 96.03 3.20e-01 89.36 125 | 111 80.9 1.88e-02 9.17e-03 1.00e-04 2.13e-02 1.29e-01 96.02 3.20e-01 89.50 126 | 112 81.1 1.30e-02 7.67e-03 1.00e-04 2.18e-02 1.28e-01 95.93 3.20e-01 89.39 127 | 113 80.3 2.12e-02 5.70e-03 1.00e-04 2.15e-02 1.28e-01 95.92 3.18e-01 89.51 128 | 114 81.3 1.39e-02 9.26e-03 1.00e-04 2.16e-02 1.28e-01 96.09 3.18e-01 89.50 129 | 115 81.3 1.78e-02 7.67e-03 1.00e-04 2.11e-02 1.27e-01 96.05 3.18e-01 89.48 130 | 116 81.0 1.44e-02 9.39e-03 1.00e-04 2.11e-02 1.26e-01 96.17 3.24e-01 89.36 131 | 117 81.1 2.74e-02 1.12e-02 1.00e-04 2.07e-02 1.27e-01 96.18 3.17e-01 89.47 132 | 118 81.0 2.10e-02 9.89e-03 1.00e-04 2.16e-02 1.29e-01 95.90 3.19e-01 89.47 133 | 119 81.0 1.36e-02 8.15e-03 1.00e-04 2.14e-02 1.28e-01 96.02 3.19e-01 89.58 134 | 120 81.6 2.02e-02 1.23e-02 1.00e-04 2.11e-02 1.26e-01 96.12 3.20e-01 89.48 135 | Time elapsed: 9804.232657432556 136 | Training complete. Now testing... 137 | 138 | testing loss: 3.32e-01 testing accuracy: 89.08 139 | RKNet( 140 | (dynamicBlocks): ModuleList( 141 | (0): rk4( 142 | (controlLayers): ModuleList( 143 | (0): DoubleSymLayer( 144 | (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 145 | (act): ReLU() 146 | (normLayer): TvNorm() 147 | (convt): ConvTranspose2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 148 | ) 149 | (1): DoubleSymLayer( 150 | (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 151 | (act): ReLU() 152 | (normLayer): TvNorm() 153 | (convt): ConvTranspose2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 154 | ) 155 | (2): DoubleSymLayer( 156 | (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 157 | (act): ReLU() 158 | (normLayer): TvNorm() 159 | (convt): ConvTranspose2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 160 | ) 161 | (3): DoubleSymLayer( 162 | (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 163 | (act): ReLU() 164 | (normLayer): TvNorm() 165 | (convt): ConvTranspose2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 166 | ) 167 | (4): DoubleSymLayer( 168 | (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 169 | (act): ReLU() 170 | (normLayer): TvNorm() 171 | (convt): ConvTranspose2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 172 | ) 173 | ) 174 | ) 175 | (1): rk4( 176 | (controlLayers): ModuleList( 177 | (0): DoubleSymLayer( 178 | (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 179 | (act): ReLU() 180 | (normLayer): TvNorm() 181 | (convt): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 182 | ) 183 | (1): DoubleSymLayer( 184 | (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 185 | (act): ReLU() 186 | (normLayer): TvNorm() 187 | (convt): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 188 | ) 189 | (2): DoubleSymLayer( 190 | (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 191 | (act): ReLU() 192 | (normLayer): TvNorm() 193 | (convt): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 194 | ) 195 | (3): DoubleSymLayer( 196 | (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 197 | (act): ReLU() 198 | (normLayer): TvNorm() 199 | (convt): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 200 | ) 201 | (4): DoubleSymLayer( 202 | (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 203 | (act): ReLU() 204 | (normLayer): TvNorm() 205 | (convt): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 206 | ) 207 | ) 208 | ) 209 | (2): rk4( 210 | (controlLayers): ModuleList( 211 | (0): DoubleSymLayer( 212 | (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 213 | (act): ReLU() 214 | (normLayer): TvNorm() 215 | (convt): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 216 | ) 217 | (1): DoubleSymLayer( 218 | (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 219 | (act): ReLU() 220 | (normLayer): TvNorm() 221 | (convt): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 222 | ) 223 | (2): DoubleSymLayer( 224 | (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 225 | (act): ReLU() 226 | (normLayer): TvNorm() 227 | (convt): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 228 | ) 229 | (3): DoubleSymLayer( 230 | (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 231 | (act): ReLU() 232 | (normLayer): TvNorm() 233 | (convt): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 234 | ) 235 | (4): DoubleSymLayer( 236 | (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 237 | (act): ReLU() 238 | (normLayer): TvNorm() 239 | (convt): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 240 | ) 241 | ) 242 | ) 243 | ) 244 | (connectors): ModuleList( 245 | (0): ConnectingLayer( 246 | (conv): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1)) 247 | (act): ReLU() 248 | (normLayer): BatchNorm2d(32, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True) 249 | ) 250 | (1): ConnectingLayer( 251 | (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1)) 252 | (act): ReLU() 253 | (normLayer): BatchNorm2d(64, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True) 254 | ) 255 | (2): ConnectingLayer( 256 | (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) 257 | (act): ReLU() 258 | (normLayer): BatchNorm2d(64, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True) 259 | ) 260 | ) 261 | (open): ConnectingLayer( 262 | (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 263 | (act): ReLU() 264 | (normLayer): BatchNorm2d(16, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True) 265 | ) 266 | (linear): Linear(in_features=64, out_features=10, bias=True) 267 | ) 268 | -------------------------------------------------------------------------------- /setup.md: -------------------------------------------------------------------------------- 1 | #### Setup help: 2 | Make sure you have git installed. 3 | Open terminal, go to wherever you want your local repository, and type: 4 | ``` 5 | git init 6 | git pull git@github.com:EmoryMLIP/DynamicBlocks.git 7 | ``` 8 | For running on a linux server, you may want to change the Git default editor: 9 | ``` 10 | git config --global core.editor "vim" 11 | ``` 12 | 13 | All the code should follow python3, so make sure that you use that throughout. (I get weird errors about the import statements if I try to run anything using python 2.) 14 | You may wish to set up a virtual environment (if not installed, install virtualenv). I called mine torchEnv. 15 | ``` 16 | python3 -m virtualenv torchEnv 17 | ``` 18 | 19 | (If this line gives you trouble, see \*Troubleshooting.) 20 | 21 | 22 | To activate the virtual environment: 23 | ``` 24 | source torchEnv/bin/activate 25 | ``` 26 | 27 | I did a pip freeze to get all the requirements, removed the hazardous dependencies, and stored them in requirements.txt 28 | To load them into your virtual environment, run: 29 | ``` 30 | pip3 install -r requirements.txt 31 | ``` 32 | 33 | Now run the default network (RK4 scheme using a Double Symmetric Layer on the CIFAR-10 dataset): 34 | ``` 35 | python3 RKNet.py 36 | ``` 37 | 38 | #### NOTE: default behavior assumes that when running on cpu, the user is debugging. As such, functions will overwrite many parameters to train much smaller models and prevent the cpu from crashing. 39 | 40 | #### When running outside of an IDE: 41 | 42 | To run python functions in any subfolders, add the root folder to the path via: 43 | ``` 44 | source startup.sh 45 | ``` 46 | 47 | Then, run the functions from the root folder; for example, run the rk4 test as: 48 | 49 | ``` 50 | python3 modules/testRK4Block.py 51 | ``` 52 | 53 | 54 | 55 | ----------------------------------------------------------------------------------------------------------------------- 56 | ### \*Troubleshooting setup 57 | 58 | Check the default python for creating virtual environments 59 | ``` 60 | echo $VIRTUALENV_PYTHON 61 | ``` 62 | 63 | You can set this to python3: 64 | ``` 65 | export VIRTUALENV_PYTHON=/usr/bin/python3 66 | ``` 67 | 68 | Also, ```virtualenv -p python3 torchEnv``` can be an alternative to ```python3 -m virtualenv torchEnv``` . 69 | -------------------------------------------------------------------------------- /startup.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="${PYTHONPATH}:." 2 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # utils.py 2 | 3 | import matplotlib # For MACOSX 4 | matplotlib.use('TkAgg') 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import scipy.io as sio 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | from torch.utils.data.sampler import SubsetRandomSampler 13 | import torch.nn as nn 14 | import torch.nn.init as init 15 | import torchvision 16 | import torch.optim as optim 17 | 18 | import time 19 | import datetime 20 | import math 21 | import copy 22 | from bisect import bisect 23 | import os 24 | 25 | ####### metric for regularization in time 26 | # regMetric = torch.dist # 2-norm 27 | def regMetric(x,y): 28 | return torch.norm(x-y,p=1) # 1-norm 29 | ####### 30 | 31 | dis = nn.CrossEntropyLoss() 32 | 33 | def misfit(X,C): 34 | """ 35 | softmax cross-entropy loss 36 | 37 | X - dims: nExamples -by- nFeats (+1) 38 | W - dims: nFeats +1 (if bias) -by- nClasses 39 | C - LongTensor containing class labels; dims = (nClasses by 1) where each entry is a value 0 to nClasses 40 | 41 | """ 42 | 43 | # # remove the maximum for all examples to prevent overflow 44 | X = X.view(C.numel(), -1) 45 | S,tt = torch.max(X,1) 46 | S = X-S.view(-1,1) 47 | return dis(S,C), torch.exp(S) 48 | 49 | def misfitW(X,W,C, device): 50 | """ 51 | deprecated....but used in test functions currently 52 | softmax cross-entropy loss 53 | 54 | X - dims: nExamples -by- nFeats (+1) 55 | W - dims: nFeats +1 (if bias) -by- nClasses 56 | C - LongTensor containing class labels; dims = (nClasses by 1) where each entry is a value 0 to nClasses 57 | 58 | """ 59 | 60 | featW = W.size(0) 61 | 62 | if X.size(1) + 1 == featW: 63 | # add bias 64 | e = torch.ones(X.size(0), 1, X.size(2), X.size(3)).to(device) 65 | X = torch.cat((X, e), dim=1) 66 | 67 | X = X.view(-1, featW) 68 | S = torch.matmul(X, W) 69 | # remove the maximum for all examples to prevent overflow 70 | Sm, tt = torch.max(S, 1) 71 | S = S - Sm.view(-1, 1) 72 | return dis(S, C), torch.exp(S) 73 | 74 | 75 | def getAccuracy(S,labels): 76 | """compute accuracy""" 77 | _, predicted = torch.max(S.data, 1) 78 | total = labels.size(0) 79 | correct = (predicted == labels).sum().item() 80 | return correct/total, correct, total 81 | 82 | def conv1x1(x,K): 83 | """1x1 convolution with no padding""" 84 | return F.conv2d(x, K, stride=1, padding=0) 85 | 86 | def conv3x3(x,K): 87 | """3x3 convolution with padding""" 88 | return F.conv2d(x, K, stride=1, padding=1) 89 | 90 | def convt3x3(x,K): 91 | """3x3 convolution transpose with padding""" 92 | return F.conv_transpose2d(x, K, stride=1, padding=1) 93 | 94 | 95 | def specifyGPU( gpuNumber ): 96 | """ 97 | DEPRECATED 98 | directs operating system to only use the input gpu number 99 | :param gpuNumber: integer number associated with GPU you wish to use 100 | Entering a gpuNumber exceeding the number of GPUs you have, will result 101 | in the os choosing the cpu as the device 102 | """ 103 | # import os 104 | 105 | nSystemGPUs = 2 # we only have 2 GPUs on our server 106 | 107 | if gpuNumber > nSystemGPUs-1: 108 | print("Warning: specifyGPU: Do you have more than %d GPUs? If not, specifyGPU(%d) " 109 | "will have issues" % (nSystemGPUs, gpuNumber), file=sys.stderr) 110 | 111 | # restricts code to only seeing GPU labeled gpuNumber 112 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpuNumber) 113 | 114 | 115 | def imshow(img): # need to import matplotlib things 116 | img = img / 2 + 0.5 # unnormalize 117 | npimg = img.numpy() 118 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 119 | plt.show() 120 | 121 | 122 | 123 | def interpNN(theta,tTheta,tInter,inter=None, blankLayer=None, bKeepParam=False): 124 | """ 125 | nearest neighbor interpolation 126 | TODO: refactor name .... this is really left neighbor interpolation 127 | """ 128 | if len(theta) != len(tTheta): 129 | print('interLayer1D: lengths do not match') 130 | return -1 131 | 132 | nPoints = len(tInter) # number of interpolation time points 133 | # make list of length same as tInter, and each element is of size K 134 | # assume all layers in the ModuleList theta have the same dimensions 135 | 136 | if inter is None: 137 | inter=nn.ModuleList() 138 | for i in range(nPoints): 139 | if blankLayer is None: 140 | module=copy.deepcopy(theta[0]) 141 | else: 142 | module = copy.deepcopy(blankLayer) 143 | module.load_state_dict(theta[0].state_dict()) 144 | inter.append(module) 145 | 146 | for k in range(nPoints): 147 | # get param for current time point 148 | # assume tTheta is sorted 149 | 150 | if tInter[k] <= tTheta[0]: # if interp point is outside the tTheta range 151 | recursively_interpolate(inter[k], theta[0], theta[0], 1, 0, bKeepParam=bKeepParam) 152 | elif tInter[k] >= tTheta[-1]: 153 | recursively_interpolate(inter[k], theta[-1], theta[-1], 1, 0, bKeepParam=bKeepParam) 154 | else: 155 | idxTh = bisect(tTheta, tInter[k]) 156 | # idxTh contains index of right point in tTheta to use for interpolation 157 | recursively_interpolate(inter[k], theta[idxTh-1], theta[idxTh-1], 1, 0, bKeepParam=bKeepParam) 158 | 159 | # ensure that running_mean and running_var send to gpu 160 | params = theta[0].parameters() 161 | first_param = next(params) 162 | device = first_param.device 163 | inter.to(device) 164 | return inter 165 | 166 | 167 | 168 | def interpLayer1Dpreallocated(theta,tTheta,tInter,inter=None, blankLayer=None, bKeepParam=False): 169 | """ 170 | 1D (linear) interpolation. For the observations theta at tTheta, find the observations at points ti. 171 | - ASSUMPTIONS: all tensors in the list theta have the same dimension and tTheta is sorted ascending 172 | - theta are my K parameters 173 | :param theta: list of LAYERS, think of them as measurements : f(x0) , f(x1), f(x2) , ... 174 | :param tTheta: points where we have the measurements: x0,x1,x2,... 175 | :param tInter: points where we want to have an approximate function value: a,b,c,d,... 176 | :param blankLayer: module, an empty form of the layer to be interpolated (a workaround for when adding control layers) 177 | :param bKeepParam: boolean, maintain Parameter status (False will mean interpolated parameters will be reduced 178 | to buffer status) 179 | :return: inter: approximate values f(a), f(b), f(c), f(d) using the assumption 180 | that connecting every successive theta to its previous one with a line 181 | """ 182 | if len(theta) != len(tTheta): 183 | print('interLayer1D: lengths do not match') 184 | return -1 185 | 186 | nPoints = len(tInter) # number of interpolation time points 187 | # make list of length same as tInter, and each element is of size K 188 | # assume all layers in the ModuleList theta have the same dimensions 189 | 190 | if inter is None: 191 | inter=nn.ModuleList() 192 | for i in range(nPoints): 193 | if blankLayer is None: 194 | module=copy.deepcopy(theta[0]) 195 | else: 196 | module = copy.deepcopy(blankLayer) 197 | module.load_state_dict(theta[0].state_dict()) 198 | inter.append(module) 199 | 200 | 201 | for k in range(nPoints): 202 | # get K for current time point 203 | # assume tTheta is sorted 204 | 205 | if tInter[k] <= tTheta[0]: # if interp point is outside the tTheta range 206 | recursively_interpolate(inter[k], theta[0], theta[0], 1, 0, bKeepParam=bKeepParam) 207 | elif tInter[k] >= tTheta[-1]: 208 | recursively_interpolate(inter[k], theta[-1], theta[-1], 1, 0, bKeepParam=bKeepParam) 209 | else: 210 | idxTh = bisect(tTheta, tInter[k]) 211 | # idxTh contains index of right point in tTheta to use for interpolation 212 | leftTh = tTheta[idxTh-1] 213 | rightTh = tTheta[idxTh] 214 | h = rightTh - leftTh 215 | wtLeft = (rightTh - tInter[k]) / h 216 | wtRight = (tInter[k] - leftTh) / h 217 | 218 | recursively_interpolate(inter[k], theta[idxTh - 1], theta[idxTh], wtLeft, wtRight, bKeepParam=bKeepParam) 219 | 220 | # ensure that running_mean and running_var send to gpu 221 | params = theta[0].parameters() 222 | first_param = next(params) 223 | device = first_param.device 224 | inter.to(device) 225 | return inter 226 | 227 | 228 | def recursively_interpolate(output,left,right,wleft,wright, bKeepParam=False): 229 | """ 230 | helper functions to interpolate the submodules and parameters 231 | :param output: module which holds interpolated values (a state layer) 232 | :param left: module to the left, used for interpolation 233 | :param right: module to the right, used for interpolation 234 | :param wleft: float, weight for left item 235 | :param wright: float, weight for right item 236 | :param bKeepParam: boolean, maintain Parameter status (False will mean interpolated parameters will be reduced 237 | to buffer status) 238 | :return: 239 | """ 240 | # center, left right = c,l,r 241 | for (cChild,lChild,rChild) in zip(output.named_children(),left.named_children(),right.named_children()): 242 | # print(cChild[0],lChild[0],rChild[0]) 243 | recursively_interpolate(cChild[1], lChild[1], rChild[1], wleft, wright, bKeepParam=bKeepParam) 244 | for (l, r) in zip(left.named_parameters(recurse=False), right.named_parameters(recurse=False)): 245 | # l[0] and r[0] hold names, l[1] and r[1] hold the values/Tensors 246 | if bKeepParam: 247 | exec( 'output.' + l[0] + '.data = wleft *l[1] + wright * r[1]') # find a better way to do this 248 | else: # replace parameter with a buffer 249 | output.__delattr__(l[0]) 250 | output.register_buffer(l[0], wleft*l[1] + wright*r[1]) 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | def listNorm(list1, list2): 259 | """ 260 | RETIRED, used in tests 261 | input two lists of tensors of same length 262 | convert/vectorize the tensors in each list 263 | compute their 2-norm 264 | 265 | :param list1: [K1, K2, K3,....] where Ki is a tensor 266 | :param list2: [A1, A2, A3,....] where Ai is a tensor 267 | :return: float 2-norm 268 | """ 269 | retValue = 0.0 270 | 271 | # use sqrt to iteratively update the norm for each element in the list 272 | for a,b in zip(list1,list2): 273 | a = a.contiguous().view(a.numel()) # vectorize 274 | b = b.contiguous().view(b.numel()) # vectorize 275 | retValue = torch.sqrt( retValue**2 + torch.norm(a-b,p=2)**2) 276 | 277 | return retValue 278 | 279 | 280 | 281 | 282 | def getDataLoaders( sDataset , batchSizeTrain, device, batchSizeTest=None, percentVal = 0.20, 283 | transformTrain=None, transformTest=None, nTrain=None, nVal=None, datapath=None): 284 | """ 285 | set up data loaders 286 | following https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb 287 | for the training and validation split 288 | 289 | :param sDataset: string, representation of the dataset, e.g. 'cifar10', 'cifar100', 'stl10' 290 | :param batchSizeTrain: int, batch size for the training data 291 | :param device: pytorch device 292 | :param batchSizeTest: int, batch size for the testing data 293 | :param percentVal: float, percent of the training data to set aside as validation 294 | :param transformTrain: set of transforms for the training data 295 | :param transformTest: set of transforms for the testing data 296 | :param nTrain: int, number of training images (None means use all, subject to percentVal) 297 | :param nVal: int, number of validation images (None means use all, subject to percentVal) 298 | percentVal is applied first, then if nTrain or nVal is not None, then a smaller subset of the split data will be used 299 | :param datapath: string, path to where the data is saved 300 | :return: 301 | """ 302 | 303 | 304 | shift = transforms.Normalize( 305 | mean=[0.5, 0.5, 0.5], 306 | std=[1., 1., 1.], 307 | ) 308 | 309 | # augmentations and convert to tensor 310 | if transformTrain is None: 311 | transformTrain = transforms.Compose( 312 | [transforms.ToTensor(), 313 | shift]) 314 | if transformTest is None: 315 | transformTest = transforms.Compose( 316 | [transforms.ToTensor(), 317 | shift]) 318 | 319 | if batchSizeTest is None: 320 | batchSizeTest = batchSizeTrain 321 | 322 | sDataPath= './data' 323 | if datapath is not None: 324 | sDataPath = datapath 325 | 326 | #---------------------------------------------------------------------------------- 327 | #----------------------------------CIFAR-10---------------------------------------- 328 | if sDataset == 'cifar10': 329 | 330 | trainSet = torchvision.datasets.CIFAR10(root=sDataPath, train=True, 331 | download=True, transform=transformTrain) 332 | valSet = torchvision.datasets.CIFAR10(root=sDataPath, train=True, 333 | download=True, transform=transformTest) 334 | testSet = torchvision.datasets.CIFAR10(root=sDataPath, train=False, 335 | download=True, transform=transformTest) 336 | nChanIn = 3 337 | nClasses = 10 338 | # ----------------------------------CIFAR-100---------------------------------------- 339 | if sDataset == 'cifar100': 340 | trainSet = torchvision.datasets.CIFAR100(root=sDataPath, train=True, 341 | download=True, transform=transformTrain) 342 | valSet = torchvision.datasets.CIFAR100(root=sDataPath, train=True, 343 | download=True, transform=transformTest) 344 | testSet = torchvision.datasets.CIFAR100(root=sDataPath, train=False, 345 | download=True, transform=transformTest) 346 | nChanIn = 3 347 | nClasses = 100 348 | # ----------------------------------STL-10----------------------------------------- 349 | elif sDataset == 'stl10': 350 | 351 | trainSet = torchvision.datasets.STL10(root=sDataPath, split='train', 352 | download=True, transform=transformTrain) 353 | valSet = torchvision.datasets.STL10(root=sDataPath, split='train', 354 | download=True, transform=transformTest) 355 | testSet = torchvision.datasets.STL10(root=sDataPath, split='test', 356 | download=True, transform=transformTest) 357 | nChanIn = 3 358 | nClasses = 10 359 | # ---------------------------------------------------------------------------------- 360 | 361 | # dimImg = trainSet.train_data.shape[1:3] # image dimensions 362 | 363 | nImages = len(trainSet) 364 | indices = list(range(nImages)) 365 | nValImages = int(np.floor(percentVal * nImages)) 366 | 367 | # np.random.seed(3) 368 | np.random.shuffle(indices) 369 | 370 | idxTrain, idxVal = indices[nValImages:], indices[:nValImages] 371 | 372 | if nTrain is not None: 373 | idxTrain = idxTrain[0:nTrain] 374 | if nVal is not None: 375 | idxVal = idxVal[0:nVal] 376 | 377 | samplerTrain = SubsetRandomSampler(idxTrain) 378 | samplerVal = SubsetRandomSampler(idxVal) 379 | 380 | loaderTrain = torch.utils.data.DataLoader(trainSet, batch_size=batchSizeTrain, 381 | sampler=samplerTrain) # is pin_memory=True a good idea???????? 382 | loaderVal = torch.utils.data.DataLoader(valSet, batch_size=batchSizeTest, 383 | sampler=samplerVal) 384 | loaderTest = torch.utils.data.DataLoader(testSet, batch_size=batchSizeTest, 385 | shuffle=False) 386 | 387 | print('training on %d images ; validating on %d images' % (len(idxTrain), len(idxVal))) 388 | 389 | return loaderTrain, loaderVal, loaderTest, nChanIn, nClasses 390 | 391 | 392 | 393 | def printFeatureMatrix(nt, nBlocks, vFeat): 394 | """ 395 | Construct and print a matrix that contains all of the channel sizes 396 | Example: 397 | Model architecture where each column is a layer 398 | no. Input Channels [ 3 32 32 32 32 64 64 64 64] 399 | no. Output Channels [ 32 32 32 32 64 64 64 64 64] 400 | conv kernel height [ 3 3 3 3 1 3 3 3 1] 401 | conv kernel width [ 3 3 3 3 1 3 3 3 1] 402 | :param nt: number of time steps in each RK4 block 403 | :param nBlocks: number of blocks 404 | :param vFeat: feature vector 405 | """ 406 | mFeat = np.zeros((4, (nt + 1) * nBlocks + 1), dtype=int) 407 | 408 | k = 0 409 | # opening layer 410 | mFeat[0, k] = 3 # RGB input 411 | mFeat[1, k] = vFeat[0] # channels after opening layer 412 | mFeat[2, k] = 3 # 3x3 convolution 413 | mFeat[3, k] = 3 414 | k = k + 1 415 | 416 | # block and connecting layer 417 | for blk in range(nBlocks): 418 | for i in range(nt): 419 | mFeat[0, k] = vFeat[blk] 420 | mFeat[1, k] = vFeat[blk] 421 | mFeat[2, k] = 3 422 | mFeat[3, k] = 3 423 | k = k + 1 424 | 425 | mFeat[0, k] = vFeat[blk] 426 | mFeat[1, k] = vFeat[blk + 1] 427 | mFeat[2, k] = 1 428 | mFeat[3, k] = 1 429 | k = k + 1 430 | 431 | rowLabels = ["no. Input Channels ", "no. Output Channels", "conv kernel height ", "conv kernel width "] 432 | print("\nModel architecture where each column is a layer") 433 | for rowLabel, row in zip(rowLabels, mFeat): 434 | print('%s [%s]' % (rowLabel, ' '.join('%03s' % i for i in row))) 435 | print("\n") 436 | return 0 437 | 438 | 439 | def normalInit(dims): 440 | """ 441 | Essentially, PyTorch's init.xavier_normal_ but clamped 442 | :param K: tensor to be initialized/overwritten 443 | :return: initialized tensor on the device in the nn.Parameter wrapper 444 | """ 445 | K = torch.zeros(dims) 446 | fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(K) 447 | sd = math.sqrt(2.0 / (fan_in + fan_out)) 448 | # sd = math.sqrt(2 / (nChanLayer1 * 3 * 3 * 3)) # Meganet approach 449 | with torch.no_grad(): 450 | K = K.normal_(0, sd) 451 | 452 | K = torch.clamp(K, min = -2*sd, max=2*sd) 453 | 454 | return K 455 | 456 | 457 | def appendTensorList(params, Klist, nWeights, weight_decay=None): 458 | """ 459 | RETIRED. used in test functions 460 | helper function for adding a list of tensors to the params 461 | 462 | 463 | :param params: the params object that will be passed to the optimizer 464 | :param Klist: list of Tensors that you want to be parameters for the optimizer 465 | :param nWeights: int, current number of weights in params (will be incremented) 466 | :param weight_decay: optional float, specific tikhinov regularization for Klist weights 467 | (specifically, if Klist are the scaling weights for a tvNorm, you may want no weight_decay) 468 | :return: updated params, updated number of weights in params 469 | """ 470 | 471 | if weight_decay is not None: 472 | for K in Klist: 473 | nWeights += K.numel() 474 | params.append({'params': K , 'weight_decay': weight_decay}) 475 | else: 476 | for K in Klist: 477 | nWeights += K.numel() 478 | params.append({'params': K}) 479 | 480 | return params, nWeights 481 | 482 | 483 | 484 | 485 | 486 | 487 | def getVectorizedParams(net,device): 488 | """ 489 | obtain vectors of the running means, running variances, and parameters 490 | use built-in parameters_to_vector approach 491 | 492 | :param net: the network 493 | :param device: retired/unused 494 | :return: vector of the running means, vector of the running variances, vector of the parameters 495 | """ 496 | param_device = None 497 | 498 | vRunMean = [] 499 | vRunVar = [] 500 | 501 | for name, param in net.named_buffers(): 502 | param_device = nn.utils.convert_parameters._check_param_device(param,param_device) 503 | if 'running_mean' in name: 504 | vRunMean.append(param.view(-1)) 505 | elif 'running_var' in name: 506 | vRunVar.append(param.view(-1)) 507 | 508 | runMean = torch.cat(vRunMean) 509 | runVar = torch.cat(vRunVar) 510 | params = nn.utils.convert_parameters.parameters_to_vector(net.parameters()) 511 | 512 | return runMean, runVar, params 513 | 514 | 515 | def trainNet(net,optimizer,lr,nEpoch, device, sBasePath, loaderTrain,loaderVal, nMini=1, verbose=True, 516 | regularization_param=0,checkpointing=False): 517 | """ 518 | function for training the network 519 | 520 | :param net: model to be trained 521 | :param optimizer: pytorch optimizer 522 | :param lr: numpy array of length nEpoch, learning for each epoch 523 | :param nEpoch: int, number of epochs 524 | :param device: pytorch device 525 | :param sBasePath: string, path for results files (will append .txt , _model.pt , etc. for each file) 526 | :param loaderTrain: dataloader for training data 527 | :param loaderVal: dataloader for validation sata 528 | :param nMini: int, when on cpu/debugging, print out results for every nMini batches 529 | :param verbose: boolean, when True and on cpu/debugging print out more often than every epoch 530 | :param regularization_param: float, weight applied to regularization metric (for regularization in time) 531 | :param checkpointing: boolean, if True use checkpointing 532 | :return: 533 | """ 534 | 535 | # save the train-val split indices 536 | torch.save([loaderTrain.sampler.indices,loaderVal.sampler.indices], sBasePath + '_trainValSplit.pt') 537 | 538 | sPathOpt = '' 539 | valOpt = 100000 # keep track of optimal validation loss 540 | checkpt = time.time() 541 | 542 | results = np.zeros((nEpoch,10)) 543 | 544 | oldRunMean, oldRunVar, oldParams = getVectorizedParams(net,device) 545 | 546 | 547 | regValue = 0 548 | if regularization_param>0: 549 | print('%-9s %-9s %-11s %-11s %-9s %-11s %-9s %-9s %-9s %-9s %-13s' % 550 | ('epoch', 'time', '|runMean|', '|runVar|', 551 | 'lr', '|params|', 'avgLoss', 'acc', 'valLoss', 'valAcc', 'reg'), flush=True) 552 | 553 | else: 554 | print('%-9s %-9s %-11s %-11s %-9s %-11s %-9s %-9s %-9s %-9s' % 555 | ('epoch', 'time', '|runMean|', '|runVar|', 556 | 'lr', '|params|', 'avgLoss', 'acc', 'valLoss', 'valAcc'), flush=True) 557 | 558 | 559 | for epoch in range(nEpoch): # loop over the dataset multiple times 560 | 561 | """ 562 | ########################## REMOVE THESE ################ 563 | if epoch==100: 564 | if net.dynamicBlocks[0].controlLayers[0]._get_name() == 'DoubleLayer': 565 | removeBatchNormRKNet(net) # remove batch norm 566 | print('removing batch norm') 567 | oldRunMean, oldRunVar, oldParams = getVectorizedParams(net,device) 568 | else: 569 | regularization_param = 0 # for doubleSym remove regularization 570 | """ 571 | 572 | # adjust learning rate by epoch 573 | for param_group in optimizer.param_groups: 574 | param_group['lr'] = lr.item(epoch) 575 | 576 | running_loss = 0.0 577 | running_num_correct = 0.0 578 | running_num_total = 0.0 579 | running_reg = 0.0 580 | for i, data in enumerate(loaderTrain, 0): 581 | 582 | net.train() # set model to train mode 583 | 584 | if checkpointing is True: 585 | loss, numCorrect, numTotal = net.checkpoint_train(data, optimizer, device,regularization_param) 586 | else: 587 | # regular optimization without checkpointing as before 588 | # get the inputs 589 | inputs, labels = data 590 | 591 | # zero the parameter gradients 592 | optimizer.zero_grad() 593 | 594 | inputs, labels = inputs.to(device), labels.to(device) 595 | # forward + backward + optimize 596 | 597 | x = net(inputs) 598 | loss, Si = misfit(x,labels) 599 | 600 | #add regularization to loss 601 | if regularization_param != 0: 602 | regValue = regularization_param*net.regularization() 603 | loss = loss + regValue 604 | loss.backward() 605 | optimizer.step() 606 | 607 | 608 | # impose bound constraints / clip the weights 609 | if hasattr(net, 'clip'): # some networks we created (ResNet and old RKNet) are not ClippedModules) 610 | net.clip() 611 | 612 | 613 | _ , numCorrect, numTotal = getAccuracy(Si,labels) 614 | 615 | running_loss += numTotal * loss.item() 616 | running_num_correct += numCorrect 617 | running_num_total += numTotal 618 | running_reg += numTotal * regValue 619 | 620 | # print statistics 621 | # print every few mini-batches when on the CPU 622 | if device.type == 'cpu': 623 | if verbose == True: 624 | if i % nMini == nMini - 1: # print every nMini mini-batches 625 | newRunMean, newRunVar, newParams = getVectorizedParams(net,device) 626 | print('%-4d %-4d %-9.1f %-11.2e %-11.2e %-9.2e %-11.2e %-9.2e %-9.2f' % 627 | (epoch + 1, 628 | running_num_total, 629 | time.time() - checkpt, 630 | torch.norm(oldRunMean-newRunMean), 631 | torch.norm(oldRunVar - newRunVar), 632 | optimizer.param_groups[0]['lr'], 633 | torch.norm(oldParams - newParams), 634 | running_loss / running_num_total, 635 | running_num_correct*100 / running_num_total),flush=True) 636 | 637 | ### EVALUATION 638 | # after 1 training epoch, validate 639 | 640 | valLoss, valAcc = evalNet(net, device, loaderVal) 641 | 642 | newRunMean, newRunVar, newParams = getVectorizedParams(net,device) 643 | 644 | results[epoch, 0] = epoch + 1 645 | results[epoch, 1] = time.time() - checkpt 646 | results[epoch, 2] = torch.norm(oldRunMean-newRunMean) 647 | results[epoch, 3] = torch.norm(oldRunVar - newRunVar) 648 | results[epoch, 4] = optimizer.param_groups[0]['lr'] 649 | results[epoch, 5] = torch.norm(oldParams - newParams) 650 | results[epoch, 6] = running_loss / running_num_total 651 | results[epoch, 7] = running_num_correct * 100 / running_num_total 652 | results[epoch, 8] = valLoss 653 | results[epoch, 9] = valAcc * 100 654 | 655 | if regularization_param>0: 656 | print('%-9d %-9.1f %-11.2e %-11.2e %-9.2e %-11.2e %-9.2e %-9.2f %-9.2e %-9.2f %-13.2e' % 657 | (results[epoch, 0], 658 | results[epoch, 1], 659 | results[epoch, 2], 660 | results[epoch, 3], 661 | results[epoch, 4], 662 | results[epoch, 5], 663 | results[epoch, 6], 664 | results[epoch, 7], 665 | results[epoch, 8], 666 | results[epoch, 9], 667 | running_reg / running_num_total),flush=True) 668 | else: 669 | print('%-9d %-9.1f %-11.2e %-11.2e %-9.2e %-11.2e %-9.2e %-9.2f %-9.2e %-9.2f' % 670 | (results[epoch, 0], 671 | results[epoch, 1], 672 | results[epoch, 2], 673 | results[epoch, 3], 674 | results[epoch, 4], 675 | results[epoch, 5], 676 | results[epoch, 6], 677 | results[epoch, 7], 678 | results[epoch, 8], 679 | results[epoch, 9]),flush=True) 680 | 681 | if valLoss < valOpt: 682 | # sPath = sBasePath + '_acc_' + "{:.3f}".format(valAcc) 683 | sPath = sBasePath + '_model' 684 | torch.save(net, sPath + '.pt') # save entire model 685 | valOpt = valLoss # update 686 | sPathOpt = sPath 687 | 688 | 689 | checkpt = time.time() 690 | 691 | 692 | 693 | oldRunMean, oldRunVar, oldParams = newRunMean, newRunVar, newParams 694 | 695 | # end for epoch 696 | 697 | saveResults(sBasePath + '_results.mat', results) 698 | 699 | return sPathOpt 700 | 701 | # --------------------------------------------------------------------------- 702 | # --------------------------------------------------------------------------- 703 | # --------------------------------------------------------------------------- 704 | 705 | 706 | 707 | 708 | def evalNet(net, device, loaderTest): 709 | """ 710 | evaluate the networks weights.... 711 | use for both validation and test sets, just change loaderTest 712 | 713 | return loss, acc (val_loss, val_acc OR test_loss, test_acc) 714 | """ 715 | 716 | 717 | net.eval() # set model to evaluation mode 718 | 719 | # compute validation accuracy 720 | vali = 0 721 | with torch.no_grad(): 722 | # valRunLoss = 0.0 723 | weightedLoss = 0.0 724 | valRunCorrect = 0 725 | valRunTotal = 0 726 | 727 | for vali, valData in enumerate(loaderTest, 0): 728 | imagesVal, labelsVal = valData 729 | imagesVal, labelsVal = imagesVal.to(device), labelsVal.to(device) 730 | xVal = net(imagesVal) 731 | lossVal, SiVal = misfit(xVal, labelsVal) 732 | accVal, batchCorrect, batchCount = getAccuracy(SiVal, labelsVal) 733 | 734 | # valRunLoss += lossVal.item() 735 | weightedLoss += batchCount * lossVal.item() 736 | valRunCorrect += batchCorrect 737 | valRunTotal += batchCount 738 | 739 | if device.type=='cpu': # for debugging 740 | break 741 | 742 | valAcc = valRunCorrect / valRunTotal 743 | valLoss = weightedLoss / valRunTotal # valRunLoss / (vali + 1) 744 | 745 | return valLoss , valAcc 746 | 747 | 748 | 749 | 750 | 751 | 752 | # 753 | # def helperDoubleSymRemoveBN(convLayer, normLayer): 754 | # """ 755 | # 756 | # FAILS TESTS BECAUSE of the activation function within 757 | # 758 | # -K^T ( act( N ( K(Y) ) ) ) 759 | # Easy to remove the N, but then no guarantee to maintain same weights for conv and conv transpose 760 | # 761 | # incorporate the batch norm in the convolution 762 | # 763 | # from -K^T ( s * (KY + beta + rm)/( sqrt( rv + eps ) ) + b ) + t 764 | # 765 | # to -K'^T ( K'*Y + beta' ) + t 766 | # 767 | # where K' = s^(1/2) / (rv+eps)^(1/4) * K and beta' = (s^(1/2) / (rv+eps)^(1/4) )*(b + s*(beta - rm)) 768 | # 769 | # t - convt bias 770 | # s - norm scaling weight 771 | # b - norm scaling bias 772 | # rm - running_mean 773 | # rv - running_variance 774 | # eps - epsilon of norm 775 | # K - convolution weights 776 | # beta - convolution bias 777 | # 778 | # """ 779 | # 780 | # # convLayer = layer.conv 781 | # # normLayer = layer.normLayer 782 | # # transLayer = layer.convt 783 | # 784 | # rm = normLayer.running_mean.data 785 | # rv = normLayer.running_var.data 786 | # s = normLayer.weight.data 787 | # b = normLayer.bias.data 788 | # epsil = normLayer.eps 789 | # beta = convLayer.bias.data # conv bias 790 | # K = convLayer.weight.data 791 | # 792 | # 793 | # # new K and b 794 | # 795 | # # denom = (rv+epsil )**(1/4) # remove the epsilon will give better results, but could fall victim to division by zero 796 | # denom = (rv) ** (1 / 4) + epsil**(9 / 8) # hack: epsilon tuned to give desired behavior 797 | # 798 | # num = b + (s * (beta - rm)) 799 | # 800 | # sqs = torch.sqrt(s) 801 | # beta = (sqs / denom) * num 802 | # 803 | # s = sqs/denom 804 | # s = s.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 805 | # K = s * K 806 | # 807 | # convLayer.bias.data = beta 808 | # convLayer.weight.data = K 809 | 810 | 811 | 812 | 813 | def helperRemoveBN(convLayer, normLayer): 814 | """ convLayer will absorb the information in normLayer assuming N( K(Y) )... norm of conv of features""" 815 | rm = normLayer.running_mean.data 816 | rv = normLayer.running_var.data 817 | s = normLayer.weight.data 818 | b = normLayer.bias.data 819 | epsil = normLayer.eps 820 | beta = convLayer.bias.data # conv bias 821 | K = convLayer.weight.data 822 | 823 | # new K and b 824 | denom = torch.sqrt(rv + epsil) 825 | s = s / denom 826 | beta = b + (s * (beta - rm)) 827 | s = s.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 828 | K = s * K 829 | 830 | convLayer.bias.data = beta 831 | convLayer.weight.data = K 832 | 833 | return 0 834 | 835 | 836 | def removeBatchNorm(layer): 837 | """ 838 | incorporate the batch norm in the convolution 839 | 840 | from s * (KY + beta + rm)/( sqrt( rv + eps ) ) + b 841 | 842 | to K'*Y + beta' 843 | 844 | where K' = s / sqrt(rv+eps) * K and beta' = b + s*(beta - rm)/sqrt(rv+eps) 845 | 846 | s - norm scaling weight 847 | b - norm scaling bias 848 | rm - running_mean 849 | rv - running_variance 850 | eps - epsilon of norm 851 | K - convolution weights 852 | beta - convolution bias 853 | 854 | """ 855 | bUnchanged = 1 856 | 857 | if layer._get_name() == 'DoubleLayer' or layer._get_name() == 'PreactDoubleLayer': 858 | if hasattr(layer,'normLayer1'): 859 | if layer.normLayer1._get_name()=='BatchNorm2d': 860 | bUnchanged = helperRemoveBN(layer.conv1, layer.normLayer1) 861 | layer.__delattr__('normLayer1') 862 | if hasattr(layer, 'normLayer2'): 863 | if layer.normLayer2._get_name() == 'BatchNorm2d': 864 | bUnchanged = helperRemoveBN(layer.conv2, layer.normLayer2) 865 | layer.__delattr__('normLayer2') 866 | elif layer._get_name() == 'DoubleSymLayer': 867 | if layer.normLayer._get_name() == 'BatchNorm2d': 868 | print('no known way to remove batchnorm from DoubleSymLayer with a non-identity activation function') 869 | else: 870 | if hasattr(layer, 'normLayer'): 871 | if layer.normLayer._get_name() == 'BatchNorm2d': 872 | bUnchanged =helperRemoveBN(layer.conv, layer.normLayer) 873 | layer.__delattr__('normLayer') 874 | 875 | return bUnchanged 876 | 877 | def removeBatchNormRKNet(net): 878 | """remove batch norm in the state layers and control layers 879 | 880 | Only works for rk1 , tY = tTheta 881 | 882 | 883 | for rk4 and rk1 with more states than layers....what do I do with the interpolation? 884 | This code will remove all the batch norms, but the resulting network will not be tuned 885 | to similar accuracy or performance 886 | 887 | """ 888 | bUnchanged = 1 889 | #--------------------------------------------- 890 | # Remove Batch Norm if exists and incorporate into the convolutions 891 | #--------------------------------------------- 892 | 893 | # remove the batch norm 894 | for blk in net.dynamicBlocks: # for each dynamic block 895 | 896 | if blk._get_name() == 'rk1': 897 | states = blk.stateLayers 898 | nControlLayers = len(blk.controlLayers)-1 899 | reducedStates = states 900 | elif blk._get_name() == 'rk4': # account for the odd setup of the stateLayers in the rk4 901 | states = [] 902 | reducedStates = [] 903 | for i in range(len(blk.stateLayers)): 904 | reducedStates.append(blk.stateLayers[i][0]) 905 | for j in [0,1,2,3]: 906 | states.append(blk.stateLayers[i][j]) 907 | 908 | reducedStates.append(blk.stateLayers[i][3]) # don't forget the last layer 909 | nControlLayers = len(blk.controlLayers) 910 | else: 911 | print('don\'t know what scheme is used in block') 912 | 913 | 914 | for l in states: # for each state layer in the dynamic block 915 | 916 | bUnchanged = removeBatchNorm(l) 917 | 918 | # need to remove it in the control layer as well 919 | if l._get_name() == 'DoubleLayer' or l._get_name() == 'PreactDoubleLayer': 920 | for i in range(nControlLayers): 921 | thisLayer = blk.controlLayers[i] 922 | if hasattr(thisLayer,'normLayer1'): 923 | if thisLayer.normLayer1._get_name() == 'BatchNorm2d': 924 | thisLayer.__delattr__('normLayer1') 925 | thisLayer.conv1.weight.data = reducedStates[i].conv1.weight.data 926 | thisLayer.conv1.bias.data = reducedStates[i].conv1.bias.data 927 | if hasattr(thisLayer, 'normLayer2'): 928 | if thisLayer.normLayer2._get_name() == 'BatchNorm2d': 929 | thisLayer.__delattr__('normLayer2') 930 | thisLayer.conv2.weight.data = reducedStates[i].conv2.weight.data 931 | thisLayer.conv2.bias.data = reducedStates[i].conv2.bias.data 932 | 933 | if blk._get_name() == 'rk1': 934 | removeBatchNorm(blk.controlLayers[-1]) 935 | 936 | 937 | if bUnchanged==1: 938 | print('No batch norms were removed') 939 | 940 | return net 941 | 942 | 943 | 944 | 945 | # For all run functions 946 | def getRunDefaults(sDataset, gpu=0): 947 | """ 948 | the default running hyperparameters for experimentss 949 | :param sDataset: string, name of dataset, 'cifar10' , 'cifar100', 'stl10' 950 | :param gpu: int, None means use cpu; otherwise, signify which gpu to use 951 | :return: vFeat, transformTrain, transformTest, batchSize, lr, tTheta, nTrain, nVal, device 952 | vFeat - vector of number of features/channels per dynamic block 953 | transformTrain - default transformations for training dataloader 954 | transformTest - default transformations for testing dataloader 955 | batchSize - int, batch size 956 | lr - numpy array, learning rate per epoch 957 | tTheta - list, time discretization for controls 958 | nTrain - int, number of training examples 959 | nVal - int, number of validation examples 960 | device - pytorch device 961 | """ 962 | 963 | 964 | lr = np.concatenate((1e-1 * np.ones((80, 1), dtype=np.float32), 965 | 1e-2 * np.ones((40, 1), dtype=np.float32), 966 | 1e-3 * np.ones((40, 1), dtype=np.float32), 967 | 1e-4 * np.ones((40, 1), dtype=np.float32)), axis=0) 968 | 969 | 970 | tTheta=[0,1,2,3,4] 971 | 972 | if sDataset == 'cifar10' or sDataset=='cifar100': 973 | vFeat = [16,32,64,64] 974 | 975 | shift = transforms.Normalize( 976 | mean=[0.5, 0.5, 0.5], 977 | std=[1., 1., 1.], 978 | ) 979 | 980 | # augmentations and convert to tensor 981 | 982 | transformTrain = transforms.Compose( 983 | [transforms.RandomHorizontalFlip(p=0.5), 984 | transforms.RandomCrop(32, padding=4), 985 | transforms.ToTensor(), 986 | shift]) 987 | 988 | transformTest = transforms.Compose( 989 | [transforms.ToTensor(), 990 | shift]) 991 | 992 | batchSize = 125 993 | 994 | if sDataset=='cifar10': 995 | # cifar10 is so much easier, 120 epochs is good enough 996 | lr = np.concatenate((1e-1 * np.ones((60, 1), dtype=np.float32), 997 | 1e-2 * np.ones((20, 1), dtype=np.float32), 998 | 1e-3 * np.ones((20, 1), dtype=np.float32), 999 | 1e-4 * np.ones((20, 1), dtype=np.float32)), axis=0) 1000 | 1001 | if sDataset=='cifar100': 1002 | lr = np.concatenate((1e-1 * np.ones((60), dtype=np.float32), 1003 | 1/np.geomspace(1/1e-1, 1/1e-2, num=20), 1004 | 1e-2 * np.ones((20), dtype=np.float32), 1005 | 1 / np.geomspace(1 / 1e-2, 1 / 1e-3, num=20), 1006 | 1e-3 * np.ones((20), dtype=np.float32), 1007 | 1 / np.geomspace(1 / 1e-3, 1 / 1e-4, num=20), 1008 | 1e-4 * np.ones((20), dtype=np.float32)), axis=0) 1009 | elif sDataset == 'stl10': 1010 | vFeat = [16, 32, 64, 128, 128] 1011 | 1012 | normalize = transforms.Normalize( 1013 | mean=[0.485, 0.456, 0.406], 1014 | std=[0.229, 0.224, 0.225], 1015 | ) 1016 | 1017 | # augmentations and convert to tensor 1018 | 1019 | transformTrain = transforms.Compose( 1020 | [transforms.RandomHorizontalFlip(p=0.5), 1021 | transforms.RandomCrop(96, padding=12), 1022 | transforms.ToTensor(), 1023 | normalize]) 1024 | 1025 | 1026 | transformTest = transforms.Compose( 1027 | [transforms.ToTensor(), 1028 | normalize]) 1029 | 1030 | batchSize = 30 1031 | 1032 | # TODO sDataset = tiny Image Net 1033 | 1034 | 1035 | # if on CPU, run small version of the model 1036 | if not torch.cuda.is_available() or gpu is None: 1037 | print('running CPU default vFeat and lr') 1038 | vFeat = [4, 7, 7] 1039 | lr = np.array([1e-1, 1e-4]) 1040 | nTrain = batchSize*4 + 5 1041 | nVal = batchSize + 2 1042 | device = torch.device('cpu') 1043 | else: 1044 | # specifyGPU(gpu) 1045 | device = torch.device('cuda:' + str(gpu)) 1046 | nTrain = None 1047 | nVal = None 1048 | 1049 | return vFeat, transformTrain, transformTest, batchSize, lr, tTheta, nTrain, nVal, device 1050 | 1051 | 1052 | 1053 | 1054 | 1055 | 1056 | 1057 | 1058 | 1059 | 1060 | 1061 | 1062 | ################################################### 1063 | # SAVING AND LOADING 1064 | ################################################### 1065 | 1066 | def saveResults(path, npArray): 1067 | sio.savemat(path, {'results': npArray}) 1068 | 1069 | def loadResults(path): 1070 | return sio.loadmat(path)['results'] 1071 | 1072 | def saveTensor(path, inputs): 1073 | sio.savemat(path, {'inputs' : inputs.cpu().numpy()}) 1074 | 1075 | def loadTensor(path, inputs=None): 1076 | loaded_inputs = torch.Tensor(sio.loadmat(path)) 1077 | if inputs is not None: 1078 | inputs = loaded_inputs 1079 | 1080 | return loaded_inputs 1081 | 1082 | 1083 | def saveParams(path,Kopen,Kresnet,Kconnect,Knorm,KtvNorm,W, runMean, runVar): 1084 | 1085 | par = {} 1086 | par["Kopen"] = Kopen.detach().cpu().numpy() 1087 | for i in range(len(Kresnet)): 1088 | key = "Kresnet_" + str(i) 1089 | par[key] = Kresnet[i].detach().cpu().numpy() 1090 | for i in range(len(Kconnect)): 1091 | key = "Kconnect_" + str(i) 1092 | par[key] = Kconnect[i].detach().cpu().numpy() 1093 | for i in range(len(Knorm)): 1094 | key = "Knorm_" + str(i) 1095 | par[key] = Knorm[i].detach().cpu().numpy() 1096 | for i in range(len(KtvNorm)): 1097 | key = "KtvNorm_" + str(i) 1098 | par[key] = KtvNorm[i].detach().cpu().numpy() 1099 | par["W"] = W.detach().cpu().numpy() 1100 | 1101 | for i in range(len(runMean)): 1102 | key = "runMean_" + str(i) 1103 | par[key] = runMean[i].detach().cpu().numpy() 1104 | for i in range(len(runVar)): 1105 | key = "runVar_" + str(i) 1106 | par[key] = runVar[i].detach().cpu().numpy() 1107 | 1108 | # print(par) 1109 | # data = {'parameters': par} 1110 | sio.savemat(path, par) 1111 | return 0 1112 | 1113 | def loadParams(path,Kopen,Kresnet,Kconnect,Knorm,KtvNorm,W,runMean, runVar, device): 1114 | loaded_params = sio.loadmat(path) 1115 | 1116 | # loaded_params['W'] - W.detach().numpy() # should be all 0s 1117 | # torch.Tensor(loaded_params['W']) - W # should be all 0s 1118 | 1119 | # load them back into the parameters 1120 | for name in loaded_params.keys(): 1121 | if name[0] != '_': # ignore the titles with first char '_' , __headers__ and such 1122 | if '_' in name: 1123 | word, idx = str.split(name, '_') 1124 | if word == 'Kconnect': 1125 | vars()[word][int(idx)].data = torch.Tensor(loaded_params[name]).to(device) 1126 | else: 1127 | vars()[word][int(idx)].data = torch.Tensor(loaded_params[name].squeeze()).to(device) 1128 | else: 1129 | vars()[name].data = torch.Tensor(loaded_params[name].squeeze()).to(device) 1130 | 1131 | return loaded_params 1132 | 1133 | 1134 | def testSaveLoad(Kopen,Kresnet,Kconnect,Knorm,KtvNorm,W): 1135 | import copy 1136 | 1137 | # import h5py # import tables # pytables 1138 | # save model 1139 | # PATH = 'modelWeights.idk' 1140 | # torch.save(KtvNorm, PATH) 1141 | 1142 | 1143 | 1144 | oldParams = copy.deepcopy([Kopen] + Kresnet + Kconnect + Knorm + KtvNorm + [W]) 1145 | 1146 | saveParams('par1.mat', Kopen, Kresnet, Kconnect, Knorm, KtvNorm, W) 1147 | 1148 | # clear # use zero_ ??? 1149 | # Kopen.data = torch.zeros(Kopen.shape) 1150 | Kopen.data.zero_() 1151 | for i in range(len(Kresnet)): 1152 | Kresnet[i].data = torch.zeros(Kresnet[i].shape) 1153 | for i in range(len(Kconnect)): 1154 | Kconnect[i].data = torch.zeros(Kconnect[i].shape) 1155 | for i in range(len(Knorm)): 1156 | Knorm[i].data = torch.zeros(Knorm[i].shape) 1157 | for i in range(len(KtvNorm)): 1158 | KtvNorm[i].data = torch.zeros(KtvNorm[i].shape) 1159 | W.data = torch.zeros(W.shape) 1160 | 1161 | loadParams('par1.mat',Kopen,Kresnet,Kconnect,Knorm,KtvNorm,W) 1162 | 1163 | print("\nnorm difference") 1164 | print(listNorm(oldParams, [Kopen] + Kresnet + Kconnect + Knorm + KtvNorm + [W])) 1165 | print('\n') 1166 | 1167 | 1168 | 1169 | 1170 | 1171 | # Testing functions 1172 | 1173 | 1174 | 1175 | def func(x,K,K2): 1176 | # z = 2*K 1177 | z = conv3x3(x,K) 1178 | # z = DoubleSymLayer(x,K) 1179 | # z = tvNorm(K) 1180 | # z = rk4DoubleSymBlock(x, [K,K2], tvnorm=False, weightBias = [None]*10) 1181 | 1182 | # vFeat = [4,8,12,16] 1183 | # nChanIn = 3 1184 | # nClasses = 10 1185 | # tY = [0,1,2,3,4] 1186 | # tTheta = tY 1187 | # dynamicScheme = rk4 1188 | # layer = DoubleSymLayer 1189 | # net = RKNet(tY, tTheta, nChanIn, nClasses, vFeat, dynamicScheme=dynamicScheme, layer=layer) 1190 | # 1191 | 1192 | return z 1193 | 1194 | def dCheckDoubleSym(): 1195 | # derivative check for the DoubleSym Layer 1196 | # define function f(b) 1197 | # b = torch.Tensor([1.,1.,2.,2.]) 1198 | b = torch.randn(4,4,3,3) 1199 | b2 = torch.randn(4,4,3,3) 1200 | b.requires_grad = True 1201 | 1202 | x = torch.randn(2, 4, 3, 3) # 2 images, 4 channels, 3x3 pixels 1203 | x.requires_grad = False 1204 | # x = torch.randn(4) 1205 | 1206 | # f = DoubleSymLayer(x, b, tvnorm=False) 1207 | # f = conv3x3(x,b) 1208 | # f = 2 * b 1209 | 1210 | f = func(x, b, b2) 1211 | 1212 | v = torch.randn(4*4*3*3) 1213 | # v = torch.randn(4) 1214 | 1215 | # K = torch.Tensor([5,7,11,13]) 1216 | # K = torch.Tensor(torch.randn(4, 4, 3, 3)) 1217 | K = torch.randn(4,4,3,3) 1218 | K2 = torch.randn(4,4,3,3) 1219 | dx = torch.ones(2, 4, 3, 3) 1220 | 1221 | torch.autograd.backward([f], [dx, K, K2] ) # stores gradient of f with respect to K in b.grad 1222 | J = b.grad.view(1,-1) 1223 | 1224 | Jv = torch.matmul(J,v) 1225 | 1226 | 1227 | err = np.zeros((3,30)) 1228 | for k in range(1,30): 1229 | h = 2**(-k) 1230 | 1231 | # fK = 2*K 1232 | fK = func(x,K, K2) 1233 | 1234 | first = func(x, K+(h*v).view(K.shape), K2) - fK 1235 | 1236 | # first = DoubleSymLayer( x, K+(h*v).view(K.shape) , tvnorm=False ) - f 1237 | # first = conv3x3(x, K + (h * v).view(K.shape)) - f 1238 | # first = 2* ( K + (h * v).view(K.shape) ) - fK 1239 | 1240 | E0 = torch.norm(first) 1241 | E1 = torch.norm(first - h*Jv) 1242 | 1243 | print('h=%1.2e E0=%1.2e E1=%1.2e' % (h,E0,E1)) 1244 | err[0, k - 1] = h 1245 | err[1, k - 1] = E0 1246 | err[2, k - 1] = E1 1247 | 1248 | 1249 | return err 1250 | 1251 | 1252 | 1253 | 1254 | 1255 | def inter1D(theta,tTheta,tInter): 1256 | """ 1257 | 1D (linear) interpolation. For the observations theta at tTheta, find the observations at points ti. 1258 | - ASSUMPTIONS: all tensors in the list theta have the same dimension and tTheta is sorted ascending 1259 | - theta are my K parameters 1260 | :param theta: list of Tensors, think of them as measurements : f(x0) , f(x1), f(x2) , ... 1261 | :param tTheta: points where we have the measurements: x0,x1,x2,... 1262 | :param tInter: points where we want to have an approximate function value: a,b,c,d,... 1263 | :return: inter: approximate values f(a), f(b), f(c), f(d) using the assumption 1264 | that connecting every successive theta to its previous one with a line 1265 | """ 1266 | if len(theta) != len(tTheta): 1267 | print('inter1D: lengths do not match') 1268 | return -1 1269 | 1270 | nPoints = len(tInter) # number of interpolation time points 1271 | # make list of length same as tInter, and each element is of size K 1272 | # interpK = thi 1273 | # assume all tensors in the list theta have the same dimensions 1274 | inter = [torch.zeros(theta[0].shape)] * nPoints 1275 | 1276 | 1277 | for k in range(nPoints): 1278 | # get K for current time point 1279 | # assume tTheta is sorted 1280 | if tInter[k] <= tTheta[0]: # if interp point is outside the tTheta range 1281 | inter[k] = 1*theta[0] 1282 | elif tInter[k] >= tTheta[-1]: 1283 | inter[k] = 1*theta[-1] 1284 | # elif tInter[k] in tTheta: # if interp point is already in the list, give that function value 1285 | # idxTh = tTheta.index(tInter[k]) 1286 | # inter[k] = theta[idxTh] 1287 | else: 1288 | idxTh = bisect(tTheta, tInter[k]) 1289 | # idxTh contains index of right point in tTheta to use for interpolation 1290 | leftTh = tTheta[idxTh-1] 1291 | rightTh = tTheta[idxTh] 1292 | h = rightTh - leftTh 1293 | wtLeft = (rightTh - tInter[k]) / h 1294 | wtRight = (tInter[k] - leftTh) / h 1295 | inter[k] = wtLeft*theta[idxTh-1] + wtRight*theta[idxTh] 1296 | 1297 | return inter 1298 | 1299 | 1300 | 1301 | def testInter1D(): 1302 | """test example for inter1D""" 1303 | t = torch.ones(2,1) 1304 | K = [2*t, 4*t,6*t,5*t] 1305 | tTheta = [1,2,4,5] 1306 | ti = [0.7,1.0,1.9,2.0,2.1,2.4,2.7,3,3.3,5.0, 5.4] 1307 | interpolatedK = inter1D(K, tTheta, ti) 1308 | print(interpolatedK) 1309 | return 0 1310 | 1311 | 1312 | --------------------------------------------------------------------------------