├── utils
│   ├── __init__.py
│   ├── visualize.py
│   └── logger.py
├── img
│   └── fabric.png
├── readme.md
└── neural_fabrics.py

/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/img/fabric.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vabh/convolutional-neural-fabrics/HEAD/img/fabric.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | PyTorch implementation of [Convolutional Neural Fabrics arxiv:1606.02492](http://arxiv.org/abs/1606.02492)
2 | There are some minor differences from the paper:
3 | - The raw image is first convolved to obtain `channels` feature maps.
4 | - Each upsampling is followed by a convolution, and the result is then summed with the other inputs. In the paper, the inputs are first summed and the sum is then convolved (see the sketch at the end of this readme).
5 | - These choices can easily be changed in the `UpSample`, `DownSample` and `SameRes` class definitions inside `neural_fabrics.py`. Feel free to implement your own procedure and experiment.
6 | 
7 | To run on CIFAR-10:
8 | 
9 |     python neural_fabrics.py --dataset cifar10 --save fabric_cifar10
10 | 
11 | 
12 | Test set error: 7.2%, with rotation and translation augmented training data.
13 | 
14 | 
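15 | 
16 | For illustration, a minimal sketch of the two node updates (`same`, `up` and `down` stand for the `SameRes`, `UpSample` and `DownSample` modules defined in `neural_fabrics.py`; all other names below are purely illustrative):
17 | 
18 |     # this implementation: each cross-scale edge resizes *and* convolves its input,
19 |     # and the node's own convolution is applied to the sum
20 |     out = same(prev_same + up(prev_coarser) + down(prev_finer))
21 | 
22 |     # the paper's formulation: resize only, sum, then a single convolution
23 |     out = conv(prev_same + upsample(prev_coarser) + downsample(prev_finer))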
15 |
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | # https://github.com/szagoruyko/functional-zoo/blob/master/visualize.py
2 | from graphviz import Digraph
3 | from torch.autograd import Variable
4 |
5 | def make_dot(var):  # build a graphviz Digraph of the autograd graph that produced `var`
6 | node_attr = dict(style='filled',
7 | shape='box',
8 | align='left',
9 | fontsize='12',
10 | ranksep='0.1',
11 | height='0.2')
12 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
13 | seen = set()
14 |
15 | def add_nodes(var):
16 | if var not in seen:
17 | if isinstance(var, Variable):
18 | value = '('+(', ').join(['%d'% v for v in var.size()])+')'
19 | dot.node(str(id(var)), str(value), fillcolor='lightblue')
20 | else:
21 | dot.node(str(id(var)), str(type(var).__name__))
22 | seen.add(var)
23 | if hasattr(var, 'previous_functions'):
24 | for u in var.previous_functions:
25 | dot.edge(str(id(u[0])), str(id(var)))
26 | add_nodes(u[0])
27 | add_nodes(var.creator)
28 | return dot
29 |
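30 | # Example usage (a sketch; assumes the pre-0.4 autograd API used above, with
31 | # `model` being any nn.Module and `x` an input Variable):
32 | #
33 | #   y = model(x)
34 | #   dot = make_dot(y)
35 | #   dot.format = 'png'
36 | #   dot.render('graph')   # writes graph.png via the graphviz binary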
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | class Logger:
6 | def __init__(self, log_file, names=None, delimiter='\t'):
7 | assert log_file is not None
8 | if names is None:
9 | names = []
10 |
11 | self.log_file = log_file
12 | self.names = names
13 | self.delim = delimiter
14 | self.fields = len(names)
15 |
16 | header = self._gather_values(self.names, prefix='#')
17 | with open(log_file, 'w') as f:
18 | f.write(header + '\n')
19 |
20 | def _gather_values(self, vals, prefix=''):
21 | output = ''
22 | for value in vals:
23 | output = output + self.delim + str(value)
24 | output = prefix + output
25 | return output
26 |
27 | def add(self, vals):
28 | assert len(vals) == self.fields
29 | output = self._gather_values(vals)
30 | with open(self.log_file, 'a') as f:
31 | f.write(output + '\n')
32 |
33 | def plot(self):
34 | data = np.loadtxt(self.log_file, skiprows=1)
35 | plt.clf()
36 | p = plt.plot(data)
37 | plt.legend(p, self.names)
38 | plt.grid()
39 | plt.savefig(self.log_file+'.png', format='png')
40 |
41 | if __name__ == '__main__':
42 | l = Logger('test.log', names=['a', 'b', 'c'])
43 |     for i in range(4):
44 |         l.add([i, 2 * i, i ** 2])  # numeric values so that plot() can parse the log
45 |     l.plot()
46 |
47 |
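48 | # Log format: a '#'-prefixed, delimiter-separated header line followed by one
49 | # delimiter-separated row per add() call; plot() draws every column and saves
50 | # the figure to '<log_file>.png'.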
--------------------------------------------------------------------------------
/neural_fabrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import os
4 | import random
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.init as weight_init
8 | import torch.backends.cudnn as cudnn
9 | import torch.optim as optim
10 | import torch.utils.data
11 | import torchvision.datasets as dset
12 | import torchvision.transforms as transforms
13 | from torch.autograd import Variable
14 |
15 | from utils import logger
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--dataset', default='cifar10', help='cifar10 | cifar100')
19 | parser.add_argument('--dataroot', default='./data', help='path to dataset')
20 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
21 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
22 | parser.add_argument('--niter', type=int, default=200, help='number of epochs to train for')
23 | parser.add_argument('--lr', type=float, default=0.01, help='learning rate, default=0.01 (note: train() currently uses a hard-coded schedule)')
24 | parser.add_argument('--cuda', action='store_true', help='enables cuda')
25 | parser.add_argument('--save', help='folder to store log files, model checkpoints')
26 |
27 | opt = parser.parse_args()
28 | print(opt)
29 |
30 | #logger
31 | try:
32 | os.makedirs(opt.save)
33 | print('Logging at: ' + opt.save)
34 | except OSError:
35 | pass
36 | log = logger.Logger(opt.save+'/train.log', ['loss', 'train error', 'test error'])
37 |
38 | # set random seed
39 | opt.manualSeed = random.randint(1, 10000)  # pick a random seed and use it everywhere below
40 | print("Random Seed: ", opt.manualSeed)
41 | random.seed(opt.manualSeed)
42 | torch.manual_seed(opt.manualSeed)
43 | if torch.cuda.is_available() and opt.cuda:
44 | torch.cuda.manual_seed(opt.manualSeed)
45 |
46 | # set cudnn
47 | cudnn.benchmark = True
48 | if torch.cuda.is_available() and not opt.cuda:
49 | print("WARNING: You have a CUDA device, so you should probably run with --cuda")
50 |
51 | # get data loaders
52 | # note: only CIFAR-10 is wired up below; opt.dataset is not checked, so cifar100 would need dset.CIFAR100 here
53 | train_dataset = dset.CIFAR10(root=opt.dataroot, download=True, train=True,
54 | transform=transforms.Compose([
55 | transforms.RandomHorizontalFlip(),
56 | transforms.RandomCrop(32, padding=4),
57 | transforms.ToTensor(),
58 | transforms.Normalize((0.5, 0.5, 0.5), (0.2, 0.2, 0.2)),
59 | ]))
60 | test_dataset = dset.CIFAR10(root=opt.dataroot, download=True, train=False,
61 | transform=transforms.Compose([
62 | transforms.ToTensor(),
63 | transforms.Normalize((0.5, 0.5, 0.5), (0.2, 0.2, 0.2)),
64 | ]))
65 | assert train_dataset
66 | assert test_dataset
67 |
68 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batchSize,
69 | shuffle=True, num_workers=int(opt.workers))
70 |
71 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batchSize,
72 | shuffle=False, num_workers=int(opt.workers))
73 |
74 | # count number of incorrect classifications
75 | def compute_score(output, target):
76 | pred = output.max(1)[1]
77 | incorrect = pred.ne(target).cpu().sum()
78 | batch_size = output.size(0)
79 | return incorrect
80 |
81 |
82 | # define the model: a fabric of `layers` x `scales` nodes, each a conv + batch-norm + ReLU block
83 |
84 | class UpSample(nn.Module):
85 | def __init__(self, inChannels, outChannels):
86 | super(UpSample, self).__init__()
87 | self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
88 | self.conv = nn.Conv2d(inChannels, outChannels, kernel_size=3, stride=1, padding=1)
89 | self.batch_norm = nn.BatchNorm2d(outChannels)
90 |
91 | def forward(self, x):
92 | x = self.upsample(x)
93 | x = self.conv(x)
94 | x = self.batch_norm(x)
95 | x = nn.ReLU(True)(x)
96 | return x
97 |
98 | class DownSample(nn.Module):
99 | def __init__(self, inChannels, outChannels):
100 | super(DownSample, self).__init__()
101 | self.conv = nn.Conv2d(inChannels, outChannels, kernel_size=3, stride=2, padding=1)
102 | self.batch_norm = nn.BatchNorm2d(outChannels)
103 |
104 | def forward(self, x):
105 | x = self.conv(x)
106 | x = self.batch_norm(x)
107 | x = nn.ReLU(True)(x)
108 | return x
109 |
110 | class SameRes(nn.Module):
111 | def __init__(self, inChannels, outChannels):
112 | super(SameRes, self).__init__()
113 | self.conv = nn.Conv2d(inChannels, outChannels, kernel_size=3, stride=1, padding=1)
114 | self.batch_norm = nn.BatchNorm2d(outChannels)
115 |
116 | def forward(self, x):
117 | x = self.conv(x)
118 | x = self.batch_norm(x)
119 | x = nn.ReLU(True)(x)
120 | return x
121 |
122 | class Net(nn.Module):
123 | def __init__(self):
124 | super(Net, self).__init__()
125 |
126 | self.channels = 128
127 | self.kernel_size = 3
128 |
129 | self.layers = 8
130 | self.scales = 5
131 |
132 | self.node_ops = nn.ModuleList()
133 |
134 | self.start_node = SameRes(3, self.channels)
135 |
136 | self.fc = nn.Linear(self.channels,10)
137 |
138 | for layer in range(self.layers):
139 | self.node_ops.append(nn.ModuleList()) # add list for each layer
140 | self.node_ops[layer] = nn.ModuleList() # list for each scale
141 |
142 | if layer == 0:
143 | for i in range(self.scales):
144 | self.node_ops[layer][i] = nn.ModuleList()
145 |
146 | node = DownSample(self.channels,self.channels)
147 | self.node_ops[layer][i].append(node)
148 | else:
149 | for i in range(self.scales):
150 | self.node_ops[layer][i] = nn.ModuleList()
151 |
152 | node = SameRes(self.channels,self.channels)
153 | self.node_ops[layer][i].append(node)
154 | if i == 0:
155 | self.node_ops[layer][i].append(
156 | UpSample(self.channels,self.channels))
157 | elif i == self.scales -1:
158 | self.node_ops[layer][i].append(
159 | DownSample(self.channels,self.channels))
160 | if layer == self.layers-1:
161 | self.node_ops[layer][i].append(
162 | DownSample(self.channels,self.channels))
163 | else:
164 | self.node_ops[layer][i].append(
165 | DownSample(self.channels,self.channels))
166 | self.node_ops[layer][i].append(
167 | UpSample(self.channels,self.channels))
168 | if layer == self.layers-1:
169 | self.node_ops[layer][i].append(
170 | DownSample(self.channels,self.channels))
171 | for m in self.modules():
172 | if isinstance(m, nn.Conv2d):
173 | weight_init.kaiming_normal(m.weight)
174 | weight_init.constant(m.bias, 0.1)
175 | elif isinstance(m, nn.BatchNorm2d):
176 | m.weight.data.normal_(1.0,0.02)
177 | m.bias.data.fill_(0)
178 |
179 | def forward(self, x):
180 | node_activ = [[[] for i in range(self.scales)] for j in range(self.layers)]
181 | out = self.start_node(x)
182 | for layer in range(self.layers):
183 | if layer == 0:
184 | for i in range(self.scales):
185 | if i == 0:
186 | node_activ[layer][i] = self.node_ops[layer][i][0](out)
187 | else:
188 | node_activ[layer][i] = self.node_ops[layer][i][0](node_activ[layer][i-1])
189 | else:
190 | for i in range(self.scales):
191 | if i == 0:
192 | t1 = (node_activ[layer-1][i])
193 | t2 = self.node_ops[layer][i][1](node_activ[layer-1][i+1])
194 | t = self.node_ops[layer][i][0](t1 + t2)
195 | node_activ[layer][i] = t
196 | elif i == self.scales-1:
197 | t1 = (node_activ[layer-1][i])
198 | t2 = self.node_ops[layer][i][1](node_activ[layer-1][i-1])
199 |
200 | if layer == self.layers-1:
201 | t3 = self.node_ops[layer][i][2](node_activ[layer][i-1])
202 | t = self.node_ops[layer][i][0](t1 + t2 + t3)
203 | else:
204 | t = self.node_ops[layer][i][0](t1 + t2)
205 | node_activ[layer][i] = t
206 | else:
207 | t1 = (node_activ[layer-1][i])
208 | t2 = self.node_ops[layer][i][2](node_activ[layer-1][i+1])
209 | t3 = self.node_ops[layer][i][1](node_activ[layer-1][i-1])
210 | if layer == self.layers-1:
211 | t4 = self.node_ops[layer][i][3](node_activ[layer][i-1])
212 | t = self.node_ops[layer][i][0](t1 + t2 + t3 + t4)
213 | else:
214 | t = self.node_ops[layer][i][0](t1 + t2 + t3)
215 | node_activ[layer][i] = t
216 |
217 | out = node_activ[-1][-1]
218 | out = out.view(out.size(0),-1)
219 | out = self.fc(out)
220 | return out
221 |
222 | net = Net()
223 | # net.apply(weights_init)
224 | print(net)
225 |
226 | # criterion
227 | criterion = nn.CrossEntropyLoss()
228 |
229 | if opt.cuda:
230 | net.cuda()
231 | criterion.cuda()
232 |
233 | # the SGD optimizer is created inside train(), so the learning rate follows the schedule below
234 | 
235 | # train
236 | def train(epoch):
237 | net.train()
238 | score_epoch = 0
239 | loss_epoch = 0
240 | print('Epoch: ' + str(epoch))
241 | if epoch > 120:
242 | optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
243 | elif epoch > 80:
244 | optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
245 | else:
246 | optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
247 |
248 | for i, (images, labels) in enumerate(train_loader):
249 |         images = Variable(images.cuda() if opt.cuda else images)
250 |         labels = Variable(labels.cuda() if opt.cuda else labels, requires_grad=False)
251 |
252 | optimizer.zero_grad()
253 | output = net(images)
254 | loss = criterion(output, labels)
255 | loss.backward()
256 | optimizer.step()
257 |
258 | loss_epoch = loss_epoch + loss.data[0]
259 | score_epoch = score_epoch + compute_score(output.data, labels.data)
260 |
261 | loss_epoch = loss_epoch / len(train_loader)
262 | print('[%d/%d][%d] train_loss: %.4f err: %d'
263 | % (epoch, opt.niter, len(train_loader), loss_epoch, score_epoch))
264 | return loss_epoch, score_epoch
265 |
266 |
267 | #test network
268 | def test():
269 | net.eval()
270 | score_epoch = 0
271 | loss_epoch = 0
272 | for i, (images, labels) in enumerate(test_loader):
273 |         images = Variable(images.cuda() if opt.cuda else images, volatile=True)
274 |         labels = Variable(labels.cuda() if opt.cuda else labels, volatile=True)
275 |
276 | output = net(images)
277 | loss = criterion(output, labels)
278 |
279 | loss_epoch = loss_epoch + loss.data[0]
280 | score_epoch = score_epoch + compute_score(output.data, labels.data)
281 |
282 | loss_epoch = loss_epoch / len(test_loader)
283 | print('Test error: %d' % (score_epoch))
284 | return loss_epoch, score_epoch
285 |
286 |
287 | #train for opt.niter epochs
288 | start_loss, start_error = test()  # performance of the untrained network
289 | for epoch in range(1,opt.niter+1):
290 | train_loss, train_error = train(epoch)
291 | test_loss, test_error = test()
292 | log.add([train_loss, train_error/50000.0, test_error/10000.0])
293 | log.plot()
294 | if epoch % 10 == 0:
295 | # do checkpointing
296 | torch.save(net.state_dict(), '%s/net_epoch_%d.pth' % (opt.save, epoch))
297 |
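298 | # To evaluate a saved checkpoint later, one possible sketch (epoch 200 is only an example):
299 | #   net.load_state_dict(torch.load('%s/net_epoch_%d.pth' % (opt.save, 200)))
300 | #   test()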
--------------------------------------------------------------------------------