├── .gitignore
├── README.md
├── figures
│   ├── Figure_1.jpg
│   └── Figure_2.jpg
├── main_cifar.py
├── main_imagenet.py
├── prednet.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Compiled python modules.
*.pyc

# Byte-compiled
__pycache__/
.cache/

# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info
.eggs/

# PyPI distribution artifacts.
build/
dist/

# Sublime project files
*.sublime-project
*.sublime-workspace

# Tests
.pytest_cache/

# Other
*.DS_Store
.idea/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PCN with Local Recurrent Processing
This repository contains the code for the PCN with local recurrent processing introduced in the following paper:

[Deep Predictive Coding Network with Local Recurrent Processing for Object Recognition](https://arxiv.org/abs/1805.07526) (NIPS 2018)

Kuan Han, Haiguang Wen, Yizhen Zhang, Di Fu, Eugenio Culurciello, Zhongming Liu

The code is built on PyTorch.

## Introduction

Deep predictive coding network (PCN) with local recurrent processing is a bi-directional, dynamical neural network inspired by predictive coding in neuroscience. Unlike feedforward-only convolutional neural networks, PCN includes both feedback connections, which carry top-down predictions, and feedforward connections, which carry bottom-up prediction errors. Together, these connections let adjacent layers interact locally and recurrently, refining their representations to minimize layer-wise prediction errors.
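
The local recurrent loop can be illustrated with a toy, fully connected analog of `PcConvBp.forward` in `prednet.py`. This is a minimal sketch under stated assumptions, not the repository's implementation: plain Python lists stand in for tensors, two small dense matrices (the hypothetical `w_ff` and `w_fb`) stand in for the `FFconv` and `FBconv` convolutions, and the 1×1 bypass connection is left out. Each cycle mirrors the update `y = FFconv(relu(x - FBconv(y))) * b0 + y`, with `gain` playing the role of `b0`, which the repository initializes to 1.

```python
# Toy, fully connected analog of the PcConvBp recurrent update.
# Assumptions: dense matrices replace FFconv/FBconv; the bypass is omitted.


def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def relu(v):
    return [max(0.0, a) for a in v]


def prediction_error(x, w_fb, r):
    """Squared error between the input x and its top-down prediction from r."""
    return sum((xi - pi) ** 2 for xi, pi in zip(x, matvec(w_fb, r)))


def local_recurrence(x, w_ff, w_fb, gain=1.0, cycles=3):
    """Return the representation after the feedforward pass and after each cycle."""
    r = relu(matvec(w_ff, x))                 # initial feedforward pass
    history = [r]
    for _ in range(cycles):
        prediction = matvec(w_fb, r)          # feedback: top-down prediction of x
        error = relu([xi - pi for xi, pi in zip(x, prediction)])  # rectified error
        update = matvec(w_ff, error)          # feedforward: bottom-up error signal
        r = [ri + gain * ui for ri, ui in zip(r, update)]
        history.append(r)
    return history


if __name__ == "__main__":
    x = [1.0, 0.5]
    w_ff = [[0.5, 0.5]]      # 2 inputs -> 1 unit
    w_fb = [[1.0], [0.5]]    # 1 unit -> 2 predicted inputs
    for t, r in enumerate(local_recurrence(x, w_ff, w_fb)):
        print(t, r, prediction_error(x, w_fb, r))
```

With these toy weights the representation moves from 0.75 toward 1.0 and the prediction error shrinks on every cycle, a small-scale picture of representations converging as recurrent cycles accumulate.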
When unfolded over time, the recurrent processing gives rise to an increasingly deep hierarchy of non-linear transformations, allowing a shallow network to dynamically extend itself into an arbitrarily deep network. We train and test PCN for image classification on the SVHN, CIFAR and ImageNet datasets. Despite notably fewer layers and parameters, PCN achieves performance competitive with classical and state-of-the-art models. The internal representations in PCN converge over time and yield increasingly better accuracy in object recognition.

![Image of pcav1](https://github.com/libilab/PCN_v2/blob/master/figures/Figure_1.jpg)
(a) The plain model (left) is a feedforward CNN with 3×3 convolutional connections (solid arrows) and 1×1 bypass connections (dashed arrows).

(b) Building on the plain model, the local PCN (right) adds feedback (solid arrows) and recurrent (circular arrows) connections. The PCN consists of a stack of basic building blocks. Each block runs multiple cycles of local recurrent processing between adjacent layers and merges its input into its output through the bypass connections. The output of one block is then sent to the next block to initiate local recurrent processing in a higher block. This continues until the top of the network is reached.

## Usage

### To train PCN with local recurrent processing on ImageNet
For dependencies and the ImageNet dataset, see the instructions [here](https://github.com/pytorch/examples/tree/master/imagenet).
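
`main_imagenet.py` builds its datasets with `datasets.ImageFolder` on `<data>/train` and `<data>/val`, so both folders need one sub-folder per class, and the validation images must also be sorted into class folders. A quick pre-flight check can catch a wrong `--data` path before a long run starts. The helper below, `check_imagenet_layout`, is a hypothetical stdlib-only sketch that is not part of this repository; it checks only the folder structure, not the image files.

```python
import os


def check_imagenet_layout(root):
    """Return {split: number_of_class_folders} for root/train and root/val.

    Hypothetical helper: it mirrors the layout torchvision's ImageFolder
    expects (one sub-folder per class) and requires both splits to use the
    same class folder names, because ImageFolder derives label indices from
    the sorted folder names of each split independently.
    """
    classes = {}
    for split in ("train", "val"):
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"missing folder: {split_dir}")
        names = sorted(d for d in os.listdir(split_dir)
                       if os.path.isdir(os.path.join(split_dir, d)))
        if not names:
            raise ValueError(f"no class sub-folders in {split_dir}")
        classes[split] = names
    if classes["train"] != classes["val"]:
        raise ValueError("train and val contain different class folders")
    return {split: len(names) for split, names in classes.items()}
```

For the full dataset, `check_imagenet_layout('/Path/to/ImageNet/Dataset/Folder')` should report 1000 classes for both splits, matching `num_classes=1000` in `main_imagenet.py`.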

As an example, the following command trains a PCN with the default settings on ImageNet:
```bash
python main_imagenet.py --data /Path/to/ImageNet/Dataset/Folder
```

### To train PCN with local recurrent processing on CIFAR

As an example, the following command trains a PCN with the default settings on CIFAR-100:
```bash
python main_cifar.py
```

## Results
![Image of pcav1](https://github.com/libilab/PCN_v2/blob/master/figures/Figure_2.jpg)
PCN's classification accuracy improves with more cycles of recurrent processing on CIFAR-10, CIFAR-100 and ImageNet. The red dashed line represents the accuracy of the plain model.

## Updates
10/17/2018:

(1) README file.

## Contact
For any questions and comments, please [contact us](https://engineering.purdue.edu/libi/lab/Home.html).

--------------------------------------------------------------------------------
/figures/Figure_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/libilab/PCN-with-Local-Recurrent-Processing/36095b45c7ed534c543f14ef28eb2695ec636c18/figures/Figure_1.jpg
--------------------------------------------------------------------------------
/figures/Figure_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/libilab/PCN-with-Local-Recurrent-Processing/36095b45c7ed534c543f14ef28eb2695ec636c18/figures/Figure_2.jpg
--------------------------------------------------------------------------------
/main_cifar.py:
--------------------------------------------------------------------------------
'''Train PCN on CIFAR-100 with PyTorch.'''
from __future__ import print_function
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import argparse
from prednet import *
from utils import progress_bar
from torch.autograd import Variable

def main_cifar(model='PredNetBpD', circles=5, gpunum=1, Tied=False, weightDecay=1e-3, nesterov=False):
    use_cuda = torch.cuda.is_available()
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    batchsize = 128
    root = './'
    rep = 1
    lr = 0.01

    models = {'PredNetBpD': PredNetBpD}
    modelname = model+'_'+str(circles)+'CLS_'+str(nesterov)+'Nes_'+str(weightDecay)+'WD_'+str(Tied)+'TIED_'+str(rep)+'REP'

    # Create the checkpoint and log folders; increase the repetition index
    # until the log file name is unused, so earlier runs are not overwritten
    checkpointpath = root+'checkpoint/'
    logpath = root+'log/'
    if not os.path.isdir(checkpointpath):
        os.mkdir(checkpointpath)
    if not os.path.isdir(logpath):
        os.mkdir(logpath)
    while(os.path.isfile(logpath+'training_stats_'+modelname+'.txt')):
        rep += 1
        modelname = model+'_'+str(circles)+'CLS_'+str(nesterov)+'Nes_'+str(weightDecay)+'WD_'+str(Tied)+'TIED_'+str(rep)+'REP'

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])
    trainset = torchvision.datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batchsize, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR100(root='../data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=10, shuffle=False, num_workers=2)

    # Model
    print('==> Building model..')
    net = models[model](num_classes=100, cls=circles, Tied=Tied)

    # Define objective function
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr, weight_decay=weightDecay, nesterov=nesterov)

    # Parallel computing
    if use_cuda:
        net.cuda()
        net = torch.nn.DataParallel(net, device_ids=range(gpunum))
        cudnn.benchmark = True

    # Tensor.item() was added in PyTorch 0.4; fall back to indexing on older versions.
    def to_python_float(t):
        if hasattr(t, 'item'):
            return t.item()
        else:
            return t[0]

    # Training
    def train(epoch):
        print('\nEpoch: %d' % epoch)
        net.train()
        train_loss = 0
        correct = 0
        total = 0

        training_setting = 'batchsize=%d | epoch=%d | lr=%.1e ' % (batchsize, epoch, optimizer.param_groups[0]['lr'])
        statfile.write('\nTraining Setting: '+training_setting+'\n')

        for batch_idx, (inputs, targets) in enumerate(trainloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            inputs, targets = Variable(inputs), Variable(targets)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += to_python_float(loss.data)
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            # keep the running count as a Python number for %d formatting
            correct += to_python_float(predicted.eq(targets.data).float().cpu().sum())

            progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        statstr = 'Training: Epoch=%d | Loss: %.3f | Acc: %.3f%% (%d/%d) | best acc: %.3f' \
                  % (epoch, train_loss/(batch_idx+1), 100.*correct/total, correct, total, best_acc)
        statfile.write(statstr+'\n')

    # Testing
    def test(epoch):
        nonlocal best_acc
        net.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                inputs, targets = Variable(inputs), Variable(targets)
                outputs = net(inputs)
                loss = criterion(outputs, targets)

                test_loss += to_python_float(loss.data)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += to_python_float(predicted.eq(targets.data).float().cpu().sum())

                progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                    % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
        statstr = 'Testing: Epoch=%d | Loss: %.3f | Acc: %.3f%% (%d/%d) | best_acc: %.3f' \
                  % (epoch, test_loss/(batch_idx+1), 100.*correct/total, correct, total, best_acc)
        statfile.write(statstr+'\n')

        # Save checkpoint.
        acc = 100.*correct/total
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        torch.save(state, checkpointpath + modelname + '_last_ckpt.t7')
        if acc >= best_acc:
            print('Saving..')
            torch.save(state, checkpointpath + modelname + '_best_ckpt.t7')
            best_acc = acc

    # Step learning-rate schedule
    def decrease_learning_rate():
        """Divide the current learning rate by 10."""
        for param_group in optimizer.param_groups:
            param_group['lr'] /= 10

    for epoch in range(start_epoch, start_epoch+300):
        # train() and test() write to statfile; the with-block closes it every epoch
        with open(logpath+'training_stats_'+modelname+'.txt', 'a+') as statfile:
            if epoch == 150 or epoch == 225 or epoch == 262:
                decrease_learning_rate()
            train(epoch)
            test(epoch)

if __name__ == '__main__':
    main_cifar()

--------------------------------------------------------------------------------
/main_imagenet.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from prednet import *

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

parser.add_argument('--data', default='/Path/to/ImageNet/Dataset', type=str,
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=130, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('-c', '--circles', default=3, type=int,
                    metavar='N', help='number of PCN recurrent cycles (default: 3)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--world-size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-rank', default=1, type=int,
                    help='rank of the current distributed process')
parser.add_argument('--dist-url', default='tcp://10.0.0.10:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='gloo', type=str,
                    help='distributed backend')

best_prec1 = 0


def main_imagenet():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    # Distributed training if possible
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.dist_rank)

    # Create model
    model_name = 'PredNetBpE'
    models = {'PredNetBpE': PredNetBpE}
    modelname = model_name+'_'+str(args.circles)+'CLS'
    print("=> creating model '{}'".format(modelname))
    model = models[model_name](num_classes=1000, cls=args.circles)

    # Create path
    root = './'
    checkpointpath = root+'checkpoint/'
    logpath = root+'log/'
    if not os.path.isdir(checkpointpath):
        os.mkdir(checkpointpath)
    if not os.path.isdir(logpath):
        os.mkdir(logpath)

    # Put model into GPU or distribute it
    if not args.distributed:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    cudnn.benchmark = True

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Resume from checkpoint if needed
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Load training data
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    # Training dataloader
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    # Validation dataloader
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # 10-crop validation dataloader
    # Reference: https://discuss.pytorch.org/t/how-to-properly-do-10-crop-testing-on-imagenet/11341
    val_loader_10 = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])),
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # Evaluate the model only, if requested
    if args.evaluate:
        validate(val_loader, model, criterion, args.start_epoch, 1)
        return

    for epoch in range(args.start_epoch, args.epochs):
        with open(logpath+'training_stats_'+modelname+'.txt', 'a+') as statfile:
            if args.distributed:
                train_sampler.set_epoch(epoch)
            adjust_learning_rate(optimizer, epoch)

            # train for one epoch
            statstr = train(train_loader, model, criterion, optimizer, epoch)
            statfile.write(statstr+'\n')

            # evaluate on validation set with single-crop testing
            prec1, prec5, statstr = validate(val_loader, model, criterion, epoch, 1)
            statfile.write(statstr+'\n')

            # evaluate on validation set with 10-crop testing
            prec1_10, prec5_10, statstr_10 = validate_10(val_loader_10, model, criterion, epoch, 10)
            statfile.write(statstr_10+'\n')

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'name': modelname,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'prec1': prec1,
            'prec5': prec5,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpointpath)


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # `async` is a reserved word since Python 3.7; PyTorch renamed the flag to non_blocking
        target = target.cuda(non_blocking=True)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(to_python_float(loss.data), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print training status every 'print_freq' batches
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
    statstr = 'Train-Epoch: [{0}] | Time ({batch_time.avg:.3f}) | Data ({data_time.avg:.3f}) | Loss ({loss.avg:.4f}) | Prec@1 ({top1.avg:.3f}) | Prec@5 ({top5.avg:.3f})'.format(
        epoch, batch_time=batch_time,
        data_time=data_time, loss=losses, top1=top1, top5=top5)
    return statstr


def validate(val_loader, model, criterion, epoch, crop_num=1):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)
            input_var = torch.autograd.Variable(input)
            target_var = torch.autograd.Variable(target)

            # compute output
            output = model(input_var)
            loss = criterion(output, target_var)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(to_python_float(loss.data), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print validation status every 'print_freq' batches
            if i % args.print_freq == 0:
                print('{0}-crop-validation\t'
                      'Test: [{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       crop_num, i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
    statstr = str(crop_num)+'-crop-validation-Epoch: [{0}] | Time {batch_time.val:.3f} ({batch_time.avg:.3f}) | Loss {loss.val:.4f} ({loss.avg:.4f}) | Prec@1 {top1.val:.3f} ({top1.avg:.3f}) | Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
        epoch, batch_time=batch_time, loss=losses,
        top1=top1, top5=top5)

    return top1.avg, top5.avg, statstr

def validate_10(val_loader, model, criterion, epoch, crop_num=10):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)
            input_var = torch.autograd.Variable(input)
            target_var = torch.autograd.Variable(target)

            # compute output: fold the crops into the batch, then average the
            # predictions over the crops of each image
            bs, ncrops, c, h, w = input_var.size()
            temp_output = model(input_var.view(-1, c, h, w))
            output = temp_output.view(bs, ncrops, -1).mean(1)
            loss = criterion(output, target_var)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(to_python_float(loss.data), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print 10-crop validation status every 'print_freq' batches
            if i % args.print_freq == 0:
                print('{0}-crop-validation\t'
                      'Test: [{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       crop_num, i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
    statstr = str(crop_num)+'-crop-validation-Epoch: [{0}] | Time {batch_time.val:.3f} ({batch_time.avg:.3f}) | Loss {loss.val:.4f} ({loss.avg:.4f}) | Prec@1 {top1.val:.3f} ({top1.avg:.3f}) | Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
        epoch, batch_time=batch_time, loss=losses,
        top1=top1, top5=top5)

    return top1.avg, top5.avg, statstr


def save_checkpoint(state, is_best, checkpointpath, filename='checkpoint.pth.tar'):
    '''Save model'''
    torch.save(state, checkpointpath+filename)
    if is_best:
        shutil.copyfile(checkpointpath+filename, checkpointpath+'model_best.pth.tar')

# Reference: https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py
# Tensor.item() was added in PyTorch 0.4.0; this helper lets the code run on both PyTorch 0.3 and 0.4.
def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: correct[:k] is non-contiguous after .t()
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main_imagenet()

--------------------------------------------------------------------------------
/prednet.py:
--------------------------------------------------------------------------------
'''PredNet in PyTorch.'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

# As in the paper, the first layer is a regular conv layer (to reduce memory consumption)
class features2(nn.Module):
    def __init__(self, inchan, outchan, kernel_size=7, stride=2, padding=3, bias=False):
        super().__init__()
        self.conv = nn.Conv2d(inchan, outchan, kernel_size, stride, padding, bias=bias)
        self.featBN = nn.BatchNorm2d(outchan)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.relu(self.featBN(self.conv(x)))
        return y

class PcConvBp(nn.Module):
    def __init__(self, inchan, outchan, kernel_size=3, stride=1, padding=1, cls=0, bias=False):
        super().__init__()
        self.FFconv = nn.Conv2d(inchan, outchan, kernel_size, stride, padding, bias=bias)
        self.FBconv = nn.ConvTranspose2d(outchan, inchan, kernel_size, stride, padding, bias=bias)
        self.b0 = nn.ParameterList([nn.Parameter(torch.zeros(1,outchan,1,1))])
        self.relu = nn.ReLU(inplace=True)
        self.cls = cls
        self.bypass = nn.Conv2d(inchan, outchan, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        # initial feedforward pass
        y = self.relu(self.FFconv(x))
        # per-channel update rate, initialised to 1 and kept non-negative by the ReLU
        b0 = F.relu(self.b0[0]+1.0).expand_as(y)
        for _ in range(self.cls):
            # feedback predicts x; the rectified prediction error is fed forward to update y
            y = self.FFconv(self.relu(x - self.FBconv(y)))*b0 + y
        # the 1x1 bypass merges the block input into its output
        y = y + self.bypass(x)
        return y

''' Architecture PredNetBpE '''
class PredNetBpE(nn.Module):
    def __init__(self, num_classes=1000, cls=0, Tied = False):
        super().__init__()
        self.ics = [ 3, 64, 64, 128, 128, 128, 128, 256, 256, 256, 512, 512] # input channels
        self.ocs = [ 64, 64, 128, 128, 128, 128, 256, 256, 256, 512, 512, 512] # output channels
        self.maxpool = [False,False, True,False, True,False, True,False,False, True,False,False] # downsample flag
        self.cls = cls # num of time steps
        self.nlays = len(self.ics)

        self.baseconv = features2(self.ics[0], self.ocs[0])
        # construct PC layers
        # Unlike PCN v1, we do not have a tied version here. We may or may not incorporate a tied version in the future.
        if Tied:
            # PcConvBpTied is not defined in this repository; fail with a clear
            # message instead of a NameError
            raise NotImplementedError('A tied version of PredNetBpE is not available.')
        self.PcConvs = nn.ModuleList([PcConvBp(self.ics[i], self.ocs[i], cls=self.cls) for i in range(1, self.nlays)])
        self.BNs = nn.ModuleList([nn.BatchNorm2d(self.ics[i]) for i in range(1, self.nlays)])
        # Linear layer
        self.linear = nn.Linear(self.ocs[-1], num_classes)
        self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.BNend = nn.BatchNorm2d(self.ocs[-1])

    def forward(self, x):
        x = self.baseconv(x)
        for i in range(self.nlays-1):
            x = self.BNs[i](x)
            x = self.PcConvs[i](x)  # PC block: conv + local recurrent cycles
            if self.maxpool[i]:
                x = self.maxpool2d(x)

        # classifier
        out = self.relu(self.BNend(x))
        out = F.avg_pool2d(out, kernel_size=7, stride=1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

''' Architecture PredNetBpD '''
class PredNetBpD(nn.Module):
    def __init__(self, num_classes=10, cls=0, Tied = False):
        super().__init__()
        self.ics = [3, 64, 64, 128, 128, 256, 256, 512] # input channels
        self.ocs = [64, 64, 128, 128, 256, 256, 512, 512] # output channels
        self.maxpool = [False, False, True, False, True, False, False, False] # downsample flag
        self.cls = cls # num of time steps
        self.nlays = len(self.ics)

        # construct PC layers
        # Unlike PCN v1, we do not have a tied version here. We may or may not incorporate a tied version in the future.
        if Tied:
            # PcConvBpTied is not defined in this repository; fail with a clear
            # message instead of a NameError
            raise NotImplementedError('A tied version of PredNetBpD is not available.')
        self.PcConvs = nn.ModuleList([PcConvBp(self.ics[i], self.ocs[i], cls=self.cls) for i in range(self.nlays)])
        self.BNs = nn.ModuleList([nn.BatchNorm2d(self.ics[i]) for i in range(self.nlays)])
        # Linear layer
        self.linear = nn.Linear(self.ocs[-1], num_classes)
        self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.BNend = nn.BatchNorm2d(self.ocs[-1])

    def forward(self, x):
        for i in range(self.nlays):
            x = self.BNs[i](x)
            x = self.PcConvs[i](x)  # PC block: conv + local recurrent cycles
            if self.maxpool[i]:
                x = self.maxpool2d(x)

        # classifier
        out = F.avg_pool2d(self.relu(self.BNend(x)), x.size(-1))
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

''' Architecture PredNetBpC '''
class PredNetBpC(nn.Module):
    def __init__(self, num_classes=10, cls=0, Tied = False):
        super().__init__()
        self.ics = [3, 64, 64, 128, 128, 256, 256, 256] # input channels
        self.ocs = [64, 64, 128, 128, 256, 256, 256, 256] # output channels
        self.maxpool = [False, False, True, False, True, False, False, False] # downsample flag
        self.cls = cls # num of time steps
        self.nlays = len(self.ics)

        # construct PC layers
        # Unlike PCN v1, we do not have a tied version here. We may or may not incorporate a tied version in the future.
        if Tied == False:
            self.PcConvs = nn.ModuleList([PcConvBp(self.ics[i], self.ocs[i], cls=self.cls) for i in range(self.nlays)])
        else:
            self.PcConvs = nn.ModuleList([PcConvBpTied(self.ics[i], self.ocs[i], cls=self.cls) for i in range(self.nlays)])
        self.BNs = nn.ModuleList([nn.BatchNorm2d(self.ics[i]) for i in range(self.nlays)])
        # Linear layer
        self.linear = nn.Linear(self.ocs[-1], num_classes)
        self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.BNend = nn.BatchNorm2d(self.ocs[-1])

    def forward(self, x):
        for i in range(self.nlays):
            x = self.BNs[i](x)
            x = self.PcConvs[i](x)  # ReLU + Conv
            if self.maxpool[i]:
                x = self.maxpool2d(x)

        # classifier
        out = F.avg_pool2d(self.relu(self.BNend(x)), x.size(-1))
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimicking xlua.progress.
'''
import os
import sys
import time
import math
import shutil

import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset):
    '''Compute the per-channel mean and std of a dataset.

    The std is the average of per-image stds, so it approximates
    (rather than equals) the std over all pixels in the dataset.
    '''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # A multi-element bias tensor has no truth value; test for None instead.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


# Terminal width in columns; shutil falls back to 80 columns when stdout
# is not attached to a terminal (e.g. when output is redirected to a file).
term_width = shutil.get_terminal_size().columns

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f

--------------------------------------------------------------------------------
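The classifier heads of `PredNetBpD` and `PredNetBpC` in prednet.py average-pool over `x.size(-1)` and then feed `self.ocs[-1]` features into `nn.Linear`. The pure-Python sketch below traces channel counts and spatial sizes through the `ics`/`ocs`/`maxpool` lists to show what that head receives. It assumes, without confirmation from the code shown here, that each `PcConvBp` layer preserves spatial resolution (for example, a 3×3 convolution with padding 1) and that inputs are square 32×32 CIFAR images; `trace_shapes` is a hypothetical helper, not part of the repository.

```python
from typing import List, Tuple

# Layer configuration copied from PredNetBpD and PredNetBpC in prednet.py.
BPD_ICS = [3, 64, 64, 128, 128, 256, 256, 512]
BPD_OCS = [64, 64, 128, 128, 256, 256, 512, 512]
BPC_ICS = [3, 64, 64, 128, 128, 256, 256, 256]
BPC_OCS = [64, 64, 128, 128, 256, 256, 256, 256]
MAXPOOL = [False, False, True, False, True, False, False, False]


def trace_shapes(ics: List[int], ocs: List[int], maxpool: List[bool],
                 size: int = 32) -> List[Tuple[int, int]]:
    """Return (channels, spatial size) after each layer.

    Hypothetical helper. Assumes each PcConvBp keeps the spatial size
    unchanged and that MaxPool2d(kernel_size=2, stride=2) halves it,
    applied after the layer, as in forward().
    """
    shapes = []
    channels = ics[0]
    for i in range(len(ics)):
        # Each layer's input channel count must match the previous output.
        if channels != ics[i]:
            raise ValueError(f"layer {i}: expected {ics[i]} input channels, got {channels}")
        channels = ocs[i]
        if maxpool[i]:
            size //= 2  # floor division, like MaxPool2d without padding
        shapes.append((channels, size))
    return shapes


if __name__ == "__main__":
    for name, ics, ocs in [("PredNetBpD", BPD_ICS, BPD_OCS),
                           ("PredNetBpC", BPC_ICS, BPC_OCS)]:
        channels, size = trace_shapes(ics, ocs, MAXPOOL)[-1]
        # avg_pool2d with kernel x.size(-1) collapses size x size to 1 x 1,
        # so nn.Linear sees `channels` features per image.
        print(f"{name}: final map {channels}x{size}x{size} -> Linear({channels}, num_classes)")
```

Under these assumptions both variants end at an 8×8 map (two 2× poolings of a 32×32 input), so the `x.size(-1)` pooling kernel averages all 64 positions and the linear layer's input width depends only on `ocs[-1]`, not on the input resolution.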
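`get_mean_and_std` in utils.py averages per-image standard deviations, which in general differs from the standard deviation over every pixel in the dataset. If the exact dataset-wide value is wanted, one option (an alternative technique, not what the repository does) is to accumulate per-channel sums and sums of squares. The sketch below uses plain Python lists instead of tensors so it runs without PyTorch; `pooled_channel_stats` and the nested-list image layout are assumptions made for illustration.

```python
import math
from typing import List, Sequence, Tuple

# Hypothetical layout: an image is a list of channels,
# and a channel is a flat list of pixel values.
Image = Sequence[Sequence[float]]


def pooled_channel_stats(images: Sequence[Image]) -> Tuple[List[float], List[float]]:
    """Per-channel mean and population std over all pixels of all images."""
    num_channels = len(images[0])
    sums = [0.0] * num_channels
    sq_sums = [0.0] * num_channels
    counts = [0] * num_channels
    for image in images:
        for c, channel in enumerate(image):
            sums[c] += sum(channel)
            sq_sums[c] += sum(v * v for v in channel)
            counts[c] += len(channel)
    means = [s / n for s, n in zip(sums, counts)]
    # Var[X] = E[X^2] - E[X]^2; clamp tiny negatives caused by rounding.
    stds = [math.sqrt(max(q / n - m * m, 0.0))
            for q, n, m in zip(sq_sums, counts, means)]
    return means, stds


if __name__ == "__main__":
    # Two constant one-channel images: each has per-image std 0,
    # so averaging per-image stds gives 0, while the pooled std is 1.
    print(pooled_channel_stats([[[0.0, 0.0]], [[2.0, 2.0]]]))
```

The toy input shows why the two approaches can disagree: the spread between images is invisible to an average of per-image stds. For very large datasets, the sum-of-squares formula can lose precision; a Welford-style running update is a more numerically stable variant of the same idea.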