├── __init__.py
├── README.md
├── utils.py
├── networks.py
├── ncm_layer.py
├── training_ncm.py
└── training_incremental_ncm.py

/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepNCM
PyTorch implementation of "DeepNCM: Deep Nearest Class Mean Classifiers" by Samantha Guerriero, Barbara Caputo and Thomas Mensink, ICLR Workshop 2018 (https://openreview.net/pdf?id=rkPLZ4JPM)

To run the code with standard DeepNCM on CIFAR-100, just run:
```
python training_ncm.py --dataset=100
```
Change to ```--dataset=10``` to test on CIFAR-10. For visualization purposes, visdom is employed. To disable it, just add the option ```--no_vis``` to the above command.

The implementation of the layer is in *ncm_layer.py*.
To apply the layer to any network, just replace the standard linear classifier with the NCM one and be aware of these two differences (see the sketch below):
* The forward step has two phases: a first phase which computes the features of the given images, and a second phase which computes the class scores (negative squared distances to the class means).
* After the loss computation and the backward pass, the means of the classifier must be updated.

For the incremental case, two further differences apply:
* The forward step must be preceded by a preparation step which adds to the network the space for novel class means.
* The targets must be converted to the order in which the classifier has seen the classes, to align predictions and labels.

For both cases, full examples can be found in the corresponding training files and networks.
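A minimal sketch of one training step with the standard NCM network (assuming ```net``` is one of the NCM models from *networks.py*, e.g. ```ResNet18_NCM```, with a cross-entropy criterion and an SGD optimizer):
```
optimizer.zero_grad()
features = net(inputs)            # phase 1: compute the features
scores = net.predict(features)    # phase 2: -squared distances to the class means
loss = criterion(scores, targets)
loss.backward()
optimizer.step()
net.update_means(features, targets)   # update the class means after the backward pass

# Incremental variant: before the forward pass, call net.prepare(targets) to
# allocate means for novel classes, and remap the labels for the loss with
# targets_converted = net.linear.convert_labels(targets).
```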
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of a dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimicking xlua.progress.
'''
import os
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset):
    '''Compute the per-channel mean and std value of a dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


try:
    _, term_width = os.popen('stty size', 'r').read().split()
    term_width = int(term_width)
except ValueError:
    term_width = 80  # Fall back to a default width when not attached to a terminal.

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()

def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    # Keep at most the two most significant units (e.g. '1D2h', '3m12s').
    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
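# Hypothetical quick check (illustration only, not part of the original module):
#   >>> format_time(3725.5)
#   '1h2m'
#   >>> progress_bar(0, 100, msg='Loss: 1.234')   # draws '[>...]  Step: ... | Loss: 1.234  1/100'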
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
'''ResNet in PyTorch.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import ncm_layer


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet_NCM(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet_NCM, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # Standard linear classifier replaced by the NCM classifier
        self.linear = ncm_layer.NCM_classifier(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        # Phase 1: compute the features only; class scores come from predict()
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return out

    def update_means(self, x, y):
        self.linear.update_means(x, y)

    def predict(self, x):
        # Phase 2: class scores (-distances) from the features
        out = self.linear(x)
        return out


class ResNet_iNCM(nn.Module):
    def __init__(self, block, num_blocks, num_classes=0):
        super(ResNet_iNCM, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = ncm_layer.incremental_NCM_classifier(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        # Phase 1: compute the features only; class scores come from predict()
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return out

    def update_means(self, x, y):
        self.linear.update_means(x, y)

    def predict(self, x):
        # Phase 2: class scores (-distances) from the features
        out = self.linear(x)
        return out

    def prepare(self, y):
        # Add space for the novel class means appearing in y
        self.linear.init_from_labels(y)


def ResNet18_NCM():
    return ResNet_NCM(BasicBlock, [2,2,2,2])

def ResNet34_iNCM(classes=0):
    # classes=0 starts with an empty classifier; means are added as classes appear
    return ResNet_iNCM(BasicBlock, [3,4,6,3], num_classes=classes)

def ResNet34_NCM(classes=10):
    return ResNet_NCM(BasicBlock, [3,4,6,3], num_classes=classes)
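# A minimal sketch (illustration only, not part of the original file) of the
# two-phase forward pass on CIFAR-sized 32x32 inputs:
#
#   net = ResNet18_NCM()                # NCM network with 10 class means
#   x = torch.randn(4, 3, 32, 32)
#   features = net(x)                   # phase 1: (4, 512) feature vectors
#   scores = net.predict(features)      # phase 2: (4, 10) negative squared distances
#   net.update_means(features, torch.tensor([0, 1, 2, 3]))   # refresh the class means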
--------------------------------------------------------------------------------
/ncm_layer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class NCM_classifier(nn.Module):
    '''Nearest Class Mean classifier: scores are the negative squared Euclidean
    distances between the input features and the (non-learned) class means.'''

    # Initialize the classifier
    def __init__(self, features, classes, alpha=0.9):
        super(NCM_classifier, self).__init__()
        self.means=nn.Parameter(torch.zeros(classes,features),requires_grad=False)  # Class means
        self.running_means=nn.Parameter(torch.zeros(classes,features),requires_grad=False)
        self.alpha=alpha        # Mean decay value
        self.features=features  # Input feature dimension
        self.classes=classes

    # Forward pass (x=features): returns -||x - mean_c||^2 for each class c
    def forward(self,x):
        means_reshaped=self.means.view(1,self.classes,self.features).expand(x.shape[0],self.classes,self.features)
        features_reshaped=x.view(-1,1,self.features).expand(x.shape[0],self.classes,self.features)
        diff=(features_reshaped-means_reshaped)**2
        cumulative_diff=diff.sum(dim=-1)

        return -cumulative_diff

    # Update centers (x=features, y=labels)
    def update_means(self,x,y):
        for i in torch.unique(y):  # For each label in the batch
            N,mean=self.compute_mean(x,y,i)  # Compute the batch mean

            # If the class has no samples in the batch keep the old mean,
            # otherwise store the batch mean for the decay update below
            if N==0:
                self.running_means.data[i,:]=self.means.data[i,:]
            else:
                self.running_means.data[i,:]=mean

        # Update means
        self.update()

    # Perform the update following the mean decay procedure:
    # means <- alpha*means + (1-alpha)*running_means
    def update(self):
        self.means.data=self.alpha*self.means.data+(1-self.alpha)*self.running_means

    # Compute the mean of the batch samples carrying label i
    def compute_mean(self,x,y,i):
        mask=(i==y).view(-1,1).float()
        mask=mask.to(x.device)  # Keep the mask on the same device as the features
        N=mask.sum()
        if N==0:
            return N,0
        else:
            return N,(x.data*mask).sum(dim=0)/N


class incremental_NCM_classifier(nn.Module):

    # Initialize the classifier; with classes=0 the means are allocated lazily,
    # as novel classes are encountered (see add_class/init_from_labels)
    def __init__(self, features, classes=0, alpha=0.9):
        super(incremental_NCM_classifier, self).__init__()
        if classes==0:
            self.means=nn.Parameter(torch.Tensor(0),requires_grad=False)  # Class means
            self.running_means=nn.Parameter(torch.Tensor(0),requires_grad=False)
            self.counter=nn.Parameter(torch.Tensor(0),requires_grad=False)
        else:
            self.means=nn.Parameter(torch.zeros(classes,features),requires_grad=False)  # Class means
            self.running_means=nn.Parameter(torch.zeros(classes,features),requires_grad=False)
            self.counter=nn.Parameter(torch.zeros(classes),requires_grad=False)

        self.classes=classes
        self.alpha=alpha        # Mean decay value
        self.features=features  # Input feature dimension
        self.labels={}          # Maps dataset labels to classifier indices

    # Forward pass (x=features): returns -||x - mean_c||^2 for each class c
    def forward(self,x):
        means_reshaped=self.means.view(1,self.classes,self.features).expand(x.shape[0],self.classes,self.features)
        features_reshaped=x.view(-1,1,self.features).expand(x.shape[0],self.classes,self.features)
        diff=(features_reshaped-means_reshaped)**2
        cumulative_diff=diff.sum(dim=-1)

        return -cumulative_diff

    # Update centers with an online, count-weighted average (x=features, y=labels)
    def update_means(self,x,y):
        for i in torch.unique(y):  # For each label in the batch
            index=int(i)
            # Compute the batch mean
            N,mean=self.compute_mean(x,y,i)

            # If the label is novel, add it to the model
            if index not in self.labels.keys():
                self.add_class(index)

            converted=self.labels[index]

            # Count-weighted running average of the class mean
            if N>0:
                self.means.data[converted,:]=1/(self.counter[converted]+N)*(self.means.data[converted,:]*self.counter[converted]+mean*N)
                self.counter.data[converted]+=N

    # Update centers following the mean decay procedure instead (x=features, y=labels)
    def update_means_decay(self,x,y):
        for i in torch.unique(y):
            index=int(i)
            # Compute the batch mean
            N,mean=self.compute_mean(x,y,i)

            # If the label is novel, add it to the model
            if index not in self.labels.keys():
                self.add_class(index)

            converted=self.labels[index]

            # If the class has no samples in the batch keep the old mean,
            # otherwise store the batch mean for the decay update below
            if N==0:
                self.running_means.data[converted,:]=self.means.data[converted,:]
            else:
                self.running_means.data[converted,:]=mean

        # Update means
        self.update_decay()

    # Perform the update following the mean decay procedure:
    # means <- alpha*means + (1-alpha)*running_means
    def update_decay(self):
        self.means.data=self.alpha*self.means.data+(1-self.alpha)*self.running_means

    # Compute the mean of the batch samples carrying label i
    def compute_mean(self,x,y,i):
        mask=(i==y).view(-1,1).float()
        mask=mask.to(x.device)  # Keep the mask on the same device as the features
        N=mask.sum()
        if N==0:
            return N,0
        else:
            return N,(x.data*mask).sum(dim=0)/N

    # Convert dataset labels to the order in which the classifier has seen them
    def convert_labels(self,y):
        out=[]
        for i in y:
            out.append(self.labels[int(i)])
        return torch.LongTensor(out).to(y.device)

    def convert_single_label(self,y):
        return self.labels[y]

    # Add a class to the classifier, updating the label indices
    def add_class(self, index):
        print('Adding '+str(index)+' as '+str(self.classes))

        self.labels[index]=self.classes
        self.classes=self.classes+1

        device=self.means.data.device
        if self.classes==1:
            self.means=nn.Parameter(torch.zeros(self.classes,self.features).to(device),requires_grad=False)  # Class means
            self.running_means=nn.Parameter(torch.zeros(self.classes,self.features).to(device),requires_grad=False)
            self.counter=nn.Parameter(torch.zeros(self.classes).to(device),requires_grad=False)
        else:
            self.means.data=torch.cat([self.means.data,torch.zeros(1,self.features).to(device)],dim=0)
            self.running_means.data=torch.cat([self.running_means.data,torch.zeros(1,self.features).to(device)],dim=0)
            self.counter.data=torch.cat([self.counter.data,torch.zeros(1).to(device)],dim=0)

    def reset_counter(self):
        self.counter.data=self.counter.data*0

    # Register all labels in y as classes, updating the label indices
    def init_from_labels(self, y):
        for i in torch.unique(y):
            index=int(i)
            # For each novel label
            if index not in self.labels.keys():
                self.add_class(index)
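# A minimal sketch (illustration only, not part of the original file) of the
# incremental classifier on raw feature vectors:
#
#   clf = incremental_NCM_classifier(features=512)
#   y = torch.tensor([3, 7, 3, 7])
#   clf.init_from_labels(y)               # registers label 3 as class 0, label 7 as class 1
#   scores = clf(torch.randn(4, 512))     # (4, 2) negative squared distances
#   clf.update_means(torch.randn(4, 512), y)   # count-weighted mean update
#   clf.convert_labels(y)                 # tensor([0, 1, 0, 1]), aligned with the scores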
--------------------------------------------------------------------------------
/training_ncm.py:
--------------------------------------------------------------------------------
'''Train DeepNCM on CIFAR-10/100 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import networks
import numpy as np
import argparse

from utils import progress_bar


EPOCHS=250
LR=0.1


parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--no_vis', '-v', action='store_true', help='avoid visualization')
parser.add_argument('--dataset', type=int, default=10, help='choose dataset (10 for CIFAR-10, 100 for CIFAR-100)')

args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

if args.dataset==10:
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)
elif args.dataset==100:
    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)

    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)
else:
    raise ValueError('Unsupported dataset: choose 10 or 100')

# CIFAR-10 class names (kept for reference, unused below)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = networks.ResNet34_NCM(classes=args.dataset)
net = net.to(device)
if device == 'cuda':
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7')
    net.load_state_dict(checkpoint['net'])
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()


# Training, single epoch
def train(epoch,optimizer):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets_dev = targets.to(device)
        optimizer.zero_grad()

        # Produce features (phase 1)
        outputs = net(inputs)

        # Predict using current class means (phase 2)
        prediction = net.predict(outputs)

        # Apply loss
        loss = criterion(prediction, targets_dev)

        # Backward + update
        loss.backward()
        optimizer.step()

        # Update class means
        net.update_means(outputs, targets)

        # Logging
        train_loss += loss.item()
        _, predicted = prediction.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets_dev).sum().item()
        if batch_idx%200==0:
            print()
            print('TRAINING')
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return (train_loss/(batch_idx+1)), 100.*correct/total


# Test
def test(epoch):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            outputs = net.predict(outputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if batch_idx%100==0:
                print()
                print('TEST')
            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    acc = 100.*correct/total
    return acc


# Full training procedure
def loop(epochs=200,dataset_name='cifar'+str(args.dataset)):
    visualize=not args.no_vis
    if visualize:
        import visdom
        vis = visdom.Visdom()
        vis.env = 'deep ncm ' + dataset_name
    model_name='DEEP NCM'
    iters=[]
    losses_training=[]
    accuracy_training=[]
    accuracies_test=[]
    lr=LR
    # NCM means have requires_grad=False, so only the backbone parameters are optimized
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=LR, momentum=0.9, weight_decay=5e-4)

    for epoch in range(start_epoch, start_epoch+epochs):

        # Decay the learning rate by a factor of 10 every 50 epochs (first decay at epoch 100)
        if epoch%50==0 and epoch>50:
            lr=lr*0.1
            optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum=0.9, weight_decay=5e-4)
            print('LR now is ' + str(lr))

        # Perform 1 training epoch
        loss_epoch, acc_epoch = train(epoch,optimizer)

        # Validate the model
        result = test(epoch)

        if visualize:
            # Update lists (for visualization purposes)
            accuracies_test.append(result)
            accuracy_training.append(acc_epoch)
            losses_training.append(loss_epoch)
            iters.append(epoch)

            # Plot results
            vis.line(
                X=np.array(iters),
                Y=np.array(losses_training),
                opts={
                    'title': ' Training Loss ',
                    'xlabel': 'epochs',
                    'ylabel': 'loss'},
                name='Training Loss ',
                win=10)
            vis.line(
                X=np.array(iters),
                Y=np.array(accuracy_training),
                opts={
                    'title': ' Training Accuracy ',
                    'xlabel': 'epochs',
                    'ylabel': 'accuracy'},
                name='Training Accuracy ',
                win=11)
            vis.line(
                X=np.array(iters),
                Y=np.array(accuracies_test),
                opts={
                    'title': ' Accuracy ',
                    'xlabel': 'epochs',
                    'ylabel': 'accuracy'},
                name='Validation Accuracy ',
                win=12)


loop(epochs=EPOCHS)
--------------------------------------------------------------------------------
/training_incremental_ncm.py:
--------------------------------------------------------------------------------
'''Train incremental DeepNCM on CIFAR-10/100 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import networks
import numpy as np
import argparse

from utils import progress_bar


EPOCHS=250
LR=0.1


parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--no_vis', '-v', action='store_true', help='avoid visualization')
parser.add_argument('--dataset', type=int, default=10, help='choose dataset (10 for CIFAR-10, 100 for CIFAR-100)')

args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

if args.dataset==10:
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)
elif args.dataset==100:
    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=8)

    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)
else:
    raise ValueError('Unsupported dataset: choose 10 or 100')

# CIFAR-10 class names (kept for reference, unused below)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = networks.ResNet34_iNCM()  # starts with an empty classifier; classes are added as they appear
net = net.to(device)
if device == 'cuda':
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7')
    net.load_state_dict(checkpoint['net'])
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()


# Training, single epoch
def train(epoch,optimizer):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        optimizer.zero_grad()

        # Initialize classifier (if novel classes are present)
        net.prepare(targets)

        # Forward pass: features, then predictions
        outputs = net(inputs)
        prediction = net.predict(outputs)

        # Convert labels to match the order seen by the classifier
        targets_converted = net.linear.convert_labels(targets).to(outputs.device)

        # Compute loss
        loss = criterion(prediction, targets_converted)

        # Backward + update
        loss.backward()
        optimizer.step()

        # Update means
        net.update_means(outputs, targets)

        # Logging
        train_loss += loss.item()
        _, predicted = prediction.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets_converted).sum().item()
        if batch_idx%200==0:
            print()
            print('TRAINING')
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return (train_loss/(batch_idx+1)), 100.*correct/total

# Test: assumes every test class has already been seen during training,
# otherwise convert_labels raises a KeyError
def test(epoch):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device)
            outputs = net(inputs)
            outputs = net.predict(outputs)
            targets_converted = net.linear.convert_labels(targets).to(outputs.device)
            loss = criterion(outputs, targets_converted)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets_converted).sum().item()
            if batch_idx%100==0:
                print()
                print('TEST')
            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    acc = 100.*correct/total
    return acc


# Full training procedure
def loop(epochs=200,dataset_name='cifar'+str(args.dataset)):
    visualize=not args.no_vis
    if visualize:
        import visdom
        vis = visdom.Visdom()
        vis.env = 'incremental deep ncm ' + dataset_name
    model_name='DEEP NCM'
    iters=[]
    losses_training=[]
    accuracy_training=[]
    accuracies_test=[]
    lr=LR
    # NCM means have requires_grad=False, so only the backbone parameters are optimized
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=LR, momentum=0.9, weight_decay=5e-4)

    for epoch in range(start_epoch, start_epoch+epochs):

        # Decay the learning rate by a factor of 10 every 50 epochs (first decay at epoch 100)
        if epoch%50==0 and epoch>50:
            lr=lr*0.1
            optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum=0.9, weight_decay=5e-4)
            print('LR now is ' + str(lr))

        # Perform 1 training epoch
        loss_epoch, acc_epoch = train(epoch,optimizer)

        # Validate the model
        result = test(epoch)

        if visualize:
            # Update lists (for visualization purposes)
            accuracies_test.append(result)
            accuracy_training.append(acc_epoch)
            losses_training.append(loss_epoch)
            iters.append(epoch)

            # Plot results
            vis.line(
                X=np.array(iters),
                Y=np.array(losses_training),
                opts={
                    'title': ' Training Loss ',
                    'xlabel': 'epochs',
                    'ylabel': 'loss'},
                name='Training Loss ',
                win=10)
            vis.line(
                X=np.array(iters),
                Y=np.array(accuracy_training),
                opts={
                    'title': ' Training Accuracy ',
                    'xlabel': 'epochs',
                    'ylabel': 'accuracy'},
                name='Training Accuracy ',
                win=11)
            vis.line(
                X=np.array(iters),
                Y=np.array(accuracies_test),
                opts={
                    'title': ' Accuracy ',
                    'xlabel': 'epochs',
                    'ylabel': 'accuracy'},
                name='Validation Accuracy ',
                win=12)


loop(epochs=EPOCHS)
--------------------------------------------------------------------------------