├── README.md ├── .gitignore ├── data_loader.py ├── main.py ├── resnet.py └── model.py /README.md: -------------------------------------------------------------------------------- 1 | # A PyTorch Implementation of iCaRL 2 | A PyTorch Implementation of [iCaRL: Incremental Classifier and Representation Learning](https://arxiv.org/abs/1611.07725). 3 | 4 | The code implements experiments on CIFAR-10 and CIFAR-100 5 | 6 | ### Notes 7 | * This code does **not** reproduce the results from the paper. This may be due to following reasons: 8 | 1. Different hyperparameters than the original ones in the paper (e.g. learning rate schedules). 9 | 2. Different loss function; I replaced BCELoss to CrosEntropyLoss since it seemd to produce better results. 10 | 3. I tried to replicate the algorithm exactly as described in the paper, but I might have missed some of the details. 11 | * Versions 12 | - Python 2.7 13 | - PyTorch v0.1.12 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | 6 | class iCIFAR10(CIFAR10): 7 | def __init__(self, root, 8 | classes=range(10), 9 | train=True, 10 | transform=None, 11 | target_transform=None, 12 | download=False): 13 | super(iCIFAR10, self).__init__(root, 14 | train=train, 15 | transform=transform, 16 | target_transform=target_transform, 17 | download=download) 18 | 19 | # Select subset of classes 20 | if self.train: 21 | train_data = [] 22 | train_labels = [] 23 | 24 | for i in xrange(len(self.train_data)): 25 | if self.train_labels[i] in classes: 26 | train_data.append(self.train_data[i]) 27 | train_labels.append(self.train_labels[i]) 28 | 29 | self.train_data = np.array(train_data) 30 | self.train_labels = train_labels 31 | 32 | else: 33 | test_data = [] 34 | test_labels = [] 35 | 36 | for i in xrange(len(self.test_data)): 37 | if self.test_labels[i] in classes: 38 | test_data.append(self.test_data[i]) 39 | test_labels.append(self.test_labels[i]) 40 | 41 | self.test_data = np.array(test_data) 42 | self.test_labels = test_labels 43 | 44 | def __getitem__(self, index): 45 | if self.train: 46 | img, target = self.train_data[index], self.train_labels[index] 47 | else: 48 | img, target = self.test_data[index], self.test_labels[index] 49 | 50 | img = Image.fromarray(img) 51 | 52 | if self.transform is not None: 53 | img = self.transform(img) 54 | 55 | if self.target_transform is not None: 56 | target = self.target_transform(target) 57 | 58 | return index, img, target 59 | 60 | def __len__(self): 61 | if self.train: 62 | return len(self.train_data) 63 | else: 64 | return len(self.test_data) 65 | 66 | def get_image_class(self, label): 67 | return self.train_data[np.array(self.train_labels) == label] 68 | 69 | def append(self, images, labels): 70 | """Append dataset with images and labels 71 | 72 | Args: 73 | images: Tensor of shape (N, C, H, W) 74 | labels: list of labels 75 | """ 76 | 77 | self.train_data = np.concatenate((self.train_data, images), axis=0) 78 | self.train_labels = self.train_labels + labels 79 | 80 | class iCIFAR100(iCIFAR10): 81 | base_folder = 'cifar-100-python' 82 | url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 83 | filename = "cifar-100-python.tar.gz" 84 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 85 | train_list = [ 86 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 87 | ] 88 | test_list = [ 89 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 90 | ] 91 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.datasets as dsets 4 | import torchvision.models as models 5 | import torchvision.transforms as transforms 6 | from torch.autograd import Variable 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | 10 | import matplotlib.pyplot as plt 11 | import matplotlib.gridspec as gridspec 12 | 13 | from data_loader import iCIFAR10, iCIFAR100 14 | from model import iCaRLNet 15 | 16 | def show_images(images): 17 | N = images.shape[0] 18 | fig = plt.figure(figsize=(1, N)) 19 | gs = gridspec.GridSpec(1, N) 20 | gs.update(wspace=0.05, hspace=0.05) 21 | 22 | for i, img in enumerate(images): 23 | ax = plt.subplot(gs[i]) 24 | plt.axis('off') 25 | ax.set_xticklabels([]) 26 | ax.set_yticklabels([]) 27 | ax.set_aspect('equal') 28 | plt.imshow(img) 29 | plt.show() 30 | 31 | 32 | # Hyper Parameters 33 | total_classes = 10 34 | num_classes = 10 35 | 36 | 37 | transform = transforms.Compose([ 38 | transforms.RandomCrop(32, padding=4), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 42 | ]) 43 | 44 | transform_test = transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 47 | ]) 48 | 49 | # Initialize CNN 50 | K = 2000 # total number of exemplars 51 | icarl = iCaRLNet(2048, 1) 52 | icarl.cuda() 53 | 54 | 55 | for s in range(0, total_classes, num_classes): 56 | # Load Datasets 57 | print "Loading training examples for classes", range(s, s+num_classes) 58 | train_set = iCIFAR10(root='./data', 59 | train=True, 60 | classes=range(s,s+num_classes), 61 | download=True, 62 | transform=transform_test) 63 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, 64 | shuffle=True, num_workers=2) 65 | 66 | test_set = iCIFAR10(root='./data', 67 | train=False, 68 | classes=range(num_classes), 69 | download=True, 70 | transform=transform_test) 71 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, 72 | shuffle=True, num_workers=2) 73 | 74 | 75 | 76 | # Update representation via BackProp 77 | icarl.update_representation(train_set) 78 | m = K / icarl.n_classes 79 | 80 | # Reduce exemplar sets for known classes 81 | icarl.reduce_exemplar_sets(m) 82 | 83 | # Construct exemplar sets for new classes 84 | for y in xrange(icarl.n_known, icarl.n_classes): 85 | print "Constructing exemplar set for class-%d..." %(y), 86 | images = train_set.get_image_class(y) 87 | icarl.construct_exemplar_set(images, m, transform_test) 88 | print "Done" 89 | 90 | for y, P_y in enumerate(icarl.exemplar_sets): 91 | print "Exemplar set for class-%d:" % (y), P_y.shape 92 | #show_images(P_y[:10]) 93 | 94 | icarl.n_known = icarl.n_classes 95 | print "iCaRL classes: %d" % icarl.n_known 96 | 97 | total = 0.0 98 | correct = 0.0 99 | for indices, images, labels in train_loader: 100 | images = Variable(images).cuda() 101 | preds = icarl.classify(images, transform_test) 102 | total += labels.size(0) 103 | correct += (preds.data.cpu() == labels).sum() 104 | 105 | print('Train Accuracy: %d %%' % (100 * correct / total)) 106 | 107 | total = 0.0 108 | correct = 0.0 109 | for indices, images, labels in test_loader: 110 | images = Variable(images).cuda() 111 | preds = icarl.classify(images, transform_test) 112 | total += labels.size(0) 113 | correct += (preds.data.cpu() == labels).sum() 114 | 115 | print('Test Accuracy: %d %%' % (100 * correct / total)) 116 | 117 | 118 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet18/34/50/101/152 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1, shortcut=None): 17 | super(BasicBlock, self).__init__() 18 | self.layers = nn.Sequential( 19 | conv3x3(in_planes, planes, stride), 20 | nn.BatchNorm2d(planes), 21 | nn.ReLU(True), 22 | conv3x3(planes, planes), 23 | nn.BatchNorm2d(planes), 24 | ) 25 | self.shortcut = shortcut 26 | 27 | def forward(self, x): 28 | residual = x 29 | y = self.layers(x) 30 | if self.shortcut: 31 | residual = self.shortcut(x) 32 | y += residual 33 | y = F.relu(y) 34 | return y 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1, shortcut=None): 41 | super(Bottleneck, self).__init__() 42 | self.layers = nn.Sequential( 43 | nn.Conv2d(in_planes, planes, kernel_size=1, bias=False), 44 | nn.BatchNorm2d(planes), 45 | nn.ReLU(True), 46 | nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False), 47 | nn.BatchNorm2d(planes), 48 | nn.ReLU(True), 49 | nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False), 50 | nn.BatchNorm2d(planes * 4), 51 | ) 52 | self.shortcut = shortcut 53 | 54 | def forward(self, x): 55 | residual = x 56 | y = self.layers(x) 57 | if self.shortcut: 58 | residual = self.shortcut(x) 59 | y += residual 60 | y = F.relu(y) 61 | return y 62 | 63 | 64 | class ResNet(nn.Module): 65 | def __init__(self, block, nblocks, num_classes=10): 66 | super(ResNet, self).__init__() 67 | self.in_planes = 64 68 | self.pre_layers = nn.Sequential( 69 | conv3x3(3,64), 70 | nn.BatchNorm2d(64), 71 | nn.ReLU(True), 72 | ) 73 | self.layer1 = self._make_layer(block, 64, nblocks[0]) 74 | self.layer2 = self._make_layer(block, 128, nblocks[1], stride=2) 75 | self.layer3 = self._make_layer(block, 256, nblocks[2], stride=2) 76 | self.layer4 = self._make_layer(block, 512, nblocks[3], stride=2) 77 | self.avgpool = nn.AvgPool2d(4) 78 | self.fc = nn.Linear(512*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, nblocks, stride=1): 81 | shortcut = None 82 | if stride != 1 or self.in_planes != planes * block.expansion: 83 | shortcut = nn.Sequential( 84 | nn.Conv2d(self.in_planes, planes * block.expansion, 85 | kernel_size=1, stride=stride, bias=False), 86 | nn.BatchNorm2d(planes * block.expansion), 87 | ) 88 | layers = [] 89 | layers.append(block(self.in_planes, planes, stride, shortcut)) 90 | self.in_planes = planes * block.expansion 91 | for i in range(1, nblocks): 92 | layers.append(block(self.in_planes, planes)) 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | x = self.pre_layers(x) 97 | x = self.layer1(x) 98 | x = self.layer2(x) 99 | x = self.layer3(x) 100 | x = self.layer4(x) 101 | x = self.avgpool(x) 102 | x = x.view(x.size(0), -1) 103 | x = self.fc(x) 104 | return x 105 | 106 | 107 | def resnet18(): 108 | return ResNet(BasicBlock, [2,2,2,2]) 109 | 110 | def resnet34(): 111 | return ResNet(BasicBlock, [3,4,6,3]) 112 | 113 | def resnet50(): 114 | return ResNet(Bottleneck, [3,4,6,3]) 115 | 116 | def resnet101(): 117 | return ResNet(Bottleneck, [3,4,23,3]) 118 | 119 | def resnet152(): 120 | return ResNet(Bottleneck, [3,8,36,3]) 121 | 122 | # net = ResNet(BasicBlock, [2,2,2,2]) 123 | # x = torch.randn(1,3,32,32) 124 | # y = net(Variable(x)) 125 | # print(y.size()) 126 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from resnet import resnet18 10 | 11 | # Hyper Parameters 12 | num_epochs = 50 13 | batch_size = 100 14 | learning_rate = 0.002 15 | 16 | class iCaRLNet(nn.Module): 17 | def __init__(self, feature_size, n_classes): 18 | # Network architecture 19 | super(iCaRLNet, self).__init__() 20 | self.feature_extractor = resnet18() 21 | self.feature_extractor.fc =\ 22 | nn.Linear(self.feature_extractor.fc.in_features, feature_size) 23 | self.bn = nn.BatchNorm1d(feature_size, momentum=0.01) 24 | self.ReLU = nn.ReLU() 25 | self.fc = nn.Linear(feature_size, n_classes, bias=False) 26 | 27 | self.n_classes = n_classes 28 | self.n_known = 0 29 | 30 | # List containing exemplar_sets 31 | # Each exemplar_set is a np.array of N images 32 | # with shape (N, C, H, W) 33 | self.exemplar_sets = [] 34 | 35 | # Learning method 36 | self.cls_loss = nn.CrossEntropyLoss() 37 | self.dist_loss = nn.BCELoss() 38 | self.optimizer = optim.Adam(self.parameters(), lr=learning_rate, 39 | weight_decay=0.00001) 40 | #self.optimizer = optim.SGD(self.parameters(), lr=2.0, 41 | # weight_decay=0.00001) 42 | 43 | # Means of exemplars 44 | self.compute_means = True 45 | self.exemplar_means = [] 46 | 47 | def forward(self, x): 48 | x = self.feature_extractor(x) 49 | x = self.bn(x) 50 | x = self.ReLU(x) 51 | x = self.fc(x) 52 | return x 53 | 54 | def increment_classes(self, n): 55 | """Add n classes in the final fc layer""" 56 | in_features = self.fc.in_features 57 | out_features = self.fc.out_features 58 | weight = self.fc.weight.data 59 | 60 | self.fc = nn.Linear(in_features, out_features+n, bias=False) 61 | self.fc.weight.data[:out_features] = weight 62 | self.n_classes += n 63 | 64 | def classify(self, x, transform): 65 | """Classify images by neares-means-of-exemplars 66 | 67 | Args: 68 | x: input image batch 69 | Returns: 70 | preds: Tensor of size (batch_size,) 71 | """ 72 | batch_size = x.size(0) 73 | 74 | if self.compute_means: 75 | print "Computing mean of exemplars...", 76 | exemplar_means = [] 77 | for P_y in self.exemplar_sets: 78 | features = [] 79 | # Extract feature for each exemplar in P_y 80 | for ex in P_y: 81 | ex = Variable(transform(Image.fromarray(ex)), volatile=True).cuda() 82 | feature = self.feature_extractor(ex.unsqueeze(0)) 83 | feature = feature.squeeze() 84 | feature.data = feature.data / feature.data.norm() # Normalize 85 | features.append(feature) 86 | features = torch.stack(features) 87 | mu_y = features.mean(0).squeeze() 88 | mu_y.data = mu_y.data / mu_y.data.norm() # Normalize 89 | exemplar_means.append(mu_y) 90 | self.exemplar_means = exemplar_means 91 | self.compute_means = False 92 | print "Done" 93 | 94 | exemplar_means = self.exemplar_means 95 | means = torch.stack(exemplar_means) # (n_classes, feature_size) 96 | means = torch.stack([means] * batch_size) # (batch_size, n_classes, feature_size) 97 | means = means.transpose(1, 2) # (batch_size, feature_size, n_classes) 98 | 99 | feature = self.feature_extractor(x) # (batch_size, feature_size) 100 | for i in xrange(feature.size(0)): # Normalize 101 | feature.data[i] = feature.data[i] / feature.data[i].norm() 102 | feature = feature.unsqueeze(2) # (batch_size, feature_size, 1) 103 | feature = feature.expand_as(means) # (batch_size, feature_size, n_classes) 104 | 105 | dists = (feature - means).pow(2).sum(1).squeeze() #(batch_size, n_classes) 106 | _, preds = dists.min(1) 107 | 108 | return preds 109 | 110 | 111 | def construct_exemplar_set(self, images, m, transform): 112 | """Construct an exemplar set for image set 113 | 114 | Args: 115 | images: np.array containing images of a class 116 | """ 117 | # Compute and cache features for each example 118 | features = [] 119 | for img in images: 120 | x = Variable(transform(Image.fromarray(img)), volatile=True).cuda() 121 | feature = self.feature_extractor(x.unsqueeze(0)).data.cpu().numpy() 122 | feature = feature / np.linalg.norm(feature) # Normalize 123 | features.append(feature[0]) 124 | 125 | features = np.array(features) 126 | class_mean = np.mean(features, axis=0) 127 | class_mean = class_mean / np.linalg.norm(class_mean) # Normalize 128 | 129 | exemplar_set = [] 130 | exemplar_features = [] # list of Variables of shape (feature_size,) 131 | for k in xrange(m): 132 | S = np.sum(exemplar_features, axis=0) 133 | phi = features 134 | mu = class_mean 135 | mu_p = 1.0/(k+1) * (phi + S) 136 | mu_p = mu_p / np.linalg.norm(mu_p) 137 | i = np.argmin(np.sqrt(np.sum((mu - mu_p) ** 2, axis=1))) 138 | 139 | exemplar_set.append(images[i]) 140 | exemplar_features.append(features[i]) 141 | """ 142 | print "Selected example", i 143 | print "|exemplar_mean - class_mean|:", 144 | print np.linalg.norm((np.mean(exemplar_features, axis=0) - class_mean)) 145 | #features = np.delete(features, i, axis=0) 146 | """ 147 | 148 | self.exemplar_sets.append(np.array(exemplar_set)) 149 | 150 | 151 | def reduce_exemplar_sets(self, m): 152 | for y, P_y in enumerate(self.exemplar_sets): 153 | self.exemplar_sets[y] = P_y[:m] 154 | 155 | 156 | def combine_dataset_with_exemplars(self, dataset): 157 | for y, P_y in enumerate(self.exemplar_sets): 158 | exemplar_images = P_y 159 | exemplar_labels = [y] * len(P_y) 160 | dataset.append(exemplar_images, exemplar_labels) 161 | 162 | 163 | def update_representation(self, dataset): 164 | 165 | self.compute_means = True 166 | 167 | # Increment number of weights in final fc layer 168 | classes = list(set(dataset.train_labels)) 169 | new_classes = [cls for cls in classes if cls > self.n_classes - 1] 170 | self.increment_classes(len(new_classes)) 171 | self.cuda() 172 | print "%d new classes" % (len(new_classes)) 173 | 174 | # Form combined training set 175 | self.combine_dataset_with_exemplars(dataset) 176 | 177 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 178 | shuffle=True, num_workers=2) 179 | 180 | # Store network outputs with pre-update parameters 181 | q = torch.zeros(len(dataset), self.n_classes).cuda() 182 | for indices, images, labels in loader: 183 | images = Variable(images).cuda() 184 | indices = indices.cuda() 185 | g = F.sigmoid(self.forward(images)) 186 | q[indices] = g.data 187 | q = Variable(q).cuda() 188 | 189 | # Run network training 190 | optimizer = self.optimizer 191 | 192 | for epoch in xrange(num_epochs): 193 | for i, (indices, images, labels) in enumerate(loader): 194 | images = Variable(images).cuda() 195 | labels = Variable(labels).cuda() 196 | indices = indices.cuda() 197 | 198 | optimizer.zero_grad() 199 | g = self.forward(images) 200 | 201 | # Classification loss for new classes 202 | loss = self.cls_loss(g, labels) 203 | #loss = loss / len(range(self.n_known, self.n_classes)) 204 | 205 | # Distilation loss for old classes 206 | if self.n_known > 0: 207 | g = F.sigmoid(g) 208 | q_i = q[indices] 209 | dist_loss = sum(self.dist_loss(g[:,y], q_i[:,y])\ 210 | for y in xrange(self.n_known)) 211 | #dist_loss = dist_loss / self.n_known 212 | loss += dist_loss 213 | 214 | loss.backward() 215 | optimizer.step() 216 | 217 | if (i+1) % 10 == 0: 218 | print ('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f' 219 | %(epoch+1, num_epochs, i+1, len(dataset)//batch_size, loss.data[0])) 220 | --------------------------------------------------------------------------------