├── models ├── __init__.py ├── cliquenet.py ├── cliquenet_I_I.py ├── cliquenet_I_II.py ├── cliquenet_X.py └── utils.py ├── dataloader ├── __init__.py ├── preprocess.py └── data_generator.py ├── imagenet_pytorch ├── __init__.py ├── cliquenet.py └── utils.py ├── img ├── fig1.JPG ├── fig2.JPG ├── fig3.JPG └── tab1.JPG ├── LICENSE ├── README.md ├── train.py └── train_imagenet.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imagenet_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /img/fig1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iboing/CliqueNet/HEAD/img/fig1.JPG -------------------------------------------------------------------------------- /img/fig2.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iboing/CliqueNet/HEAD/img/fig2.JPG -------------------------------------------------------------------------------- /img/fig3.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iboing/CliqueNet/HEAD/img/fig3.JPG -------------------------------------------------------------------------------- /img/tab1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iboing/CliqueNet/HEAD/img/tab1.JPG -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yibo Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# --------------------------------------------------------------------------------
# /models/cliquenet.py:
# --------------------------------------------------------------------------------
import tensorflow as tf
from utils import *

block_num=3

def build_model(input_images, k, T, label_num, is_train, keep_prob, if_a, if_b, if_c):
    """Build the CliqueNet classification graph and return ``(logits, prob)``.

    The network is a stem (``first_transit``) followed by ``block_num`` blocks
    built with ``loop_block``. Each block yields a block feature, which is
    globally pooled and kept for the classifier, and a transit feature, which
    goes through ``transition`` to feed the next block.

    Args:
        input_images: input tensor handed to ``first_transit``.
        k: channels per layer inside a block.
        T: total number of layers, split evenly over the ``block_num`` blocks.
        label_num: number of output classes.
        is_train: training-mode flag forwarded to the helper layers.
        keep_prob: keep probability forwarded to the helper layers.
        if_a: forwarded to ``transition`` (attentional transition switch).
        if_b: forwarded to ``loop_block`` (bottleneck switch).
        if_c: when true, ``compress`` is applied to every block feature.

    Returns:
        logits, prob: un-normalised class scores and their softmax.
    """
    current = first_transit(input_images, channels=64, strides=1, with_biase=False)
    current_list = []

    ## build blocks
    # Floor division keeps ``layer_num`` an integer on Python 2 and Python 3.
    layers_per_block = T // block_num
    for i in range(block_num):
        block_feature, transit_feature = loop_block(current, if_b, channels_per_layer=k, layer_num=layers_per_block, is_train=is_train, keep_prob=keep_prob, block_name='b'+str(i))
        if if_c:
            block_feature = compress(block_feature, is_train=is_train, keep_prob=keep_prob, name='com'+str(i))
        current_list.append(global_pool(block_feature, is_train))
        if i == block_num-1:
            # No transition layer after the last block.
            break
        current = transition(transit_feature, if_a, is_train=is_train, keep_prob=keep_prob, name='tran'+str(i))

    ## final feature
    # Pooled block features are joined along axis 3.
    # NOTE(review): presumably the channel axis of NHWC tensors -- confirm in utils.global_pool.
    final_feature = tf.concat(current_list, axis=3)
    feature_length = final_feature.get_shape().as_list()[-1]
    print('final feature length: %s' % feature_length)

    feature_flatten = tf.reshape(final_feature, [-1, feature_length])
    ## final_fc
    Wfc = tf.get_variable(name='FC_W', shape=[feature_length, label_num], initializer=tf.contrib.layers.xavier_initializer())
    bfc = tf.get_variable(name='FC_b', initializer=tf.constant(0.0, shape=[label_num]))

    logits = tf.matmul(feature_flatten, Wfc)+bfc
    prob = tf.nn.softmax(logits)

    return logits, prob

# --------------------------------------------------------------------------------
# /models/cliquenet_I_I.py:
# --------------------------------------------------------------------------------
import tensorflow as tf
from utils import *

block_num=3
# (/models/cliquenet_I_I.py, continued)

def build_model(input_images, k, T, label_num, is_train, keep_prob, if_a, if_b, if_c):
    """Build the CliqueNet(I+I) graph and return ``(logits, prob)``.

    Same layout as the default model, but the blocks are built with
    ``loop_block_I_I``.

    Args:
        input_images: input tensor handed to ``first_transit``.
        k: channels per layer inside a block.
        T: total number of layers, split evenly over the ``block_num`` blocks.
        label_num: number of output classes.
        is_train: training-mode flag forwarded to the helper layers.
        keep_prob: keep probability forwarded to the helper layers.
        if_a: forwarded to ``transition``.
        if_b: forwarded to ``loop_block_I_I``.
        if_c: when true, ``compress`` is applied to every block feature.

    Returns:
        logits, prob: un-normalised class scores and their softmax.
    """
    current = first_transit(input_images, channels=64, strides=1, with_biase=False)
    current_list = []

    ## build blocks
    # Floor division keeps ``layer_num`` an integer on Python 2 and Python 3.
    layers_per_block = T // block_num
    for i in range(block_num):
        block_feature, transit_feature = loop_block_I_I(current, if_b, channels_per_layer=k, layer_num=layers_per_block, is_train=is_train, keep_prob=keep_prob, block_name='b'+str(i))
        if if_c:
            block_feature = compress(block_feature, is_train=is_train, keep_prob=keep_prob, name='com'+str(i))
        current_list.append(global_pool(block_feature, is_train))
        if i == block_num-1:
            # No transition layer after the last block.
            break
        current = transition(transit_feature, if_a, is_train=is_train, keep_prob=keep_prob, name='tran'+str(i))

    ## final feature
    final_feature = tf.concat(current_list, axis=3)
    feature_length = final_feature.get_shape().as_list()[-1]
    print('final feature length: %s' % feature_length)

    feature_flatten = tf.reshape(final_feature, [-1, feature_length])
    ## final_fc
    Wfc = tf.get_variable(name='FC_W', shape=[feature_length, label_num], initializer=tf.contrib.layers.xavier_initializer())
    bfc = tf.get_variable(name='FC_b', initializer=tf.constant(0.0, shape=[label_num]))

    logits = tf.matmul(feature_flatten, Wfc)+bfc
    prob = tf.nn.softmax(logits)

    return logits, prob

# --------------------------------------------------------------------------------
# /models/cliquenet_I_II.py:
# --------------------------------------------------------------------------------
import tensorflow as tf
from utils import *

block_num=3

def build_model(input_images, k, T, label_num, is_train, keep_prob, if_a, if_b, if_c):
    """Build the CliqueNet(I+II) graph and return ``(logits, prob)``.

    Same layout as the default model, but the blocks are built with
    ``loop_block_I_II``.

    Args:
        input_images: input tensor handed to ``first_transit``.
        k: channels per layer inside a block.
        T: total number of layers, split evenly over the ``block_num`` blocks.
        label_num: number of output classes.
        is_train: training-mode flag forwarded to the helper layers.
        keep_prob: keep probability forwarded to the helper layers.
        if_a: forwarded to ``transition``.
        if_b: forwarded to ``loop_block_I_II``.
        if_c: when true, ``compress`` is applied to every block feature.

    Returns:
        logits, prob: un-normalised class scores and their softmax.
    """
    current = first_transit(input_images, channels=64, strides=1, with_biase=False)
    current_list = []

    ## build blocks
    # Floor division keeps ``layer_num`` an integer on Python 2 and Python 3.
    layers_per_block = T // block_num
    for i in range(block_num):
        block_feature, transit_feature = loop_block_I_II(current, if_b, channels_per_layer=k, layer_num=layers_per_block, is_train=is_train, keep_prob=keep_prob, block_name='b'+str(i))
        if if_c:
            block_feature = compress(block_feature, is_train=is_train, keep_prob=keep_prob, name='com'+str(i))
        current_list.append(global_pool(block_feature, is_train))
        if i == block_num-1:
            # No transition layer after the last block.
            break
        current = transition(transit_feature, if_a, is_train=is_train, keep_prob=keep_prob, name='tran'+str(i))

    ## final feature
    final_feature = tf.concat(current_list, axis=3)
    feature_length = final_feature.get_shape().as_list()[-1]
    print('final feature length: %s' % feature_length)

    feature_flatten = tf.reshape(final_feature, [-1, feature_length])
    ## final_fc
    Wfc = tf.get_variable(name='FC_W', shape=[feature_length, label_num], initializer=tf.contrib.layers.xavier_initializer())
    bfc = tf.get_variable(name='FC_b', initializer=tf.constant(0.0, shape=[label_num]))

    logits = tf.matmul(feature_flatten, Wfc)+bfc
    prob = tf.nn.softmax(logits)

    return logits, prob

# --------------------------------------------------------------------------------
# /models/cliquenet_X.py:
# --------------------------------------------------------------------------------
import tensorflow as tf
from utils import *

x=2

block_num=3

