├── img
│   ├── fabric.png
│   └── fabric.svg
├── neural_fabrics.py
├── readme.md
└── utils
    ├── __init__.py
    ├── logger.py
    └── visualize.py
/img/fabric.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vabh/convolutional-neural-fabrics/6a7e84c44200a5df419693a2a1b4c7c0a64e9464/img/fabric.png
--------------------------------------------------------------------------------
/neural_fabrics.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.init as weight_init
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable

from utils import logger

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10', help='cifar10 | cifar100')
parser.add_argument('--dataroot', default='./data', help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--niter', type=int, default=200, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate, default=0.01')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--save', required=True, help='folder to store log files, model checkpoints')

opt = parser.parse_args()
print(opt)

# logger
try:
    os.makedirs(opt.save)
    print('Logging at: ' + opt.save)
except OSError:
    pass
log = logger.Logger(opt.save + '/train.log', ['loss', 'train error', 'test error'])

# set random seed
opt.manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if torch.cuda.is_available() and opt.cuda:
    torch.cuda.manual_seed(opt.manualSeed)

# set cudnn
cudnn.benchmark = True
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# get data loaders (only cifar10 is wired up; --dataset is accepted but unused)
train_dataset = dset.CIFAR10(root=opt.dataroot, download=True, train=True,
                             transform=transforms.Compose([
                                 transforms.RandomHorizontalFlip(),
                                 transforms.RandomCrop(32, padding=4),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.2, 0.2, 0.2)),
                             ]))
test_dataset = dset.CIFAR10(root=opt.dataroot, download=True, train=False,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.2, 0.2, 0.2)),
                            ]))
assert train_dataset
assert test_dataset

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batchSize,
                                           shuffle=True, num_workers=int(opt.workers))

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batchSize,
                                          shuffle=False, num_workers=int(opt.workers))


# count the number of incorrect classifications in a batch
def compute_score(output, target):
    pred = output.max(1)[1]
    incorrect = pred.ne(target).cpu().sum()
    return incorrect
# define model

class UpSample(nn.Module):
    """2x bilinear upsampling, followed by conv + batch norm + ReLU."""
    def __init__(self, inChannels, outChannels):
        super(UpSample, self).__init__()
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv = nn.Conv2d(inChannels, outChannels, kernel_size=3, stride=1, padding=1)
        self.batch_norm = nn.BatchNorm2d(outChannels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.upsample(x)
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        return x


class DownSample(nn.Module):
    """2x downsampling via a stride-2 conv, followed by batch norm + ReLU."""
    def __init__(self, inChannels, outChannels):
        super(DownSample, self).__init__()
        self.conv = nn.Conv2d(inChannels, outChannels, kernel_size=3, stride=2, padding=1)
        self.batch_norm = nn.BatchNorm2d(outChannels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        return x


class SameRes(nn.Module):
    """Resolution-preserving conv, followed by batch norm + ReLU."""
    def __init__(self, inChannels, outChannels):
        super(SameRes, self).__init__()
        self.conv = nn.Conv2d(inChannels, outChannels, kernel_size=3, stride=1, padding=1)
        self.batch_norm = nn.BatchNorm2d(outChannels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        return x


class Net(nn.Module):
    """A fabric of `layers` x `scales` nodes. node_ops[layer][scale] holds the
    ops that bring neighbouring activations (same, finer, coarser scale) to a
    node's resolution; their outputs are summed in forward()."""
    def __init__(self):
        super(Net, self).__init__()

        self.channels = 128
        self.kernel_size = 3

        self.layers = 8
        self.scales = 5

        self.node_ops = nn.ModuleList()

        # the raw image is first convolved to `channels` feature maps
        self.start_node = SameRes(3, self.channels)

        # the coarsest node of the last layer is 1x1, so the classifier
        # sees a `channels`-dimensional vector
        self.fc = nn.Linear(self.channels, 10)

        for layer in range(self.layers):
            self.node_ops.append(nn.ModuleList())  # one list of scales per layer

            if layer == 0:
                # first layer: a chain of downsampling nodes across the scales
                for i in range(self.scales):
                    self.node_ops[layer].append(nn.ModuleList())
                    self.node_ops[layer][i].append(
                        DownSample(self.channels, self.channels))
            else:
                for i in range(self.scales):
                    self.node_ops[layer].append(nn.ModuleList())

                    # op [0]: same-scale input from the previous layer
                    self.node_ops[layer][i].append(
                        SameRes(self.channels, self.channels))
                    if i == 0:
                        # finest scale: op [1] upsamples the coarser neighbour
                        self.node_ops[layer][i].append(
                            UpSample(self.channels, self.channels))
                    elif i == self.scales - 1:
                        # coarsest scale: op [1] downsamples the finer neighbour
                        self.node_ops[layer][i].append(
                            DownSample(self.channels, self.channels))
                        if layer == self.layers - 1:
                            # op [2]: last layer also takes the finer node of the same layer
                            self.node_ops[layer][i].append(
                                DownSample(self.channels, self.channels))
                    else:
                        # middle scales: op [1] downsamples the finer neighbour,
                        # op [2] upsamples the coarser one
                        self.node_ops[layer][i].append(
                            DownSample(self.channels, self.channels))
                        self.node_ops[layer][i].append(
                            UpSample(self.channels, self.channels))
                        if layer == self.layers - 1:
                            # op [3]: last layer also takes the finer node of the same layer
                            self.node_ops[layer][i].append(
                                DownSample(self.channels, self.channels))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                weight_init.kaiming_normal(m.weight)
                weight_init.constant(m.bias, 0.1)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)

    def forward(self, x):
        node_activ = [[[] for i in range(self.scales)] for j in range(self.layers)]
        out = self.start_node(x)
        for layer in range(self.layers):
            if layer == 0:
                # first layer: propagate down the scale pyramid
                for i in range(self.scales):
                    if i == 0:
                        node_activ[layer][i] = self.node_ops[layer][i][0](out)
                    else:
                        node_activ[layer][i] = self.node_ops[layer][i][0](node_activ[layer][i - 1])
            else:
                for i in range(self.scales):
                    if i == 0:
                        t1 = node_activ[layer - 1][i]                                  # same scale
                        t2 = self.node_ops[layer][i][1](node_activ[layer - 1][i + 1])  # coarser
                        t = self.node_ops[layer][i][0](t1 + t2)
                        node_activ[layer][i] = t
                    elif i == self.scales - 1:
                        t1 = node_activ[layer - 1][i]                                  # same scale
                        t2 = self.node_ops[layer][i][1](node_activ[layer - 1][i - 1])  # finer

                        if layer == self.layers - 1:
                            # last layer: also pull in the finer node of this layer
                            t3 = self.node_ops[layer][i][2](node_activ[layer][i - 1])
                            t = self.node_ops[layer][i][0](t1 + t2 + t3)
                        else:
                            t = self.node_ops[layer][i][0](t1 + t2)
                        node_activ[layer][i] = t
                    else:
                        t1 = node_activ[layer - 1][i]                                  # same scale
                        t2 = self.node_ops[layer][i][2](node_activ[layer - 1][i + 1])  # coarser
                        t3 = self.node_ops[layer][i][1](node_activ[layer - 1][i - 1])  # finer
                        if layer == self.layers - 1:
                            # last layer: also pull in the finer node of this layer
                            t4 = self.node_ops[layer][i][3](node_activ[layer][i - 1])
                            t = self.node_ops[layer][i][0](t1 + t2 + t3 + t4)
                        else:
                            t = self.node_ops[layer][i][0](t1 + t2 + t3)
                        node_activ[layer][i] = t

        # classify from the coarsest node of the last layer
        out = node_activ[-1][-1]
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
net = Net()
print(net)

# criterion
criterion = nn.CrossEntropyLoss()

if opt.cuda:
    net.cuda()
    criterion.cuda()


# train
def train(epoch):
    net.train()
    score_epoch = 0
    loss_epoch = 0
    print('Epoch: ' + str(epoch))
    # step learning rate schedule: 0.1 -> 0.01 -> 0.001
    # (recreating the optimizer each epoch also resets the momentum buffers)
    if epoch > 120:
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
    elif epoch > 80:
        optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
    else:
        optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)

    for i, (images, labels) in enumerate(train_loader):
        if opt.cuda:
            images, labels = images.cuda(), labels.cuda()
        images = Variable(images)
        labels = Variable(labels, requires_grad=False)

        optimizer.zero_grad()
        output = net(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        loss_epoch = loss_epoch + loss.data[0]
        score_epoch = score_epoch + compute_score(output.data, labels.data)

    loss_epoch = loss_epoch / len(train_loader)
    print('[%d/%d][%d] train_loss: %.4f err: %d'
          % (epoch, opt.niter, len(train_loader), loss_epoch, score_epoch))
    return loss_epoch, score_epoch


# test network
def test():
    net.eval()
    score_epoch = 0
    loss_epoch = 0
    for i, (images, labels) in enumerate(test_loader):
        if opt.cuda:
            images, labels = images.cuda(), labels.cuda()
        images = Variable(images, volatile=True)  # inference only, no autograd history
        labels = Variable(labels, requires_grad=False)

        output = net(images)
        loss = criterion(output, labels)

        loss_epoch = loss_epoch + loss.data[0]
        score_epoch = score_epoch + compute_score(output.data, labels.data)

    loss_epoch = loss_epoch / len(test_loader)
    print('Test error: %d' % (score_epoch))
    return loss_epoch, score_epoch


# train for opt.niter epochs
start_loss, start_error = test()
for epoch in range(1, opt.niter + 1):
    train_loss, train_error = train(epoch)
    test_loss, test_error = test()
    log.add([train_loss,
             train_error / float(len(train_dataset)),
             test_error / float(len(test_dataset))])
    log.plot()
    if epoch % 10 == 0:
        # do checkpointing
        torch.save(net.state_dict(), '%s/net_epoch_%d.pth' % (opt.save, epoch))
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
PyTorch implementation of [Convolutional Neural Fabrics, arXiv:1606.02492](http://arxiv.org/abs/1606.02492).

There are some minor differences from the paper:
- The raw image is first convolved to obtain `channels` feature maps (128 by default), which feed the first fabric node.
- Upsampling is followed by a convolution, and the result is then summed with the other inputs. In the paper, the inputs are summed first and the convolution is applied to the sum; a sketch of that ordering is shown below the run command.
- Both choices are easy to change in the `UpSample`, `DownSample` and `SameRes` class definitions inside `neural_fabrics.py`. Feel free to implement your own procedure and experiment.

To run on CIFAR-10:

```
python neural_fabrics.py --dataset cifar10 --save fabric_cifar10
```
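
For reference, below is a minimal sketch of the paper-style ordering mentioned above: neighbouring activations are brought to a node's resolution without convolving, summed, and only then passed through a single conv + BN + ReLU. It is illustrative only and is not wired into `neural_fabrics.py`; the `PaperStyleNode` name, the keyword-argument interface and the parameter-free `AvgPool2d` downsampling are assumptions made for this sketch, not the paper's exact construction.

```python
import torch.nn as nn

class PaperStyleNode(nn.Module):
    """Hypothetical sum-then-convolve fabric node. The classes in
    neural_fabrics.py convolve each input before summing instead."""
    def __init__(self, channels):
        super(PaperStyleNode, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.batch_norm = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        self.down = nn.AvgPool2d(2)  # parameter-free resampling; one choice among several

    def forward(self, same=None, finer=None, coarser=None):
        # bring every available neighbour to this node's resolution
        inputs = []
        if same is not None:
            inputs.append(same)
        if finer is not None:
            inputs.append(self.down(finer))   # halve the finer map
        if coarser is not None:
            inputs.append(self.up(coarser))   # double the coarser map
        assert inputs, 'node needs at least one input'
        # sum first, then apply the single convolution
        total = inputs[0]
        for t in inputs[1:]:
            total = total + t
        return self.relu(self.batch_norm(self.conv(total)))
```

Swapping something like this in would mean passing raw neighbouring activations to each node in `Net.forward` instead of pre-convolved terms.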

Test set error: 7.2%, with flip and translation augmented training data (`RandomHorizontalFlip` and `RandomCrop` in `neural_fabrics.py`).
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vabh/convolutional-neural-fabrics/6a7e84c44200a5df419693a2a1b4c7c0a64e9464/utils/__init__.py
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
import matplotlib
matplotlib.use('Agg')  # headless backend, so plotting works without a display
import matplotlib.pyplot as plt
import numpy as np


class Logger:
    """Delimiter-separated log file with a '#'-prefixed header row and a plot helper."""
    def __init__(self, log_file, names=None, delimiter='\t'):
        assert log_file is not None
        if names is None:
            names = []

        self.log_file = log_file
        self.names = names
        self.delim = delimiter
        self.fields = len(names)

        header = self._gather_values(self.names, prefix='#')
        with open(log_file, 'w') as f:
            f.write(header + '\n')

    def _gather_values(self, vals, prefix=''):
        output = ''
        for value in vals:
            output = output + self.delim + str(value)
        output = prefix + output
        return output

    def add(self, vals):
        assert len(vals) == self.fields
        output = self._gather_values(vals)
        with open(self.log_file, 'a') as f:
            f.write(output + '\n')

    def plot(self):
        # skip the header row; one curve per logged field
        data = np.loadtxt(self.log_file, skiprows=1)
        plt.clf()
        p = plt.plot(data)
        plt.legend(p, self.names)
        plt.grid()
        plt.savefig(self.log_file + '.png', format='png')


if __name__ == '__main__':
    l = Logger('test.log', names=['a', 'b', 'c'])
    for i in range(4):
        l.add([i, i + 1, i + 2])
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
# https://github.com/szagoruyko/functional-zoo/blob/master/visualize.py
from graphviz import Digraph
from torch.autograd import Variable


def make_dot(var):
    """Render the autograd graph of `var` as a graphviz Digraph.
    Note: relies on the legacy Variable.creator / previous_functions API."""
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def add_nodes(var):
        if var not in seen:
            if isinstance(var, Variable):
                value = '(' + (', ').join(['%d' % v for v in var.size()]) + ')'
                dot.node(str(id(var)), str(value), fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'previous_functions'):
                for u in var.previous_functions:
                    dot.edge(str(id(u[0])), str(id(var)))
                    add_nodes(u[0])
    add_nodes(var.creator)
    return dot
--------------------------------------------------------------------------------