├── README.md
├── figures
│   ├── Figure1.png
│   └── fig_3.png
├── prednet.py
├── run.py
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# PCN with Global Recurrent Processing
This repository contains the code for PCN with global recurrent processing, introduced in the following paper:

[Deep Predictive Coding Network for Object Recognition](https://arxiv.org/abs/1802.04762) (ICML 2018)

Haiguang Wen, Kuan Han, Junxing Shi, Yizhen Zhang, Eugenio Culurciello, Zhongming Liu

The code is built on PyTorch.

## Introduction

The deep predictive coding network (PCN) with global recurrent processing is a bi-directional, recurrent neural network based on the predictive coding theory in neuroscience. It has feedforward, feedback, and recurrent connections. Feedback connections from a higher layer carry the prediction of its lower-layer representation; feedforward connections carry the prediction errors to the higher layer. Given an input image, PCN runs recursive cycles of bottom-up and top-down computation to update its internal representations, reducing the difference between the bottom-up input and the top-down prediction at every layer. PCN was found to consistently outperform its feedforward-only counterpart: a model without any mechanism for recurrent dynamics. Its performance tended to improve given more cycles of computation over time.

![Image of pcav1](https://github.com/libilab/PCN_v1/blob/master/figures/Figure1.png)
(a) An example PCN with 9 layers and its CNN counterpart (the "plain" model).

(b) Two-layer substructure of PCN. Feedback (blue), feedforward (green), and recurrent (black) connections convey the top-down prediction, the bottom-up prediction error, and the past information, respectively.

(c) The dynamic process in PCN iteratively updates and refines the representation of the visual input over time. PCN outputs the probability over candidate categories for object recognition.
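
The recurrent computation is compact enough to sketch in a few lines. The snippet below paraphrases `PredNet.forward` in `prednet.py` (names are shortened for readability; this is an illustration, not runnable standalone code). `a0` and `b0` are learned per-channel update rates, rectified before use:

```python
# One recurrent cycle: top-down predictions, then bottom-up error corrections.
for t in range(cycles):
    # Feedback pass: pull each representation toward the top-down prediction.
    for i in range(n_layers - 1, 0, -1):
        p[i - 1] = FBconv[i](r[i])                    # prediction of layer i-1
        a = relu(a0[i - 1])
        r[i - 1] = relu(a * p[i - 1] + (1 - a) * r[i - 1])
    # Feedforward pass: each layer's prediction error corrects the layer above.
    r[0] = relu(FFconv[0](x - FBconv[0](r[0])) * relu(b0[0]) + r[0])
    for i in range(1, n_layers):
        r[i] = relu(FFconv[i](r[i - 1] - p[i - 1]) * relu(b0[i]) + r[i])
```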

## Usage
Install PyTorch and required dependencies such as cuDNN. See the instructions [here](https://github.com/pytorch/pytorch) for a step-by-step guide.

Clone this repo: `git clone https://github.com/libilab/PCN_v1.git`

As an example, the following command trains a PCN with 6 cycles of recurrent computation (`--cls 6`) on CIFAR-100 using 4 GPUs:

```bash
python run.py --cls 6 --model 'PredNet' --gpunum 4
```

## Results on CIFAR

![Image of pcav1](https://github.com/libilab/PCN_v1/blob/master/figures/fig_3.png)

Testing accuracies of PCNs with different time steps.

## Updates
10/17/2018:

(1) Added the readme file.

02/12/2020:

(1) Removed group normalization to match the implementation in the paper.

--------------------------------------------------------------------------------
/figures/Figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/libilab/PCN-with-Global-Recurrent-Processing/721ecbde1777b7a0c3e96ed5eefb0e7567cdc434/figures/Figure1.png
--------------------------------------------------------------------------------
/figures/fig_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/libilab/PCN-with-Global-Recurrent-Processing/721ecbde1777b7a0c3e96ed5eefb0e7567cdc434/figures/fig_3.png
--------------------------------------------------------------------------------
/prednet.py:
--------------------------------------------------------------------------------
'''PredNet in PyTorch.'''

import torch
import torch.nn as nn
import torch.nn.functional as F


# Feedforward module
class FFconv2d(nn.Module):
    def __init__(self, inchan, outchan, downsample=False):
        super().__init__()
        self.conv2d = nn.Conv2d(inchan, outchan, kernel_size=3, stride=1, padding=1, bias=False)
        self.downsample = downsample
        if self.downsample:
            self.Downsample = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv2d(x)
        if self.downsample:
            x = self.Downsample(x)
        return x


# Feedback module
class FBconv2d(nn.Module):
    def __init__(self, inchan, outchan, upsample=False):
        super().__init__()
        self.convtranspose2d = nn.ConvTranspose2d(inchan, outchan, kernel_size=3, stride=1, padding=1, bias=False)
        self.upsample = upsample
        if self.upsample:
            self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x):
        if self.upsample:
            x = self.Upsample(x)
        x = self.convtranspose2d(x)
        return x


# FFconv2d and FBconv2d that share weights
class Conv2d(nn.Module):
    def __init__(self, inchan, outchan, sample=False):
        super().__init__()
        self.kernel_size = 3
        self.weights = nn.Parameter(nn.init.xavier_normal_(torch.empty(outchan, inchan, self.kernel_size, self.kernel_size)))
        self.sample = sample
        if self.sample:
            self.Downsample = nn.MaxPool2d(kernel_size=2, stride=2)
            self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x, feedforward=True):
        if feedforward:
            x = F.conv2d(x, self.weights, stride=1, padding=1)
            if self.sample:
                x = self.Downsample(x)
        else:
            if self.sample:
                x = self.Upsample(x)
            x = F.conv_transpose2d(x, self.weights, stride=1, padding=1)
        return x


# PredNet
class PredNet(nn.Module):

    def __init__(self, num_classes=10, cls=3):
        super().__init__()
        ics = [3, 64, 64, 128, 128, 256, 256, 256]                    # input channels
        ocs = [64, 64, 128, 128, 256, 256, 256, 256]                  # output channels
        sps = [False, False, True, False, True, False, False, False]  # downsample flags
        self.cls = cls         # number of recurrent cycles
        self.nlays = len(ics)  # number of layers

        # Feedforward layers
        self.FFconv = nn.ModuleList([FFconv2d(ics[i], ocs[i], downsample=sps[i]) for i in range(self.nlays)])
        # Feedback layers
        if cls > 0:
            self.FBconv = nn.ModuleList([FBconv2d(ocs[i], ics[i], upsample=sps[i]) for i in range(self.nlays)])

        # Update rates
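        # (Added comment) a0 gates how far each representation moves toward the
        # top-down prediction, and b0 scales the bottom-up error correction;
        # both are rectified in forward(), so their effective values are
        # non-negative. The constants below (0.5 and 1.0) are only the initial
        # values before training.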
        self.a0 = nn.ParameterList([nn.Parameter(torch.zeros(1, ics[i], 1, 1) + 0.5) for i in range(1, self.nlays)])
        self.b0 = nn.ParameterList([nn.Parameter(torch.zeros(1, ocs[i], 1, 1) + 1.0) for i in range(self.nlays)])

        # Linear layer
        self.linear = nn.Linear(ocs[-1], num_classes)

    def forward(self, x):

        # Feedforward pass initializes the representations
        xr = [F.relu(self.FFconv[0](x))]
        for i in range(1, self.nlays):
            xr.append(F.relu(self.FFconv[i](xr[i-1])))

        # Dynamic process
        for t in range(self.cls):

            # Feedback prediction
            xp = []
            for i in range(self.nlays-1, 0, -1):
                xp = [self.FBconv[i](xr[i])] + xp
                a0 = F.relu(self.a0[i-1]).expand_as(xr[i-1])
                xr[i-1] = F.relu(xp[0]*a0 + xr[i-1]*(1-a0))

            # Feedforward prediction error
            b0 = F.relu(self.b0[0]).expand_as(xr[0])
            xr[0] = F.relu(self.FFconv[0](x - self.FBconv[0](xr[0]))*b0 + xr[0])
            for i in range(1, self.nlays):
                b0 = F.relu(self.b0[i]).expand_as(xr[i])
                xr[i] = F.relu(self.FFconv[i](xr[i-1] - xp[i-1])*b0 + xr[i])

        # Classifier
        out = F.avg_pool2d(xr[-1], xr[-1].size(-1))
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out


# PredNet with tied (shared) feedforward and feedback weights
class PredNetTied(nn.Module):
    def __init__(self, num_classes=10, cls=3):
        super().__init__()
        ics = [3, 64, 64, 128, 128, 256, 256, 256]                    # input channels
        ocs = [64, 64, 128, 128, 256, 256, 256, 256]                  # output channels
        sps = [False, False, True, False, True, False, False, False]  # downsample flags
        self.cls = cls         # number of recurrent cycles
        self.nlays = len(ics)  # number of layers

        # Convolutional layers
        self.conv = nn.ModuleList([Conv2d(ics[i], ocs[i], sample=sps[i]) for i in range(self.nlays)])

        # Update rates
        self.a0 = nn.ParameterList([nn.Parameter(torch.zeros(1, ics[i], 1, 1) + 0.5) for i in range(1, self.nlays)])
        self.b0 = nn.ParameterList([nn.Parameter(torch.zeros(1, ocs[i], 1, 1) + 1.0) for i in range(self.nlays)])

        # Linear layer
        self.linear = nn.Linear(ocs[-1], num_classes)

    def forward(self, x):

        # Feedforward pass initializes the representations
        xr = [F.relu(self.conv[0](x))]
        for i in range(1, self.nlays):
            xr.append(F.relu(self.conv[i](xr[i-1])))

        # Dynamic process
        for t in range(self.cls):

            # Feedback prediction
            xp = []
            for i in range(self.nlays-1, 0, -1):
                xp = [self.conv[i](xr[i], feedforward=False)] + xp
                a = F.relu(self.a0[i-1]).expand_as(xr[i-1])
                xr[i-1] = F.relu(xp[0]*a + xr[i-1]*(1-a))

            # Feedforward prediction error
            b = F.relu(self.b0[0]).expand_as(xr[0])
            xr[0] = F.relu(self.conv[0](x - self.conv[0](xr[0], feedforward=False))*b + xr[0])
            for i in range(1, self.nlays):
                b = F.relu(self.b0[i]).expand_as(xr[i])
                xr[i] = F.relu(self.conv[i](xr[i-1] - xp[i-1])*b + xr[i])

        # Classifier
        out = F.avg_pool2d(xr[-1], xr[-1].size(-1))
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
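

# A minimal smoke test (added for illustration; not part of the original
# training pipeline): run a PredNet on random CIFAR-sized input and check
# the output shape end to end.
if __name__ == '__main__':
    net = PredNet(num_classes=100, cls=2)
    out = net(torch.randn(2, 3, 32, 32))
    print(out.size())  # expected: torch.Size([2, 100])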
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from train import train_prednet
import argparse

parser = argparse.ArgumentParser(description='PyTorch CIFAR100 Training')
parser.add_argument('--cls', default=6, type=int, help='number of recurrent cycles')
parser.add_argument('--model', default='PredNet', help='model to train (PredNet or PredNetTied)')
parser.add_argument('--gpunum', default=2, type=int, help='number of GPUs used to train the model')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
args = parser.parse_args()

train_prednet(model=args.model, cls=args.cls, gpunum=args.gpunum, lr=args.lr)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
'''Train CIFAR with PyTorch.'''
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from prednet import *
from utils import progress_bar

def train_prednet(model='PredNetTied', cls=6, gpunum=4, lr=0.01):
    use_cuda = torch.cuda.is_available()  # use the GPU if possible
    best_acc = 0     # best test accuracy
    start_epoch = 0  # start from epoch 0 or the last checkpoint epoch
    batchsize = 128  # batch size
    root = './'
    rep = 1          # initial repetition index

    models = {'PredNet': PredNet, 'PredNetTied': PredNetTied}
    modelname = model+'_'+str(lr)+'LR_'+str(cls)+'CLS_'+str(rep)+'REP'

    # Checkpoint and log folders
    checkpointpath = root+'checkpoint/'
    logpath = root+'log/'

    if not os.path.isdir(checkpointpath):
        os.mkdir(checkpointpath)
    if not os.path.isdir(logpath):
        os.mkdir(logpath)
    # Bump the repetition index until an unused model name is found
    while os.path.isfile(checkpointpath + modelname + '_last_ckpt.t7'):
        rep += 1
        modelname = model+'_'+str(lr)+'LR_'+str(cls)+'CLS_'+str(rep)+'REP'

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])
    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batchsize, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    # Define the objective function
    criterion = nn.CrossEntropyLoss()

    # Model
    print('==> Building model..')
    net = models[model](num_classes=100, cls=cls)

    # Set up the optimizer
    if model == 'PredNetTied':
        convparas = [p for p in net.conv.parameters()]+\
                    [p for p in net.linear.parameters()]
    else:
        convparas = [p for p in net.FFconv.parameters()]+\
                    [p for p in net.FBconv.parameters()]+\
                    [p for p in net.linear.parameters()]

    rateparas = [p for p in net.a0.parameters()]+\
                [p for p in net.b0.parameters()]
    optimizer = optim.SGD([
        {'params': convparas},
        {'params': rateparas, 'weight_decay': 0},
    ], lr=lr, momentum=0.9, weight_decay=5e-4)
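    # (Added comment) Two parameter groups: convolutional and linear weights
    # take the global weight decay (5e-4), while the update-rate parameters
    # a0/b0 are excluded from weight decay, presumably so that regularization
    # does not pull the learned rates toward zero.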

    # Parallel computing using multiple GPUs
    if use_cuda:
        net.cuda()
        net = torch.nn.DataParallel(net, device_ids=list(range(gpunum)))
        cudnn.benchmark = True

    # Training
    def train(epoch):
        print('\nEpoch: %d' % epoch)
        net.train()
        train_loss = 0
        correct = 0
        total = 0

        training_setting = 'batchsize=%d | epoch=%d | lr=%.1e ' % (batchsize, epoch, optimizer.param_groups[0]['lr'])
        statfile.write('\nTraining Setting: '+training_setting+'\n')

        for batch_idx, (inputs, targets) in enumerate(trainloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

        # Write the training record
        statstr = 'Training: Epoch=%d | Loss: %.3f | Acc: %.3f%% (%d/%d) | best acc: %.3f' \
            % (epoch, train_loss/(batch_idx+1), 100.*correct/total, correct, total, best_acc)
        statfile.write(statstr+'\n')


    # Testing
    def test(epoch):
        nonlocal best_acc
        net.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():  # no gradients needed at test time
            for batch_idx, (inputs, targets) in enumerate(testloader):
                if use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                outputs = net(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                    % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
        statstr = 'Testing: Epoch=%d | Loss: %.3f | Acc: %.3f%% (%d/%d) | best_acc: %.3f' \
            % (epoch, test_loss/(batch_idx+1), 100.*correct/total, correct, total, best_acc)
        statfile.write(statstr+'\n')

        # Save checkpoint.
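        # (Added comment) Two checkpoints are written: '_last_ckpt.t7' after
        # every epoch, and '_best_ckpt.t7' whenever test accuracy improves.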
        acc = 100.*correct/total
        state = {
            'state_dict': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }

        torch.save(state, checkpointpath + modelname + '_last_ckpt.t7')

        # Check whether the current accuracy is the best so far
        if acc >= best_acc:
            print('Saving..')
            torch.save(state, checkpointpath + modelname + '_best_ckpt.t7')
            best_acc = acc

    # Set adaptive learning rates
    def decrease_learning_rate():
        """Divide the current learning rate by 10."""
        for param_group in optimizer.param_groups:
            param_group['lr'] /= 10

    # Train the network
    for epoch in range(start_epoch, start_epoch+250):
        statfile = open(logpath+'training_stats_'+modelname+'.txt', 'a+')  # open the log file for appending
        if epoch == 80 or epoch == 140 or epoch == 200:
            decrease_learning_rate()
        train(epoch)
        test(epoch)
        statfile.close()
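

# A sketch of how a saved checkpoint could be reloaded (added for
# illustration; the path below is hypothetical). The saved dict holds
# 'state_dict', 'acc', and 'epoch' (see test() above). Since state_dict()
# was taken from the DataParallel-wrapped net, its keys carry a 'module.'
# prefix, so load it into a net wrapped the same way:
#
#   ckpt = torch.load('./checkpoint/PredNet_0.01LR_6CLS_1REP_best_ckpt.t7')
#   net = torch.nn.DataParallel(PredNet(num_classes=100, cls=6).cuda())
#   net.load_state_dict(ckpt['state_dict'])
#   print('best acc: %.3f at epoch %d' % (ckpt['acc'], ckpt['epoch']))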
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
'''Helper functions for PyTorch, forked from https://github.com/kuangliu/pytorch-cifar:
    - get_mean_and_std: compute the mean and std of a dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimicking xlua.progress.
'''
import os
import sys
import time

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils.data


def get_mean_and_std(dataset):
    '''Compute the per-channel mean and std of a dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Initialize layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


try:
    _, term_width = os.popen('stty size', 'r').read().split()
    term_width = int(term_width)
except ValueError:  # no terminal attached (e.g. output piped to a file)
    term_width = 80

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for a new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    '''Format a duration in seconds as a short string, e.g. "2m3s".'''
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
--------------------------------------------------------------------------------