def build_model(input_images, k, T, label_num, is_train, keep_prob, if_a, if_b, if_c):
    """Build the CliqueNet(X) graph and return ``(logits, prob)``.

    Same layout as the default model, but the blocks are built with
    ``loop_block_X``, which additionally receives the module-level ``x``.

    Args:
        input_images: input tensor handed to ``first_transit``.
        k: channels per layer inside a block.
        T: total number of layers, split evenly over the ``block_num`` blocks.
        label_num: number of output classes.
        is_train: training-mode flag forwarded to the helper layers.
        keep_prob: keep probability forwarded to the helper layers.
        if_a: forwarded to ``transition``.
        if_b: forwarded to ``loop_block_X``.
        if_c: when true, ``compress`` is applied to every block feature.

    Returns:
        logits, prob: un-normalised class scores and their softmax.
    """
    current = first_transit(input_images, channels=64, strides=1, with_biase=False)
    current_list = []

    ## build blocks
    # Floor division keeps ``layer_num`` an integer on Python 2 and Python 3.
    layers_per_block = T // block_num
    for i in range(block_num):
        block_feature, transit_feature = loop_block_X(current, x, if_b, channels_per_layer=k, layer_num=layers_per_block, is_train=is_train, keep_prob=keep_prob, block_name='b'+str(i))
        if if_c:
            block_feature = compress(block_feature, is_train=is_train, keep_prob=keep_prob, name='com'+str(i))
        current_list.append(global_pool(block_feature, is_train))
        if i == block_num-1:
            # No transition layer after the last block.
            break
        current = transition(transit_feature, if_a, is_train=is_train, keep_prob=keep_prob, name='tran'+str(i))

    ## final feature
    final_feature = tf.concat(current_list, axis=3)
    feature_length = final_feature.get_shape().as_list()[-1]
    print('final feature length: %s' % feature_length)

    feature_flatten = tf.reshape(final_feature, [-1, feature_length])
    ## final_fc
    Wfc = tf.get_variable(name='FC_W', shape=[feature_length, label_num], initializer=tf.contrib.layers.xavier_initializer())
    bfc = tf.get_variable(name='FC_b', initializer=tf.constant(0.0, shape=[label_num]))

    logits = tf.matmul(feature_flatten, Wfc)+bfc
    prob = tf.nn.softmax(logits)

    return logits, prob

# --------------------------------------------------------------------------------
# /imagenet_pytorch/cliquenet.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.nn.functional as F
from torch.nn import init

import math
from utils import transition, global_pool, compress, clique_block

