35 |
36 | Fig 2. Alternate updating rule in CliqueNet. "{}" denotes the concatenation operator.
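The rule in Fig 2 can be written as equations. The form below is our reading of the `clique_block` module in `imagenet_pytorch/utils.py`, not an official statement. $X_0$ is the block input, $X_i^{(k)}$ is layer $i$ at stage $k$, $W_{ij}$ are the weights from layer $i$ to layer $j$, and $*$ is convolution. Stage I is $k=1$ and the refined Stage-II features are $k=2$; the code can repeat the second stage `loop_num` times. For readability we omit the batch normalization and ReLU that the code applies to the inputs before the 1x1 convolution, and we write the following BN-ReLU-3x3 convolution as $g$. Concatenating inputs and weights and then applying one convolution, as the code does, equals the sum of per-layer convolutions shown here.

$$
\begin{aligned}
X_i^{(1)} &= g\Big(W_{0i} * X_0 + \sum_{l<i} W_{li} * X_l^{(1)}\Big), \\
X_i^{(k)} &= g\Big(\sum_{l<i} W_{li} * X_l^{(k)} + \sum_{m>i} W_{mi} * X_m^{(k-1)}\Big), \quad k \ge 2.
\end{aligned}
$$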
37 |
38 |
39 |
40 | ## Usage
41 |
42 | - Our experiments are conducted with [TensorFlow](https://github.com/tensorflow/tensorflow) in Python 2.
43 | - Clone this repo: `git clone https://github.com/iboing/CliqueNet`
44 | - To train a model on CIFAR or SVHN, run:
45 | ```bash
46 | python train.py --gpu [gpu id] --dataset [cifar-10, cifar-100 or svhn] --k [filters per layer] --T [total layers in all three blocks] --dir [path to save models]
47 | ```
48 | - Additional techniques (optional): to use the attentional transition, bottleneck architecture, or compression strategy from our paper, add `--if_a True`, `--if_b True`, or `--if_c True`, respectively.
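A general argparse caveat when passing booleans this way: an argument declared with `type=bool` converts any non-empty string, including `"False"`, to `True`, because that is how Python's `bool()` treats strings. Since all three switches default to False, omitting a flag is the reliable way to disable it. The sketch below shows the pitfall and one common workaround, a string-to-boolean converter; `str2bool` is an illustrative name, not part of this repo.

```python
import argparse


def str2bool(value):
    """Parse common true/false spellings; reject anything else."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


# The pitfall: bool() on any non-empty string is True, so "False" enables the flag.
naive = argparse.ArgumentParser()
naive.add_argument("--if_a", default=False, type=bool)
print(naive.parse_args(["--if_a", "False"]).if_a)

# The workaround: parse the string explicitly.
fixed = argparse.ArgumentParser()
fixed.add_argument("--if_a", default=False, type=str2bool)
print(fixed.parse_args(["--if_a", "False"]).if_a)
print(fixed.parse_args([]).if_a)  # non-string default is used as-is
```

The first `print` shows `True` even though `False` was passed; the converter makes the command line mean what it says.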
49 |
50 |
51 | ## Ablation experiments
52 |
53 | With the feedback connections, CliqueNet alternately re-updates previous layers using the updated layers, which refines the features. The weights among layers are reused multiple times, so a deeper representation space can be attained with a fixed number of parameters. To test the effectiveness of CliqueNet's feature refinement, we analyze the features generated at different stages by training different versions of CliqueNet. As illustrated in Fig 3, CliqueNet(I+I) uses only Stage-I features. CliqueNet(I+II) uses the Stage-I features concatenated with the input layer as the block feature, but passes the Stage-II features to the next block. CliqueNet(II+II) uses only the refined (Stage-II) features.
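To make the weight reuse concrete, here is a toy NumPy sketch of one clique block. It assumes each "layer" is a feature vector, each connection `W[(i, j)]` is a dense matrix standing in for a convolution, and a plain ReLU stands in for the full BN-ReLU-conv transform. Stage I propagates forward from X_0; Stage II re-updates every layer from the others, using the already-updated layers before it and the Stage-I layers after it, with the very same matrices. All names are illustrative, not from the repo.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def make_block_weights(num_layers, in_dim, dim, rng):
    """W[(i, j)] maps layer i to layer j; index 0 is the block input X_0."""
    weights = {}
    for j in range(1, num_layers + 1):
        weights[(0, j)] = rng.standard_normal((dim, in_dim)) * 0.1
        for i in range(1, num_layers + 1):
            if i != j:
                weights[(i, j)] = rng.standard_normal((dim, dim)) * 0.1
    return weights


def clique_block(x0, weights, num_layers):
    """Return (stage1, stage2) dicts {layer_index: feature vector}."""
    # Stage I: layer j sees X_0 and the Stage-I layers before it.
    stage1 = {}
    for j in range(1, num_layers + 1):
        total = weights[(0, j)] @ x0
        for i in range(1, j):
            total = total + weights[(i, j)] @ stage1[i]
        stage1[j] = relu(total)

    # Stage II: no X_0 input. Layers before j are already refined,
    # layers after j still hold their Stage-I values. Same weights as Stage I.
    stage2 = dict(stage1)  # copy so the Stage-I results are preserved
    for j in range(1, num_layers + 1):
        total = np.zeros_like(stage1[j])
        for i in range(1, num_layers + 1):
            if i != j:
                total = total + weights[(i, j)] @ stage2[i]
        stage2[j] = relu(total)
    return stage1, stage2


rng = np.random.default_rng(0)
weights = make_block_weights(num_layers=5, in_dim=8, dim=4, rng=rng)
stage1, stage2 = clique_block(rng.standard_normal(8), weights, num_layers=5)
```

Both stages read from the same `weights` dictionary, which is the point made above: refinement deepens the computation without adding parameters.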
54 |
55 |

56 |
57 | Fig 3. Schema of CliqueNet(i+j), where i, j ∈ {I, II}.
58 |
59 | |Model |block feature |transit |error(%)|
60 | |-----------------|-----------------|--------|--------|
61 | |CliqueNet(I+I) |{ X_0, Stage-I } |Stage-I |6.64 |
62 | |CliqueNet(I+II) |{ X_0, Stage-I } |Stage-II|6.10 |
63 | |CliqueNet(II+II) |{ X_0, Stage-II }|Stage-II|5.76 |
64 |
65 | Tab 1. Results of different versions of CliqueNet.
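The three variants in Tab 1 differ only in which features leave the block. A sketch of that choice, assuming NHWC arrays concatenated on the last (channel) axis as in the TensorFlow code; `block_outputs` and the variant strings are illustrative names, not the repo's API.

```python
import numpy as np


def block_outputs(variant, x0, stage1, stage2):
    """Return (block_feature, transit) for CliqueNet(I+I), (I+II) or (II+II).

    x0, stage1, stage2: arrays of shape (N, H, W, C_*); stage1 and stage2 are
    the concatenated Stage-I and Stage-II layers of one block.
    """
    if variant not in ("I+I", "I+II", "II+II"):
        raise ValueError("unknown variant: %s" % variant)
    # Block feature: X_0 concatenated with Stage-I or Stage-II features (Tab 1).
    block_stage = stage2 if variant == "II+II" else stage1
    block_feature = np.concatenate([x0, block_stage], axis=-1)
    # Transit: what is passed on to the next block.
    transit = stage1 if variant == "I+I" else stage2
    return block_feature, transit


x0 = np.zeros((1, 2, 2, 3))
stage1 = np.ones((1, 2, 2, 4))
stage2 = np.full((1, 2, 2, 4), 2.0)
feature, transit = block_outputs("I+II", x0, stage1, stage2)
```

Only CliqueNet(I+II) mixes the two stages: its block feature is built from Stage I while the next block receives Stage II.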
66 |
67 | To run the experiments above, replace the model import in `train.py` (`from models.cliquenet import build_model`) with:
68 | ```python
69 | from models.cliquenet_I_I import build_model
70 | ```
71 | for CliqueNet(I+I), and
72 | ```python
73 | from models.cliquenet_I_II import build_model
74 | ```
75 | for CliqueNet(I+II).
76 |
77 | We further consider the case where the feedback is only partially processed. Concretely, with k=64 and T=15 (5 layers in each of the three blocks), we use the Stage-II feature but apply only the first `X` update steps of Stage II (see Fig 2). `X=0` is therefore equivalent to CliqueNet(I+I), and `X=5` corresponds to CliqueNet(II+II).
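Since T counts the layers of all three blocks, T=15 gives 5 layers per block, which is why `X` ranges from 0 to 5. The sketch below is our reading of "only the first `X` steps": in Stage II the first `X` layers are re-updated in order and the remaining layers keep their Stage-I values. This interpretation, the toy dense-matrix layers, and all names are assumptions; `models/cliquenet_X.py` is the authoritative version.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def partial_stage2(stage1, weights, x_steps):
    """Re-update only layers 1..x_steps; later layers keep their Stage-I values.

    stage1:  dict {layer_index: feature vector} for layers 1..L
    weights: dict {(i, j): matrix} mapping layer i to layer j (i != j)
    """
    num_layers = len(stage1)
    if not 0 <= x_steps <= num_layers:
        raise ValueError("x_steps must lie in [0, %d]" % num_layers)
    refined = dict(stage1)  # x_steps=0 leaves this equal to Stage I
    for j in range(1, x_steps + 1):
        # Layers before j are already refined; layers after j are still Stage I.
        total = sum(weights[(i, j)] @ refined[i]
                    for i in range(1, num_layers + 1) if i != j)
        refined[j] = relu(total)
    return refined


T, num_blocks = 15, 3
layers_per_block = T // num_blocks  # 5 layers per block, so X is in 0..5
```

With `x_steps=layers_per_block` every layer is refined, matching the full CliqueNet(II+II) update.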
78 |
79 |
80 | |Model|CIFAR-10| CIFAR-100|
81 | |--------------|----|-----|
82 | |CliqueNet(X=0)|5.83|24.79|
83 | |CliqueNet(X=1)|5.63|24.65|
84 | |CliqueNet(X=2)|5.54|24.37|
85 | |CliqueNet(X=3)|5.41|23.75|
86 | |CliqueNet(X=4)|5.20|24.04|
87 | |CliqueNet(X=5)|5.12|23.73|
88 |
89 | Tab 2. Error rates (%) of CliqueNets with different `X`.
90 |
91 | To run the experiments with different `X`, change the model import in `train.py` to:
92 | ```python
93 | from models.cliquenet_X import build_model
94 | ```
95 | and set the value of `X` in `./models/cliquenet_X.py`.
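Instead of editing the import by hand for each run, the model module could be chosen at runtime with `importlib`. A minimal sketch, assuming the module names listed in this section and that each module exposes `build_model`; the variant keys and the `model_module_name` / `load_build_model` helpers are hypothetical and not part of `train.py`.

```python
import importlib

# Module names taken from this README; "default" is the import train.py ships with.
MODEL_MODULES = {
    "default": "models.cliquenet",
    "I+I": "models.cliquenet_I_I",
    "I+II": "models.cliquenet_I_II",
    "X": "models.cliquenet_X",
}


def model_module_name(variant):
    """Map a variant key to its module path, rejecting unknown keys."""
    try:
        return MODEL_MODULES[variant]
    except KeyError:
        raise ValueError("unknown variant %r; choose from %s"
                         % (variant, sorted(MODEL_MODULES)))


def load_build_model(variant):
    """Import the chosen module (run from the repo root) and return build_model."""
    module = importlib.import_module(model_module_name(variant))
    return module.build_model
```

In `train.py`, a hypothetical `--variant` argument could then feed `build_model = load_build_model(args.variant)`. For the `X` experiments, the value of `X` would still have to be set inside `models/cliquenet_X.py`.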
96 |
97 | ## Comparison with the state of the art
98 |
99 | The results listed below demonstrate the superiority of CliqueNet over DenseNet when no additional techniques (bottleneck, compression, etc.) are used.
100 |
101 | |Model | FLOPs | Params | CIFAR-10 | CIFAR-100 | SVHN |
102 | |------------------------------------| ------|--------| -------- |-----------|------|
103 | |DenseNet (k = 12, T = 36) | 0.53G | 1.0M | 7.00 | 27.55 | 1.79 |
104 | |DenseNet (k = 12, T = 96) | 3.54G | 7.0M | 5.77 | 23.79 | 1.67 |
105 | |DenseNet (k = 24, T = 96) | 13.78G| 27.2M | 5.83 | 23.42 | 1.59 |
106 | | | | | | | |
107 | |CliqueNet (k = 36, T = 12) | 0.91G | 0.94M | 5.93 | 27.32 | 1.77 |
108 | |CliqueNet (k = 64, T = 15) | 4.21G | 4.49M | 5.12 | 23.98 | 1.62 |
109 | |CliqueNet (k = 80, T = 15) | 6.45G | 6.94M | 5.10 | 23.32 | 1.56 |
110 | |CliqueNet (k = 80, T = 18) | 9.45G | 10.14M | 5.06 | 23.14 | 1.51 |
111 |
112 | Tab 3. Main results (error rates, %) on CIFAR and SVHN without data augmentation.
113 |
114 | Because a larger T leads to a higher computation cost and slightly more parameters, we prefer a larger k in our experiments. For a fairer comparison, we also evaluate DenseNets and CliqueNets with exactly the same k and T, see Tab 4.
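To see how k and the number of layers per block (T/3) affect model size, one can count the convolution weights of a single block. The sketch below follows the `clique_block` module in `imagenet_pytorch/utils.py`; it assumes the TensorFlow blocks are wired the same way and ignores batch-norm parameters, transitions and the classifier. Each of the L layers has a 1x1 kernel from the block input (c_in channels), a 1x1 kernel from each of the other L-1 layers, and one 3x3 kernel. The function name and the example sizes are illustrative.

```python
def clique_block_conv_params(c_in, k, num_layers):
    """Convolution weights in one clique block (no batch norm, no biases)."""
    from_input = num_layers * c_in * k                      # 1x1: X_0 -> each layer
    between_layers = num_layers * (num_layers - 1) * k * k  # 1x1: layer i -> layer j, i != j
    conv3x3 = num_layers * 9 * k * k                        # one 3x3 conv per layer
    return from_input + between_layers + conv3x3


# Same block input width; compare doubling k with doubling the layers per block.
for k, layers in [(36, 4), (72, 4), (36, 8)]:
    print("k=%d, layers=%d: %d conv weights" % (k, layers, clique_block_conv_params(64, k, layers)))
```

The pairwise 1x1 term grows quadratically in both k and the number of layers, while the 3x3 term grows quadratically in k but only linearly in the layer count. FLOPs are not modeled here; they also depend on the feature-map size and on both stages being computed.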
115 |
116 | |Model|Params|CIFAR-10 | CIFAR-100|
117 | |--------------------|-----|----|-----|
118 | |DenseNet(k=12,T=36) |1.02M|7.00|27.55|
119 | |CliqueNet(k=12,T=36)|1.05M|5.79|26.85|
120 | | | | | |
121 | |DenseNet(k=24,T=18) |0.99M|7.13|27.70|
122 | |CliqueNet(k=24,T=18)|0.99M|6.04|26.57|
123 | | | | | |
124 | |DenseNet(k=36,T=12) |0.96M|6.89|27.54|
125 | |CliqueNet(k=36,T=12)|0.94M|5.93|27.32|
126 |
127 | Tab 4. Comparisons (error rates, %) with the same k and T.
128 |
129 | Note that the result of DenseNet(k=12, T=36) is taken from the original DenseNet paper. The other results come from our own implementations under the same experimental settings.
130 |
131 |
132 |
133 | ## Results on ImageNet
134 | Our code for experiments on ImageNet with TensorFlow will be released soon.
135 |
136 | Here we provide a [PyTorch](http://pytorch.org) version for training CliqueNet on ImageNet. To run it:
137 |
138 | ```bash
139 | python train_imagenet.py [path to the imagenet dataset]
140 | ```
141 | (By default, CliqueNet-S3 is trained with a batch size of 160 and the attentional transition enabled; pass `--no-attention` to disable it.)
142 |
143 | The PyTorch pre-trained model can be downloaded here (Google Drive): [S3_model](https://drive.google.com/open?id=1IcsmIrTYmxd62Whh5nf8Z7grcCxNAfDk).
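`train_imagenet.py` wraps the model in `torch.nn.DataParallel`, and the keys of a DataParallel model's `state_dict()` carry a `module.` prefix. To load such weights into an unwrapped model, the prefix has to be removed first. A minimal sketch, assuming the downloaded file has the same layout as the checkpoints written by `save_checkpoint` (a dict with a `'state_dict'` entry), which we have not verified for the released S3 model; `strip_module_prefix` is an illustrative helper, not part of the repo.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DataParallel prefix removed."""
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


# Hypothetical usage with PyTorch (the checkpoint path is a placeholder):
#   checkpoint = torch.load("path/to/downloaded_checkpoint", map_location="cpu")
#   model = cliquenet.build_cliquenet(input_channels=64, list_channels=[40, 80, 160, 160],
#                                     list_layer_num=[6, 6, 6, 6], if_att=True)
#   model.load_state_dict(strip_module_prefix(checkpoint["state_dict"]))
```

Alternatively, wrap the freshly built model in `torch.nn.DataParallel` first, as `train_imagenet.py` does, and load the state dict unchanged.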
144 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import tensorflow as tf
5 | import time
6 | from dataloader.data_generator import load_data
7 | from models.cliquenet import build_model
8 |
9 | def into_batch(data, label, batch_size, shuffle):
10 | if shuffle:
11 | rand_indexes = np.random.permutation(data.shape[0])
12 | data = data[rand_indexes]
13 | label = label[rand_indexes]
14 |
15 | batch_count=len(data)//batch_size
16 | batches_data = np.split(data[:batch_count*batch_size], batch_count)
17 | batches_data.append(data[batch_count*batch_size:])
18 | batches_labels = np.split(label[:batch_count * batch_size], batch_count)
19 | batches_labels.append(label[batch_count*batch_size:])
20 | batch_count+=1
21 |
22 | return batches_data, batches_labels, batch_count
23 |
24 |
25 | def count_params():
26 | total_params=0
27 | for variable in tf.trainable_variables():
28 | shape=variable.get_shape()
29 | params=1
30 | for dim in shape:
31 | params=params*dim.value
32 | total_params+=params
33 | print("Total training params: %.2fM" % (total_params / 1e6))
34 |
35 |
36 | if __name__=='__main__':
37 | ##
38 | train_params={'normalize_type': 'by-channels', ## by-channels, divide-255, divide-256
39 | 'initial_lr': 0.1,
40 | 'weight_decay': 1e-4,
41 | 'batch_size': 64,
42 | 'total_epoch': 300,
43 | 'keep_prob':0.8
44 | }
45 |
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument('--gpu', default="0")
48 | parser.add_argument('--dataset',
49 | choices=['cifar-10', 'cifar-100', 'svhn'])
50 | parser.add_argument('--k', type=int,
51 | help='filters per layer')
52 | parser.add_argument('--T', type=int,
53 | help='total layers in all blocks')
54 | parser.add_argument('--dir',
55 | help='folder to store models')
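56 | parser.add_argument('--if_a', default=False, type=lambda s: s.lower() in ('true', '1'),  # type=bool would treat 'False' as True
57 | help='use attentional transition (True/False)')
58 | parser.add_argument('--if_b', default=False, type=lambda s: s.lower() in ('true', '1'),
59 | help='use bottleneck architecture (True/False)')
60 | parser.add_argument('--if_c', default=False, type=lambda s: s.lower() in ('true', '1'),
61 | help='use compression (True/False)')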
62 |
63 | args = parser.parse_args()
64 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
65 |
66 | dataset=args.dataset
67 |
68 | if dataset=='svhn':
69 | total_epoches=40
70 | else:
71 | total_epoches = train_params['total_epoch']
72 |
73 | result_dir=args.dir
74 |
75 | batch_size = train_params['batch_size']
76 | lr = train_params['initial_lr']
77 | kp = train_params['keep_prob']
78 | weight_decay = train_params['weight_decay']
79 |
80 | if not os.path.exists(result_dir):
81 | os.makedirs(result_dir)
82 |
83 | train_data, train_label, test_data, test_label=load_data(dataset, train_params['normalize_type'])
84 |
85 | image_size=train_data.shape[1:]
86 | label_num=train_label.shape[-1]
87 |
88 | graph=tf.Graph()
89 | with graph.as_default():
90 | input_images=tf.placeholder(tf.float32, [None, image_size[0], image_size[1], image_size[2]], name='input_images')
91 | true_labels=tf.placeholder(tf.float32, [None, label_num], name='labels')
92 | is_train=tf.placeholder(tf.bool, shape=[])
93 | learning_rate=tf.placeholder(tf.float32, shape=[], name='learning_rate')
94 | keep_prob=tf.placeholder(tf.float32, shape=[], name='keep_prob')
95 | ### build model
96 | logits, prob=build_model(input_images, args.k, args.T, label_num, is_train, keep_prob, args.if_a, args.if_b, args.if_c)
97 | ### loss and accuracy
98 | loss_cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=true_labels))
99 | if_correct = tf.equal(tf.argmax(prob, 1), tf.argmax(true_labels, 1))
100 | accuracy = tf.reduce_mean(tf.cast(if_correct, tf.float32))
101 | l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
102 | ### optimizer
103 | optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9, use_nesterov=True)
104 | train_op = optimizer.minimize(loss_cross_entropy + l2_loss*weight_decay)
105 | saver=tf.train.Saver()
106 |
107 | ### begin training ###
108 |
109 | config=tf.ConfigProto()
110 | config.gpu_options.allow_growth = True
111 |
112 | with tf.Session(config=config, graph=graph) as sess:
113 | sess.run(tf.global_variables_initializer())
114 | count_params()
115 | ### train batch data ###
116 |
117 | # to shuffle before each epoch
118 |
119 | ### test batch data ###
120 |
121 | batches_data_test, batches_labels_test, batch_count_test = into_batch(test_data, test_label, batch_size, shuffle=False)
122 |
123 | loss_train=[]
124 | acc_train=[]
125 | loss_test=[]
126 | acc_test=[]
127 | best_acc=0
128 | for epoch in range(1, total_epoches+1):
129 | if epoch == total_epoches/2 : lr=lr*0.1
130 | if epoch == total_epoches*3/4 : lr=lr*0.1
131 |
132 | batches_data, batches_labels, batch_count=into_batch(train_data, train_label, batch_size, shuffle=True)
133 |
134 | ### train ###
135 | loss_per_bat=[]
136 | acc_per_bat=[]
137 | for batch_id in range(batch_count):
138 | data_per_bat = batches_data[batch_id]
139 | label_per_bat = batches_labels[batch_id]
140 | result_per_bat = sess.run([train_op, loss_cross_entropy, accuracy],
141 | feed_dict={input_images : data_per_bat,
142 | true_labels : label_per_bat,
143 | learning_rate : lr,
144 | is_train : True,
145 | keep_prob: kp})
146 | loss_per_bat.append(result_per_bat[1])
147 | acc_per_bat.append(result_per_bat[2])
148 | if (batch_id+1) % 100==0:
149 | print 'epoch:', epoch, 'batch:', batch_id+1, 'in', batch_count
150 | print 'loss:', result_per_bat[1], 'accuracy:', result_per_bat[2]
151 |
152 | saver.save(sess, os.path.join(result_dir, dataset+'_epoch_%d.ckpt' % epoch))
153 | loss_train.append(np.mean(loss_per_bat))
154 | acc_train.append(np.mean(acc_per_bat))
155 |
156 | ### test ###
157 | loss_per_bat=[]
158 | acc_per_bat=[]
159 | for batch_id in range(batch_count_test):
160 | data_per_bat = batches_data_test[batch_id]
161 | label_per_bat = batches_labels_test[batch_id]
162 | result_per_bat = sess.run([loss_cross_entropy, accuracy],
163 | feed_dict={input_images : data_per_bat,
164 | true_labels : label_per_bat,
165 | is_train: False,
166 | keep_prob: 1})
167 | loss_per_bat.append(result_per_bat[0]) ## result[0]->loss
168 | acc_per_bat.append(result_per_bat[1]) ## result[1]->acc
169 | loss_test.append(np.mean(loss_per_bat))
170 | acc_test.append(np.mean(acc_per_bat))
171 |
172 | if acc_test[-1]>best_acc:
173 | best_acc=acc_test[-1]
174 |
175 | print time.ctime()
176 | print 'epoch:',epoch
177 | print 'train loss:', loss_train[-1],'acc:',acc_train[-1]
178 | print 'test loss:', loss_test[-1], 'acc:', acc_test[-1]
179 | print 'best test acc:', best_acc
180 | print '\n'
181 |
182 | np.save(os.path.join(result_dir, os.path.basename(os.path.normpath(result_dir))+'_loss_train.npy'), np.array(loss_train))
183 | np.save(os.path.join(result_dir, os.path.basename(os.path.normpath(result_dir))+'_acc_train.npy'), np.array(acc_train))
184 | np.save(os.path.join(result_dir, os.path.basename(os.path.normpath(result_dir))+'_loss_test.npy'), np.array(loss_test))
185 | np.save(os.path.join(result_dir, os.path.basename(os.path.normpath(result_dir))+'_acc_test.npy'), np.array(acc_test))
186 |
--------------------------------------------------------------------------------
/imagenet_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import torch.backends.cudnn as cudnn
5 | import torch.optim
6 | import torch.utils.data
7 | import torch.nn.functional as F
8 | from torch.nn import init
9 |
10 | class attention(nn.Module):
11 | def __init__(self, input_channels, map_size):
12 | super(attention, self).__init__()
13 |
14 | self.pool = nn.AvgPool2d(kernel_size = map_size)
15 | self.fc1 = nn.Linear(in_features = input_channels,out_features = input_channels // 2)
16 | self.fc2 = nn.Linear(in_features = input_channels // 2, out_features = input_channels)
17 |
18 |
19 | def forward(self, x):
20 | output = self.pool(x)
21 | output = output.view(output.size()[0], output.size()[1])
22 | output = self.fc1(output)
23 | output = F.relu(output)
24 | output = self.fc2(output)
25 | output = F.sigmoid(output)
26 | output = output.view(output.size()[0],output.size()[1],1,1)
27 | output = torch.mul(x, output)
28 | return output
29 |
30 |
31 | class transition(nn.Module):
32 | def __init__(self, if_att, current_size, input_channels, keep_prob):
33 | super(transition, self).__init__()
34 | self.input_channels = input_channels
35 | self.keep_prob = keep_prob
36 | self.bn = nn.BatchNorm2d(self.input_channels)
37 | self.conv = nn.Conv2d(self.input_channels, self.input_channels, kernel_size = 1, bias = False)
38 | # self.dropout = nn.Dropout2d(1 - self.keep_prob)
39 | self.pool = nn.AvgPool2d(kernel_size = 2)
40 | self.if_att = if_att
41 | if self.if_att == True:
42 | self.attention = attention(input_channels = self.input_channels, map_size = current_size)
43 |
44 | def forward(self, x):
45 | output = self.bn(x)
46 | output = F.relu(output)
47 | output = self.conv(output)
48 | if self.if_att==True:
49 | output = self.attention(output)
50 | # output = self.dropout(output)
51 | output = self.pool(output)
52 | return output
53 |
54 | class global_pool(nn.Module):
55 | def __init__(self, input_size, input_channels):
56 | super(global_pool, self).__init__()
57 | self.input_size = input_size
58 | self.input_channels = input_channels
59 | self.bn = nn.BatchNorm2d(self.input_channels)
60 | self.pool = nn.AvgPool2d(kernel_size = self.input_size)
61 |
62 | def forward(self, x):
63 | output = self.bn(x)
64 | output = F.relu(output)
65 | output = self.pool(output)
66 | return output
67 |
68 | class compress(nn.Module):
69 | def __init__(self, input_channels, keep_prob):
70 | super(compress, self).__init__()
71 | self.keep_prob = keep_prob
72 | self.bn = nn.BatchNorm2d(input_channels)
73 | self.conv = nn.Conv2d(input_channels, input_channels//2, kernel_size = 1, padding = 0, bias = False)
74 |
75 |
76 | def forward(self, x):
77 | output = self.bn(x)
78 | output = F.relu(output)
79 | output = self.conv(output)
80 | # output = F.dropout2d(output, 1 - self.keep_prob)
81 | return output
82 |
83 | class clique_block(nn.Module):
84 | def __init__(self, input_channels, channels_per_layer, layer_num, loop_num, keep_prob):
85 | super(clique_block, self).__init__()
86 | self.input_channels = input_channels
87 | self.channels_per_layer = channels_per_layer
88 | self.layer_num = layer_num
89 | self.loop_num = loop_num
90 | self.keep_prob = keep_prob
91 |
92 | # conv 1 x 1
93 | self.conv_param = nn.ModuleList([nn.Conv2d(self.channels_per_layer, self.channels_per_layer, kernel_size = 1, padding = 0, bias = False)
94 | for i in range((self.layer_num + 1) ** 2)])
95 |
96 | for i in range(1, self.layer_num + 1):
97 | self.conv_param[i] = nn.Conv2d(self.input_channels, self.channels_per_layer, kernel_size = 1, padding = 0, bias = False)
98 | for i in range(1, self.layer_num + 1):
99 | self.conv_param[i * (self.layer_num + 2)] = None
100 | for i in range(0, self.layer_num + 1):
101 | self.conv_param[i * (self.layer_num + 1)] = None
102 |
103 | self.forward_bn = nn.ModuleList([nn.BatchNorm2d(self.input_channels + i * self.channels_per_layer) for i in range(self.layer_num)])
104 | self.forward_bn_b = nn.ModuleList([nn.BatchNorm2d(self.channels_per_layer) for i in range(self.layer_num)])
105 | self.loop_bn = nn.ModuleList([nn.BatchNorm2d(self.channels_per_layer * (self.layer_num - 1)) for i in range(self.layer_num)])
106 | self.loop_bn_b = nn.ModuleList([nn.BatchNorm2d(self.channels_per_layer) for i in range(self.layer_num)])
107 |
108 | # conv 3 x 3
109 | self.conv_param_bottle = nn.ModuleList([nn.Conv2d(self.channels_per_layer, self.channels_per_layer, kernel_size = 3, padding = 1, bias = False)
110 | for i in range(self.layer_num)])
111 |
112 |
113 | def forward(self, x):
114 | # key: 1, 2, 3, 4, 5, update every loop
115 | self.blob_dict={}
116 | # save every loops results
117 | self.blob_dict_list=[]
118 |
119 | # first forward
120 | for layer_id in range(1, self.layer_num + 1):
121 | bottom_blob = x
122 | # bottom_param = self.param_dict['0_' + str(layer_id)]
123 |
124 | bottom_param = self.conv_param[layer_id].weight
125 | for layer_id_id in range(1, layer_id):
126 | # pdb.set_trace()
127 | bottom_blob = torch.cat((bottom_blob, self.blob_dict[str(layer_id_id)]), 1)
128 | # bottom_param = torch.cat((bottom_param, self.param_dict[str(layer_id_id) + '_' + str(layer_id)]), 1)
129 | bottom_param = torch.cat((bottom_param, self.conv_param[layer_id_id * (self.layer_num + 1) + layer_id].weight), 1)
130 | next_layer = self.forward_bn[layer_id - 1](bottom_blob)
131 | next_layer = F.relu(next_layer)
132 | # conv 1 x 1
133 | next_layer = F.conv2d(next_layer, bottom_param, stride = 1, padding = 0)
134 | # conv 3 x 3
135 | next_layer = self.forward_bn_b[layer_id - 1](next_layer)
136 | next_layer = F.relu(next_layer)
137 | next_layer = F.conv2d(next_layer, self.conv_param_bottle[layer_id - 1].weight, stride = 1, padding = 1)
138 | # next_layer = F.dropout2d(next_layer, 1 - self.keep_prob)
139 | self.blob_dict[str(layer_id)] = next_layer
140 | self.blob_dict_list.append(self.blob_dict)
141 |
142 | # loop
143 | for loop_id in range(self.loop_num):
144 | for layer_id in range(1, self.layer_num + 1):
145 |
146 | layer_list = [l_id for l_id in range(1, self.layer_num + 1)]
147 | layer_list.remove(layer_id)
148 |
149 | bottom_blobs = self.blob_dict[str(layer_list[0])]
150 | # bottom_param = self.param_dict[layer_list[0] + '_' + str(layer_id)]
151 | bottom_param = self.conv_param[layer_list[0] * (self.layer_num + 1) + layer_id].weight
152 | for bottom_id in range(len(layer_list) - 1):
153 | bottom_blobs = torch.cat((bottom_blobs, self.blob_dict[str(layer_list[bottom_id + 1])]), 1)
154 | # bottom_param = torch.cat((bottom_param, self.param_dict[layer_list[bottom_id+1]+'_'+str(layer_id)]), 1)
155 | bottom_param = torch.cat((bottom_param, self.conv_param[layer_list[bottom_id + 1] * (self.layer_num + 1) + layer_id].weight), 1)
156 | bottom_blobs = self.loop_bn[layer_id - 1](bottom_blobs)
157 | bottom_blobs = F.relu(bottom_blobs)
158 | # conv 1 x 1
159 | mid_blobs = F.conv2d(bottom_blobs, bottom_param, stride = 1, padding = 0)
160 | # conv 3 x 3
161 | top_blob = self.loop_bn_b[layer_id - 1](mid_blobs)
162 | top_blob = F.relu(top_blob)
163 | top_blob = F.conv2d(top_blob, self.conv_param_bottle[layer_id - 1].weight, stride = 1, padding = 1)
164 | self.blob_dict[str(layer_id)] = top_blob
165 | self.blob_dict_list.append(self.blob_dict)
166 |
167 | assert len(self.blob_dict_list) == 1 + self.loop_num
168 |
169 | # output
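170 | block_feature_I = self.blob_dict_list[0]['1']  # NOTE: every entry of blob_dict_list is the same dict object, so this reads the final (Stage-II) features, not Stage-I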
171 | for layer_id in range(2, self.layer_num + 1):
172 | block_feature_I = torch.cat((block_feature_I, self.blob_dict_list[0][str(layer_id)]), 1)
173 | block_feature_I = torch.cat((x, block_feature_I), 1)
174 |
175 | block_feature_II = self.blob_dict_list[self.loop_num]['1']
176 | for layer_id in range(2, self.layer_num + 1):
177 | block_feature_II = torch.cat((block_feature_II, self.blob_dict_list[self.loop_num][str(layer_id)]), 1)
178 | return block_feature_I, block_feature_II
179 |
--------------------------------------------------------------------------------
/train_imagenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import shutil
3 | import datetime
4 | import time
5 | import random
6 | import os
7 |
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.backends.cudnn as cudnn
13 | import torch.optim
14 | import torch.utils.data
15 | import torchvision.transforms as transforms
16 | import torchvision.datasets as datasets
17 |
18 |
19 | import imagenet_pytorch.cliquenet as cliquenet
20 |
21 |
22 | parser = argparse.ArgumentParser(description='CliqueNet ImageNet Training')
23 |
24 | parser.add_argument('data', metavar='DIR',
25 | help='path to the imagenet dataset')
26 |
27 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
28 | help='number of data loading workers (default: 8)')
29 |
30 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
31 | help='number of total epochs to run (default: 100)')
32 |
33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
34 | help='manual epoch number (useful on restarts)')
35 |
36 | parser.add_argument('-b', '--batch-size', default=160, type=int,
37 | metavar='N', help='mini-batch size (default: 160)')
38 |
39 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
40 | metavar='LR', help='initial learning rate (default: 0.1)')
41 |
42 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
43 | help='momentum (default: 0.9)')
44 |
45 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
46 | metavar='W', help='weight decay (default: 1e-4)')
47 |
48 | parser.add_argument('--print-freq', '-p', default=50, type=int,
49 | metavar='N', help='print frequency (default: 50)')
50 |
51 | parser.add_argument('--resume', default=None, type=str, metavar='PATH',
52 | help='path to latest checkpoint')
53 |
54 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
55 | help='evaluate model on validation set')
56 |
57 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
58 | help='use pre-trained model')
59 | parser.add_argument('--no-attention', dest='attention', action='store_false', help='not use attentional transition')
60 | parser.set_defaults(attention=True)
61 |
62 | best_prec1 = 0
63 |
64 |
65 | def main():
66 |
67 | global args, best_prec1
68 | args = parser.parse_args()
69 | if args.attention:
70 | print 'attentional transition is used'
71 | # create model
72 | model = cliquenet.build_cliquenet(input_channels=64, list_channels=[40, 80, 160, 160], list_layer_num=[6, 6, 6, 6], if_att=args.attention)
73 | model = torch.nn.DataParallel(model).cuda()
74 |
75 | # define loss function (criterion) and optimizer
76 | criterion = nn.CrossEntropyLoss().cuda()
77 | optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
78 |
79 | # # optionally resume from a checkpoint
80 | if args.resume:
81 | if os.path.isfile(args.resume):
82 | print("=> loading checkpoint '{}'".format(args.resume))
83 | checkpoint = torch.load(args.resume)
84 | args.start_epoch = checkpoint['epoch']
85 | best_prec1 = checkpoint['best_prec1']
86 | model.load_state_dict(checkpoint['state_dict'])
87 | optimizer.load_state_dict(checkpoint['optimizer'])
88 | print("=> loaded checkpoint '{}' (epoch {})"
89 | .format(args.resume, checkpoint['epoch']))
90 | else:
91 | print("=> no checkpoint found at '{}'".format(args.resume))
92 |
93 | cudnn.benchmark = True
94 |
95 |
96 | # Data loading code
97 | traindir = os.path.join(args.data, 'train')
98 | valdir = os.path.join(args.data, 'val')
99 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
100 | std=[0.229, 0.224, 0.225])
101 |
102 | train_loader = torch.utils.data.DataLoader(
103 | datasets.ImageFolder(traindir, transforms.Compose([
104 | transforms.RandomResizedCrop(224),
105 | transforms.RandomHorizontalFlip(),
106 | transforms.ToTensor(),
107 | normalize,
108 | ])),
109 | batch_size=args.batch_size, shuffle=True,
110 | num_workers=args.workers, pin_memory=True)
111 |
112 | val_loader = torch.utils.data.DataLoader(
113 | datasets.ImageFolder(valdir, transforms.Compose([
114 | transforms.Resize(256),
115 | transforms.CenterCrop(224),
116 | transforms.ToTensor(),
117 | normalize,
118 | ])),
119 | batch_size=args.batch_size, shuffle=False,
120 | num_workers=args.workers, pin_memory=True)
121 |
122 | if args.evaluate:
123 | validate(val_loader, model, criterion)
124 | return
125 |
126 | # get_number_of_param(model)
127 |
128 | for epoch in range(args.start_epoch, args.epochs):
129 | adjust_learning_rate(optimizer, epoch)
130 |
131 | # train for one epoch
132 | train(train_loader, model, criterion, optimizer, epoch)
133 |
134 | # evaluate on validation set
135 | prec1 = validate(val_loader, model, criterion)
136 |
137 | # remember best prec@1 and save checkpoint
138 | is_best = prec1 > best_prec1
139 | best_prec1 = max(prec1, best_prec1)
140 | save_checkpoint({
141 | 'epoch': epoch + 1,
142 | 'state_dict': model.state_dict(),
143 | 'best_prec1': best_prec1,
144 | 'optimizer' : optimizer.state_dict(),
145 | }, is_best)
146 |
147 | def get_number_of_param(model):
148 | """get the number of param for every element"""
149 | count = 0
150 | for param in model.parameters():
151 | param_size = param.size()
152 | count_of_one_param = 1
153 | for dis in param_size:
154 | count_of_one_param *= dis
155 | print(param.size(), count_of_one_param)
156 | print(count)
157 | count += count_of_one_param
158 | print('total number of the model is %d'%count)
159 |
160 |
161 | def train(train_loader, model, criterion, optimizer, epoch):
162 | """train model"""
163 | batch_time = AverageMeter()
164 | data_time = AverageMeter()
165 | losses = AverageMeter()
166 | top1 = AverageMeter()
167 | top5 = AverageMeter()
168 |
169 | # switch to train mode
170 | model.train()
171 |
172 | end = time.time()
173 | # last_datetime = datetime.datetime.now()
174 | for i, (input, target) in enumerate(train_loader):
175 | # measure data loading time
176 | data_time.update(time.time() - end)
177 | input = input.cuda()
178 | target = target.cuda(async=True)
179 | input_var = torch.autograd.Variable(input)
180 | # pdb.set_trace()
181 | target_var = torch.autograd.Variable(target)
182 |
183 | # compute output
184 | output = model(input_var)
185 | loss = criterion(output, target_var)
186 |
187 | # measure accuracy and record loss
188 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
189 | losses.update(loss.data[0], input.size(0))
190 | top1.update(prec1[0], input.size(0))
191 | top5.update(prec5[0], input.size(0))
192 |
193 | # compute gradient and do SGD step
194 | optimizer.zero_grad()
195 | loss.backward()
196 | optimizer.step()
197 |
198 | # measure elapsed time
199 | batch_time.update(time.time() - end)
200 | end = time.time()
201 |
202 | if i % args.print_freq == 0:
203 | print('Epoch: [{0}][{1}/{2}]\t'
204 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
205 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
206 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
207 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
208 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
209 | epoch, i, len(train_loader), batch_time=batch_time,
210 | data_time=data_time, loss=losses, top1=top1, top5=top5))
211 |
212 | print time.ctime()
213 |
214 |
215 | def validate(val_loader, model, criterion):
216 | """validate model"""
217 | batch_time = AverageMeter()
218 | losses = AverageMeter()
219 | top1 = AverageMeter()
220 | top5 = AverageMeter()
221 |
222 | # switch to evaluate mode
223 | model.eval()
224 |
225 | end = time.time()
226 | for i, (input, target) in enumerate(val_loader):
227 | target = target.cuda(async=True)
228 | input_var = torch.autograd.Variable(input, volatile=True)
229 | target_var = torch.autograd.Variable(target, volatile=True)
230 |
231 | # compute output
232 | output = model(input_var)
233 | loss = criterion(output, target_var)
234 |
235 | # measure accuracy and record loss
236 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
237 | losses.update(loss.data[0], input.size(0))
238 | top1.update(prec1[0], input.size(0))
239 | top5.update(prec5[0], input.size(0))
240 |
241 | # measure elapsed time
242 | batch_time.update(time.time() - end)
243 | end = time.time()
244 |
245 | if i % args.print_freq == 0:
246 | print('Test: [{0}/{1}]\t'
247 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
248 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
249 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
250 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
251 | i, len(val_loader), batch_time=batch_time, loss=losses,
252 | top1=top1, top5=top5))
253 |
254 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
255 | .format(top1=top1, top5=top5))
256 |
257 | return top1.avg
258 |
259 |
260 | def save_checkpoint(state, is_best, filename='./cliquenet_s0.pth.tar'):
261 | """Save the trained model"""
262 | torch.save(state, filename)
263 | if is_best:
264 | shutil.copyfile(filename, './best_cliquenet_s0.pth.tar')
265 |
266 |
267 | class AverageMeter(object):
268 | """Computes and stores the average and current value"""
269 | def __init__(self):
270 | self.reset()
271 |
272 | def reset(self):
273 | self.val = 0
274 | self.avg = 0
275 | self.sum = 0
276 | self.count = 0
277 |
278 | def update(self, val, n=1):
279 | self.val = val
280 | self.sum += val * n
281 | self.count += n
282 | self.avg = self.sum / self.count
283 |
284 |
285 | def adjust_learning_rate(optimizer, epoch):
286 | """Sets the learning rate to the initial Learning rate decayed by 10 every 30 epochs"""
287 | lr = args.lr * (0.1 ** (epoch // 30))
288 | print('current learning rate is: %f'%lr)
289 | for param_group in optimizer.param_groups:
290 | param_group['lr'] = lr
291 |
292 |
293 | def accuracy(output, target, topk=(1,)):
294 | """Computes the precision@k for the specified values of k"""
295 | maxk = max(topk)
296 | batch_size = target.size(0)
297 |
298 | _, pred = output.topk(maxk, 1, True, True)
299 | pred = pred.t()
300 | correct = pred.eq(target.view(1, -1).expand_as(pred))
301 |
302 | res = []
303 | for k in topk:
304 | correct_k = correct[:k].view(-1).float().sum(0)
305 | res.append(correct_k.mul_(100.0 / batch_size))
306 | return res
307 |
308 |
309 | if __name__ == '__main__':
310 | main()
311 |
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def bias_var(out_channels, init_method):
5 | initial_value=tf.constant(0.0, shape=[out_channels])
6 | biases=tf.Variable(initial_value)
7 |
8 | return biases
9 |
10 | def conv_var(kernel_size, in_channels, out_channels, init_method, name):
11 | shape=[kernel_size[0], kernel_size[1], in_channels, out_channels]
12 | if init_method=='msra':
13 | return tf.get_variable(name=name, shape=shape, initializer=tf.contrib.layers.variance_scaling_initializer())
14 | elif init_method=='xavier':
15 | return tf.get_variable(name=name, shape=shape, initializer=tf.contrib.layers.xavier_initializer())
16 |
17 | def attentional_transition(input_layer, name):
18 |
19 | channels=input_layer.get_shape().as_list()[-1]
20 | map_size=input_layer.get_shape().as_list()[1]
21 |
22 | bottom_fc=tf.nn.avg_pool(input_layer, [1, map_size, map_size, 1], [1, map_size, map_size, 1], 'VALID')
23 |
24 | assert bottom_fc.get_shape().as_list()[-1]==channels ## none,1,1,C
25 |
26 | bottom_fc=tf.reshape(bottom_fc, [-1, channels]) ## none, C
27 |
28 | Wfc=tf.get_variable(name=name+'_W1', shape=[channels, channels//2], initializer=tf.contrib.layers.xavier_initializer())
29 | bfc=tf.get_variable(name=name+'_b1', initializer=tf.constant(0.0, shape=[channels//2]))
30 |
31 | mid_fc=tf.nn.relu(tf.matmul(bottom_fc, Wfc)+bfc)
32 |
33 | Wfc=tf.get_variable(name=name+'_W2', shape=[channels//2, channels], initializer=tf.contrib.layers.xavier_initializer())
34 | bfc=tf.get_variable(name=name+'_b2', initializer=tf.constant(0.0, shape=[channels]))
35 |
36 | top_fc=tf.nn.sigmoid(tf.matmul(mid_fc, Wfc)+bfc) ## none, C
37 |
38 | top_fc = tf.reshape(top_fc, [-1, 1, 1, channels])
39 |
40 | output_layer = tf.multiply(input_layer, top_fc)
41 |
42 | return output_layer
43 |
44 | def transition(input_layer, if_a, is_train, keep_prob, name):
45 | channels=input_layer.get_shape().as_list()[-1]
46 | output_layer=tf.contrib.layers.batch_norm(input_layer, scale=True, is_training=is_train, updates_collections=None)
47 | output_layer=tf.nn.relu(output_layer)
48 | filters=conv_var(kernel_size=(1,1), in_channels=channels, out_channels=channels, init_method='msra', name=name)
49 | output_layer=tf.nn.conv2d(output_layer, filters, [1, 1, 1, 1], padding='SAME')
50 | output_layer=tf.nn.dropout(output_layer, keep_prob)
51 | ## attentional transition
52 | if if_a:
53 | output_layer=attentional_transition(output_layer, name=name+'-ATT')
54 |
55 | output_layer=tf.nn.avg_pool(output_layer, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID')
56 |
57 | return output_layer
58 |
59 | def compress(input_layer, is_train, keep_prob, name):
60 | channels=input_layer.get_shape().as_list()[-1]
61 | output_layer=tf.contrib.layers.batch_norm(input_layer, scale=True, is_training=is_train, updates_collections=None)
62 | output_layer=tf.nn.relu(output_layer)
63 | filters=conv_var(kernel_size=(1,1), in_channels=channels, out_channels=channels//2, init_method='msra', name=name)
64 | output_layer=tf.nn.conv2d(output_layer, filters, [1, 1, 1, 1], padding='SAME')
65 | output_layer=tf.nn.dropout(output_layer, keep_prob)
66 |
67 | return output_layer
68 |
69 |
70 | def global_pool(input_layer, is_train):
71 | output_layer=tf.contrib.layers.batch_norm(input_layer, scale=True, is_training=is_train, updates_collections=None)
72 | output_layer=tf.nn.relu(output_layer)
73 |
74 | map_size=input_layer.get_shape().as_list()[1]
75 | return tf.nn.avg_pool(output_layer, [1, map_size, map_size, 1], [1, map_size, map_size, 1], 'VALID')
76 |
77 | def first_transit(input_layer, channels, strides, with_biase=False):
78 | filters=conv_var(kernel_size=(3,3), in_channels=3, out_channels=channels, init_method='msra', name='first_tran')
79 | conved=tf.nn.conv2d(input_layer, filters, [1, strides, strides, 1], padding='SAME')
80 | if with_biase==True:
81 | biases=bias_var(out_channels=channels)
82 | biased=tf.nn.bias_add(conved, biases)
83 | return biased
84 | return conved
85 |
86 |
87 | def loop_block(input_layer, if_b, channels_per_layer, layer_num, is_train, keep_prob, block_name, loop_num=1):
88 | if if_b: layer_num = layer_num//2 ## if bottleneck is used, the T value should be multiplied by 2.
89 | channels=channels_per_layer
90 | node_0_channels=input_layer.get_shape().as_list()[-1]
91 | ## init param
92 | param_dict={}
93 | kernel_size=(1, 1) if if_b==True else (3, 3)
94 | for layer_id in range(1, layer_num):
95 | add_id=1
96 | while layer_id+add_id <= layer_num:
97 |
98 | ## ->
99 | filters=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id)+'_'+str(layer_id+add_id))
100 | param_dict[str(layer_id)+'_'+str(layer_id+add_id)]=filters
101 | ## <-
102 | filters_inv=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id+add_id)+'_'+str(layer_id))
103 | param_dict[str(layer_id+add_id)+'_'+str(layer_id)]=filters_inv
104 | add_id+=1
105 |
106 | for layer_id in range(layer_num):
107 | filters=conv_var(kernel_size=kernel_size, in_channels=node_0_channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(0)+'_'+str(layer_id+1))
108 | param_dict[str(0)+'_'+str(layer_id+1)]=filters
109 |
110 | assert len(param_dict)==layer_num*(layer_num-1)+layer_num
111 |
112 | ### bottleneck param ###
113 | if if_b==True:
114 | param_dict_B={}
115 | for layer_id in range(1, layer_num+1):
116 | filters=conv_var(kernel_size=(3,3), in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+'to-'+str(layer_id))
117 | param_dict_B[str(layer_id)]=filters
118 |
119 | ## init blob
120 | blob_dict={}
121 |
122 | for layer_id in range(1, layer_num+1):
123 | bottom_blob=input_layer
124 | bottom_param=param_dict['0_'+str(layer_id)]
125 | for layer_id_id in range(1, layer_id):
126 | bottom_blob=tf.concat((bottom_blob, blob_dict[str(layer_id_id)]), axis=3)
127 | bottom_param=tf.concat((bottom_param, param_dict[str(layer_id_id)+'_'+str(layer_id)]), axis=2)
128 |
129 |
130 | mid_layer=tf.contrib.layers.batch_norm(bottom_blob, scale=True, is_training=is_train, updates_collections=None)
131 | mid_layer=tf.nn.relu(mid_layer)
132 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME')
133 | mid_layer=tf.nn.dropout(mid_layer, keep_prob)
134 | ## Bottleneck
135 | if if_b==True:
136 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None)
137 | next_layer=tf.nn.relu(next_layer)
138 | next_layer=tf.nn.conv2d(next_layer, param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME')
139 | next_layer=tf.nn.dropout(next_layer, keep_prob)
140 | else:
141 | next_layer=mid_layer
142 | blob_dict[str(layer_id)]=next_layer
143 |
144 | ## begin loop
145 | for loop_id in range(loop_num):
146 | for layer_id in range(1, layer_num+1): ## [1,2,3,4,5]
147 |
148 | layer_list=[str(l_id) for l_id in range(1, layer_num+1)]
149 | layer_list.remove(str(layer_id))
150 |
151 | bottom_blobs=blob_dict[layer_list[0]]
152 | bottom_param=param_dict[layer_list[0]+'_'+str(layer_id)]
153 | for bottom_id in range(len(layer_list)-1):
154 | bottom_blobs=tf.concat((bottom_blobs, blob_dict[layer_list[bottom_id+1]]),
155 | axis=3) ### concatenate the data blobs
156 | bottom_param=tf.concat((bottom_param, param_dict[layer_list[bottom_id+1]+'_'+str(layer_id)]),
157 | axis=2) ### concatenate the parameters
158 |
159 | mid_layer=tf.contrib.layers.batch_norm(bottom_blobs, scale=True, is_training=is_train, updates_collections=None)
160 | mid_layer=tf.nn.relu(mid_layer)
161 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME') ### update the data blob
162 | mid_layer=tf.nn.dropout(mid_layer, keep_prob)
163 | ## Bottleneck
164 | if if_b==True:
165 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None)
166 | next_layer=tf.nn.relu(next_layer)
167 | next_layer=tf.nn.conv2d(next_layer, param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME')
168 | next_layer=tf.nn.dropout(next_layer, keep_prob)
169 | else:
170 | next_layer=mid_layer
171 | blob_dict[str(layer_id)]=next_layer
172 |
173 | transit_feature=blob_dict['1']
174 | for layer_id in range(2, layer_num+1):
175 | transit_feature=tf.concat((transit_feature, blob_dict[str(layer_id)]), axis=3)
176 |
177 | block_feature=tf.concat((input_layer, transit_feature), axis=3)
178 |
179 | return block_feature, transit_feature
180 |
181 | def loop_block_I_I(input_layer, if_b, channels_per_layer, layer_num, is_train, keep_prob, block_name):
182 | if if_b: layer_num = layer_num//2 ## if bottleneck is used, the T value should be multiplied by 2.
183 | channels=channels_per_layer
184 | node_0_channels=input_layer.get_shape().as_list()[-1]
185 | ## init param
186 | param_dict={}
187 | kernel_size=(1, 1) if if_b==True else (3, 3)
188 | for layer_id in range(1, layer_num):
189 | add_id=1
190 | while layer_id+add_id <= layer_num:
191 |
192 | ## ->
193 | filters=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id)+'_'+str(layer_id+add_id))
194 | param_dict[str(layer_id)+'_'+str(layer_id+add_id)]=filters
195 | ## <-
196 | filters_inv=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id+add_id)+'_'+str(layer_id))
197 | param_dict[str(layer_id+add_id)+'_'+str(layer_id)]=filters_inv
198 | add_id+=1
199 |
200 | for layer_id in range(layer_num):
201 | filters=conv_var(kernel_size=kernel_size, in_channels=node_0_channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(0)+'_'+str(layer_id+1))
202 | param_dict[str(0)+'_'+str(layer_id+1)]=filters
203 |
204 | assert len(param_dict)==layer_num*(layer_num-1)+layer_num
205 |
206 | ### bottleneck param ###
207 | if if_b==True:
208 | param_dict_B={}
209 | for layer_id in range(1, layer_num+1):
210 | filters=conv_var(kernel_size=(3,3), in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+'to-'+str(layer_id))
211 | param_dict_B[str(layer_id)]=filters
212 |
213 | ## init blob
214 | blob_dict={}
215 |
216 | for layer_id in range(1, layer_num+1):
217 | bottom_blob=input_layer
218 | bottom_param=param_dict['0_'+str(layer_id)]
219 | for layer_id_id in range(1, layer_id):
220 | bottom_blob=tf.concat((bottom_blob, blob_dict[str(layer_id_id)]), axis=3)
221 | bottom_param=tf.concat((bottom_param, param_dict[str(layer_id_id)+'_'+str(layer_id)]), axis=2)
222 |
223 |
224 | mid_layer=tf.contrib.layers.batch_norm(bottom_blob, scale=True, is_training=is_train, updates_collections=None)
225 | mid_layer=tf.nn.relu(mid_layer)
226 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME')
227 | mid_layer=tf.nn.dropout(mid_layer, keep_prob)
228 | ## Bottleneck
229 | if if_b==True:
230 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None)
231 | next_layer=tf.nn.relu(next_layer)
232 | next_layer=tf.nn.conv2d(next_layer, param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME')
233 | next_layer=tf.nn.dropout(next_layer, keep_prob)
234 | else:
235 | next_layer=mid_layer
236 | blob_dict[str(layer_id)]=next_layer
237 |
238 | ## no loop
239 |
240 | transit_feature=blob_dict['1']
241 | for layer_id in range(2, layer_num+1):
242 | transit_feature=tf.concat((transit_feature, blob_dict[str(layer_id)]), axis=3)
243 |
244 | block_feature=tf.concat((input_layer, transit_feature), axis=3)
245 |
246 | return block_feature, transit_feature
247 |
248 |
249 | def loop_block_I_II(input_layer, if_b, channels_per_layer, layer_num, is_train, keep_prob, block_name, loop_num=1):
250 | if if_b: layer_num = layer_num//2 ## if bottleneck is used, the T value should be multiplied by 2.
251 | import copy
252 |
253 | channels=channels_per_layer
254 | node_0_channels=input_layer.get_shape().as_list()[-1]
255 | ## init param
256 | param_dict={}
257 | kernel_size=(1, 1) if if_b==True else (3, 3)
258 | for layer_id in range(1, layer_num):
259 | add_id=1
260 | while layer_id+add_id <= layer_num:
261 |
262 | ## ->
263 | filters=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id)+'_'+str(layer_id+add_id))
264 | param_dict[str(layer_id)+'_'+str(layer_id+add_id)]=filters
265 | ## <-
266 | filters_inv=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id+add_id)+'_'+str(layer_id))
267 | param_dict[str(layer_id+add_id)+'_'+str(layer_id)]=filters_inv
268 | add_id+=1
269 |
270 | for layer_id in range(layer_num):
271 | filters=conv_var(kernel_size=kernel_size, in_channels=node_0_channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(0)+'_'+str(layer_id+1))
272 | param_dict[str(0)+'_'+str(layer_id+1)]=filters
273 |
274 | assert len(param_dict)==layer_num*(layer_num-1)+layer_num
275 |
276 | ### bottleneck param ###
277 | if if_b==True:
278 | param_dict_B={}
279 | for layer_id in range(1, layer_num+1):
280 | filters=conv_var(kernel_size=(3,3), in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+'to-'+str(layer_id))
281 | param_dict_B[str(layer_id)]=filters
282 |
283 | ## init blob
284 | blob_dict={}
285 | blob_dict_list=[]
286 |
287 | for layer_id in range(1, layer_num+1):
288 | bottom_blob=input_layer
289 | bottom_param=param_dict['0_'+str(layer_id)]
290 | for layer_id_id in range(1, layer_id):
291 | bottom_blob=tf.concat((bottom_blob, blob_dict[str(layer_id_id)]), axis=3)
292 | bottom_param=tf.concat((bottom_param, param_dict[str(layer_id_id)+'_'+str(layer_id)]), axis=2)
293 |
294 |
295 | mid_layer=tf.contrib.layers.batch_norm(bottom_blob, scale=True, is_training=is_train, updates_collections=None)
296 | mid_layer=tf.nn.relu(mid_layer)
297 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME')
298 | mid_layer=tf.nn.dropout(mid_layer, keep_prob)
299 | ## Bottleneck
300 | if if_b==True:
301 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None)
302 | next_layer=tf.nn.relu(next_layer)
303 | next_layer=tf.nn.conv2d(next_layer, param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME')
304 | next_layer=tf.nn.dropout(next_layer, keep_prob)
305 | else:
306 | next_layer=mid_layer
307 | blob_dict[str(layer_id)]=next_layer
308 |
309 | blob_dict_list.append(blob_dict)
310 |
311 | ## begin loop
312 | for loop_id in range(loop_num):
313 | blob_dict_new = copy.copy(blob_dict_list[-1])
314 | for layer_id in range(1, layer_num+1): ## [1,2,3,4,5]
315 |
316 | layer_list=[str(l_id) for l_id in range(1, layer_num+1)]
317 | layer_list.remove(str(layer_id))
318 |
319 | bottom_blobs=blob_dict_new[layer_list[0]]
320 | bottom_param=param_dict[layer_list[0]+'_'+str(layer_id)]
321 | for bottom_id in range(len(layer_list)-1):
322 | bottom_blobs=tf.concat((bottom_blobs, blob_dict_new[layer_list[bottom_id+1]]),
323 | axis=3) ### concatenate the data blobs
324 | bottom_param=tf.concat((bottom_param, param_dict[layer_list[bottom_id+1]+'_'+str(layer_id)]),
325 | axis=2) ### concatenate the parameters
326 |
327 | mid_layer=tf.contrib.layers.batch_norm(bottom_blobs, scale=True, is_training=is_train, updates_collections=None)
328 | mid_layer=tf.nn.relu(mid_layer)
329 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME') ### update the data blob
330 | mid_layer=tf.nn.dropout(mid_layer, keep_prob)
331 | ## Bottleneck
332 | if if_b==True:
333 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None)
334 | next_layer=tf.nn.relu(next_layer)
335 | next_layer=tf.nn.conv2d(next_layer, param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME')
336 | next_layer=tf.nn.dropout(next_layer, keep_prob)
337 | else:
338 | next_layer=mid_layer
339 | blob_dict_new[str(layer_id)]=next_layer
340 | blob_dict_list.append(blob_dict_new)
341 |
342 | assert len(blob_dict_list)==1+loop_num
343 |
344 | stage_I = blob_dict_list[0]['1']
345 | for layer_id in range(2, layer_num+1):
346 | stage_I=tf.concat((stage_I, blob_dict_list[0][str(layer_id)]), axis=3)
347 |
348 | stage_II = blob_dict_list[1]['1']
349 | for layer_id in range(2, layer_num+1):
350 | stage_II=tf.concat((stage_II, blob_dict_list[1][str(layer_id)]), axis=3)
351 |
352 | block_feature = tf.concat((input_layer, stage_I), axis=3)
353 | transit_feature = stage_II
354 |
355 | return block_feature, transit_feature
356 |
357 |
358 | def loop_block_X(input_layer, x_value, if_b, channels_per_layer, layer_num, is_train, keep_prob, block_name, loop_num=1):
359 | if if_b: layer_num = layer_num//2 ## if bottleneck is used, the T value should be multiplied by 2.
360 | channels=channels_per_layer
361 | node_0_channels=input_layer.get_shape().as_list()[-1]
362 | ## init param
363 | param_dict={}
364 | kernel_size=(1, 1) if if_b==True else (3, 3)
365 | for layer_id in range(1, layer_num):
366 | add_id=1
367 | while layer_id+add_id <= layer_num:
368 |
369 | ## ->
370 | filters=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id)+'_'+str(layer_id+add_id))
371 | param_dict[str(layer_id)+'_'+str(layer_id+add_id)]=filters
372 | ## <-
373 | filters_inv=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id+add_id)+'_'+str(layer_id))
374 | param_dict[str(layer_id+add_id)+'_'+str(layer_id)]=filters_inv
375 | add_id+=1
376 |
377 | for layer_id in range(layer_num):
378 | filters=conv_var(kernel_size=kernel_size, in_channels=node_0_channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(0)+'_'+str(layer_id+1))
379 | param_dict[str(0)+'_'+str(layer_id+1)]=filters
380 |
381 | assert len(param_dict)==layer_num*(layer_num-1)+layer_num
382 |
383 | ### bottleneck param ###
384 | if if_b==True:
385 | param_dict_B={}
386 | for layer_id in range(1, layer_num+1):
387 | filters=conv_var(kernel_size=(3,3), in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+'to-'+str(layer_id))
388 | param_dict_B[str(layer_id)]=filters
389 |
390 | ## init blob
391 | blob_dict={}
392 |
393 | for layer_id in range(1, layer_num+1):
394 | bottom_blob=input_layer
395 | bottom_param=param_dict['0_'+str(layer_id)]
396 | for layer_id_id in range(1, layer_id):
397 | bottom_blob=tf.concat((bottom_blob, blob_dict[str(layer_id_id)]), axis=3)
398 | bottom_param=tf.concat((bottom_param, param_dict[str(layer_id_id)+'_'+str(layer_id)]), axis=2)
399 |
400 |
401 | mid_layer=tf.contrib.layers.batch_norm(bottom_blob, scale=True, is_training=is_train, updates_collections=None)
402 | mid_layer=tf.nn.relu(mid_layer)
403 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME')
404 | mid_layer=tf.nn.dropout(mid_layer, keep_prob)
405 | ## Bottleneck
406 | if if_b==True:
407 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None)
408 | next_layer=tf.nn.relu(next_layer)
409 | next_layer=tf.nn.conv2d(next_layer, param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME')
410 | next_layer=tf.nn.dropout(next_layer, keep_prob)
411 | else:
412 | next_layer=mid_layer
413 | blob_dict[str(layer_id)]=next_layer
414 |
415 | ## begin loop
416 | for loop_id in range(loop_num):
417 | for layer_id in range(1, x_value+1): ## only layers 1..x_value are re-updated in Stage-II
418 |
419 | layer_list=[str(l_id) for l_id in range(1, layer_num+1)]
420 | layer_list.remove(str(layer_id))
421 |
422 | bottom_blobs=blob_dict[layer_list[0]]
423 | bottom_param=param_dict[layer_list[0]+'_'+str(layer_id)]
424 | for bottom_id in range(len(layer_list)-1):
425 | bottom_blobs=tf.concat((bottom_blobs, blob_dict[layer_list[bottom_id+1]]),
426 | axis=3) ### concatenate the data blobs
427 | bottom_param=tf.concat((bottom_param, param_dict[layer_list[bottom_id+1]+'_'+str(layer_id)]),
428 | axis=2) ### concatenate the parameters
429 |
430 | mid_layer=tf.contrib.layers.batch_norm(bottom_blobs, scale=True, is_training=is_train, updates_collections=None)
431 | mid_layer=tf.nn.relu(mid_layer)
432 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME') ### update the data blob
433 | mid_layer=tf.nn.dropout(mid_layer, keep_prob)
434 | ## Bottleneck
435 | if if_b==True:
436 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None)
437 | next_layer=tf.nn.relu(next_layer)
438 | next_layer=tf.nn.conv2d(next_layer, param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME')
439 | next_layer=tf.nn.dropout(next_layer, keep_prob)
440 | else:
441 | next_layer=mid_layer
442 | blob_dict[str(layer_id)]=next_layer
443 |
444 | transit_feature=blob_dict['1']
445 | for layer_id in range(2, layer_num+1):
446 | transit_feature=tf.concat((transit_feature, blob_dict[str(layer_id)]), axis=3)
447 |
448 | block_feature=tf.concat((input_layer, transit_feature), axis=3)
449 |
450 | return block_feature, transit_feature
451 |
--------------------------------------------------------------------------------
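To make the Stage-I / Stage-II wiring of `loop_block` easier to follow, here is a framework-free sketch that tracks only the bookkeeping: which filter keys are created and which feature versions each update reads. `clique_param_keys` and `clique_schedule` are hypothetical helpers written for illustration; they perform no convolution and assume the non-bottleneck case (`if_b=False`) with `loop_num` re-update passes.

```python
def clique_param_keys(layer_num):
    """Hypothetical helper: keys of the filter dict built in loop_block.

    'i_j' is the filter from layer i to layer j; '0_j' comes from the block input X0.
    """
    keys = ['0_%d' % j for j in range(1, layer_num + 1)]  # X0 -> every layer
    for i in range(1, layer_num + 1):
        for j in range(1, layer_num + 1):
            if i != j:
                keys.append('%d_%d' % (i, j))  # one filter per direction for every pair
    return keys  # layer_num*(layer_num-1) + layer_num keys, as asserted in loop_block


def clique_schedule(layer_num, loop_num=1):
    """Hypothetical helper: symbolic trace of the alternate updating rule.

    Returns (output, inputs) pairs in execution order; 'Xi^s' is layer i after
    s updates (s=1 is Stage-I, s=2 is Stage-II).
    """
    version = {}  # latest version of each layer, overwritten in place like blob_dict
    trace = []
    # Stage-I: layer i reads X0 and the already-computed layers 1..i-1.
    for i in range(1, layer_num + 1):
        inputs = ['X0'] + ['X%d^%d' % (j, version[j]) for j in range(1, i)]
        version[i] = 1
        trace.append(('X%d^1' % i, inputs))
    # Re-update passes: layer i reads every other layer (but not X0) at its latest
    # version, so layers before i are already refined and layers after i are not.
    for _ in range(loop_num):
        for i in range(1, layer_num + 1):
            inputs = ['X%d^%d' % (j, version[j]) for j in range(1, layer_num + 1) if j != i]
            version[i] += 1
            trace.append(('X%d^%d' % (i, version[i]), inputs))
    return trace


if __name__ == "__main__":
    print(len(clique_param_keys(5)))
    for output, inputs in clique_schedule(3):
        print(output, '<-', inputs)
```

Tracing `clique_schedule(3)` shows the alternate updating rule: in Stage-II, `X2` is recomputed from the already-refined `X1^2` and the not-yet-refined `X3^1`, mirroring how `loop_block` overwrites `blob_dict` in place. With `loop_num=1`, the ablation variants differ only in which stage they return: `loop_block_I_I` stops after Stage-I, `loop_block_I_II` concatenates Stage-I with the input as the block feature but transits Stage-II, and `loop_block` uses Stage-II for both.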