├── README.md ├── utils.py ├── main.py └── net.py /README.md: -------------------------------------------------------------------------------- 1 | This repository provides a PyTorch re-implementation of CliqueNet; the original paper is available [here](https://arxiv.org/abs/1802.10419). In this implementation, the model is evaluated on the CIFAR-10 dataset. 2 | 3 | There are several differences from the implementation described in the paper: 4 | - we use post-activation (conv-bn-relu) instead of pre-activation (bn-relu-conv) 5 | - we adopt the attentional transition and compression strategies, but we do not use the bottleneck inside clique blocks 6 | - we offer a simple data augmentation option (random flips) 7 | 8 | ## Requirements 9 | Our code is based on the latest version of PyTorch; please visit the [official site](https://pytorch.org) to install it. 10 | 11 | ## Usage 12 | 13 | To train a CliqueNet on CIFAR-10, use the following command: 14 | ```Shell 15 | python main.py [-h] [-batch_size BATCH_SIZE] [-num_epochs NUM_EPOCHS] [-lr LR] 16 | [-clip CLIP] [-disable_cuda] [-augmentation] 17 | [-print_freq PRINT_FREQ] [-pretrained PRETRAINED] [-gpu GPU] 18 | 19 | optional arguments: 20 | -h, --help show this help message and exit 21 | -batch_size BATCH_SIZE 22 | -num_epochs NUM_EPOCHS 23 | -lr LR Initial learning rate 24 | -disable_cuda Disable CUDA 25 | -augmentation Apply data augmentation 26 | -print_freq PRINT_FREQ 27 | Log print frequency 28 | -pretrained PRETRAINED 29 | -gpu GPU Which gpu to use 30 | 31 | 32 | ``` 33 | 34 | You can also modify the hyperparameters in `main.py` to change the network configuration. 35 | 36 | ## Results 37 | 38 | We ran a simple experiment: with a dropout ratio of 0.1, we trained the network for 200 epochs without data augmentation. The best result so far on the CIFAR-10 test set is 92.23, which still leaves a large margin between our results and those reported in the paper. We will try to close this gap later. 
# -*- coding: utf-8 -*-
"""
Created on Mon May 14 20:25:51 2018

@author: Yuxi Li
"""

import argparse

import torch


def get_dataloader(args):
    """Build the CIFAR-10 training and test data loaders.

    Args:
        args: Namespace providing ``batch_size`` (int) and ``augmentation``
            (bool), e.g. the object returned by ``get_args``.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader is
        shuffled, the test loader is not. The dataset is downloaded to
        ``./data`` when it is not already there.
    """
    # Imported here rather than at module level so that argument parsing
    # stays usable on machines where torchvision is not installed.
    from torchvision import datasets, transforms

    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    if args.augmentation:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transform

    trainset = datasets.CIFAR10(root='./data', train=True,
                                download=True, transform=train_transform)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    testset = datasets.CIFAR10(root='./data', train=False,
                               download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    return train_loader, test_loader


def get_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv``, as before.

    Returns:
        The parsed ``argparse.Namespace`` with one extra attribute,
        ``use_cuda``, which is true only when CUDA was not disabled on the
        command line and ``torch.cuda.is_available()`` reports a device.
    """
    parser = argparse.ArgumentParser(description='CliqueNet')

    parser.add_argument('-batch_size', type=int, default=128)
    parser.add_argument('-num_epochs', type=int, default=1)
    parser.add_argument('-lr', type=float, default=2e-2,
                        help="Initial learning rate")
    parser.add_argument('-disable_cuda', action='store_true',
                        help='Disable CUDA')
    parser.add_argument('-augmentation', action='store_true',
                        help='Apply data augmentation')
    parser.add_argument('-print_freq', type=int, default=10,
                        help="Log print frequency")
    # The description used to be passed as ``default``, which made the option
    # truthy on every run and sent callers off to load a checkpoint named
    # after the help text. No checkpoint is requested unless one is given.
    parser.add_argument('-pretrained', type=str, default=None,
                        help="Start from a pretrained model")
    parser.add_argument('-gpu', type=int, default=0, help="Which gpu to use")

    args = parser.parse_args(argv)
    args.use_cuda = not args.disable_cuda and torch.cuda.is_available()

    return args


if __name__ == "__main__":
    # Smoke test: report the training-set size and print one image channel.
    args = get_args()
    loader, _ = get_dataloader(args)
    print(len(loader.dataset))
    for data in loader:
        x, y = data
        print(x[0, 0, :, :])
        break
# -*- coding: utf-8 -*-
"""Train CliqueNet on CIFAR-10 and checkpoint the best-scoring epochs."""

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import CrossEntropyLoss

from net import CliqueNet
from utils import get_args, get_dataloader


def _count_correct(score, labels):
    """Return how many rows of ``score`` have their arg-max equal to ``labels``."""
    pred = score.max(1)[1]
    return pred.eq(labels).sum().item()


def _train_one_epoch(model, loader, criterion, optimizer, device, print_freq):
    """Run one optimisation pass over ``loader`` and return the training accuracy."""
    model.train()
    correct = 0
    for b, (imgs, labels) in enumerate(loader, start=1):
        optimizer.zero_grad()
        imgs = imgs.to(device)
        labels = labels.to(device)

        score = model(imgs)
        loss = criterion(score, labels)
        loss.backward()
        optimizer.step()

        # stats
        acc = _count_correct(score, labels)
        correct += acc
        if b % print_freq == 0:
            print("batch:{}, lr:{:.4f}".format(b, optimizer.param_groups[0]['lr']))
            # The final batch of an epoch can be smaller than the configured
            # batch size, so report against the number of samples actually seen.
            print("total loss: {:.4f}, acc: {:}/{}".format(loss.item(), acc, labels.size(0)))
    return float(correct) / len(loader.dataset)


def _evaluate(model, loader, device):
    """Return the accuracy of ``model`` over ``loader`` without tracking gradients."""
    model.eval()
    correct = 0
    # No backward pass happens here; disabling autograd avoids building and
    # holding the graph for every test batch.
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            correct += _count_correct(model(imgs), labels)
    return float(correct) / len(loader.dataset)


def main():
    """Parse the options, then train and evaluate for the requested epochs."""
    args = get_args()
    train_loader, test_loader = get_dataloader(args)
    use_cuda = args.use_cuda
    num_classes = 10
    dropout_prob = 0.1

    # An explicit device replaces the former ``with torch.cuda.device(...)``
    # wrapper, which was entered even when CUDA was disabled or unavailable.
    device = torch.device('cuda', args.gpu) if use_cuda else torch.device('cpu')

    # hyper-parameters: 4 layers per clique block, 36 filters per layer
    model = CliqueNet(3, num_classes, 4, 36, attention=True, compression=True,
                      dropout_prob=dropout_prob)
    criterion = CrossEntropyLoss()

    if args.pretrained:
        # Only load when the value points at an existing file; anything else
        # means there is no checkpoint to resume from.
        if os.path.isfile(args.pretrained):
            # Load onto the CPU first so a checkpoint written on a GPU can be
            # restored on any machine; the model is moved afterwards.
            model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
        else:
            print("Pretrained checkpoint not found, training from scratch: {}".format(args.pretrained))
    if use_cuda:
        print("activating cuda")
    model = model.to(device)

    total_epochs = args.num_epochs
    milestones = [int(total_epochs*0.5), int(total_epochs*0.75)]
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                          weight_decay=0.0001, nesterov=True)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    best_acc = 0.0
    for epoch in range(args.num_epochs):
        # Train
        print("Epoch {}".format(epoch))
        acc = _train_one_epoch(model, train_loader, criterion, optimizer,
                               device, args.print_freq)
        print("Epoch{} Train acc:{:4}".format(epoch, acc))
        scheduler.step()

        # Test
        print('Testing...')
        acc = _evaluate(model, test_loader, device)
        print("Epoch{} Test acc:{:4}".format(epoch, acc))

        # ``>=`` keeps the latest checkpoint among epochs that tie the best score.
        if acc >= best_acc:
            best_acc = acc
            os.makedirs('./model', exist_ok=True)
            print("Writing checkpoint to: model/model_{}.pth".format(epoch))
            torch.save(model.state_dict(), "model/model_{}.pth".format(epoch))


if __name__ == "__main__":
    main()
def global_pooling(x):
    """Average ``x`` over its last two dimensions.

    For an input of shape ``[n, c, h, w]`` the result has shape ``[n, c]``.
    """
    return x.mean(dim=-1).mean(dim=-1)


class CliqueNet(nn.Module):
    """CliqueNet classifier: a stem convolution followed by three clique blocks.

    Every block yields a feature map (block input concatenated with its
    stage-2 output). That map is optionally compressed by a 1x1 convolution,
    globally average-pooled, and the three pooled vectors are concatenated
    and fed to a linear classifier.

    Args:
        nin: Number of input image channels.
        num_classes: Size of the classifier output.
        layers: Number of layers inside each clique block.
        filters: Number of filters produced by each layer of a block.
        attention: Forwarded to every ``Transition``.
        compression: When true, halve the channel count of each block feature
            with a 1x1 conv-bn-relu before pooling.
        dropout_prob: Dropout probability used throughout the network.
    """

    def __init__(self, nin, num_classes, layers, filters, attention=False,
                 compression=False, dropout_prob=0.0):
        super(CliqueNet, self).__init__()
        stem_channels = 64
        block_width = layers * filters  # channels of a block's stage-2 output

        self.conv = nn.Conv2d(nin, stem_channels, kernel_size=3, padding=1, stride=1)
        self.bn = nn.BatchNorm2d(stem_channels)

        # The first block consumes the stem output; the following ones consume
        # a transition output, which has ``block_width`` channels.
        block_inputs = [stem_channels, block_width, block_width]
        self.clique = nn.ModuleList(
            [CliqueBlock(channels, layers, filters, kernel=3, dropout_prob=dropout_prob)
             for channels in block_inputs])
        self.transition = nn.ModuleList(
            [Transition(block_width, block_width, attention, dropout_prob)
             for _ in range(3)])

        # Each block feature is cat([block input, stage-2 output]).
        feature_channels = [channels + block_width for channels in block_inputs]

        if compression:
            # ``//`` keeps the channel counts integral; true division would
            # hand floats to nn.Conv2d / nn.Linear.
            self.compression = nn.ModuleList(
                [self.conv_bn_relu(channels, channels // 2, dropout_prob)
                 for channels in feature_channels])
            feature_size = sum(channels // 2 for channels in feature_channels)
        else:
            self.compression = None
            # Must equal the width concatenated in forward():
            # (64 + L*F) + 2*L*F + 2*L*F.
            feature_size = sum(feature_channels)

        self.predict = nn.Linear(feature_size, num_classes)

    def conv_bn_relu(self, nin, nout, dropout_prob):
        """Return a 1x1 convolution -> batch norm -> ReLU -> 2-D dropout stack."""
        return nn.Sequential(
            nn.Conv2d(nin, nout, kernel_size=1, padding=0, stride=1),
            nn.BatchNorm2d(nout),
            nn.ReLU(),
            nn.Dropout2d(dropout_prob))

    def forward(self, x):
        """Return the class scores for the image batch ``x``."""
        x = F.relu(self.bn(self.conv(x)))

        features = []
        for index, (clique, transition) in enumerate(zip(self.clique, self.transition)):
            feature, s2 = clique(x)
            # NOTE: the transition output of the last block is computed but
            # never consumed; it is kept so the module layout is unchanged.
            x = transition(s2)

            if self.compression is not None:
                feature = self.compression[index](feature)

            features.append(global_pooling(feature))

        return self.predict(torch.cat(features, dim=1))
class CliqueBlock(nn.Module):
    """Clique block whose layers are updated in two alternating stages.

    Stage 1 runs the layers in order, each one convolving the block input
    together with all previously computed layers. Stage 2 then recomputes
    every layer from the most recent versions of all the other layers.
    The layer-to-layer kernels live in ``self.W`` and are used by both stages.

    Args:
        nin: Number of channels of the block input ``x0``.
        layers: Number of layers in the block.
        filters: Number of channels produced by every layer.
        kernel: Spatial size of the convolution kernels.
        dropout_prob: Probability of the 2-D dropout applied after each layer.
    """

    def __init__(self, nin, layers, filters, kernel, dropout_prob=0.0):
        super(CliqueBlock, self).__init__()
        self.layers = layers
        self.channel = filters
        self.kernel = kernel
        self.nin = nin
        self.filters = filters

        num_kernels = layers * (layers - 1)  # ordered (from, to) pairs, from != to
        num_norms = 2 * layers               # one activation stack per layer per stage

        # Parameter layout:
        #   W0: kernels from the block input x0 to every layer.
        #   W : kernels between layers, grouped by source layer; see
        #       coordinate2idx for the (from, to) -> row mapping.
        #   b : biases, first ``layers`` rows for stage 1, the rest for stage 2.
        self.W0 = nn.Parameter(torch.rand(self.layers, filters, nin, kernel, kernel))
        self.W = nn.Parameter(torch.rand(num_kernels, filters, filters, kernel, kernel))
        self.b = nn.Parameter(torch.rand(2 * self.layers, filters))
        self.activates = nn.ModuleList(
            [nn.Sequential(nn.BatchNorm2d(filters), nn.ReLU(), nn.Dropout2d(dropout_prob))
             for _ in range(num_norms)])

        self.reset_parameters(0.01)

    def reset_parameters(self, std):
        """Re-draw every parameter of the block from N(0, ``std``^2)."""
        # NOTE(review): self.parameters() also yields the BatchNorm weights and
        # biases inside self.activates, so they are re-drawn as well. Confirm
        # this is intended before restricting it to W0 / W / b.
        for weight in self.parameters():
            weight.data.normal_(mean=0, std=std)

    def stage1(self, x0):
        """Run the forward (feed-in-order) pass.

        Args:
            x0: Block input with ``nin`` channels.

        Returns:
            The outputs of every layer except the first, concatenated along
            the channel dimension in layer order.
        """
        output = None
        for i in range(self.layers):
            if i == 0:
                data = x0
                weight = self.W0[i]
            else:
                # Layer i sees x0 and the outputs of layers 0 .. i-1.
                data = torch.cat([data, output], dim=1)
                weight = torch.cat(
                    [self.W0[i]] + [self.W[self.coordinate2idx(j, i)] for j in range(i)],
                    dim=1)

            bias = self.b[i]

            # Integer division: conv2d requires an integral padding.
            conv = F.conv2d(data, weight, bias, stride=1, padding=self.kernel // 2)
            output = self.activates[i](conv)

        # Drop x0 and the first layer; stage 2 recomputes the first layer from the rest.
        return torch.cat([data[:, (self.nin + self.filters):, :, :], output], dim=1)

    def stage2(self, x):
        """Recompute every layer from the latest versions of all the others.

        Args:
            x: Stage-1 outputs of every layer except the first (see ``stage1``).

        Returns:
            The updated outputs of all layers, concatenated in layer order.
        """
        output = None
        # Source layers whose features currently make up ``data``, in channel
        # order. A list (not a bare ``range``) so it can be rotated below.
        from_layers = list(range(1, self.layers))

        for i in range(self.layers):
            if i == 0:
                data = x
            else:
                # Slide the window: drop the oldest layer, append the newest.
                data = torch.cat([data[:, self.filters:, :, :], output], dim=1)

            weight = torch.cat(
                [self.W[self.coordinate2idx(j, i)] for j in from_layers], dim=1)
            bias = self.b[self.layers + i]
            from_layers = from_layers[1:] + [self.recurrent_index(from_layers[-1] + 1)]

            conv = F.conv2d(data, weight, bias, stride=1, padding=self.kernel // 2)
            output = self.activates[self.layers + i](conv)

        return torch.cat([data, output], dim=1)

    def coordinate2idx(self, from_idx, to_idx):
        """Map a ``(from, to)`` layer pair to its row in ``self.W``.

        Rows are grouped by source layer, ``layers - 1`` rows each; within a
        group the destinations appear in increasing order with the source
        itself skipped.

        The previous formula subtracted one unconditionally. For
        ``to_idx < from_idx`` that landed on the last row of the preceding
        group, so backward links shared kernels with forward links and
        ``layers - 1`` rows were never used. Rows for ``from_idx < to_idx``
        are unchanged, but checkpoints trained with the old mapping assign
        their backward links to different rows.
        """
        assert from_idx != to_idx
        offset = to_idx - 1 if to_idx > from_idx else to_idx
        return from_idx * (self.layers - 1) + offset

    def recurrent_index(self, a):
        """Wrap the layer index ``a`` into ``[0, layers)``."""
        return a % self.layers

    def forward(self, x0):
        """Run both stages.

        Returns:
            ``(feature, s2)`` where ``s2`` is the stage-2 output and
            ``feature`` is ``x0`` concatenated with ``s2`` along channels.
        """
        s1 = self.stage1(x0)
        s2 = self.stage2(s1)

        feature = torch.cat([x0, s2], dim=1)

        return feature, s2
class Transition(nn.Module):
    """Transition between clique blocks.

    Applies a 1x1 conv-bn-relu-dropout stack, optionally re-weights the
    channels with a learned gate, and then halves the spatial resolution
    with average pooling.

    Args:
        nin: Number of input channels.
        nout: Number of output channels.
        attention: When true, scale each channel by a sigmoid gate computed
            from the globally pooled features.
        dropout_prob: Probability of the 2-D dropout after the convolution.
    """

    def __init__(self, nin, nout, attention=False, dropout_prob=0.0):
        super(Transition, self).__init__()
        self.trans = nn.Sequential(
            nn.Conv2d(nin, nout, kernel_size=1, padding=0, stride=1),
            nn.BatchNorm2d(nout),
            nn.ReLU(),
            nn.Dropout2d(dropout_prob))

        # ceil_mode keeps the trailing row/column of odd-sized feature maps.
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

        if attention:
            # ``//`` keeps the hidden width integral; nn.Linear rejects floats.
            hidden = nout // 2
            self.attention = nn.Sequential(
                nn.Linear(nout, hidden),
                nn.ReLU(),
                nn.Linear(hidden, nout),
                nn.Sigmoid())
        else:
            self.attention = None

    def forward(self, x):
        """Return the transformed, optionally gated, and pooled feature map."""
        s = self.trans(x)

        if self.attention is not None:
            w = global_pooling(s)            # [n, nout] channel descriptors
            w = self.attention(w)            # [n, nout] gates in (0, 1)
            s = w[:, :, None, None] * s      # broadcast over the spatial dims

        return self.pool(s)


if __name__ == '__main__':
    pass