class build_cliquenet(nn.Module):
    """CliqueNet classifier: a conv stem, a stack of clique blocks and one linear layer.

    Args:
        input_channels: output channels of the stem convolution.
        list_channels: channels per layer for each block; its length sets the
            number of blocks.
        list_layer_num: number of layers for each block.
        if_att: forwarded to every ``transition`` module.
    """

    def __init__(self, input_channels, list_channels, list_layer_num, if_att):
        super(build_cliquenet, self).__init__()
        self.fir_trans = nn.Conv2d(3, input_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.fir_bn = nn.BatchNorm2d(input_channels)
        self.fir_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.block_num = len(list_channels)

        self.if_att = if_att
        self.list_block = nn.ModuleList()
        self.list_trans = nn.ModuleList()
        self.list_gb = nn.ModuleList()
        self.list_gb_channel = []
        self.list_compress = nn.ModuleList()
        # Feature-map size handed to the first block's helpers; halved after every block.
        # NOTE(review): 56 presumably assumes 224x224 inputs (stem stride 2, pool stride 2) -- confirm.
        input_size_init = 56

        for i in range(self.block_num):
            # The first block is fed by the stem; later blocks are fed by the
            # layers of the preceding block.
            if i == 0:
                block_in_channels = input_channels
            else:
                block_in_channels = list_channels[i-1] * list_layer_num[i-1]

            self.list_block.append(clique_block(input_channels=block_in_channels, channels_per_layer=list_channels[i], layer_num=list_layer_num[i], loop_num=1, keep_prob=0.8))
            self.list_gb_channel.append(block_in_channels + list_channels[i] * list_layer_num[i])

            if i < self.block_num - 1:
                self.list_trans.append(transition(self.if_att, current_size=input_size_init, input_channels=list_channels[i] * list_layer_num[i], keep_prob=0.8))

            # global_pool is built for half of the channels that compress receives.
            self.list_gb.append(global_pool(input_size=input_size_init, input_channels=self.list_gb_channel[i] // 2))
            self.list_compress.append(compress(input_channels=self.list_gb_channel[i], keep_prob=0.8))
            input_size_init = input_size_init // 2

        self.fc = nn.Linear(in_features=sum(self.list_gb_channel) // 2, out_features=1000)

        # Weight initialisation: normal(0, sqrt(2 / (kh * kw * out_channels))) for
        # convolutions, (1, 0) for batch norm, zero bias for linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        """Return the output of the final linear layer for the batch ``x``."""
        output = self.fir_trans(x)
        output = self.fir_bn(output)
        output = F.relu(output)
        output = self.fir_pool(output)

        feature_I_list = []

        # use stage II + stage II mode
        for i in range(self.block_num):
            block_feature_I, block_feature_II = self.list_block[i](output)
            block_feature_I = self.list_compress[i](block_feature_I)
            feature_I_list.append(self.list_gb[i](block_feature_I))
            if i < self.block_num - 1:
                output = self.list_trans[i](block_feature_II)

        # Join the pooled block features along the channel dimension.
        final_feature = torch.cat(feature_I_list, 1)

        final_feature = final_feature.view(final_feature.size()[0], final_feature.size()[1])
        output = self.fc(final_feature)
        return output

# --------------------------------------------------------------------------------
# /dataloader/preprocess.py:
# --------------------------------------------------------------------------------
import numpy as np
import os
from scipy.io import loadmat


def unpickle(file):
    """Load one pickled CIFAR batch file and return its dictionary.

    If the dictionary has a ``'data'`` entry it is reshaped from flat rows to
    ``(N, 32, 32, 3)`` images.

    NOTE(security): unpickling executes arbitrary code; only use this on
    trusted dataset files.
    """
    try:
        import cPickle as pickle  # Python 2
        load_kwargs = {}
    except ImportError:
        import pickle  # Python 3
        # The batches were pickled by Python 2; latin1 decodes them with str keys.
        load_kwargs = {'encoding': 'latin1'}

    # ``with`` closes the file even when unpickling fails.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, **load_kwargs)
    if 'data' in batch:
        # (N, 3, 32, 32) -> (N, 32, 32, 3)
        batch['data'] = batch['data'].reshape((-1, 3, 32, 32)).transpose((0, 2, 3, 1))

    return batch

def load_data_one_10(f):
    """Load one CIFAR-10 batch file; return ``(data, labels)``."""
    batch = unpickle(f)
    data = batch['data']
    labels = batch['labels']
    print("Loading %s: %d" % (f, len(data)))
    return data, labels

def load_data_one_100(f):
    """Load one CIFAR-100 file; return ``(data, fine_labels)``."""
    batch = unpickle(f)
    data = batch['data']
    labels = batch['fine_labels']
    print("Loading %s: %d" % (f, len(data)))
    return data, labels

def labels_to_one_hot(labels, classes):
    """Convert a 1D sequence of integer labels to a float one-hot matrix."""
    labels = np.asarray(labels)
    new_labels = np.zeros((labels.shape[0], classes))
    new_labels[np.arange(labels.shape[0]), labels] = 1.0
    return new_labels

def load_data10(phase, data_path):
    """Load the CIFAR-10 split selected by ``phase``.

    ``phase == 'train'`` reads ``data_batch_1`` .. ``data_batch_5``; any other
    value reads ``test_batch``. Returns ``(data, one_hot_labels)``.
    """
    if phase=='train':
        files = ['data_batch_%d' % d for d in range(1, 6)]
    else:
        files = ['test_batch']

    data_parts = []
    label_parts = []
    for f in files:
        data_n, labels_n = load_data_one_10(data_path + '/' + f)
        data_parts.append(data_n)
        label_parts.append(labels_n)
    # One concatenation instead of re-copying the growing array per file.
    data = np.concatenate(data_parts, axis=0)
    labels = np.concatenate(label_parts, axis=0)
    labels = labels_to_one_hot(labels, 10)  ## cifar-10
    return data, labels

def load_data100(phase, data_path):
    """Load the CIFAR-100 split selected by ``phase``.

    ``phase == 'train'`` reads the ``train`` file; any other value reads
    ``test``. Returns ``(data, one_hot_labels)``.
    """
    if phase=='train':
        files = phase
    else:
        files = 'test'

    data, labels = load_data_one_100(data_path + '/' + files)
    labels = np.hstack(labels)
    labels = labels_to_one_hot(labels, 100)  ## cifar-100
    return data, labels

def cifar_preprocess(cifar_type, data_path):
    """Load train and test splits of ``'cifar-10'`` or ``'cifar-100'``.

    Returns ``(train_data, train_labels, test_data, test_labels)``.

    Raises:
        ValueError: if ``cifar_type`` is neither ``'cifar-10'`` nor ``'cifar-100'``.
    """
    if cifar_type=='cifar-10':
        train_data, train_labels = load_data10('train', data_path)
        test_data, test_labels = load_data10('test', data_path)
    elif cifar_type=='cifar-100':
        train_data, train_labels = load_data100('train', data_path)
        test_data, test_labels = load_data100('test', data_path)
    else:
        # Previously fell through to an unbound-variable error at the return.
        raise ValueError("unknown cifar_type: %r" % (cifar_type,))

    return train_data, train_labels, test_data, test_labels

def _load_svhn_mat(mat_file):
    """Read one SVHN ``.mat`` file; return ``(images, one_hot_labels)``."""
    raw_data = loadmat(mat_file)
    data = raw_data['X'].transpose(3, 0, 1, 2)  # sample axis moved to the front
    labels = raw_data['y'].reshape(-1)
    labels[labels==10] = 0  # label 10 is remapped to class 0
    return data, labels_to_one_hot(labels, 10)

def svhn_preprocess(train_path, extra_path, test_path):
    """Load SVHN; the training set is the ``train`` and ``extra`` files combined.

    Returns ``(train_data, train_labels, test_data, test_labels)``.
    """
    train_file = os.path.join(train_path, 'train_32x32.mat')
    extra_file = os.path.join(extra_path, 'extra_32x32.mat')
    test_file = os.path.join(test_path, 'test_32x32.mat')
    ##
    train_data_0, train_labels_0 = _load_svhn_mat(train_file)
    ##
    extra_data, extra_labels = _load_svhn_mat(extra_file)
    ##
    train_data = np.concatenate((train_data_0, extra_data), axis=0)
    train_labels = np.concatenate((train_labels_0, extra_labels), axis=0)
    ##
    test_data, test_labels = _load_svhn_mat(test_file)

    return train_data, train_labels, test_data, test_labels

# --------------------------------------------------------------------------------
# /dataloader/data_generator.py:
# --------------------------------------------------------------------------------
import sys
import os
try:
    from urllib import urlretrieve  ## python 2
except ImportError:
    from urllib.request import urlretrieve  ## python 3
import tarfile
import zipfile
from preprocess import cifar_preprocess, svhn_preprocess
import numpy as np

def download_progress(count, block_size, total_size):
    """``urlretrieve`` report hook that rewrites a one-line progress message."""
    if total_size > 0:
        # The last block may overshoot the file size; cap the display at 100%.
        pct_complete = min(float(count * block_size) / total_size, 1.0)
        msg = "\r {0:.1%} already downloaded".format(pct_complete)
    else:
        # Size unknown: report the byte count instead of dividing by it.
        msg = "\r {0} bytes already downloaded".format(count * block_size)
    sys.stdout.write(msg)
    sys.stdout.flush()


def download(data_type):
    """Download (and extract, for archives) a dataset unless it is already present.

    Args:
        data_type: one of ``'cifar-10'``, ``'cifar-100'``, ``'svhn_train'``,
            ``'svhn_test'``, ``'svhn_extra'``.

    Returns:
        Path of the directory that holds the dataset files.

    Raises:
        KeyError: if ``data_type`` is not one of the names above.
    """
    url_dict={'cifar-10':'http://www.cs.toronto.edu/'
                         '~kriz/cifar-10-python.tar.gz',
              'cifar-100': 'http://www.cs.toronto.edu/'
                           '~kriz/cifar-100-python.tar.gz',
              'svhn_train': "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
              'svhn_test': "http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
              'svhn_extra': "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat"}
    url = url_dict[data_type]
    file_name = url.split('/')[-1]
    data_folder = './'+data_type
    if data_type=='cifar-10':
        data_path = os.path.join(data_folder, data_type+'-batches-py')
    elif data_type=='cifar-100':
        data_path = os.path.join(data_folder, data_type+'-python')
    elif data_type in ('svhn_train', 'svhn_test', 'svhn_extra'):
        # NOTE(review): for SVHN only the folder is checked, so an interrupted
        # .mat download is reported as "already exists".
        data_path = data_folder

    if not os.path.exists(data_path):

        # The folder can be left over from an interrupted run; mkdir would raise.
        if not os.path.isdir(data_folder):
            os.mkdir(data_folder)
        file_path = os.path.join(data_folder, file_name)

        print("Downloading data file from %s to %s" % (url, file_path))
        urlretrieve(url=url,
                    filename=file_path,
                    reporthook=download_progress)

        print("\nExtracting file...")
        # NOTE(security): extractall trusts member paths inside the archive.
        if file_path.endswith(".zip"):
            with zipfile.ZipFile(file=file_path, mode='r') as archive:
                archive.extractall(data_folder)
        elif file_path.endswith((".tar.gz", ".tgz")):
            with tarfile.open(name=file_path, mode="r:gz") as archive:
                archive.extractall(data_folder)

        print("Successfully downloaded and extracted")

    else:
        print("Data file already exists")

    return data_path


def data_normalization(train_data_raw, test_data_raw, normalize_type):
    """Normalize train and test images; return ``(train_data, test_data)``.

    Args:
        train_data_raw, test_data_raw: image arrays; ``'by-channels'`` indexes
            them as 4-D with the channel on the last axis.
        normalize_type: ``'divide-255'``, ``'divide-256'``, ``'by-channels'``
            (per-channel mean/std) or the string ``'None'`` (no change).

    Raises:
        ValueError: for any other ``normalize_type`` (previously returned ``None``).
    """
    if normalize_type=='divide-255':
        train_data = train_data_raw/255.0
        test_data = test_data_raw/255.0

        return train_data, test_data
    elif normalize_type=='divide-256':
        train_data = train_data_raw/256.0
        test_data = test_data_raw/256.0

        return train_data, test_data
    elif normalize_type=='by-channels':
        train_data = np.zeros(train_data_raw.shape)
        test_data = np.zeros(test_data_raw.shape)
        # NOTE(review): the statistics are computed over train AND test images,
        # as in the original code; kept so existing results do not change.
        # Hoisted: the combined array was rebuilt once per channel.
        images = np.concatenate((train_data_raw, test_data_raw), axis=0)
        for channel in range(train_data_raw.shape[-1]):
            channel_mean = np.mean(images[:,:,:,channel])
            channel_std = np.std(images[:,:,:,channel])
            train_data[:,:,:,channel] = (train_data_raw[:,:,:,channel]-channel_mean)/channel_std
            test_data[:,:,:,channel] = (test_data_raw[:,:,:,channel]-channel_mean)/channel_std

        return train_data, test_data

    elif normalize_type=='None':

        return train_data_raw, test_data_raw

    raise ValueError("unknown normalize_type: %r" % (normalize_type,))

def load_data(data_type, normalize_type):  ## cifar-10, cifar-100 or svhn
    """Download if needed, load and normalize a dataset.

    Returns ``(train_data, train_labels, test_data, test_labels)``.
    """
    if data_type=='svhn':
        data_path_train = download(data_type+'_train')
        data_path_extra = download(data_type+'_extra')
        data_path_test = download(data_type+'_test')
        train_data_raw, train_labels, test_data_raw, test_labels = svhn_preprocess(data_path_train, data_path_extra, data_path_test)
    else:
        data_path = download(data_type)
        train_data_raw, train_labels, test_data_raw, test_labels = cifar_preprocess(data_type, data_path)

    train_data, test_data = data_normalization(train_data_raw, test_data_raw, normalize_type)

    return train_data, train_labels, test_data, test_labels

# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# # CliqueNet
#
#
# This repository is for [Convolutional Neural Networks with Alternately Updated Clique](https://arxiv.org/abs/1802.10419) (to appear in CVPR 2018, Oral presentation),
#
# by Yibo Yang, Zhisheng Zhong, Tiancheng Shen, and [Zhouchen Lin](http://www.cis.pku.edu.cn/faculty/vision/zlin/zlin.htm).
7 | 8 | ### citation 9 | If you find CliqueNet useful in your research, please consider citing: 10 | 11 | @article{yang18, 12 | author={Yibo Yang and Zhisheng Zhong and Tiancheng Shen and Zhouchen Lin}, 13 | title={Convolutional Neural Networks with Alternately Updated Clique}, 14 | journal={arXiv preprint arXiv:1802.10419}, 15 | year={2018} 16 | } 17 | 18 | ### table of contents 19 | - [Introduction](#introduction) 20 | - [Usage](#usage) 21 | - [Ablation experiments](#ablation-experiments) 22 | - [Comparison with state of the arts](#comparison-with-state-of-the-arts) 23 | - [Results on ImageNet](#results-on-imagenet) 24 | 25 | ## Introduction 26 | CliqueNet is a newly proposed convolutional neural network architecture where any pair of layers in the same block are connected bilaterally (Fig 1). Any layer is both the input and the output of another one, so information flow can be maximized. During propagation, the layers are updated alternately (Fig 2), so that each layer always receives feedback from the layers that were updated more recently. We show that the refined features are more discriminative and lead to better performance. On benchmark classification datasets including CIFAR-10, CIFAR-100, SVHN, and ILSVRC 2012, we achieve results better than or comparable to the state of the art with fewer parameters. This repo contains the code of our project, and also provides some experimental results that are not included in the paper. 27 | 28 | 29 |
30 | 31 | Fig 1. An illustration of a block with 4 layers. Node 0 denotes the input layer of this block. 32 | 33 | 34 |
35 | 36 | Fig 2. Alternate updating rule in CliqueNet. "{}" denotes the concatenation operator. 37 | 38 | 39 | 40 | ## Usage 41 | 42 | - Our experiments are conducted with [TensorFlow](https://github.com/tensorflow/tensorflow) in Python 2. 43 | - Clone this repo: `git clone https://github.com/iboing/CliqueNet` 44 | - An example to train a model on CIFAR or SVHN: 45 | ```bash 46 | python train.py --gpu [gpu id] --dataset [cifar-10 or cifar-100 or svhn] --k [filters per layer] --T [all layers of three blocks] --dir [path to save models] 47 | ``` 48 | - Additional techniques (optional): if you want to use the attentional transition, bottleneck architecture, or compression strategy from our paper, add `--if_a True`, `--if_b True`, and `--if_c True`, respectively. 49 | 50 | 51 | ## Ablation experiments 52 | 53 | With the feedback connections, CliqueNet alternately re-updates previous layers with updated layers, to enable refined features. The weights among layers are re-used multiple times, so that a deeper representation space can be attained with a fixed number of parameters. In order to test the effectiveness of CliqueNet's feature refinement, we analyze the features generated in different stages by conducting experiments using different versions of CliqueNet. As illustrated by Fig 3, CliqueNet(I+I) only uses the Stage-I feature. CliqueNet(I+II) uses the Stage-I feature concatenated with the input layer as the block feature, but transits the Stage-II feature into the next block. CliqueNet(II+II) only uses refined features. 54 | 55 |
56 | 57 | Fig 3. A schema for CliqueNet(i+j), i,j belong to {I,II}. 58 | 59 | |Model |block feature |transit |error(%)| 60 | |-----------------|-----------------|--------|--------| 61 | |CliqueNet(I+I) |{ X_0, Stage-I } |Stage-I |6.64 | 62 | |CliqueNet(I+II) |{ X_0, Stage-I } |Stage-II|6.10 | 63 | |CliqueNet(II+II) |{ X_0, Stage-II }|Stage-II|5.76 | 64 | 65 | Tab 1. Results of different versions of CliqueNets. 66 | 67 | To run the experiments above, please modify `train.py` as: 68 | ```python 69 | from models.cliquenet_I_I import build_model 70 | ``` 71 | for CliqueNet(I+I), and 72 | ```python 73 | from models.cliquenet_I_II import build_model 74 | ``` 75 | for CliqueNet(I+II). 76 | 77 | We further consider a situation where the feedback is not processed entirely. Concretely, when k=64 and T=15, we use the Stage-II feature, but only the first `X` steps, see Fig 2. Then `X=0` is just the case of CliqueNet(I+I), and `X=5` corresponds to CliqueNet(II+II). 78 | 79 | 80 | |Model|CIFAR-10| CIFAR-100| 81 | |--------------|----|-----| 82 | |CliqueNet(X=0)|5.83|24.79| 83 | |CliqueNet(X=1)|5.63|24.65| 84 | |CliqueNet(X=2)|5.54|24.37| 85 | |CliqueNet(X=3)|5.41|23.75| 86 | |CliqueNet(X=4)|5.20|24.04| 87 | |CliqueNet(X=5)|5.12|23.73| 88 | 89 | Tab 2. Performance of CliqueNets with different `X`. 90 | 91 | To run the experiments with different `X`, modify `train.py` as: 92 | ```python 93 | from models.cliquenet_X import build_model 94 | ``` 95 | and set the value of the variable `x` (this is `X`) in `./models/cliquenet_X.py`. 96 | 97 | ## Comparison with state of the arts 98 | 99 | The results listed below demonstrate the superiority of CliqueNet over DenseNet when there are no additional techniques (bottleneck, compression, etc.).
100 | 101 | |Model | FLOPs | Params | CIFAR-10 | CIFAR-100 | SVHN | 102 | |------------------------------------| ------|--------| -------- |-----------|------| 103 | |DenseNet (k = 12, T = 36) | 0.53G | 1.0M | 7.00 | 27.55 | 1.79 | 104 | |DenseNet (k = 12, T = 96) | 3.54G | 7.0M | 5.77 | 23.79 | 1.67 | 105 | |DenseNet (k = 24, T = 96) | 13.78G| 27.2M | 5.83 | 23.42 | 1.59 | 106 | | | | | | | | 107 | |CliqueNet (k = 36, T = 12) | 0.91G | 0.94M | 5.93 | 27.32 | 1.77 | 108 | |CliqueNet (k = 64, T = 15) | 4.21G | 4.49M | 5.12 | 23.98 | 1.62 | 109 | |CliqueNet (k = 80, T = 15) | 6.45G | 6.94M | 5.10 | 23.32 | 1.56 | 110 | |CliqueNet (k = 80, T = 18) | 9.45G | 10.14M | 5.06 | 23.14 | 1.51 | 111 | 112 | Tab 3. Main results on CIFAR and SVHN without data augmentation. 113 | 114 | Because a larger T leads to higher computation cost and slightly more parameters, we prefer using a larger k in our experiments. To make the comparison fairer, we also consider the situation where k and T of DenseNets and CliqueNets are exactly the same, see Tab 4. 115 | 116 | |Model|Params|CIFAR-10 | CIFAR-100| 117 | |--------------------|-----|----|-----| 118 | |DenseNet(k=12,T=36) |1.02M|7.00|27.55| 119 | |CliqueNet(k=12,T=36)|1.05M|5.79|26.85| 120 | | | | | | 121 | |DenseNet(k=24,T=18) |0.99M|7.13|27.70| 122 | |CliqueNet(k=24,T=18)|0.99M|6.04|26.57| 123 | | | | | | 124 | |DenseNet(k=36,T=12) |0.96M|6.89|27.54| 125 | |CliqueNet(k=36,T=12)|0.94M|5.93|27.32| 126 | 127 | Tab 4. Comparisons with the same k and T. 128 | 129 | Note that the result of DenseNet(k=12, T=36) is the one reported by the original paper. The others were implemented by us under the same experimental settings. 130 | 131 | 132 | 133 | ## Results on ImageNet 134 | Our code for experiments on ImageNet with TensorFlow will be released soon. 135 | 136 | Here we provide a [PyTorch](http://pytorch.org) version to train a CliqueNet on ImageNet.
# (/README.md, continued)
# An example to run:
#
# ```Python
# python train_imagenet.py [path to the imagenet dataset]
# ```
# (As the default, CliqueNet-S3 is trained, batchsize is 160 and attentional transition is used.)
#
# The PyTorch pre-trained model can be downloaded here (Google Drive): [S3_model](https://drive.google.com/open?id=1IcsmIrTYmxd62Whh5nf8Z7grcCxNAfDk).
#
# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
import os
import argparse
import numpy as np
import tensorflow as tf
import time
from dataloader.data_generator import load_data
from models.cliquenet import build_model

def into_batch(data, label, batch_size, shuffle):
    """Split ``data`` and ``label`` into consecutive mini-batches.

    Args:
        data: array of samples, indexed along axis 0.
        label: array of labels aligned with ``data``.
        batch_size: maximum number of samples per batch.
        shuffle: when true, samples and labels are permuted together first.

    Returns:
        batches_data, batches_labels, batch_count. The last batch holds the
        remainder and may be smaller than ``batch_size``. No empty batch is
        produced when ``len(data)`` is an exact multiple of ``batch_size``.
    """
    if shuffle:
        rand_indexes = np.random.permutation(data.shape[0])
        data = data[rand_indexes]
        label = label[rand_indexes]

    # Plain slicing handles the remainder and inputs shorter than one batch.
    starts = list(range(0, len(data), batch_size))
    batches_data = [data[start:start + batch_size] for start in starts]
    batches_labels = [label[start:start + batch_size] for start in starts]
    batch_count = len(batches_data)

    return batches_data, batches_labels, batch_count


def count_params():
    """Print the number of trainable parameters of the default graph, in millions."""
    total_params=0
    for variable in tf.trainable_variables():
        shape=variable.get_shape()
        params=1
        for dim in shape:
            params=params*dim.value
        total_params+=params
    print("Total training params: %.2fM" % (total_params / 1e6))


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` would turn every non-empty string, including ``'False'``,
    into ``True``.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def main():
    """Parse the command line, train the model and save checkpoints and curves."""
    ##
    train_params={'normalize_type': 'by-channels', ## by-channels, divide-255, divide-256
                  'initial_lr': 0.1,
                  'weight_decay': 1e-4,
                  'batch_size': 64,
                  'total_epoch': 300,
                  'keep_prob':0.8
                  }

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default="0")
    parser.add_argument('--dataset', required=True,
                        choices=['cifar-10', 'cifar-100', 'svhn'])
    parser.add_argument('--k', type=int, required=True,
                        help='filters per layer')
    parser.add_argument('--T', type=int, required=True,
                        help='total layers in all blocks')
    parser.add_argument('--dir', required=True,
                        help='folder to store models')
    parser.add_argument('--if_a', default=False, type=_str2bool,
                        help='if use attentional transition')
    parser.add_argument('--if_b', default=False, type=_str2bool,
                        help='if use bottleneck architecture')
    parser.add_argument('--if_c', default=False, type=_str2bool,
                        help='if use compression')

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    dataset=args.dataset

    if dataset=='svhn':
        total_epoches=40
    else:
        total_epoches = train_params['total_epoch']

    result_dir=args.dir
    # Prefix for the saved curve files. The basename keeps the path valid when
    # --dir contains separators (e.g. "runs/exp1").
    run_name = os.path.basename(os.path.normpath(result_dir))

    batch_size = train_params['batch_size']
    lr = train_params['initial_lr']
    kp = train_params['keep_prob']
    weight_decay = train_params['weight_decay']

    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    train_data, train_label, test_data, test_label=load_data(dataset, train_params['normalize_type'])

    image_size=train_data.shape[1:]
    label_num=train_label.shape[-1]

    graph=tf.Graph()
    with graph.as_default():
        input_images=tf.placeholder(tf.float32, [None, image_size[0], image_size[1], image_size[2]], name='input_images')
        true_labels=tf.placeholder(tf.float32, [None, label_num], name='labels')
        is_train=tf.placeholder(tf.bool, shape=[])
        learning_rate=tf.placeholder(tf.float32, shape=[], name='learning_rate')
        keep_prob=tf.placeholder(tf.float32, shape=[], name='keep_prob')
        ### build model
        logits, prob=build_model(input_images, args.k, args.T, label_num, is_train, keep_prob, args.if_a, args.if_b, args.if_c)
        ### loss and accuracy
        loss_cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=true_labels))
        if_correct = tf.equal(tf.argmax(prob, 1), tf.argmax(true_labels, 1))
        accuracy = tf.reduce_mean(tf.cast(if_correct, tf.float32))
        l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
        ### optimizer
        optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9, use_nesterov=True)
        train_op = optimizer.minimize(loss_cross_entropy + l2_loss*weight_decay)
        saver=tf.train.Saver()

    ### begin training ###

    config=tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config, graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        count_params()

        ### test batch data ###
        # The test set is batched once; the training set is re-shuffled every epoch.
        batches_data_test, batches_labels_test, batch_count_test = into_batch(test_data, test_label, batch_size, shuffle=False)

        loss_train=[]
        acc_train=[]
        loss_test=[]
        acc_test=[]
        best_acc=0
        for epoch in range(1, total_epoches+1):
            # Divide the learning rate by 10 at 1/2 and 3/4 of training.
            # Floor division keeps the milestones integers on Python 3.
            if epoch == total_epoches // 2:
                lr=lr*0.1
            if epoch == total_epoches * 3 // 4:
                lr=lr*0.1

            batches_data, batches_labels, batch_count=into_batch(train_data, train_label, batch_size, shuffle=True)

            ### train ###
            loss_per_bat=[]
            acc_per_bat=[]
            for batch_id in range(batch_count):
                data_per_bat = batches_data[batch_id]
                label_per_bat = batches_labels[batch_id]
                result_per_bat = sess.run([train_op, loss_cross_entropy, accuracy],
                                          feed_dict={input_images : data_per_bat,
                                                     true_labels : label_per_bat,
                                                     learning_rate : lr,
                                                     is_train : True,
                                                     keep_prob: kp})
                loss_per_bat.append(result_per_bat[1])
                acc_per_bat.append(result_per_bat[2])
                if (batch_id+1) % 100==0:
                    print('epoch: %s batch: %s in %s' % (epoch, batch_id+1, batch_count))
                    print('loss: %s accuracy: %s' % (result_per_bat[1], result_per_bat[2]))

            saver.save(sess, os.path.join(result_dir, dataset+'_epoch_%d.ckpt' % epoch))
            loss_train.append(np.mean(loss_per_bat))
            acc_train.append(np.mean(acc_per_bat))

            ### test ###
            loss_per_bat=[]
            acc_per_bat=[]
            for batch_id in range(batch_count_test):
                data_per_bat = batches_data_test[batch_id]
                label_per_bat = batches_labels_test[batch_id]
                result_per_bat = sess.run([loss_cross_entropy, accuracy],
                                          feed_dict={input_images : data_per_bat,
                                                     true_labels : label_per_bat,
                                                     is_train: False,
                                                     keep_prob: 1})
                loss_per_bat.append(result_per_bat[0])  ## result[0]->loss
                acc_per_bat.append(result_per_bat[1])   ## result[1]->acc
            loss_test.append(np.mean(loss_per_bat))
            acc_test.append(np.mean(acc_per_bat))

            if acc_test[-1]>best_acc:
                best_acc=acc_test[-1]

            print(time.ctime())
            print('epoch: %s' % epoch)
            print('train loss: %s acc: %s' % (loss_train[-1], acc_train[-1]))
            print('test loss: %s acc: %s' % (loss_test[-1], acc_test[-1]))
            print('best test acc: %s' % best_acc)
            print('\n')

            # Rewritten every epoch so the curves survive an interrupted run;
            # the files after the last epoch hold the complete history.
            np.save(os.path.join(result_dir, run_name+'_loss_train.npy'), np.array(loss_train))
            np.save(os.path.join(result_dir, run_name+'_acc_train.npy'), np.array(acc_train))
            np.save(os.path.join(result_dir, run_name+'_loss_test.npy'), np.array(loss_test))
            np.save(os.path.join(result_dir, run_name+'_acc_test.npy'), np.array(acc_test))


if __name__=='__main__':
    main()

# --------------------------------------------------------------------------------
# /imagenet_pytorch/utils.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.nn.functional as F
from torch.nn import init

# class ...  (this definition continues beyond the end of the excerpt)
class attention(nn.Module):
    """Channel-wise gating used by the attentional transition.

    The input map is average-pooled over a ``map_size`` window, passed through
    two fully connected layers (C -> C // 2 -> C) with ReLU and sigmoid, and the
    resulting per-channel weights rescale the input.

    Args:
        input_channels: number of channels C of the input.
        map_size: spatial size of the input map; the pooling kernel covers it
            entirely, so the input must be ``map_size`` x ``map_size``.
    """

    def __init__(self, input_channels, map_size):
        super(attention, self).__init__()
        self.pool = nn.AvgPool2d(kernel_size=map_size)
        self.fc1 = nn.Linear(in_features=input_channels, out_features=input_channels // 2)
        self.fc2 = nn.Linear(in_features=input_channels // 2, out_features=input_channels)

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate (same shape as ``x``)."""
        batch, channels = x.size(0), x.size(1)
        gate = self.pool(x).view(batch, channels)
        gate = F.relu(self.fc1(gate))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        gate = torch.sigmoid(self.fc2(gate))
        return torch.mul(x, gate.view(batch, channels, 1, 1))


class transition(nn.Module):
    """BN -> ReLU -> 1x1 conv -> optional attention -> 2x2 average pooling.

    Args:
        if_att: whether to insert the ``attention`` gate after the convolution.
        current_size: spatial size of the incoming map (forwarded to ``attention``).
        input_channels: channel count, unchanged by this layer.
        keep_prob: stored only; no dropout is applied by this module.
    """

    def __init__(self, if_att, current_size, input_channels, keep_prob):
        super(transition, self).__init__()
        self.input_channels = input_channels
        self.keep_prob = keep_prob
        self.bn = nn.BatchNorm2d(self.input_channels)
        self.conv = nn.Conv2d(self.input_channels, self.input_channels, kernel_size=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2)
        self.if_att = if_att
        if self.if_att:
            self.attention = attention(input_channels=self.input_channels, map_size=current_size)

    def forward(self, x):
        """Return the transformed map with half the spatial resolution."""
        output = self.conv(F.relu(self.bn(x)))
        if self.if_att:
            output = self.attention(output)
        return self.pool(output)


class global_pool(nn.Module):
    """BN -> ReLU -> average pooling over an ``input_size`` window."""

    def __init__(self, input_size, input_channels):
        super(global_pool, self).__init__()
        self.input_size = input_size
        self.input_channels = input_channels
        self.bn = nn.BatchNorm2d(self.input_channels)
        self.pool = nn.AvgPool2d(kernel_size=self.input_size)

    def forward(self, x):
        """Return the pooled features."""
        return self.pool(F.relu(self.bn(x)))


class compress(nn.Module):
    """BN -> ReLU -> 1x1 conv that halves the number of channels.

    ``keep_prob`` is stored only; no dropout is applied by this module.
    """

    def __init__(self, input_channels, keep_prob):
        super(compress, self).__init__()
        self.keep_prob = keep_prob
        self.bn = nn.BatchNorm2d(input_channels)
        self.conv = nn.Conv2d(input_channels, input_channels // 2, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        """Return the compressed feature map (same spatial size, C // 2 channels)."""
        return self.conv(F.relu(self.bn(x)))


class clique_block(nn.Module):
    """Clique block with bottleneck layers (1x1 conv followed by 3x3 conv).

    Layers are numbered 1..``layer_num``; node 0 is the block input.  The 1x1
    weights live in ``conv_param``, a flat list indexed by
    ``src * (layer_num + 1) + dst``.  Entries that are never used (any edge
    into node 0 and every ``i -> i`` edge) are ``None``.

    Args:
        input_channels: channels of the block input.
        channels_per_layer: channels produced by each layer.
        layer_num: number of layers in the block.
        loop_num: number of update passes after the initial forward pass.
        keep_prob: stored only; no dropout is applied by this module.
    """

    def __init__(self, input_channels, channels_per_layer, layer_num, loop_num, keep_prob):
        super(clique_block, self).__init__()
        self.input_channels = input_channels
        self.channels_per_layer = channels_per_layer
        self.layer_num = layer_num
        self.loop_num = loop_num
        self.keep_prob = keep_prob

        # conv 1 x 1: start from a full (layer_num + 1)^2 grid of layer-to-layer
        # weights, then patch the special rows/columns below.  The construction
        # order is kept so parameter names and initialisation are unchanged.
        grid = self.layer_num + 1
        self.conv_param = nn.ModuleList([
            nn.Conv2d(self.channels_per_layer, self.channels_per_layer,
                      kernel_size=1, padding=0, bias=False)
            for _ in range(grid ** 2)])

        # Edges 0 -> dst read the block input, so they take input_channels.
        for dst in range(1, grid):
            self.conv_param[dst] = nn.Conv2d(self.input_channels, self.channels_per_layer,
                                             kernel_size=1, padding=0, bias=False)
        # Self edges i -> i are unused.
        for node in range(1, grid):
            self.conv_param[node * (grid + 1)] = None
        # Edges src -> 0 are unused.
        for src in range(grid):
            self.conv_param[src * grid] = None

        self.forward_bn = nn.ModuleList([
            nn.BatchNorm2d(self.input_channels + i * self.channels_per_layer)
            for i in range(self.layer_num)])
        self.forward_bn_b = nn.ModuleList([
            nn.BatchNorm2d(self.channels_per_layer) for _ in range(self.layer_num)])
        self.loop_bn = nn.ModuleList([
            nn.BatchNorm2d(self.channels_per_layer * (self.layer_num - 1))
            for _ in range(self.layer_num)])
        self.loop_bn_b = nn.ModuleList([
            nn.BatchNorm2d(self.channels_per_layer) for _ in range(self.layer_num)])

        # conv 3 x 3
        self.conv_param_bottle = nn.ModuleList([
            nn.Conv2d(self.channels_per_layer, self.channels_per_layer,
                      kernel_size=3, padding=1, bias=False)
            for _ in range(self.layer_num)])

    def _weight(self, src, dst):
        """Return the 1x1 weight of the edge ``src -> dst``."""
        return self.conv_param[src * (self.layer_num + 1) + dst].weight

    def _update(self, features, weight, bn_in, bn_mid, layer_id):
        """Apply BN-ReLU-conv1x1 then BN-ReLU-conv3x3 to produce layer ``layer_id``."""
        out = F.relu(bn_in(features))
        out = F.conv2d(out, weight, stride=1, padding=0)
        out = F.relu(bn_mid(out))
        return F.conv2d(out, self.conv_param_bottle[layer_id - 1].weight, stride=1, padding=1)

    def forward(self, x):
        """Run the initial pass and ``loop_num`` update passes.

        Returns:
            ``(block_feature_I, block_feature_II)`` where ``block_feature_II``
            is the channel-wise concatenation of the final layer outputs and
            ``block_feature_I`` is ``x`` concatenated with those same outputs.

        NOTE(review): the earlier code recorded a dictionary of blobs after
        each stage, but both records referred to one dictionary object, so
        ``block_feature_I`` was always built from the final blobs.  That
        effective behaviour is kept on purpose, presumably matching what
        existing checkpoints were trained with.  Confirm before changing.
        Activations are no longer stored on ``self`` between calls.
        """
        layer_ids = list(range(1, self.layer_num + 1))
        blobs = {}

        # First pass: layer i sees the input and layers 1..i-1.
        for layer_id in layer_ids:
            earlier = layer_ids[:layer_id - 1]
            features = torch.cat([x] + [blobs[i] for i in earlier], 1)
            weight = torch.cat([self._weight(0, layer_id)] +
                               [self._weight(i, layer_id) for i in earlier], 1)
            blobs[layer_id] = self._update(features, weight,
                                           self.forward_bn[layer_id - 1],
                                           self.forward_bn_b[layer_id - 1],
                                           layer_id)

        # Update passes: layer i is recomputed from every other layer's latest output.
        for _ in range(self.loop_num):
            for layer_id in layer_ids:
                others = [i for i in layer_ids if i != layer_id]
                features = torch.cat([blobs[i] for i in others], 1)
                weight = torch.cat([self._weight(i, layer_id) for i in others], 1)
                blobs[layer_id] = self._update(features, weight,
                                               self.loop_bn[layer_id - 1],
                                               self.loop_bn_b[layer_id - 1],
                                               layer_id)

        block_feature_II = torch.cat([blobs[i] for i in layer_ids], 1)
        block_feature_I = torch.cat((x, block_feature_II), 1)
        return block_feature_I, block_feature_II
parser = argparse.ArgumentParser(description='CliqueNet ImageNet Training')

parser.add_argument('data', metavar='DIR',
                    help='path to the imagenet dataset')

parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')

parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run (default: 100)')

parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')

parser.add_argument('-b', '--batch-size', default=160, type=int,
                    metavar='N', help='mini-batch size (default: 160)')

parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate (default: 0.1)')

parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum (default: 0.9)')

parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')

parser.add_argument('--print-freq', '-p', default=50, type=int,
                    metavar='N', help='print frequency (default: 50)')

parser.add_argument('--resume', default=None, type=str, metavar='PATH',
                    help='path to latest checkpoint')

parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')

parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--no-attention', dest='attention', action='store_false',
                    help='not use attentional transition')
parser.set_defaults(attention=True)

best_prec1 = 0


def main():
    """Parse the command line, then train and/or evaluate CliqueNet on ImageNet.

    Requires CUDA.  The tensor API used here (``torch.no_grad``, ``.item()``,
    ``non_blocking``) needs PyTorch 0.4 or later.
    """
    global args, best_prec1
    args = parser.parse_args()
    if args.attention:
        print('attentional transition is used')
    # create model
    model = cliquenet.build_cliquenet(input_channels=64, list_channels=[40, 80, 160, 160],
                                      list_layer_num=[6, 6, 6, 6], if_att=args.attention)
    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)


def get_number_of_param(model):
    """Print the size of every parameter tensor and the running/total count.

    Returns:
        The total number of scalar parameters in ``model``.
    """
    count = 0
    for param in model.parameters():
        count_of_one_param = param.numel()
        print(param.size(), count_of_one_param)
        print(count)
        count += count_of_one_param
    print('total number of the model is %d' % count)
    return count


def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch, logging every ``args.print_freq`` batches."""
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        # `non_blocking` is the current name of the old `async` argument,
        # which is a reserved word since Python 3.7.
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss (plain floats, detached from the graph)
        prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
        batch_size = images.size(0)
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))

    print(time.ctime())


def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 precision."""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # torch.no_grad() replaces the removed `volatile=True` Variable flag.
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            batch_size = images.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='./cliquenet_s0.pth.tar'):
    """Save ``state`` to ``filename`` and copy it to the best-model file when ``is_best``."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, './best_cliquenet_s0.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against n == 0 on the first update; force true division on Python 2.
        self.avg = self.sum / float(self.count) if self.count else 0


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial Learning rate decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    print('current learning rate is: %f' % lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: score tensor with one row per sample and one column per class.
        target: 1-D tensor of class indices, one per sample.
        topk: iterable of k values.

    Returns:
        A list with one 1-element tensor per k holding the percentage of
        samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` is a transposed (non-contiguous) view, so reshape() is
        # required where view() would raise; keepdim keeps `res[i][0]` valid.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()
def bias_var(out_channels, init_method=None):
    """Create a zero-initialised bias variable of length ``out_channels``.

    ``init_method`` is accepted for signature compatibility but is not used;
    it is optional so that callers passing only ``out_channels`` work.
    """
    initial_value = tf.constant(0.0, shape=[out_channels])
    biases = tf.Variable(initial_value)

    return biases


def conv_var(kernel_size, in_channels, out_channels, init_method, name):
    """Create (or fetch) a convolution kernel variable.

    Args:
        kernel_size: ``(height, width)`` of the kernel.
        in_channels: number of input channels.
        out_channels: number of output channels.
        init_method: ``'msra'`` (variance scaling) or ``'xavier'``.
        name: variable name passed to ``tf.get_variable``.

    Raises:
        ValueError: if ``init_method`` is not one of the supported names.
    """
    shape = [kernel_size[0], kernel_size[1], in_channels, out_channels]
    if init_method == 'msra':
        initializer = tf.contrib.layers.variance_scaling_initializer()
    elif init_method == 'xavier':
        initializer = tf.contrib.layers.xavier_initializer()
    else:
        raise ValueError("unknown init_method %r (expected 'msra' or 'xavier')" % (init_method,))
    return tf.get_variable(name=name, shape=shape, initializer=initializer)


def _bn_relu(input_layer, is_train):
    """Batch-normalise ``input_layer`` and apply ReLU."""
    output_layer = tf.contrib.layers.batch_norm(input_layer, scale=True, is_training=is_train,
                                                updates_collections=None)
    return tf.nn.relu(output_layer)


def _bn_relu_conv_drop(input_layer, filters, is_train, keep_prob):
    """BN -> ReLU -> stride-1 'SAME' convolution with ``filters`` -> dropout."""
    output_layer = _bn_relu(input_layer, is_train)
    output_layer = tf.nn.conv2d(output_layer, filters, [1, 1, 1, 1], padding='SAME')
    return tf.nn.dropout(output_layer, keep_prob)


def attentional_transition(input_layer, name):
    """Rescale the channels of ``input_layer`` with a learned gate.

    The map is average-pooled over its full spatial extent, passed through two
    fully connected layers (C -> C // 2 -> C) with ReLU and sigmoid, and the
    result multiplies the input channel-wise.  The input is read as
    (batch, size, size, C) with a statically known square spatial size.
    """
    channels = input_layer.get_shape().as_list()[-1]
    map_size = input_layer.get_shape().as_list()[1]
    hidden = channels // 2  # integer division: shapes must be ints on Python 3

    bottom_fc = tf.nn.avg_pool(input_layer, [1, map_size, map_size, 1], [1, map_size, map_size, 1], 'VALID')

    assert bottom_fc.get_shape().as_list()[-1] == channels  ## none,1,1,C

    bottom_fc = tf.reshape(bottom_fc, [-1, channels])  ## none, C

    Wfc = tf.get_variable(name=name + '_W1', shape=[channels, hidden],
                          initializer=tf.contrib.layers.xavier_initializer())
    bfc = tf.get_variable(name=name + '_b1', initializer=tf.constant(0.0, shape=[hidden]))

    mid_fc = tf.nn.relu(tf.matmul(bottom_fc, Wfc) + bfc)

    Wfc = tf.get_variable(name=name + '_W2', shape=[hidden, channels],
                          initializer=tf.contrib.layers.xavier_initializer())
    bfc = tf.get_variable(name=name + '_b2', initializer=tf.constant(0.0, shape=[channels]))

    top_fc = tf.nn.sigmoid(tf.matmul(mid_fc, Wfc) + bfc)  ## none, C

    top_fc = tf.reshape(top_fc, [-1, 1, 1, channels])

    return tf.multiply(input_layer, top_fc)


def transition(input_layer, if_a, is_train, keep_prob, name):
    """BN -> ReLU -> 1x1 conv -> dropout -> optional attention -> 2x2 average pooling."""
    channels = input_layer.get_shape().as_list()[-1]
    filters_bn_input = _bn_relu(input_layer, is_train)
    filters = conv_var(kernel_size=(1, 1), in_channels=channels, out_channels=channels,
                       init_method='msra', name=name)
    output_layer = tf.nn.conv2d(filters_bn_input, filters, [1, 1, 1, 1], padding='SAME')
    output_layer = tf.nn.dropout(output_layer, keep_prob)
    ## attentional transition
    if if_a:
        output_layer = attentional_transition(output_layer, name=name + '-ATT')

    output_layer = tf.nn.avg_pool(output_layer, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID')

    return output_layer


def compress(input_layer, is_train, keep_prob, name):
    """BN -> ReLU -> 1x1 conv halving the channel count -> dropout."""
    channels = input_layer.get_shape().as_list()[-1]
    output_layer = _bn_relu(input_layer, is_train)
    filters = conv_var(kernel_size=(1, 1), in_channels=channels, out_channels=channels // 2,
                       init_method='msra', name=name)
    output_layer = tf.nn.conv2d(output_layer, filters, [1, 1, 1, 1], padding='SAME')
    output_layer = tf.nn.dropout(output_layer, keep_prob)

    return output_layer


def global_pool(input_layer, is_train):
    """BN -> ReLU -> average pooling over the whole (square) feature map."""
    output_layer = _bn_relu(input_layer, is_train)

    map_size = input_layer.get_shape().as_list()[1]
    return tf.nn.avg_pool(output_layer, [1, map_size, map_size, 1], [1, map_size, map_size, 1], 'VALID')


def first_transit(input_layer, channels, strides, with_biase=False):
    """Apply the initial 3x3 convolution (3 input channels) with the given stride.

    A zero-initialised bias is added when ``with_biase`` is true.
    """
    filters = conv_var(kernel_size=(3, 3), in_channels=3, out_channels=channels,
                       init_method='msra', name='first_tran')
    conved = tf.nn.conv2d(input_layer, filters, [1, strides, strides, 1], padding='SAME')
    if with_biase:
        biases = bias_var(out_channels=channels)
        return tf.nn.bias_add(conved, biases)
    return conved


def loop_block(input_layer, if_b, channels_per_layer, layer_num, is_train, keep_prob, block_name, loop_num=1):
    """Build one clique block and return ``(block_feature, transit_feature)``.

    Layers are numbered 1..layer_num and node 0 is ``input_layer``.  A first
    pass computes layer i from the input and layers 1..i-1; then ``loop_num``
    passes recompute every layer from all other layers' latest outputs.

    Args:
        input_layer: input tensor; channels are read from its last axis.
        if_b: use the bottleneck form (1x1 conv followed by a 3x3 conv).  When
            set, ``layer_num`` is halved first.
        channels_per_layer: channels produced by each layer.
        layer_num: number of layers (before halving for the bottleneck).
        is_train: training flag forwarded to batch normalisation.
        keep_prob: dropout keep probability.
        block_name: prefix for the variable names of this block.
        loop_num: number of update passes after the first pass.

    Returns:
        ``transit_feature`` is the concatenation of the final layer outputs;
        ``block_feature`` is ``input_layer`` concatenated with it.
    """
    if if_b:
        layer_num = layer_num // 2  ## if bottleneck is used, the T value should be multiplied by 2.
    channels = channels_per_layer
    node_0_channels = input_layer.get_shape().as_list()[-1]

    ## init param: one kernel per directed edge between distinct layers
    param_dict = {}
    kernel_size = (1, 1) if if_b else (3, 3)
    for layer_id in range(1, layer_num):
        for other_id in range(layer_id + 1, layer_num + 1):
            # forward edge first, then the reverse edge (creation order is kept)
            for src, dst in ((layer_id, other_id), (other_id, layer_id)):
                key = str(src) + '_' + str(dst)
                param_dict[key] = conv_var(kernel_size=kernel_size, in_channels=channels,
                                           out_channels=channels, init_method='msra',
                                           name=block_name + '-' + key)

    ## edges from the block input (node 0) to every layer
    for layer_id in range(1, layer_num + 1):
        key = '0_' + str(layer_id)
        param_dict[key] = conv_var(kernel_size=kernel_size, in_channels=node_0_channels,
                                   out_channels=channels, init_method='msra',
                                   name=block_name + '-' + key)

    assert len(param_dict) == layer_num * (layer_num - 1) + layer_num

    ### bottleneck param ###
    param_dict_B = {}
    if if_b:
        for layer_id in range(1, layer_num + 1):
            param_dict_B[str(layer_id)] = conv_var(kernel_size=(3, 3), in_channels=channels,
                                                   out_channels=channels, init_method='msra',
                                                   name=block_name + '-' + 'to-' + str(layer_id))

    def update(layer_id, bottom_blob, bottom_param):
        """Compute the new output of ``layer_id`` from its concatenated inputs."""
        top = _bn_relu_conv_drop(bottom_blob, bottom_param, is_train, keep_prob)
        if if_b:
            top = _bn_relu_conv_drop(top, param_dict_B[str(layer_id)], is_train, keep_prob)
        return top

    ## init blob: first pass
    blob_dict = {}
    for layer_id in range(1, layer_num + 1):
        bottom_blob = input_layer
        bottom_param = param_dict['0_' + str(layer_id)]
        for layer_id_id in range(1, layer_id):
            bottom_blob = tf.concat((bottom_blob, blob_dict[str(layer_id_id)]), axis=3)
            bottom_param = tf.concat((bottom_param, param_dict[str(layer_id_id) + '_' + str(layer_id)]), axis=2)
        blob_dict[str(layer_id)] = update(layer_id, bottom_blob, bottom_param)

    ## begin loop: update passes
    for _ in range(loop_num):
        for layer_id in range(1, layer_num + 1):
            layer_list = [str(l_id) for l_id in range(1, layer_num + 1) if l_id != layer_id]

            bottom_blobs = blob_dict[layer_list[0]]
            bottom_param = param_dict[layer_list[0] + '_' + str(layer_id)]
            for other in layer_list[1:]:
                bottom_blobs = tf.concat((bottom_blobs, blob_dict[other]), axis=3)  ### concatenate the data blobs
                bottom_param = tf.concat((bottom_param, param_dict[other + '_' + str(layer_id)]), axis=2)  ### concatenate the parameters
            blob_dict[str(layer_id)] = update(layer_id, bottom_blobs, bottom_param)  ### update the data blob

    transit_feature = blob_dict['1']
    for layer_id in range(2, layer_num + 1):
        transit_feature = tf.concat((transit_feature, blob_dict[str(layer_id)]), axis=3)

    block_feature = tf.concat((input_layer, transit_feature), axis=3)

    return block_feature, transit_feature


def loop_block_I_I(input_layer, if_b, channels_per_layer, layer_num, is_train, keep_prob, block_name):
    """Build a clique block with the first pass only (no update passes).

    This is ``loop_block`` with ``loop_num=0``: the same variables and ops are
    created in the same order, and both returned features come from the
    first-pass outputs.
    """
    return loop_block(input_layer, if_b, channels_per_layer, layer_num, is_train, keep_prob,
                      block_name, loop_num=0)


# NOTE: only the opening of loop_block_I_II lies inside this span; it is
# reproduced unchanged and its body continues on the following lines.
def loop_block_I_II(input_layer, if_b, channels_per_layer, layer_num, is_train, keep_prob, block_name, loop_num=1):
    if if_b: layer_num = layer_num/2 ## if bottleneck is used, the T value should be multiplied by 2.
251 | import copy 252 | 253 | channels=channels_per_layer 254 | node_0_channels=input_layer.get_shape().as_list()[-1] 255 | ## init param 256 | param_dict={} 257 | kernel_size=(1, 1) if if_b==True else (3, 3) 258 | for layer_id in range(1, layer_num): 259 | add_id=1 260 | while layer_id+add_id <= layer_num: 261 | 262 | ## -> 263 | filters=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id)+'_'+str(layer_id+add_id)) 264 | param_dict[str(layer_id)+'_'+str(layer_id+add_id)]=filters 265 | ## <- 266 | filters_inv=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id+add_id)+'_'+str(layer_id)) 267 | param_dict[str(layer_id+add_id)+'_'+str(layer_id)]=filters_inv 268 | add_id+=1 269 | 270 | for layer_id in range(layer_num): 271 | filters=conv_var(kernel_size=kernel_size, in_channels=node_0_channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(0)+'_'+str(layer_id+1)) 272 | param_dict[str(0)+'_'+str(layer_id+1)]=filters 273 | 274 | assert len(param_dict)==layer_num*(layer_num-1)+layer_num 275 | 276 | ### bottleneck param ### 277 | if if_b==True: 278 | param_dict_B={} 279 | for layer_id in range(1, layer_num+1): 280 | filters=conv_var(kernel_size=(3,3), in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+'to-'+str(layer_id)) 281 | param_dict_B[str(layer_id)]=filters 282 | 283 | ## init blob 284 | blob_dict={} 285 | blob_dict_list=[] 286 | 287 | for layer_id in range(1, layer_num+1): 288 | bottom_blob=input_layer 289 | bottom_param=param_dict['0_'+str(layer_id)] 290 | for layer_id_id in range(1, layer_id): 291 | bottom_blob=tf.concat((bottom_blob, blob_dict[str(layer_id_id)]), axis=3) 292 | bottom_param=tf.concat((bottom_param, param_dict[str(layer_id_id)+'_'+str(layer_id)]), axis=2) 293 | 294 | 295 | mid_layer=tf.contrib.layers.batch_norm(bottom_blob, scale=True, 
is_training=is_train, updates_collections=None) 296 | mid_layer=tf.nn.relu(mid_layer) 297 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME') 298 | mid_layer=tf.nn.dropout(mid_layer, keep_prob) 299 | ## Bottle neck 300 | if if_b==True: 301 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None) 302 | next_layer=tf.nn.relu(next_layer) 303 | next_layer=tf.nn.conv2d(next_layer, param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME') 304 | next_layer=tf.nn.dropout(next_layer, keep_prob) 305 | else: 306 | next_layer=mid_layer 307 | blob_dict[str(layer_id)]=next_layer 308 | 309 | blob_dict_list.append(blob_dict) 310 | 311 | ## begin loop 312 | for loop_id in range(loop_num): 313 | blob_dict_new = copy.copy(blob_dict_list[-1]) 314 | for layer_id in range(1, layer_num+1): ## [1,2,3,4,5] 315 | 316 | layer_list=[str(l_id) for l_id in range(1, layer_num+1)] 317 | layer_list.remove(str(layer_id)) 318 | 319 | bottom_blobs=blob_dict_new[layer_list[0]] 320 | bottom_param=param_dict[layer_list[0]+'_'+str(layer_id)] 321 | for bottom_id in range(len(layer_list)-1): 322 | bottom_blobs=tf.concat((bottom_blobs, blob_dict_new[layer_list[bottom_id+1]]), 323 | axis=3) ### concatenate the data blobs 324 | bottom_param=tf.concat((bottom_param, param_dict[layer_list[bottom_id+1]+'_'+str(layer_id)]), 325 | axis=2) ### concatenate the parameters 326 | 327 | mid_layer=tf.contrib.layers.batch_norm(bottom_blobs, scale=True, is_training=is_train, updates_collections=None) 328 | mid_layer=tf.nn.relu(mid_layer) 329 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME') ### update the data blob 330 | mid_layer=tf.nn.dropout(mid_layer, keep_prob) 331 | ## Bottle neck 332 | if if_b==True: 333 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None) 334 | next_layer=tf.nn.relu(next_layer) 335 | next_layer=tf.nn.conv2d(next_layer, 
param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME') 336 | next_layer=tf.nn.dropout(next_layer, keep_prob) 337 | else: 338 | next_layer=mid_layer 339 | blob_dict_new[str(layer_id)]=next_layer 340 | blob_dict_list.append(blob_dict_new) 341 | 342 | assert len(blob_dict_list)==1+loop_num 343 | 344 | stage_I = blob_dict_list[0]['1'] 345 | for layer_id in range(2, layer_num+1): 346 | stage_I=tf.concat((stage_I, blob_dict_list[0][str(layer_id)]), axis=3) 347 | 348 | stage_II = blob_dict_list[1]['1'] 349 | for layer_id in range(2, layer_num+1): 350 | stage_II=tf.concat((stage_II, blob_dict_list[1][str(layer_id)]), axis=3) 351 | 352 | block_feature = tf.concat((input_layer, stage_I), axis=3) 353 | transit_feature = stage_II 354 | 355 | return block_feature, transit_feature 356 | 357 | 358 | def loop_block_X(input_layer, x_value, if_b, channels_per_layer, layer_num, is_train, keep_prob, block_name, loop_num=1): 359 | if if_b: layer_num = layer_num/2 ## if bottleneck is used, the T value should be multiplied by 2. 
360 | channels=channels_per_layer 361 | node_0_channels=input_layer.get_shape().as_list()[-1] 362 | ## init param 363 | param_dict={} 364 | kernel_size=(1, 1) if if_b==True else (3, 3) 365 | for layer_id in range(1, layer_num): 366 | add_id=1 367 | while layer_id+add_id <= layer_num: 368 | 369 | ## -> 370 | filters=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id)+'_'+str(layer_id+add_id)) 371 | param_dict[str(layer_id)+'_'+str(layer_id+add_id)]=filters 372 | ## <- 373 | filters_inv=conv_var(kernel_size=kernel_size, in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(layer_id+add_id)+'_'+str(layer_id)) 374 | param_dict[str(layer_id+add_id)+'_'+str(layer_id)]=filters_inv 375 | add_id+=1 376 | 377 | for layer_id in range(layer_num): 378 | filters=conv_var(kernel_size=kernel_size, in_channels=node_0_channels, out_channels=channels, init_method='msra', name=block_name+'-'+str(0)+'_'+str(layer_id+1)) 379 | param_dict[str(0)+'_'+str(layer_id+1)]=filters 380 | 381 | assert len(param_dict)==layer_num*(layer_num-1)+layer_num 382 | 383 | ### bottleneck param ### 384 | if if_b==True: 385 | param_dict_B={} 386 | for layer_id in range(1, layer_num+1): 387 | filters=conv_var(kernel_size=(3,3), in_channels=channels, out_channels=channels, init_method='msra', name=block_name+'-'+'to-'+str(layer_id)) 388 | param_dict_B[str(layer_id)]=filters 389 | 390 | ## init blob 391 | blob_dict={} 392 | 393 | for layer_id in range(1, layer_num+1): 394 | bottom_blob=input_layer 395 | bottom_param=param_dict['0_'+str(layer_id)] 396 | for layer_id_id in range(1, layer_id): 397 | bottom_blob=tf.concat((bottom_blob, blob_dict[str(layer_id_id)]), axis=3) 398 | bottom_param=tf.concat((bottom_param, param_dict[str(layer_id_id)+'_'+str(layer_id)]), axis=2) 399 | 400 | 401 | mid_layer=tf.contrib.layers.batch_norm(bottom_blob, scale=True, is_training=is_train, updates_collections=None) 
402 | mid_layer=tf.nn.relu(mid_layer) 403 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME') 404 | mid_layer=tf.nn.dropout(mid_layer, keep_prob) 405 | ## Bottle neck 406 | if if_b==True: 407 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None) 408 | next_layer=tf.nn.relu(next_layer) 409 | next_layer=tf.nn.conv2d(next_layer, param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME') 410 | next_layer=tf.nn.dropout(next_layer, keep_prob) 411 | else: 412 | next_layer=mid_layer 413 | blob_dict[str(layer_id)]=next_layer 414 | 415 | ## begin loop 416 | for loop_id in range(loop_num): 417 | for layer_id in range(1, x_value+1): ## [1,2,3,4,5] 418 | 419 | layer_list=[str(l_id) for l_id in range(1, layer_num+1)] 420 | layer_list.remove(str(layer_id)) 421 | 422 | bottom_blobs=blob_dict[layer_list[0]] 423 | bottom_param=param_dict[layer_list[0]+'_'+str(layer_id)] 424 | for bottom_id in range(len(layer_list)-1): 425 | bottom_blobs=tf.concat((bottom_blobs, blob_dict[layer_list[bottom_id+1]]), 426 | axis=3) ### concatenate the data blobs 427 | bottom_param=tf.concat((bottom_param, param_dict[layer_list[bottom_id+1]+'_'+str(layer_id)]), 428 | axis=2) ### concatenate the parameters 429 | 430 | mid_layer=tf.contrib.layers.batch_norm(bottom_blobs, scale=True, is_training=is_train, updates_collections=None) 431 | mid_layer=tf.nn.relu(mid_layer) 432 | mid_layer=tf.nn.conv2d(mid_layer, bottom_param, [1,1,1,1], padding='SAME') ### update the data blob 433 | mid_layer=tf.nn.dropout(mid_layer, keep_prob) 434 | ## Bottle neck 435 | if if_b==True: 436 | next_layer=tf.contrib.layers.batch_norm(mid_layer, scale=True, is_training=is_train, updates_collections=None) 437 | next_layer=tf.nn.relu(next_layer) 438 | next_layer=tf.nn.conv2d(next_layer, param_dict_B[str(layer_id)], [1,1,1,1], padding='SAME') 439 | next_layer=tf.nn.dropout(next_layer, keep_prob) 440 | else: 441 | next_layer=mid_layer 442 | 
blob_dict[str(layer_id)]=next_layer 443 | 444 | transit_feature=blob_dict['1'] 445 | for layer_id in range(2, layer_num+1): 446 | transit_feature=tf.concat((transit_feature, blob_dict[str(layer_id)]), axis=3) 447 | 448 | block_feature=tf.concat((input_layer, transit_feature), axis=3) 449 | 450 | return block_feature, transit_feature 451 | --------------------------------------------------------------------------------