├── eval.sh ├── split.sh ├── .gitignore ├── group.sh ├── README.md ├── download_cifar100.py ├── cifar100.py ├── utils.py ├── resnet.py ├── eval.py ├── resnet_split.py ├── train_split.py └── train.py /eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | export CUDA_VISIBLE_DEVICES=0 3 | export LD_PRELOAD="/usr/lib/libtcmalloc.so" 4 | checkpoint="./split_2-2-2/model.ckpt-119999" 5 | basemodel="./group_2-2-2/model.ckpt-199999" 6 | output_file="./split_2-2-2/eval-119999.pkl" 7 | #data_dir="./cifar100/train_val_split" 8 | data_dir="/data1/dalgu/cifar100/train_val_split" 9 | 10 | python eval.py --checkpoint $checkpoint \ 11 | --basemodel $basemodel \ 12 | --output_file $output_file \ 13 | --data_dir $data_dir \ 14 | --batch_size 100 \ 15 | --test_iter 100 \ 16 | --num_residual_units 2 \ 17 | --k 8 \ 18 | --ngroups1 2 \ 19 | --ngroups2 2 \ 20 | --ngroups3 2 \ 21 | --gpu_fraction 0.96 \ 22 | --display 10 \ 23 | #--finetune True \ 24 | #--load_last_layers True \ 25 | -------------------------------------------------------------------------------- /split.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | export CUDA_VISIBLE_DEVICES=0 3 | export LD_PRELOAD="/usr/lib/libtcmalloc.so" 4 | train_dir="./split_2-2-2" 5 | #data_dir="./cifar100/train_val_split" 6 | data_dir="/data1/dalgu/cifar100/train_val_split" 7 | 8 | python train_split.py --train_dir $train_dir \ 9 | --data_dir $data_dir \ 10 | --batch_size 90 \ 11 | --test_interval 500 \ 12 | --test_iter 50 \ 13 | --num_residual_units 2 \ 14 | --k 8 \ 15 | --ngroups1 2 \ 16 | --ngroups2 2 \ 17 | --ngroups3 2 \ 18 | --l2_weight 0.0005 \ 19 | --initial_lr 0.1 \ 20 | --lr_step_epoch 100.0,140.0 \ 21 | --lr_decay 0.1 \ 22 | --max_steps 120000 \ 23 | --checkpoint_interval 5000 \ 24 | --gpu_fraction 0.96 \ 25 | --display 100 \ 26 | --basemodel "./group_2-2-2/model.ckpt-199999" \ 27 | #--checkpoint "./split_2-2-2/model.ckpt-40000" 
\ 28 | #--finetune True \ 29 | #--load_last_layers True \ 30 | 31 | 32 | # Finetune with Deep split 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## General 2 | 3 | # Compiled Object files 4 | *.slo 5 | *.lo 6 | *.o 7 | *.cuo 8 | 9 | # Compiled Dynamic libraries 10 | *.so 11 | *.dylib 12 | 13 | # Compiled Static libraries 14 | *.lai 15 | *.la 16 | *.a 17 | 18 | # Compiled protocol buffers 19 | *.pb.h 20 | *.pb.cc 21 | *_pb2.py 22 | 23 | # Compiled python 24 | *.pyc 25 | 26 | # Compiled MATLAB 27 | *.mex* 28 | 29 | # IPython notebook checkpoints 30 | .ipynb_checkpoints 31 | 32 | # Editor temporaries 33 | *.swp 34 | *~ 35 | 36 | # Sublime Text settings 37 | *.sublime-workspace 38 | *.sublime-project 39 | 40 | # Eclipse Project settings 41 | *.*project 42 | .settings 43 | 44 | # QtCreator files 45 | *.user 46 | 47 | # PyCharm files 48 | .idea 49 | 50 | # OSX dir files 51 | .DS_Store 52 | 53 | ## Project specific 54 | # CIFAR-100 dataset 55 | cifar100* 56 | cifar-100-binary* 57 | !cifar100.py 58 | 59 | # Training & evaluation shell files 60 | *.sh 61 | !group.sh 62 | !split.sh 63 | !eval.sh 64 | scripts/ 65 | 66 | # Trained checkpoints 67 | baseline*/ 68 | group*/ 69 | split*/ 70 | -------------------------------------------------------------------------------- /group.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | export CUDA_VISIBLE_DEVICES=0 3 | # Your tcmalloc .so path 4 | export LD_PRELOAD="/usr/lib/libtcmalloc.so" 5 | train_dir="./group_2-2-2" 6 | # Our train/val split dataset 7 | # Run 'python download_cifar100.py' before training 8 | data_dir="./cifar100/train_val_split" 9 | 10 | python train.py --train_dir $train_dir \ 11 | --data_dir $data_dir \ 12 | --batch_size 90 \ 13 | --test_interval 500 \ 14 | --test_iter 50 \ 15 | --num_residual_units 2 \ 16 | --k 8 \ 17 | 
--ngroups1 2 \ 18 | --ngroups2 2 \ 19 | --ngroups3 2 \ 20 | --l2_weight 0.0001 \ 21 | --gamma1 1.0 \ 22 | --gamma2 1.0 \ 23 | --gamma3 10.0 \ 24 | --dropout_keep_prob 0.5 \ 25 | --initial_lr 0.1 \ 26 | --lr_step_epoch 240.0,300.0 \ 27 | --lr_decay 0.1 \ 28 | --bn_no_scale True \ 29 | --weighted_group_loss True \ 30 | --max_steps 200000 \ 31 | --checkpoint_interval 5000 \ 32 | --group_summary_interval 5000 \ 33 | --gpu_fraction 0.96 \ 34 | --display 100 \ 35 | #--checkpoint "./group_2-2-2/model.ckpt-149999" \ 36 | #--finetune True \ 37 | #--basemodel "./baseline/model.ckpt-449999" \ 38 | #--load_last_layers True \ 39 | 40 | 41 | # Deep split(2-2-2) 42 | # Dropout with prob 0.5 43 | # Softmax reparametrization & Training from scratch 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # splitnet-wrn 2 | 3 | TensorFlow implementation of SplitNet: Learning to Semantically Split Deep Networks for Parameter Reduction and Model Parallelization, ICML 2017 4 | 5 | ![SplitNet Concept](https://user-images.githubusercontent.com/13655756/27619160-8537abb6-5bfb-11e7-8854-2b5ee8be2312.png) 6 | 7 | - Juyong Kim\*(SNU), Yookoon Park\*(SNU), Gunhee Kim(SNU), and Sung Ju Hwang(UNIST) (\*: Equal contributions) 8 | 9 | We propose a novel deep neural network that is both lightweight and effectively structured for model parallelization. Our network, which we name as *SplitNet*, automatically learns to split the network weights into either a set or a hierarchy of multiple groups that use disjoint sets of features, by learning both the class-to-group and feature-to-group assignment matrices along with the network weights. This produces a tree-structured network that involves no connection between branched subtrees of semantically disparate class groups. 
SplitNet thus greatly reduces the number of parameters and required computations, and is also embarrassingly model-parallelizable at test time, since the evaluation for each subnetwork is completely independent except for the shared lower layer weights that can be duplicated over multiple processors, or assigned to a separate processor. We validate our method with two different deep network models (ResNet and AlexNet) on two datasets (CIFAR-100 and ILSVRC 2012) for image classification, on which our method obtains networks with a significantly reduced number of parameters while achieving comparable or superior accuracies over the original full deep networks, and accelerated test speed with multiple GPUs. 10 | 11 | ## Prerequisites 12 | 13 | 1. TensorFlow 14 | 2. Train/val/test split of the CIFAR-100 dataset (please run `python download_cifar100.py`) 15 | 16 | ## How To Run 17 | 18 | ```shell 19 | # Clone the repo. 20 | git clone https://github.com/dalgu90/splitnet-wrn.git 21 | cd splitnet-wrn 22 | 23 | # Download CIFAR-100 dataset and split the train set into train/val 24 | python download_cifar100.py 25 | 26 | # Find grouping of deep (2-2-2) split of WRN-16-8 27 | ./group.sh 28 | 29 | # Split and finetune 30 | ./split.sh 31 | 32 | # To evaluate 33 | ./eval.sh 34 | ``` 35 | 36 | ## Acknowledgement 37 | 38 | This work was supported by Samsung Research Funding Center of Samsung Electronics under Project Number SRFC-IT150203. 
39 | 40 | 41 | ## Authors 42 | 43 | [Juyong Kim](http://juyongkim.com/)*1, Yookoon Park*1, [Gunhee Kim](http://www.cs.cmu.edu/~gunhee/)1, and [Sung Ju Hwang](http://www.sungjuhwang.com/)2 44 | 45 | 1: [Vision and Learning Lab](http://vision.snu.ac.kr/) @ Computer Science and Engineering, Seoul National University, Seoul, Korea 46 | 2: [MLVR Lab](http://ml.unist.ac.kr/) @ School of Electrical and Computer Engineering, UNIST, Ulsan, South Korea 47 | \*: Equal contribution 48 | 49 | 50 | ## License 51 | ``` 52 | MIT license 53 | ``` 54 | If you find any problem, please feel free to contact to the authors. :^) 55 | -------------------------------------------------------------------------------- /download_cifar100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import tarfile 5 | import cPickle as pickle 6 | import numpy as np 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--dataset_path", help="The directory to save splited CIFAR-100 dataset") 11 | args = parser.parse_args() 12 | 13 | # CIFAR-100 download parameters 14 | cifar_url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' 15 | cifar_dpath = 'cifar100' 16 | cifar_py_name = 'cifar-100-python' 17 | cifar_fname = cifar_py_name + '.tar.gz' 18 | 19 | # CIFAR-100 dataset train/val split parameters 20 | dataset_path = 'cifar100/train_val_split' if not args.dataset_path else args.dataset_path 21 | dataset_path = os.path.expanduser(dataset_path) 22 | num_train_instance = 45000 23 | num_val_instance = 50000 - num_train_instance 24 | num_test_instance = 10000 25 | 26 | def download_file(url, path): 27 | import urllib2 28 | file_name = url.split('/')[-1] 29 | u = urllib2.urlopen(url) 30 | f = open(os.path.join(path, file_name), 'wb') 31 | meta = u.info() 32 | file_size = int(meta.getheaders("Content-Length")[0]) 33 | print "Downloading: %s Bytes: %s" % (file_name, file_size) 34 | 35 | 
def download_file(url, path):
    """Download `url` into the directory `path`, printing progress to stdout.

    The file is stored under the last path component of `url`.

    Args:
        url: HTTP(S) address of the file to fetch.
        path: Existing directory that receives the downloaded file.

    Raises:
        IndexError: if the server sends no Content-Length header.
    """
    import urllib2  # Python 2 only, like the rest of this script (cPickle).
    file_name = url.split('/')[-1]
    u = urllib2.urlopen(url)
    try:
        meta = u.info()
        file_size = int(meta.getheaders("Content-Length")[0])
        print("Downloading: %s Bytes: %s" % (file_name, file_size))

        download_size = 0
        block_size = 8192
        # Open the target only after the header was read, and let the context
        # manager close it even when the transfer fails half-way.
        with open(os.path.join(path, file_name), 'wb') as f:
            while True:
                buf = u.read(block_size)
                if not buf:
                    break
                download_size += len(buf)
                f.write(buf)
                status = "\r%12d [%3.2f%%]" % (download_size, download_size * 100. / file_size)
                sys.stdout.write(status)
                sys.stdout.flush()
        # Terminate the carriage-return progress line.
        sys.stdout.write('\n')
    finally:
        u.close()


# Check if the dataset split already exists
if os.path.exists(os.path.join(dataset_path, 'train')):
    print('CIFAR-100 train/val split exists\nNothing to be done... Quit!')
    sys.exit(0)

# Download and extract CIFAR-100 (skipped when the extracted 'train' file is present)
if not os.path.exists(os.path.join(cifar_dpath, cifar_py_name, 'train')):
    print('Downloading CIFAR-100')
    if not os.path.exists(cifar_dpath):
        os.makedirs(cifar_dpath)
    tar_fpath = os.path.join(cifar_dpath, cifar_fname)
    # 169001437 is the expected archive size in bytes; any other size is
    # treated as a missing or partial download and fetched again.
    if not os.path.exists(tar_fpath) or os.path.getsize(tar_fpath) != 169001437:
        download_file(cifar_url, cifar_dpath)
    print('Extracting CIFAR-100')
    with tarfile.open(tar_fpath) as tar:
        tar.extractall(path=cifar_dpath)

# Load the dataset and split
print('Load CIFAR-100 dataset')
# Pickles are binary data: open them in 'rb' so loading does not depend on the
# platform's text-mode newline handling.
with open(os.path.join(cifar_dpath, cifar_py_name, 'train'), 'rb') as fd:
    train_orig = pickle.load(fd)
    train_orig_data = train_orig['data']
    train_orig_label = np.array(train_orig['fine_labels'], dtype=np.uint8)
with open(os.path.join(cifar_dpath, cifar_py_name, 'test'), 'rb') as fd:
    test_orig = pickle.load(fd)
    test_orig_data = test_orig['data']
    test_orig_label = np.array(test_orig['fine_labels'], dtype=np.uint8)

# Split the dataset
print('Split the dataset')
# A real list is required: random.shuffle cannot permute a lazy range object.
train_val_idxs = list(range(len(train_orig_label)))
random.shuffle(train_val_idxs)
train_idxs = train_val_idxs[:num_train_instance]
val_idxs = train_val_idxs[num_train_instance:]
train = {'data': train_orig_data[train_idxs, :],
         'labels': train_orig_label[train_idxs]}
val = {'data': train_orig_data[val_idxs, :],
       'labels': train_orig_label[val_idxs]}
train_val = {'data': train_orig_data, 'labels': train_orig_label}
test = {'data': test_orig_data, 'labels': test_orig_label}

# Save the dataset split
print('Save the dataset split')
if not os.path.exists(dataset_path):
    os.makedirs(dataset_path)
for name, data in zip(['train', 'val', 'train_val', 'test'], [train, val, train_val, test]):
    print('[%s] ' % name + ', '.join(['%s: %s' % (k, str(v.shape)) for k, v in data.items()]))
    with open(os.path.join(dataset_path, name), 'wb') as fd:
        pickle.dump(data, fd)

print('done')
train_orig_data[train_val_idxs[num_train_instance:] ,:], 84 | 'labels': train_orig_label[train_val_idxs[num_train_instance:]]} 85 | train_val = {'data':train_orig_data, 'labels':train_orig_label} 86 | test = {'data':test_orig_data, 'labels':test_orig_label} 87 | 88 | # Save the dataset split 89 | print('Save the dataset split') 90 | if not os.path.exists(dataset_path): 91 | os.makedirs(dataset_path) 92 | for name, data in zip(['train', 'val', 'train_val', 'test'], [train, val, train_val, test]): 93 | print('[%s] ' % name + ', '.join(['%s: %s' % (k, str(v.shape)) for k, v in data.iteritems()])) 94 | with open(os.path.join(dataset_path, name), 'wb') as fd: 95 | pickle.dump(data, fd) 96 | 97 | print 'done' 98 | -------------------------------------------------------------------------------- /cifar100.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import random 7 | import threading 8 | import cPickle as pickle 9 | import numpy as np 10 | import skimage.util 11 | 12 | import tensorflow as tf 13 | 14 | # Process images of this size. Note that this differs from the original CIFAR 15 | # image size of 32 x 32. If one alters this number, then the entire model 16 | # architecture will change and any model would need to be retrained. 17 | IMAGE_SIZE = 32 18 | 19 | # Global constants describing the CIFAR-100 data set. 20 | NUM_CLASSES = 100 21 | 22 | 23 | class ThreadsafeIter: 24 | """Takes an iterator/generator and makes it thread-safe by 25 | serializing call to the `next` method of given iterator/generator. 
class ThreadsafeIter(object):
    """Wrap an iterator/generator so that concurrent consumers are serialized.

    Every advance of the wrapped iterator happens while holding a lock, so
    several threads can pull items from one shared generator.
    """

    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            # The built-in next() drives both Python 2 and Python 3 iterators;
            # calling `.next()` on the wrapped object only works on Python 2.
            return next(self.it)

    # Python 2 iterator protocol, also used by existing `.next()` callers.
    next = __next__


class CIFAR100Runner(object):
    """Fill a TensorFlow FIFO queue with preprocessed CIFAR-100 examples.

    The dataset pickle is loaded into memory once. Threads started by
    `start_threads` repeatedly draw small batches from `data_iterator`,
    preprocess them with NumPy and enqueue them; `get_inputs` returns the
    tensors that dequeue a batch.
    """

    # Class-level flag: get_inputs() adds the 'images' summary only for the
    # first runner that asks for inputs.
    _image_summary_added = False

    def __init__(self, pkl_path, shuffle=False, distort=False,
                 capacity=2000, image_per_thread=16):
        """Load the pickled split and build the queue and its enqueue op.

        Args:
            pkl_path: Path of a pickle holding a dict with 'data' and 'labels'.
            shuffle: Reshuffle the example order at the start of every pass.
            distort: Apply random crop / horizontal flip augmentation.
            capacity: Capacity of the FIFO queue, in examples.
            image_per_thread: Number of examples enqueued per feeding step.
        """
        self._shuffle = shuffle
        self._distort = distort
        with open(pkl_path, 'rb') as fd:
            data = pickle.load(fd)
        # Rows are reshaped to [N, 3, 32, 32] and moved to [N, 32, 32, 3].
        self._images = data['data'].reshape([-1, 3, 32, 32]).transpose((0, 2, 3, 1)).copy(order='C')
        self._labels = data['labels']  # numpy 1-D array
        self.size = len(self._labels)

        self.queue = tf.FIFOQueue(shapes=[[32, 32, 3], []],
                                  dtypes=[tf.float32, tf.int32],
                                  capacity=capacity)
        self.dataX = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3])
        self.dataY = tf.placeholder(dtype=tf.int32, shape=[None, ])
        self.enqueue_op = self.queue.enqueue_many([self.dataX, self.dataY])
        self.image_per_thread = image_per_thread

    def _preprocess_image(self, input_image):
        """Preprocess a single image by crop and whitening (and augmenting if needed).

        Args:
            input_image: An image. 3D array of [height, width, channel] size.

        Returns:
            output_image: Preprocessed image. 3D float array of size
                [32, 32, channel].
        """
        image = input_image
        if self._distort:
            # Reflect-pad 4 pixels on every side, then cut a random 32x32 window.
            image = np.pad(image, ((4, 4), (4, 4), (0, 0)), 'reflect')
            # Offsets are drawn from [0, extra), matching the original sampling.
            top = random.randrange(image.shape[0] - 32)
            left = random.randrange(image.shape[1] - 32)
            image = image[top:top + 32, left:left + 32, :]

            # Random horizontal flip. A reversed view is taken instead of
            # assigning a flipped view onto its own memory.
            if random.randrange(2) == 1:
                image = image[:, ::-1, :]
        else:
            # Center crop. Floor division keeps the offsets integral even though
            # the module enables true division.
            top = (image.shape[0] - 32) // 2
            left = (image.shape[1] - 32) // 2
            image = image[top:top + 32, left:left + 32, :]

        # Image whitening (per channel)
        mean = np.mean(image, axis=(0, 1), dtype=np.float32)
        std = np.std(image, axis=(0, 1), dtype=np.float32)
        # A constant channel has zero deviation; dividing by 1 keeps the centred
        # values at 0 instead of producing NaN.
        std = np.where(std == 0, np.float32(1.0), std)
        output_image = (image - mean) / std

        return output_image

    def _preprocess_images(self, input_images):
        """Apply `_preprocess_image` to every image and stack the results as float32."""
        output_images = np.zeros_like(input_images, dtype=np.float32)
        for i in range(output_images.shape[0]):
            output_images[i] = self._preprocess_image(input_images[i])

        return output_images

    def get_inputs(self, batch_size):
        """Return tensors that dequeue a batch of `batch_size` images and labels."""
        images_batch, labels_batch = self.queue.dequeue_many(batch_size)
        if not CIFAR100Runner._image_summary_added:
            tf.summary.image('images', images_batch)
            CIFAR100Runner._image_summary_added = True

        return images_batch, labels_batch

    def data_iterator(self):
        """Yield (images, labels) lists of `image_per_thread` examples, forever.

        Examples are taken in index order (a fresh random order per pass when
        shuffling is enabled). A batch that reaches the end of the data wraps
        around and continues with the next pass.

        Raises:
            ValueError: if the dataset holds no examples.
        """
        if self.size == 0:
            # Without this guard the fill loop below could never complete.
            raise ValueError('Cannot iterate over an empty dataset.')

        idxs = np.arange(0, self.size)
        if self._shuffle:
            random.shuffle(idxs)

        pos = 0
        while True:
            images_batch = []
            labels_batch = []
            while len(labels_batch) < self.image_per_thread:
                # Take what is still missing, but never read past the end of a pass.
                take = min(self.image_per_thread - len(labels_batch), self.size - pos)
                chunk = idxs[pos:pos + take]
                images_batch.extend(self._images[chunk])
                labels_batch.extend(self._labels[chunk])
                pos += take

                if pos == self.size:
                    pos = 0
                    if self._shuffle:
                        random.shuffle(idxs)
            yield images_batch, labels_batch

    def thread_main(self, sess, iterator):
        """Body of a feeding thread: preprocess and enqueue batches until the process exits."""
        while True:
            images_val, labels_val = next(iterator)
            process_images_val = self._preprocess_images(images_val)
            sess.run(self.enqueue_op, feed_dict={self.dataX: process_images_val, self.dataY: labels_val})

    def start_threads(self, sess, n_threads=1):
        """Start `n_threads` daemon threads feeding the queue and return them."""
        iterator = ThreadsafeIter(self.data_iterator())
        threads = []
        for _ in range(n_threads):
            t = threading.Thread(target=self.thread_main, args=(sess, iterator,))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads
images_batch = [] 135 | labels_batch = [] 136 | batch_cnt = 0 137 | while True: 138 | if idxs_idx + (self.image_per_thread - batch_cnt) < self.size: 139 | temp_cnt = self.image_per_thread - batch_cnt 140 | else: 141 | temp_cnt = self.size - idxs_idx 142 | 143 | images_batch.extend(self._images[idxs[idxs_idx:idxs_idx+temp_cnt]]) 144 | labels_batch.extend(self._labels[idxs[idxs_idx:idxs_idx+temp_cnt]]) 145 | idxs_idx += temp_cnt 146 | batch_cnt += temp_cnt 147 | 148 | if idxs_idx == self.size: 149 | idxs_idx = 0 150 | if self._shuffle: 151 | random.shuffle(idxs) 152 | 153 | if batch_cnt == self.image_per_thread: 154 | break 155 | yield images_batch, labels_batch 156 | 157 | def thread_main(self, sess, iterator): 158 | """ 159 | Function run on alternate thread. Basically, keep adding data to the queue. 160 | """ 161 | while True: 162 | images_val, labels_val = iterator.next() 163 | process_images_val = self._preprocess_images(images_val) 164 | sess.run(self.enqueue_op, feed_dict={self.dataX:process_images_val, self.dataY:labels_val}) 165 | 166 | def start_threads(self, sess, n_threads=1): 167 | """ Start background threads to feed queue """ 168 | iterator = ThreadsafeIter(self.data_iterator()) 169 | threads = [] 170 | for n in range(n_threads): 171 | t = threading.Thread(target=self.thread_main, args=(sess,iterator,)) 172 | t.daemon = True # thread will close when parent quits 173 | t.start() 174 | threads.append(t) 175 | return threads 176 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | ## TensorFlow helper functions 5 | 6 | WEIGHT_DECAY_KEY = 'WEIGHT_DECAY' 7 | 8 | def _relu(x, leakness=0.0, name=None): 9 | if leakness > 0.0: 10 | name = 'lrelu' if name is None else name 11 | return tf.maximum(x, x*leakness, name='lrelu') 12 | else: 13 | name = 'relu' if name is None else name 14 | 
def _register_weight_decay(var):
    """Add `var` to the WEIGHT_DECAY_KEY collection unless it is already there."""
    if var not in tf.get_collection(WEIGHT_DECAY_KEY):
        tf.add_to_collection(WEIGHT_DECAY_KEY, var)


def _relu(x, leakness=0.0, name=None):
    """Apply ReLU, or leaky ReLU when `leakness` is positive.

    Args:
        x: Input tensor.
        leakness: Slope used for negative inputs; 0.0 selects the plain ReLU.
        name: Op name. Defaults to 'lrelu' or 'relu' depending on the variant.

    Returns:
        The activated tensor.
    """
    if leakness > 0.0:
        # The caller's name is honoured; the default only applies when none is given.
        return tf.maximum(x, x*leakness, name='lrelu' if name is None else name)
    return tf.nn.relu(x, name='relu' if name is None else name)


def _dropout(x, keep_prob=1.0, name=None):
    """Apply dropout with the given keep probability; `keep_prob == 1.0` returns `x` itself."""
    assert keep_prob >= 0.0 and keep_prob <= 1.0
    if keep_prob == 1.0:
        return x
    return tf.nn.dropout(x, keep_prob, name=name)


def _conv(x, filter_size, out_channel, strides, pad='SAME', input_q=None, output_q=None, name='conv'):
    """2-D convolution with a square kernel created under variable scope `name`.

    The kernel is registered for weight decay. When both group-assignment
    matrices are given, the split losses for the kernel are added as well.

    Args:
        x: Input tensor; its 4th dimension is read as the input channel count.
        filter_size: Height and width of the kernel.
        out_channel: Number of output channels.
        strides: Stride used for both spatial dimensions.
        pad: Padding scheme passed to `tf.nn.conv2d`.
        input_q: Group assignment of the input channels, or None.
        output_q: Group assignment of the output channels, or None.
        name: Variable scope name.

    Returns:
        The convolution output tensor.

    Raises:
        ValueError: if exactly one of `input_q` / `output_q` is given.
    """
    # Identity tests: `== None` would dispatch to the tensor's own equality operator.
    if (input_q is None) != (output_q is None):
        raise ValueError('Input/Output splits are not correctly given.')

    in_shape = x.get_shape().as_list()
    with tf.variable_scope(name):
        # Main operation: conv2d
        kernel = tf.get_variable('kernel', [filter_size, filter_size, in_shape[3], out_channel],
                        tf.float32, initializer=tf.random_normal_initializer(
                            stddev=np.sqrt(1.0/filter_size/filter_size/in_shape[3])))
        _register_weight_decay(kernel)
        conv = tf.nn.conv2d(x, kernel, [1, strides, strides, 1], pad)

        # Split loss; both matrices are present or both absent at this point.
        # The call stays inside the scope because _add_split_loss reads the
        # current variable scope name.
        if input_q is not None:
            _add_split_loss(kernel, input_q, output_q)

    return conv


def _conv_with_init(x, filter_size, out_channel, strides, pad='SAME', init_k=None, name='conv'):
    """Like `_conv` without split losses; the kernel may start from `init_k`.

    Args:
        init_k: Initial kernel value, or None for the default random-normal
            initializer.
        (remaining arguments as in `_conv`)

    Returns:
        The convolution output tensor.
    """
    in_shape = x.get_shape().as_list()
    with tf.variable_scope(name):
        # Main operation: conv2d
        if init_k is not None:
            initializer_k = tf.constant_initializer(init_k)
        else:
            initializer_k = tf.random_normal_initializer(stddev=np.sqrt(1.0/filter_size/filter_size/in_shape[3]))
        kernel = tf.get_variable('kernel', [filter_size, filter_size, in_shape[3], out_channel],
                        tf.float32, initializer=initializer_k)
        _register_weight_decay(kernel)
        conv = tf.nn.conv2d(x, kernel, [1, strides, strides, 1], pad)

    return conv


def _fc(x, out_dim, input_q=None, output_q=None, name='fc'):
    """Fully connected layer (`x @ weights + biases`) created under scope `name`.

    The weights are registered for weight decay. When both group-assignment
    matrices are given, the split losses for the weights are added as well.

    Args:
        x: Input tensor; its 2nd dimension is read as the input feature count.
        out_dim: Number of output units.
        input_q: Group assignment of the input features, or None.
        output_q: Group assignment of the output units, or None.
        name: Variable scope name.

    Returns:
        The layer output tensor.

    Raises:
        ValueError: if exactly one of `input_q` / `output_q` is given.
    """
    if (input_q is None) != (output_q is None):
        raise ValueError('Input/Output splits are not correctly given.')

    with tf.variable_scope(name):
        # Main operation: fc
        w = tf.get_variable('weights', [x.get_shape()[1], out_dim],
                        tf.float32, initializer=tf.random_normal_initializer(
                            stddev=np.sqrt(1.0/x.get_shape().as_list()[1])))
        _register_weight_decay(w)
        b = tf.get_variable('biases', [out_dim], tf.float32,
                            initializer=tf.constant_initializer(0.0))
        fc = tf.nn.bias_add(tf.matmul(x, w), b)

        # Split loss; both matrices are present or both absent at this point.
        if input_q is not None:
            _add_split_loss(w, input_q, output_q)

    return fc


def _fc_with_init(x, out_dim, init_w=None, init_b=None, name='fc'):
    """Like `_fc` without split losses; weights and biases may start from given values.

    Args:
        init_w: Initial weight value, or None for the default random-normal
            initializer.
        init_b: Initial bias value, or None for zeros.
        (remaining arguments as in `_fc`)

    Returns:
        The layer output tensor.
    """
    with tf.variable_scope(name):
        # Main operation: fc
        if init_w is not None:
            initializer_w = tf.constant_initializer(init_w)
        else:
            initializer_w = tf.random_normal_initializer(stddev=np.sqrt(1.0/x.get_shape().as_list()[1]))
        if init_b is not None:
            initializer_b = tf.constant_initializer(init_b)
        else:
            initializer_b = tf.constant_initializer(0.0)

        w = tf.get_variable('weights', [x.get_shape()[1], out_dim],
                        tf.float32, initializer=initializer_w)
        b = tf.get_variable('biases', [out_dim], tf.float32,
                            initializer=initializer_b)
        _register_weight_decay(w)
        fc = tf.nn.bias_add(tf.matmul(x, w), b)

    return fc
def _get_split_q(ngroups, dim, name='split'):
    """Create a trainable [ngroups, dim] assignment matrix, softmax-normalised over groups.

    Args:
        ngroups: Number of groups (rows).
        dim: Number of items being assigned (columns).
        name: Variable scope name holding the 'alpha' logits.

    Returns:
        Tensor `q` of shape [ngroups, dim]; the softmax is taken along axis 0.
    """
    with tf.variable_scope(name):
        alpha = tf.get_variable('alpha', shape=[ngroups, dim], dtype=tf.float32,
                                initializer=tf.random_normal_initializer(stddev=0.01))
        q = tf.nn.softmax(alpha, dim=0, name='q')

    return q


def _merge_split_q(q, merge_idxs, name='merge'):
    """Sum rows of `q` that share a merge index into a coarser assignment matrix.

    Args:
        q: 2-D tensor of shape [ngroups, dim].
        merge_idxs: Sequence of length ngroups; `merge_idxs[j]` is the output
            row that input row `j` is added to.
        name: Variable scope name.

    Returns:
        Tensor of shape [max(merge_idxs) + 1, dim].

    Raises:
        ValueError: if some index in 0..max(merge_idxs) receives no row.
    """
    assert len(q.get_shape()) == 2
    ngroups, dim = q.get_shape().as_list()
    assert ngroups == len(merge_idxs)

    with tf.variable_scope(name):
        max_idx = int(np.max(merge_idxs))
        merged_rows = []
        for i in range(max_idx + 1):
            members = [tf.slice(q, [j, 0], [1, dim])
                       for j in range(ngroups) if merge_idxs[j] == i]
            if not members:
                # tf.add_n rejects an empty list with a far less helpful message.
                raise ValueError('merge_idxs assigns no group to merged index %d' % i)
            merged_rows.append(tf.add_n(members))
        ret = tf.concat(merged_rows, 0)

    return ret


def _get_even_merge_idxs(N, split):
    """Return merge indices that spread `N` groups over `split` buckets as evenly as possible.

    Consecutive groups share a bucket, and earlier buckets receive the surplus
    when `N` is not divisible by `split`, e.g. N=5, split=2 -> [0, 0, 0, 1, 1].

    Args:
        N: Number of groups to merge.
        split: Number of merged groups; must not exceed `N`.

    Returns:
        List of `N` ints in non-decreasing order, each in [0, split).
    """
    assert N >= split
    # Floor division: true division would produce float bucket sizes, which
    # cannot be used to repeat a list.
    num_elems = [(N + split - i - 1) // split for i in range(split)]
    return [i for i, n in enumerate(num_elems) for _ in range(n)]


def _overlap_and_uniform_loss(q, dim, ngroups):
    """Return the normalised (overlap, uniform) loss terms of one assignment matrix.

    Args:
        q: Assignment tensor of shape [ngroups, dim].
        dim: Number of columns of `q`.
        ngroups: Number of rows of `q`.
    """
    overlaps = []
    sizes = []
    for i in range(ngroups):
        # Each unordered pair of groups once (j < i).
        for j in range(i):
            overlaps.append(tf.reduce_sum(q[i, :] * q[j, :]))
        # One squared group size per group.
        # NOTE(review): the nesting of this append is reconstructed; it matches
        # the normaliser dim^2/ngroups, which is exactly this sum for perfectly
        # even groups. Confirm against the original file.
        sizes.append(tf.square(tf.reduce_sum(q[i, :])))
    overlap = tf.reduce_sum(overlaps)/(float(dim*(ngroups-1))/float(2*ngroups))
    uniform = tf.reduce_sum(sizes)/(float(dim*dim)/float(ngroups))
    return overlap, uniform


def _add_split_loss(w, input_q, output_q):
    """Register the group-split regularisers for weight `w` in graph collections.

    Collections written:
        'OVERLAP_LOSS_WEIGHTS': assignment matrices already penalised.
        'OVERLAP_LOSS' / 'UNIFORM_LOSS': mean overlap / uniformity terms of the
            assignment matrices newly seen in this call.
        'WEIGHT_SPLIT_WEIGHTS': weights already penalised.
        'WEIGHT_SPLIT': the weight-split term for `w`.

    Args:
        w: 2-D (fc) or 4-D (conv) weight; its last two dimensions are read as
            (in_dim, out_dim).
        input_q: [ngroups, in_dim] assignment of the input dimension.
        output_q: [ngroups, out_dim] assignment of the output dimension.
    """
    # Check input tensors' measurements
    assert len(w.get_shape()) == 2 or len(w.get_shape()) == 4
    in_dim, out_dim = w.get_shape().as_list()[-2:]
    assert len(input_q.get_shape()) == 2
    assert len(output_q.get_shape()) == 2
    assert in_dim == input_q.get_shape().as_list()[1]
    assert out_dim == output_q.get_shape().as_list()[1]
    assert input_q.get_shape().as_list()[0] == output_q.get_shape().as_list()[0]  # ngroups
    ngroups = input_q.get_shape().as_list()[0]
    assert ngroups > 1

    # Add overlap / uniform losses once per assignment matrix. Matrices whose
    # op name contains "concat" are skipped.
    T_list = []
    U_list = []
    for q, dim in ((input_q, in_dim), (output_q, out_dim)):
        if q not in tf.get_collection('OVERLAP_LOSS_WEIGHTS') \
                and not "concat" in q.op.name:
            tf.add_to_collection('OVERLAP_LOSS_WEIGHTS', q)
            print('\t\tAdd overlap & split loss for %s' % q.name)
            overlap, uniform = _overlap_and_uniform_loss(q, dim, ngroups)
            T_list.append(overlap)
            U_list.append(uniform)
    if T_list:
        tf.add_to_collection('OVERLAP_LOSS', tf.add_n(T_list)/len(T_list))
    if U_list:
        tf.add_to_collection('UNIFORM_LOSS', tf.add_n(U_list)/len(U_list))

    # NOTE(review): the collapsed source does not show how far this `if`
    # extends; everything up to the final print is assumed to be guarded so
    # that each weight is penalised once. Confirm against the original file.
    S_list = []
    if w not in tf.get_collection('WEIGHT_SPLIT_WEIGHTS'):
        tf.add_to_collection('WEIGHT_SPLIT_WEIGHTS', w)

        ones_col = tf.ones((in_dim,), dtype=tf.float32)
        ones_row = tf.ones((out_dim,), dtype=tf.float32)
        is_conv = len(w.get_shape()) == 4
        if is_conv:
            # Mean squared kernel entry per (in, out) channel pair.
            w_norm = tf.reduce_mean(tf.square(w), [0, 1])
            std_dev = np.sqrt(1.0/float(w.get_shape().as_list()[0])**2/in_dim)
        else:
            w_norm = w
            std_dev = np.sqrt(1.0/float(in_dim))

        for i in range(ngroups):
            if is_conv:
                # w_norm already holds squared values, so the masks are squared
                # instead of the product.
                wg_row = tf.transpose(tf.transpose(w_norm * tf.square(output_q[i,:])) * tf.square(ones_col - input_q[i,:]))
                wg_row_l2 = tf.reduce_sum(tf.sqrt(tf.reduce_sum(wg_row, 1))) / (in_dim*np.sqrt(out_dim))
                wg_col = tf.transpose(tf.transpose(w_norm * tf.square(ones_row - output_q[i,:])) * tf.square(input_q[i,:]))
                wg_col_l2 = tf.reduce_sum(tf.sqrt(tf.reduce_sum(wg_col, 0))) / (np.sqrt(in_dim)*out_dim)
            else:  # len(w.get_shape()) == 2
                wg_row = tf.transpose(tf.transpose(w_norm * output_q[i,:]) * (ones_col - input_q[i,:]))
                wg_row_l2 = tf.reduce_sum(tf.sqrt(tf.reduce_sum(wg_row * wg_row, 1))) / (in_dim*np.sqrt(out_dim))
                wg_col = tf.transpose(tf.transpose(w_norm * (ones_row - output_q[i,:])) * input_q[i,:])
                wg_col_l2 = tf.reduce_sum(tf.sqrt(tf.reduce_sum(wg_col * wg_col, 0))) / (np.sqrt(in_dim)*out_dim)
            S_list.append(wg_row_l2 + wg_col_l2)
        S = tf.add_n(S_list)/(2*(ngroups-1)*std_dev/ngroups)
        tf.add_to_collection('WEIGHT_SPLIT', S)

        # Add histogram for w if split losses are added
        scope_name = tf.get_variable_scope().name
        tf.summary.histogram("%s/" % scope_name, w)
        print('\t\tAdd split loss for %s(%dx%d, %d groups)' \
              % (tf.get_variable_scope().name, in_dim, out_dim, ngroups))

    return


def _bn(x, is_train, global_step=None, name='bn', no_scale=False):
    """Batch normalisation with hand-maintained moving statistics.

    Batch moments are taken over axes [0, 1, 2]. Assign ops that move the stored
    statistics towards the batch statistics are added to
    `tf.GraphKeys.UPDATE_OPS`; the caller has to run them.

    Args:
        x: Input tensor (presumably [batch, height, width, channel] - confirm).
        is_train: Boolean tensor; selects batch statistics when true and the
            stored moving statistics otherwise.
        global_step: Accepted but not used by this function.
        name: Variable scope name.
        no_scale: When true, the 'gamma' variable is created non-trainable.

    Returns:
        The normalised tensor.
    """
    moving_average_decay = 0.9
    with tf.variable_scope(name):
        decay = moving_average_decay
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
        mu = tf.get_variable('mu', batch_mean.get_shape(), tf.float32,
                        initializer=tf.zeros_initializer(), trainable=False)
        # Despite its name, 'sigma' tracks the variance: it is updated from
        # batch_var and passed as the variance argument below.
        sigma = tf.get_variable('sigma', batch_var.get_shape(), tf.float32,
                        initializer=tf.ones_initializer(), trainable=False)
        beta = tf.get_variable('beta', batch_mean.get_shape(), tf.float32,
                        initializer=tf.zeros_initializer())
        gamma = tf.get_variable('gamma', batch_var.get_shape(), tf.float32,
                        initializer=tf.ones_initializer(), trainable=(not no_scale))
        # BN when training
        update = 1.0 - decay
        update_mu = mu.assign_sub(update*(mu - batch_mean))
        update_sigma = sigma.assign_sub(update*(sigma - batch_var))
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mu)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_sigma)

        mean, var = tf.cond(is_train, lambda: (batch_mean, batch_var),
                            lambda: (mu, sigma))
        bn = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)

    return bn
tf.float32, 247 | initializer=tf.zeros_initializer(), trainable=False) 248 | sigma = tf.get_variable('sigma', batch_var.get_shape(), tf.float32, 249 | initializer=tf.ones_initializer(), trainable=False) 250 | beta = tf.get_variable('beta', batch_mean.get_shape(), tf.float32, 251 | initializer=tf.zeros_initializer()) 252 | gamma = tf.get_variable('gamma', batch_var.get_shape(), tf.float32, 253 | initializer=tf.ones_initializer(), trainable=(not no_scale)) 254 | # BN when training 255 | update = 1.0 - decay 256 | # with tf.control_dependencies([tf.Print(decay, [decay])]): 257 | # update_mu = mu.assign_sub(update*(mu - batch_mean)) 258 | update_mu = mu.assign_sub(update*(mu - batch_mean)) 259 | update_sigma = sigma.assign_sub(update*(sigma - batch_var)) 260 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mu) 261 | tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_sigma) 262 | 263 | mean, var = tf.cond(is_train, lambda: (batch_mean, batch_var), 264 | lambda: (mu, sigma)) 265 | bn = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5) 266 | 267 | # bn = tf.nn.batch_normalization(x, batch_mean, batch_var, beta, gamma, 1e-5) 268 | 269 | # bn = tf.contrib.layers.batch_norm(inputs=x, decay=decay, 270 | # updates_collections=[tf.GraphKeys.UPDATE_OPS], center=True, 271 | # scale=True, epsilon=1e-5, is_training=is_train, 272 | # trainable=True) 273 | return bn 274 | 275 | 276 | ## Other helper functions 277 | 278 | 279 | 280 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from collections import namedtuple 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | import utils 9 | 10 | 11 | HParams = namedtuple('HParams', 12 | 'batch_size, num_classes, num_residual_units, k, weight_decay, momentum, finetune, ' 13 | 'ngroups1, ngroups2, ngroups3, gamma1, gamma2, gamma3, ' 14 | 'dropout_keep_prob, 
class ResNet(object):
    """Wide residual network whose later stages carry group-split variables.

    The graph is assembled by `build_model()`; `build_train_op()` then adds
    the regularisers and the optimiser step.  While building, the class
    keeps a running count of FLOPs and weights for the conv / fc layers it
    creates (see `_add_flops_weights`).

    Args:
        hp: an `HParams` namedtuple.
        images: input image tensor.
        labels: integer label tensor.
        global_step: global step variable passed to the optimiser and BN.
    """

    def __init__(self, hp, images, labels, global_step):
        self._hp = hp  # Hyperparameters
        self._images = images  # Input image
        self._labels = labels
        self._global_step = global_step
        self.lr = tf.placeholder(tf.float32)
        self.is_train = tf.placeholder(tf.bool)
        # Bookkeeping for the FLOP / weight counters.
        self._counted_scope = []
        self._flops = 0
        self._weights = 0

    def build_model(self):
        """Build the forward graph, predictions, accuracy and cross-entropy loss."""
        print('Building model')
        filters = [16, 16 * self._hp.k, 32 * self._hp.k, 64 * self._hp.k]
        strides = [1, 2, 2]

        # Split variables are only created for stages with more than one
        # group; otherwise the attributes are set to None and the layers
        # below are built without split arguments.
        with tf.variable_scope("group"):
            if self._hp.ngroups1 > 1:
                self.split_q1 = utils._get_split_q(self._hp.ngroups1, self._hp.num_classes, name='split_q1')
                self.split_p1 = utils._get_split_q(self._hp.ngroups1, filters[3], name='split_p1')
                tf.summary.histogram("group/split_p1/", self.split_p1)
                tf.summary.histogram("group/split_q1/", self.split_q1)
            else:
                self.split_q1 = None
                self.split_p1 = None

            if self._hp.ngroups2 > 1:
                # The output split of unit_3_x is derived from the input
                # split of the logits layer (split_p1).
                self.split_q2 = utils._merge_split_q(self.split_p1, utils._get_even_merge_idxs(self._hp.ngroups1, self._hp.ngroups2), name='split_q2')
                self.split_p2 = utils._get_split_q(self._hp.ngroups2, filters[2], name='split_p2')
                self.split_r21 = utils._get_split_q(self._hp.ngroups2, filters[3], name='split_r21')
                self.split_r22 = utils._get_split_q(self._hp.ngroups2, filters[3], name='split_r22')
                tf.summary.histogram("group/split_q2/", self.split_q2)
                tf.summary.histogram("group/split_p2/", self.split_p2)
                tf.summary.histogram("group/split_r21/", self.split_r21)
                tf.summary.histogram("group/split_r22/", self.split_r22)
            else:
                self.split_p2 = None
                self.split_q2 = None
                self.split_r21 = None
                self.split_r22 = None

            if self._hp.ngroups3 > 1:
                # Same construction one stage earlier: derived from split_p2.
                self.split_q3 = utils._merge_split_q(self.split_p2, utils._get_even_merge_idxs(self._hp.ngroups2, self._hp.ngroups3), name='split_q3')
                self.split_p3 = utils._get_split_q(self._hp.ngroups3, filters[1], name='split_p3')
                self.split_r31 = utils._get_split_q(self._hp.ngroups3, filters[2], name='split_r31')
                self.split_r32 = utils._get_split_q(self._hp.ngroups3, filters[2], name='split_r32')
                tf.summary.histogram("group/split_q3/", self.split_q3)
                tf.summary.histogram("group/split_p3/", self.split_p3)
                tf.summary.histogram("group/split_r31/", self.split_r31)
                tf.summary.histogram("group/split_r32/", self.split_r32)
            else:
                self.split_p3 = None
                self.split_q3 = None
                self.split_r31 = None
                self.split_r32 = None

        # Init. conv.
        print('\tBuilding unit: init_conv')
        x = utils._conv(self._images, 3, filters[0], 1, name='init_conv')

        x = self._residual_block_first(x, filters[1], strides[0], name='unit_1_0')
        x = self._residual_block(x, name='unit_1_1')

        x = self._residual_block_first(x, filters[2], strides[1], input_q=self.split_p3, output_q=self.split_q3, split_r=self.split_r31, name='unit_2_0')
        x = self._residual_block(x, split_q=self.split_q3, split_r=self.split_r32, name='unit_2_1')

        x = self._residual_block_first(x, filters[3], strides[2], input_q=self.split_p2, output_q=self.split_q2, split_r=self.split_r21, name='unit_3_0')
        x = self._residual_block(x, split_q=self.split_q2, split_r=self.split_r22, name='unit_3_1')

        # Last unit
        with tf.variable_scope('unit_last') as scope:
            print('\tBuilding unit: %s' % scope.name)
            x = utils._bn(x, self.is_train, self._global_step)
            x = utils._relu(x)
            x = tf.reduce_mean(x, [1, 2])

        # Logit
        with tf.variable_scope('logits') as scope:
            print('\tBuilding unit: %s' % scope.name)
            x_shape = x.get_shape().as_list()
            x = tf.reshape(x, [-1, x_shape[1]])
            if self.split_p1 is not None and self.split_q1 is not None:
                x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout')
            x = self._fc(x, self._hp.num_classes, input_q=self.split_p1, output_q=self.split_q1)

        self._logits = x

        # Probs & preds & acc
        self.probs = tf.nn.softmax(x, name='probs')
        self.preds = tf.to_int32(tf.argmax(self._logits, 1, name='preds'))
        ones = tf.constant(np.ones([self._hp.batch_size]), dtype=tf.float32)
        zeros = tf.constant(np.zeros([self._hp.batch_size]), dtype=tf.float32)
        correct = tf.where(tf.equal(self.preds, self._labels), ones, zeros)
        self.acc = tf.reduce_mean(correct, name='acc')
        tf.summary.scalar('accuracy', self.acc)

        # Loss & acc
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=x, labels=self._labels)
        self.loss = tf.reduce_mean(loss)
        tf.summary.scalar('cross_entropy', self.loss)

    def _residual_block_first(self, x, out_channel, strides, input_q=None, output_q=None, split_r=None, name="unit"):
        """Build the first residual unit of a stage (may change channels / stride).

        Dropout is inserted only when all three split arguments are given.
        """
        in_channel = x.get_shape().as_list()[-1]
        with tf.variable_scope(name) as scope:
            print('\tBuilding residual unit: %s' % scope.name)
            x = self._bn(x, name='bn_1', no_scale=self._hp.bn_no_scale)
            x = self._relu(x, name='relu_1')
            if input_q is not None and output_q is not None and split_r is not None:
                x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout_1')
            # Shortcut connection
            if in_channel == out_channel:
                if strides == 1:
                    shortcut = tf.identity(x)
                else:
                    shortcut = tf.nn.max_pool(x, [1, strides, strides, 1], [1, strides, strides, 1], 'VALID')
            else:
                shortcut = self._conv(x, 1, out_channel, strides, input_q=input_q, output_q=output_q, name='shortcut')
            # Residual
            x = self._conv(x, 3, out_channel, strides, input_q=input_q, output_q=split_r, name='conv_1')
            x = self._bn(x, name='bn_2', no_scale=self._hp.bn_no_scale)
            x = self._relu(x, name='relu_2')
            if input_q is not None and output_q is not None and split_r is not None:
                x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout_2')
            x = self._conv(x, 3, out_channel, 1, input_q=split_r, output_q=output_q, name='conv_2')
            # Merge
            x = x + shortcut
        return x

    def _residual_block(self, x, split_q=None, split_r=None, name="unit"):
        """Build a channel-preserving residual unit with an identity shortcut."""
        num_channel = x.get_shape().as_list()[-1]
        with tf.variable_scope(name) as scope:
            print('\tBuilding residual unit: %s' % scope.name)
            # Shortcut connection
            shortcut = x
            # Residual
            x = self._bn(x, name='bn_1', no_scale=self._hp.bn_no_scale)
            x = self._relu(x, name='relu_1')
            if split_q is not None and split_r is not None:
                x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout_1')
            x = self._conv(x, 3, num_channel, 1, input_q=split_q, output_q=split_r, name='conv_1')
            x = self._bn(x, name='bn_2', no_scale=self._hp.bn_no_scale)
            x = self._relu(x, name='relu_2')
            if split_q is not None and split_r is not None:
                x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout_2')
            x = self._conv(x, 3, num_channel, 1, input_q=split_r, output_q=split_q, name='conv_2')
            # Merge
            x = x + shortcut
        return x

    def build_train_op(self):
        """Add L2 / group regularisers and create `self.train_op`.

        The total loss is the cross-entropy plus weight decay plus whichever
        of the OVERLAP_LOSS / WEIGHT_SPLIT / UNIFORM_LOSS collections are
        non-empty and have a positive gamma coefficient.
        """
        # Learning rate
        tf.summary.scalar('learing_rate', self.lr)

        losses = [self.loss]

        # Add l2 loss
        with tf.variable_scope('l2_loss'):
            costs = [tf.nn.l2_loss(var) for var in tf.get_collection(utils.WEIGHT_DECAY_KEY)]
            l2_loss = tf.multiply(self._hp.weight_decay, tf.add_n(costs))
        losses.append(l2_loss)

        # Add group split loss
        with tf.variable_scope('group/'):
            if tf.get_collection('OVERLAP_LOSS') and self._hp.gamma1 > 0:
                cost1 = tf.reduce_mean(tf.get_collection('OVERLAP_LOSS'))
                cost1 = cost1 * self._hp.gamma1
                tf.summary.scalar('group/overlap_loss/', cost1)
                losses.append(cost1)

            if tf.get_collection('WEIGHT_SPLIT') and self._hp.gamma2 > 0:
                if self._hp.weighted_group_loss:
                    # Weight each term by its own (gradient-stopped) value.
                    reg_weights = [tf.stop_gradient(x) for x in tf.get_collection('WEIGHT_SPLIT')]
                    regs = [tf.stop_gradient(x) * x for x in tf.get_collection('WEIGHT_SPLIT')]
                    cost2 = tf.reduce_sum(regs) / tf.reduce_sum(reg_weights)
                else:
                    cost2 = tf.reduce_mean(tf.get_collection('WEIGHT_SPLIT'))
                cost2 = cost2 * self._hp.gamma2
                tf.summary.scalar('group/weight_split_loss/', cost2)
                losses.append(cost2)

            if tf.get_collection('UNIFORM_LOSS') and self._hp.gamma3 > 0:
                cost3 = tf.reduce_mean(tf.get_collection('UNIFORM_LOSS'))
                cost3 = cost3 * self._hp.gamma3
                tf.summary.scalar('group/group_uniform_loss/', cost3)
                losses.append(cost3)

        self._total_loss = tf.add_n(losses)

        # Gradient descent step
        opt = tf.train.MomentumOptimizer(self.lr, self._hp.momentum)
        grads_and_vars = opt.compute_gradients(self._total_loss, tf.trainable_variables())
        if self._hp.finetune:
            for idx, (grad, var) in enumerate(grads_and_vars):
                # The third stage is built under the scopes 'unit_3_0' and
                # 'unit_3_1', so the substring to match is "unit_3"; the
                # previous "unit3" never matched any variable.
                if "unit_3" in var.op.name or \
                        "unit_last" in var.op.name or \
                        "logits" in var.op.name:
                    if grad is None:
                        # Variable does not contribute to the loss; nothing to scale.
                        continue
                    print('Scale up learning rate of % s by 10.0' % var.op.name)
                    grad = 10.0 * grad
                    grads_and_vars[idx] = (grad, var)

        apply_grad_op = opt.apply_gradients(grads_and_vars, global_step=self._global_step)

        # Batch normalization moving average update
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
            with tf.control_dependencies(update_ops + [apply_grad_op]):
                self.train_op = tf.no_op()
        else:
            self.train_op = apply_grad_op

    # Helper functions(counts FLOPs and number of weights)
    def _conv(self, x, filter_size, out_channel, stride, pad="SAME", input_q=None, output_q=None, name="conv"):
        """Wrap `utils._conv` and record the layer's FLOPs and weight count."""
        _, h, w, in_channel = x.get_shape().as_list()
        x = utils._conv(x, filter_size, out_channel, stride, pad, input_q, output_q, name)
        # Floor division keeps the counters integral on Python 3 as well.
        f = 2 * (h // stride) * (w // stride) * in_channel * out_channel * filter_size * filter_size
        n_weights = in_channel * out_channel * filter_size * filter_size
        scope_name = tf.get_variable_scope().name + "/" + name
        self._add_flops_weights(scope_name, f, n_weights)
        return x

    def _fc(self, x, out_dim, input_q=None, output_q=None, name="fc"):
        """Wrap `utils._fc` and record the layer's FLOPs and weight count."""
        _, in_dim = x.get_shape().as_list()
        x = utils._fc(x, out_dim, input_q, output_q, name)
        f = 2 * (in_dim + 1) * out_dim
        n_weights = (in_dim + 1) * out_dim
        scope_name = tf.get_variable_scope().name + "/" + name
        self._add_flops_weights(scope_name, f, n_weights)
        return x

    def _bn(self, x, name="bn", no_scale=False):
        """Wrap `utils._bn`; batch norm is not included in the counters."""
        x = utils._bn(x, self.is_train, self._global_step, name, no_scale=no_scale)
        return x

    def _relu(self, x, name="relu"):
        """Wrap `utils._relu` with a leakiness argument of 0.0."""
        x = utils._relu(x, 0.0, name)
        return x

    def _dropout(self, x, keep_prob, name="dropout"):
        """Wrap `utils._dropout`."""
        x = utils._dropout(x, keep_prob, name)
        return x

    def _get_data_size(self, x):
        """Return the product of all static dimensions of `x` except the first."""
        return np.prod(x.get_shape().as_list()[1:])

    def _add_flops_weights(self, scope_name, f, w):
        """Add `f` FLOPs and `w` weights once per distinct `scope_name`."""
        if scope_name not in self._counted_scope:
            self._flops += f
            self._weights += w
            self._counted_scope.append(scope_name)
import numpy as np
try:
    import cPickle as pickle  # Python 2: C-accelerated pickle
except ImportError:
    import pickle  # Python 3: cPickle no longer exists as a separate module

import cifar100
import resnet_split as resnet


# Command-line flags for the evaluation script.  Help strings are kept
# verbatim because they are emitted at runtime by the flag parser.

# Dataset Configuration
tf.app.flags.DEFINE_string('data_dir', './cifar100/train_val_split', """Path to the CIFAR-100 data.""")
tf.app.flags.DEFINE_integer('num_classes', 100, """Number of classes in the dataset.""")
tf.app.flags.DEFINE_integer('num_test_instance', 10000, """Number of test images.""")

# Network Configuration
tf.app.flags.DEFINE_integer('batch_size', 100, """Number of images to process in a batch.""")
tf.app.flags.DEFINE_integer('num_residual_units', 2, """Number of residual block per group.
                                                Total number of conv layers will be 6n+4""")
tf.app.flags.DEFINE_integer('k', 8, """Network width multiplier""")
tf.app.flags.DEFINE_integer('ngroups1', 1, """Grouping number on logits""")
tf.app.flags.DEFINE_integer('ngroups2', 1, """Grouping number on unit_3_x""")
tf.app.flags.DEFINE_integer('ngroups3', 1, """Grouping number on unit_2_x""")

# Optimization Configuration
tf.app.flags.DEFINE_float('l2_weight', 0.0001, """L2 loss weight applied all the weights""")
tf.app.flags.DEFINE_float('momentum', 0.9, """The momentum of MomentumOptimizer""")
tf.app.flags.DEFINE_float('initial_lr', 0.1, """Initial learning rate""")
tf.app.flags.DEFINE_string('lr_step_epoch', "80.0,120.0,160.0", """Epochs after which learing rate decays""")
tf.app.flags.DEFINE_float('lr_decay', 0.1, """Learning rate decay factor""")
tf.app.flags.DEFINE_boolean('finetune', False, """Whether to finetune.""")

# Evaluation Configuration
tf.app.flags.DEFINE_string('basemodel', './group/model.ckpt-199999', """Base model to load paramters""")
tf.app.flags.DEFINE_string('checkpoint', './split/model.ckpt-149999', """Path to the model checkpoint file""")
tf.app.flags.DEFINE_string('output_file', './split/eval.pkl', """Path to the result pkl file""")
tf.app.flags.DEFINE_integer('test_iter', 100, """Number of test batches during the evaluation""")
tf.app.flags.DEFINE_integer('display', 10, """Number of iterations to display training info.""")
tf.app.flags.DEFINE_float('gpu_fraction', 0.95, """The fraction of GPU memory to be allocated""")
tf.app.flags.DEFINE_boolean('log_device_placement', False, """Whether to log device placement.""")

FLAGS = tf.app.flags.FLAGS
def get_lr(initial_lr, lr_decay, lr_decay_steps, global_step):
    """Return the step-decayed learning rate at `global_step`.

    The rate starts at `initial_lr` and is multiplied by `lr_decay` once
    for every entry of `lr_decay_steps` that `global_step` has reached.
    """
    rate = initial_lr
    for boundary in lr_decay_steps:
        if global_step < boundary:
            continue
        rate = rate * lr_decay
    return rate
def train():
    """Evaluate a split network on the CIFAR-100 test set.

    Despite its name this function only evaluates.  It reads the split
    assignments and the corresponding kernel slices from `FLAGS.basemodel`,
    builds the `resnet_split` graph from them, restores `FLAGS.checkpoint`,
    runs `FLAGS.test_iter` batches and pickles accuracy, confusion matrix
    and timing to `FLAGS.output_file`.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)
    print('\tNumber of Groups: %d-%d-%d' % (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1))
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Evaluation Configuration]')
    print('\tCheckpoint file: %s' % FLAGS.checkpoint)
    print('\tOutput file path: %s' % FLAGS.output_file)
    print('\tTest iterations: %d' % FLAGS.test_iter)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        print('Load CIFAR-100 dataset')
        test_dataset_path = os.path.join(FLAGS.data_dir, 'test')
        with tf.variable_scope('test_image'):
            cifar100_test = cifar100.CIFAR100Runner(test_dataset_path, image_per_thread=1,
                                                    shuffle=False, distort=False, capacity=5000)
            test_images, test_labels = cifar100_test.get_inputs(FLAGS.batch_size)

        # Build a Graph that computes the predictions from the inference model.
        images = tf.placeholder(tf.float32, [FLAGS.batch_size, cifar100.IMAGE_SIZE, cifar100.IMAGE_SIZE, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Get splitted params
        if not FLAGS.basemodel:
            print('No basemodel found to load split params')
            sys.exit(-1)
        print('Load split params from %s' % FLAGS.basemodel)

        reader = tf.train.NewCheckpointReader(FLAGS.basemodel)
        split_params = {}

        def get_perms(q_name, ngroups):
            # Assign every column of the split variable to the group with
            # the largest alpha, and return one index array per group.
            split_alpha = reader.get_tensor(q_name + '/alpha')
            q_amax = np.argmax(split_alpha, axis=0)
            return [np.where(q_amax == i)[0] for i in range(ngroups)]

        def slice_kernel(kernel, in_idxs, out_idxs):
            # Select input channels (axis 2), then output channels (axis 3).
            return kernel[:, :, in_idxs, :][:, :, :, out_idxs]

        def first_unit_params(unit, p_idxs, q_idxs, r_idxs):
            # Per-group kernels of a stage's first unit: p -> r -> q, with a
            # p -> q shortcut.
            shortcut_k = reader.get_tensor(unit + '/shortcut/kernel')
            conv1_k = reader.get_tensor(unit + '/conv_1/kernel')
            conv2_k = reader.get_tensor(unit + '/conv_2/kernel')
            groups = list(zip(p_idxs, q_idxs, r_idxs))
            return {'shortcut': [slice_kernel(shortcut_k, p, q) for p, q, _ in groups],
                    'conv1': [slice_kernel(conv1_k, p, r) for p, _, r in groups],
                    'conv2': [slice_kernel(conv2_k, r, q) for _, q, r in groups],
                    'p_perms': p_idxs, 'q_perms': q_idxs, 'r_perms': r_idxs}

        def plain_unit_params(unit, q_idxs, r_idxs):
            # Per-group kernels of a channel-preserving unit: q -> r -> q.
            conv1_k = reader.get_tensor(unit + '/conv_1/kernel')
            conv2_k = reader.get_tensor(unit + '/conv_2/kernel')
            groups = list(zip(q_idxs, r_idxs))
            return {'conv1': [slice_kernel(conv1_k, q, r) for q, r in groups],
                    'conv2': [slice_kernel(conv2_k, r, q) for q, r in groups],
                    'p_perms': q_idxs, 'r_perms': r_idxs}

        print('\tlogits...')
        base_logits_w = reader.get_tensor('logits/fc/weights')
        base_logits_b = reader.get_tensor('logits/fc/biases')
        if FLAGS.ngroups1 > 1:
            split_p1_idxs = get_perms('group/split_p1', FLAGS.ngroups1)
            split_q1_idxs = get_perms('group/split_q1', FLAGS.ngroups1)
        else:
            # The base model only creates split_p1 / split_q1 when
            # ngroups1 > 1; with a single group every unit belongs to it.
            split_p1_idxs = [np.arange(base_logits_w.shape[0])]
            split_q1_idxs = [np.arange(base_logits_w.shape[1])]

        split_params['logits'] = {
            'weights': [base_logits_w[p, :][:, q] for p, q in zip(split_p1_idxs, split_q1_idxs)],
            'biases': [base_logits_b[q] for q in split_q1_idxs],
            'input_perms': split_p1_idxs,
            'output_perms': split_q1_idxs,
        }

        if FLAGS.ngroups2 > 1:
            print('\tunit_3_x...')
            split_p2_idxs = get_perms('group/split_p2', FLAGS.ngroups2)
            split_q2_idxs = _merge_split_idxs(split_p1_idxs, _get_even_merge_idxs(FLAGS.ngroups1, FLAGS.ngroups2))
            split_r21_idxs = get_perms('group/split_r21', FLAGS.ngroups2)
            split_r22_idxs = get_perms('group/split_r22', FLAGS.ngroups2)
            split_params['unit_3_0'] = first_unit_params('unit_3_0', split_p2_idxs, split_q2_idxs, split_r21_idxs)
            split_params['unit_3_1'] = plain_unit_params('unit_3_1', split_q2_idxs, split_r22_idxs)

        if FLAGS.ngroups3 > 1:
            # NOTE(review): relies on split_p2_idxs from the branch above,
            # i.e. on ngroups2 > 1 whenever ngroups3 > 1.
            print('\tunit_2_x...')
            split_p3_idxs = get_perms('group/split_p3', FLAGS.ngroups3)
            split_q3_idxs = _merge_split_idxs(split_p2_idxs, _get_even_merge_idxs(FLAGS.ngroups2, FLAGS.ngroups3))
            split_r31_idxs = get_perms('group/split_r31', FLAGS.ngroups3)
            split_r32_idxs = get_perms('group/split_r32', FLAGS.ngroups3)
            split_params['unit_2_0'] = first_unit_params('unit_2_0', split_p3_idxs, split_q3_idxs, split_r31_idxs)
            split_params['unit_2_1'] = plain_unit_params('unit_2_1', split_q3_idxs, split_r32_idxs)

        # Build model
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            ngroups1=FLAGS.ngroups1,
                            ngroups2=FLAGS.ngroups2,
                            ngroups3=FLAGS.ngroups3,
                            split_params=split_params,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.build_model()
        print('Number of Weights: %d' % network._weights)
        print('FLOPs: %d' % network._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))

        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint:
            saver.restore(sess, FLAGS.checkpoint)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        else:
            # Also covers an empty path, which cannot be restored.
            print('No checkpoint file found.')
            sys.exit(1)

        # Start queue runners & summary_writer
        cifar100_test.start_threads(sess, n_threads=1)

        # Test!
        test_loss = 0.0
        test_acc = 0.0
        test_time = 0.0
        confusion_matrix = np.zeros((FLAGS.num_classes, FLAGS.num_classes), dtype=np.int32)
        for i in range(FLAGS.test_iter):
            # Fetch the batch first so that only the forward pass is timed.
            test_images_val, test_labels_val = sess.run([test_images, test_labels])
            start_time = time.time()
            loss_value, acc_value, pred_value = sess.run(
                [network.loss, network.acc, network.preds],
                feed_dict={network.is_train: False, images: test_images_val, labels: test_labels_val})
            duration = time.time() - start_time
            test_loss += loss_value
            test_acc += acc_value
            test_time += duration
            for l, p in zip(test_labels_val, pred_value):
                confusion_matrix[l, p] += 1

            if i % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: iter %d, loss=%.4f, acc=%.4f (%.1f examples/sec; %.3f sec/batch)')
                print (format_str % (datetime.now(), i, loss_value, acc_value,
                                     examples_per_sec, sec_per_batch))
        test_loss /= FLAGS.test_iter
        test_acc /= FLAGS.test_iter

        # Print and save results
        sec_per_image = test_time / FLAGS.test_iter / FLAGS.batch_size
        print ('Done! Acc: %.6f, Test time: %.3f sec, %.7f sec/example' % (test_acc, test_time, sec_per_image))
        print ('Saving result... ')
        result = {'accuracy': test_acc, 'confusion_matrix': confusion_matrix,
                  'test_time': test_time, 'sec_per_image': sec_per_image}
        with open(FLAGS.output_file, 'wb') as fd:
            pickle.dump(result, fd)
        print ('done!')
') 264 | result = {'accuracy': test_acc, 'confusion_matrix': confusion_matrix, 265 | 'test_time': test_time, 'sec_per_image': sec_per_image} 266 | with open(FLAGS.output_file, 'wb') as fd: 267 | pickle.dump(result, fd) 268 | print ('done!') 269 | 270 | 271 | def _merge_split_q(q, merge_idxs, name='merge'): 272 | ngroups, dim = q.shape 273 | max_idx = np.max(merge_idxs) 274 | temp_list = [] 275 | for i in range(max_idx + 1): 276 | temp = [] 277 | for j in range(ngroups): 278 | if merge_idxs[j] == i: 279 | temp.append(q[j,:]) 280 | temp_list.append(np.sum(temp, axis=0)) 281 | ret = np.array(temp_list) 282 | 283 | return ret 284 | 285 | def _merge_split_idxs(split_idxs, merge_idxs, name='merge'): 286 | ngroups = len(split_idxs) 287 | max_idx = np.max(merge_idxs) 288 | ret = [] 289 | for i in range(max_idx + 1): 290 | temp = [] 291 | for j in range(ngroups): 292 | if merge_idxs[j] == i: 293 | temp.append(split_idxs[j]) 294 | ret.append(np.concatenate(temp)) 295 | 296 | return ret 297 | 298 | def _get_even_merge_idxs(N, split): 299 | assert N >= split 300 | num_elems = [(N + split - i - 1)/split for i in range(split)] 301 | expand_split = [[i] * n for i, n in enumerate(num_elems)] 302 | return [t for l in expand_split for t in l] 303 | 304 | 305 | def main(argv=None): # pylint: disable=unused-argument 306 | train() 307 | 308 | 309 | if __name__ == '__main__': 310 | tf.app.run() 311 | -------------------------------------------------------------------------------- /resnet_split.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | import utils 7 | 8 | 9 | HParams = namedtuple('HParams', 10 | 'batch_size, num_classes, num_residual_units, k, weight_decay, momentum, finetune, ' 11 | 'ngroups1, ngroups2, ngroups3, split_params') 12 | 13 | class ResNet(object): 14 | def __init__(self, hp, images, labels, global_step, name=None, 
    def build_model(self):
        """Build the forward graph of the split network.

        Stages whose group count is 1 are built with the ordinary residual
        blocks.  Stages with more groups are built from the per-group
        kernels and index lists found in `self._hp.split_params`.  Sets
        `self._logits`, `self.probs`, `self.preds`, `self.acc` and
        `self.loss`.
        """
        print('Building model')
        filters = [16, 16 * self._hp.k, 32 * self._hp.k, 64 * self._hp.k]
        strides = [1, 2, 2]

        # Init. conv.
        print('\tBuilding unit: init_conv')
        x = utils._conv(self._images, 3, filters[0], 1, name='init_conv')

        # unit_1_x
        x = self._residual_block_first(x, filters[1], strides[0], name='unit_1_0')
        x = self._residual_block(x, name='unit_1_1')

        # unit_2_x
        if self._hp.ngroups3 == 1:
            x = self._residual_block_first(x, filters[2], strides[1], name='unit_2_0')
            x = self._residual_block(x, name='unit_2_1')
        else:
            unit_2_0_shortcut_kernel = self._hp.split_params['unit_2_0']['shortcut']
            unit_2_0_conv1_kernel = self._hp.split_params['unit_2_0']['conv1']
            unit_2_0_conv2_kernel = self._hp.split_params['unit_2_0']['conv2']
            unit_2_0_p_perms = self._hp.split_params['unit_2_0']['p_perms']
            unit_2_0_q_perms = self._hp.split_params['unit_2_0']['q_perms']
            unit_2_0_r_perms = self._hp.split_params['unit_2_0']['r_perms']

            # NOTE(review): unlike _residual_block_first, this branch applies
            # conv -> bn -> relu with the final relu after the merge --
            # confirm this ordering is intended.
            with tf.variable_scope('unit_2_0'):
                shortcut = self._conv_split(x, filters[2], strides[1], unit_2_0_shortcut_kernel, unit_2_0_p_perms, unit_2_0_q_perms, name='shortcut')
                x = self._conv_split(x, filters[2], strides[1], unit_2_0_conv1_kernel, unit_2_0_p_perms, unit_2_0_r_perms, name='conv_1')
                x = self._bn(x, name='bn_1')
                x = self._relu(x, name='relu_1')
                x = self._conv_split(x, filters[2], 1, unit_2_0_conv2_kernel, unit_2_0_r_perms, unit_2_0_q_perms, name='conv_2')
                x = self._bn(x, name='bn_2')
                x = x + shortcut
                x = self._relu(x, name='relu_2')

            unit_2_1_conv1_kernel = self._hp.split_params['unit_2_1']['conv1']
            unit_2_1_conv2_kernel = self._hp.split_params['unit_2_1']['conv2']
            unit_2_1_p_perms = self._hp.split_params['unit_2_1']['p_perms']
            unit_2_1_r_perms = self._hp.split_params['unit_2_1']['r_perms']

            with tf.variable_scope('unit_2_1'):
                shortcut = x
                x = self._conv_split(x, filters[2], 1, unit_2_1_conv1_kernel, unit_2_1_p_perms, unit_2_1_r_perms, name='conv_1')
                x = self._bn(x, name='bn_1')
                x = self._relu(x, name='relu_1')
                x = self._conv_split(x, filters[2], 1, unit_2_1_conv2_kernel, unit_2_1_r_perms, unit_2_1_p_perms, name='conv_2')
                x = self._bn(x, name='bn_2')
                x = x + shortcut
                x = self._relu(x, name='relu_2')

        # unit_3_x
        if self._hp.ngroups2 == 1:
            x = self._residual_block_first(x, filters[3], strides[2], name='unit_3_0')
            x = self._residual_block(x, name='unit_3_1')
        else:
            unit_3_0_shortcut_kernel = self._hp.split_params['unit_3_0']['shortcut']
            unit_3_0_conv1_kernel = self._hp.split_params['unit_3_0']['conv1']
            unit_3_0_conv2_kernel = self._hp.split_params['unit_3_0']['conv2']
            unit_3_0_p_perms = self._hp.split_params['unit_3_0']['p_perms']
            unit_3_0_q_perms = self._hp.split_params['unit_3_0']['q_perms']
            unit_3_0_r_perms = self._hp.split_params['unit_3_0']['r_perms']

            with tf.variable_scope('unit_3_0'):
                shortcut = self._conv_split(x, filters[3], strides[2], unit_3_0_shortcut_kernel, unit_3_0_p_perms, unit_3_0_q_perms, name='shortcut')
                x = self._conv_split(x, filters[3], strides[2], unit_3_0_conv1_kernel, unit_3_0_p_perms, unit_3_0_r_perms, name='conv_1')
                x = self._bn(x, name='bn_1')
                x = self._relu(x, name='relu_1')
                x = self._conv_split(x, filters[3], 1, unit_3_0_conv2_kernel, unit_3_0_r_perms, unit_3_0_q_perms, name='conv_2')
                x = self._bn(x, name='bn_2')
                x = x + shortcut
                x = self._relu(x, name='relu_2')

            unit_3_1_conv1_kernel = self._hp.split_params['unit_3_1']['conv1']
            unit_3_1_conv2_kernel = self._hp.split_params['unit_3_1']['conv2']
            unit_3_1_p_perms = self._hp.split_params['unit_3_1']['p_perms']
            unit_3_1_r_perms = self._hp.split_params['unit_3_1']['r_perms']

            with tf.variable_scope('unit_3_1'):
                shortcut = x
                x = self._conv_split(x, filters[3], 1, unit_3_1_conv1_kernel, unit_3_1_p_perms, unit_3_1_r_perms, name='conv_1')
                x = self._bn(x, name='bn_1')
                x = self._relu(x, name='relu_1')
                x = self._conv_split(x, filters[3], 1, unit_3_1_conv2_kernel, unit_3_1_r_perms, unit_3_1_p_perms, name='conv_2')
                x = self._bn(x, name='bn_2')
                x = x + shortcut
                x = self._relu(x, name='relu_2')

        # Last unit
        with tf.variable_scope('unit_last') as scope:
            print('\tBuilding unit: %s' % scope.name)
            x = utils._bn(x, self.is_train, self._global_step)
            x = utils._relu(x)
            x = tf.reduce_mean(x, [1, 2])

        # Logit
        # The logits layer is always built from split_params['logits'],
        # regardless of ngroups1.
        logits_weights = self._hp.split_params['logits']['weights']
        logits_biases = self._hp.split_params['logits']['biases']
        logits_input_perms = self._hp.split_params['logits']['input_perms']
        logits_output_perms = self._hp.split_params['logits']['output_perms']
        with tf.variable_scope('logits') as scope:
            print('\tBuilding unit: %s - %d split' % (scope.name, len(logits_weights)))
            x_offset = 0  # accumulated below but never read
            x_list = []
            for i, (w, b, p) in enumerate(zip(logits_weights, logits_biases, logits_input_perms)):
                in_dim, out_dim = w.shape
                # Gather this group's input features (columns listed in p).
                x_split = tf.transpose(tf.gather(tf.transpose(x), p))
                x_split = self._fc_with_init(x_split, out_dim, init_w=w, init_b=b, name='split%d' % (i+1))
                x_list.append(x_split)
                x_offset += in_dim
            x = tf.concat(x_list, 1)
            # Undo the per-group ordering: column i of the result is taken
            # from the position where class i sits in the concatenation.
            output_forward_idx = list(np.concatenate(logits_output_perms))
            output_inverse_idx = [output_forward_idx.index(i) for i in range(self._hp.num_classes)]
            x = tf.transpose(tf.gather(tf.transpose(x), output_inverse_idx))

        self._logits = x

        # Probs & preds & acc
        self.probs = tf.nn.softmax(x, name='probs')
        self.preds = tf.to_int32(tf.argmax(self._logits, 1, name='preds'))
        ones = tf.constant(np.ones([self._hp.batch_size]), dtype=tf.float32)
        zeros = tf.constant(np.zeros([self._hp.batch_size]), dtype=tf.float32)
        correct = tf.where(tf.equal(self.preds, self._labels), ones, zeros)
        self.acc = tf.reduce_mean(correct, name='acc')
        tf.summary.scalar('accuracy', self.acc)

        # Loss & acc
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=x, labels=self._labels)
        self.loss = tf.reduce_mean(loss)
        tf.summary.scalar('cross_entropy', self.loss)


    def build_train_op(self):
        """Add the L2 regulariser and create `self.train_op`.

        When `hp.finetune` is set, gradients of the variables selected by
        the name test below are multiplied by 10 before being applied.
        """
        print('Building train ops')

        # Learning rate
        tf.summary.scalar('learing_rate', self.lr)

        losses = [self.loss]

        # Add l2 loss
        with tf.variable_scope('l2_loss'):
            costs = [tf.nn.l2_loss(var) for var in tf.get_collection(utils.WEIGHT_DECAY_KEY)]
            l2_loss = tf.multiply(self._hp.weight_decay, tf.add_n(costs))
        losses.append(l2_loss)

        self._total_loss = tf.add_n(losses)

        # Gradient descent step
        opt = tf.train.MomentumOptimizer(self.lr, self._hp.momentum)
        grads_and_vars = opt.compute_gradients(self._total_loss, tf.trainable_variables())
        if self._hp.finetune:
            for idx, (grad, var) in enumerate(grads_and_vars):
                # NOTE(review): unit_1_x is gated on ngroups3 and unit_2_x on
                # ngroups2, while build_model splits unit_2_x by ngroups3 and
                # unit_3_x by ngroups2 -- confirm the pairing is intended.
                if "group" in var.op.name or \
                        (("unit_1_0" in var.op.name or "unit_1_1" in var.op.name) and self._hp.ngroups3 > 1) or \
                        (("unit_2_0" in var.op.name or "unit_2_1" in var.op.name) and self._hp.ngroups2 > 1) or \
                        ("unit_3_0" in var.op.name or "unit_3_1" in var.op.name) or \
                        "logits" in var.op.name:
                    print('\tScale up learning rate of % s by 10.0' % var.op.name)
                    grad = 10.0 * grad
                    grads_and_vars[idx] = (grad,var)

        # Apply gradient
        apply_grad_op = opt.apply_gradients(grads_and_vars, global_step=self._global_step)

        # Batch normalization moving average update
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
            # train_op is a no-op that depends on both the BN updates and the
            # gradient step, so running it triggers all of them.
            with tf.control_dependencies(update_ops+[apply_grad_op]):
                self.train_op = tf.no_op()
        else:
            self.train_op = apply_grad_op
def _residual_block_first(self, x, out_channel, strides, input_q=None, output_q=None, split_r=None, name="unit"):
    """Build the first residual unit of a stage (BN -> ReLU -> conv, twice).

    Args:
        x: Input tensor; its last static dimension is taken as the input
            channel count.
        out_channel: Number of output channels of both 3x3 convolutions.
        strides: Stride of the first 3x3 convolution and of the shortcut.
        input_q, output_q, split_r: Optional grouping arguments forwarded to
            ``self._conv``.  Dropout is inserted only when all three are given.
        name: Variable scope wrapping the unit.

    Returns:
        The sum of the residual branch and the shortcut branch.
    """
    in_channel = x.get_shape().as_list()[-1]
    # Dropout is only part of the unit when every grouping argument is present.
    use_dropout = all(arg is not None for arg in (input_q, output_q, split_r))

    with tf.variable_scope(name) as scope:
        print('\tBuilding residual unit: %s' % scope.name)

        # Pre-activation shared by the shortcut and the residual branch.
        pre = self._relu(self._bn(x, name='bn_1'), name='relu_1')
        if use_dropout:
            pre = self._dropout(pre, self._hp.dropout_keep_prob, name='dropout_1')

        # Shortcut: 1x1 conv when the channel count changes, otherwise an
        # identity (stride 1) or a max-pool that matches the spatial stride.
        if in_channel != out_channel:
            shortcut = self._conv(pre, 1, out_channel, strides,
                                  input_q=input_q, output_q=output_q, name='shortcut')
        elif strides == 1:
            shortcut = tf.identity(pre)
        else:
            pool_shape = [1, strides, strides, 1]
            shortcut = tf.nn.max_pool(pre, pool_shape, list(pool_shape), 'VALID')

        # Residual branch: conv_1 -> BN -> ReLU -> (dropout) -> conv_2.
        residual = self._conv(pre, 3, out_channel, strides,
                              input_q=input_q, output_q=split_r, name='conv_1')
        residual = self._relu(self._bn(residual, name='bn_2'), name='relu_2')
        if use_dropout:
            residual = self._dropout(residual, self._hp.dropout_keep_prob, name='dropout_2')
        residual = self._conv(residual, 3, out_channel, 1,
                              input_q=split_r, output_q=output_q, name='conv_2')

        # Merge the two branches.
        return residual + shortcut
def _conv_split(self, x, out_channel, strides, kernels, input_perms, output_perms, name="unit"):
    """Run one convolution per split on a channel subset and re-merge the outputs.

    Args:
        x: Rank-4 input tensor whose last axis holds the channels.
        out_channel: Total number of output channels after merging.
        strides: Stride passed to every per-split convolution.
        kernels: One initial kernel array per split; its last three dimensions
            are read as (kernel_size, in_dim, out_dim).
        input_perms: Per-split lists of input-channel indices.
        output_perms: Per-split lists of output-channel indices; concatenated,
            they must contain every index in ``range(out_channel)``.
        name: Variable scope wrapping all splits.

    Returns:
        The concatenated split outputs, reordered back to original channel order.

    Raises:
        ValueError: If some output channel index is missing from ``output_perms``.
    """
    def _take_channels(tensor, channel_idx):
        # Gather along the channel axis: move it to the front, gather, move it back.
        channels_first = tf.transpose(tensor, (3, 0, 1, 2))
        return tf.transpose(tf.gather(channels_first, channel_idx), (1, 2, 3, 0))

    # Unpacking into four names fails fast unless x has a static rank of 4.
    _dim0, _dim1, _dim2, _dim3 = x.get_shape().as_list()

    branch_outputs = []
    with tf.variable_scope(name) as scope:
        print('\tBuilding residual unit: %s - %d split' % (scope.name, len(kernels)))
        for branch_no, (kernel, perm) in enumerate(zip(kernels, input_perms), 1):
            kernel_size, _in_dim, out_dim = kernel.shape[-3:]
            branch_input = _take_channels(x, perm)
            branch_outputs.append(
                self._conv_with_init(branch_input, kernel_size, out_dim, strides,
                                     init_k=kernel, name="split%d" % branch_no))
        merged = tf.concat(branch_outputs, 3)

        # The concatenation is ordered split by split; invert that permutation
        # so channel c of the result is original output channel c.
        forward_order = list(np.concatenate(output_perms))
        inverse_order = [forward_order.index(channel) for channel in range(out_channel)]
        return _take_channels(merged, inverse_order)
def _conv_with_init(self, x, filter_size, out_channel, stride, pad="SAME", init_k=None, name="conv"):
    """Apply ``utils._conv_with_init`` and register the layer's FLOPs and weights.

    The cost is computed from the static shape of ``x`` (four dimensions) and
    registered under ``<current variable scope>/<name>``.
    """
    _batch, height, width, in_channel = x.get_shape().as_list()
    out = utils._conv_with_init(x, filter_size, out_channel, stride, pad, init_k, name)
    n_flops = 2 * (height/stride) * (width/stride) * in_channel * out_channel * filter_size * filter_size
    n_weights = in_channel * out_channel * filter_size * filter_size
    counted_as = "/".join((tf.get_variable_scope().name, name))
    self._add_flops_weights(counted_as, n_flops, n_weights)
    return out


def _fc(self, x, out_dim, input_q=None, output_q=None, name="fc"):
    """Apply ``utils._fc`` and register the layer's FLOPs and weights.

    ``x`` must have a static rank of 2.  The weight count is
    ``(in_dim + 1) * out_dim`` (the ``+ 1`` presumably covers the bias --
    confirm against ``utils._fc``) and the FLOP count is twice that.
    """
    _batch, in_dim = x.get_shape().as_list()
    out = utils._fc(x, out_dim, input_q, output_q, name)
    n_flops = 2 * (in_dim + 1) * out_dim
    n_weights = (in_dim + 1) * out_dim
    counted_as = "/".join((tf.get_variable_scope().name, name))
    self._add_flops_weights(counted_as, n_flops, n_weights)
    return out


def _fc_with_init(self, x, out_dim, init_w=None, init_b=None, name="fc"):
    """Apply ``utils._fc_with_init`` and register the layer's FLOPs and weights.

    Uses the same cost formula as ``_fc``: ``(in_dim + 1) * out_dim`` weights
    and twice as many FLOPs.
    """
    _batch, in_dim = x.get_shape().as_list()
    out = utils._fc_with_init(x, out_dim, init_w, init_b, name)
    n_flops = 2*(in_dim + 1) * out_dim
    n_weights = (in_dim + 1) * out_dim
    counted_as = "/".join((tf.get_variable_scope().name, name))
    self._add_flops_weights(counted_as, n_flops, n_weights)
    return out


def _bn(self, x, name="bn"):
    """Apply ``utils._bn`` with this model's ``is_train`` flag and global step.

    Batch normalisation is deliberately left out of the FLOPs/weights totals:
    nothing is registered with ``_add_flops_weights`` here.
    """
    return utils._bn(x, self.is_train, self._global_step, name)
def _dropout(self, x, keep_prob, name="dropout"):
    """Forward to ``utils._dropout`` and return its result unchanged."""
    return utils._dropout(x, keep_prob, name)


def _get_data_size(self, x):
    """Return the product of every static dimension of ``x`` except the first."""
    dims_after_first = x.get_shape().as_list()[1:]
    return np.prod(dims_after_first)


def _add_flops_weights(self, scope_name, f, w):
    """Add ``f`` FLOPs and ``w`` weights to the totals, once per scope name.

    A scope name that was already counted is ignored, so building the same
    layer twice does not inflate ``self._flops`` / ``self._weights``.
    """
    if scope_name in self._counted_scope:
        return
    self._flops += f
    self._weights += w
    self._counted_scope.append(scope_name)
28 | Total number of conv layers will be 6n+4""") 29 | tf.app.flags.DEFINE_integer('k', 8, """Network width multiplier""") 30 | tf.app.flags.DEFINE_integer('ngroups1', 1, """Grouping number on logits""") 31 | tf.app.flags.DEFINE_integer('ngroups2', 1, """Grouping number on unit_3_x""") 32 | tf.app.flags.DEFINE_integer('ngroups3', 1, """Grouping number on unit_2_x""") 33 | 34 | # Optimization Configuration 35 | tf.app.flags.DEFINE_float('l2_weight', 0.0001, """L2 loss weight applied all the weights""") 36 | tf.app.flags.DEFINE_float('momentum', 0.9, """The momentum of MomentumOptimizer""") 37 | tf.app.flags.DEFINE_float('initial_lr', 0.1, """Initial learning rate""") 38 | tf.app.flags.DEFINE_string('lr_step_epoch', "80.0,120.0,160.0", """Epochs after which learing rate decays""") 39 | tf.app.flags.DEFINE_float('lr_decay', 0.1, """Learning rate decay factor""") 40 | tf.app.flags.DEFINE_boolean('finetune', False, """Whether to finetune.""") 41 | 42 | # Training Configuration 43 | tf.app.flags.DEFINE_string('train_dir', './train', """Directory where to write log and checkpoint.""") 44 | tf.app.flags.DEFINE_integer('max_steps', 100000, """Number of batches to run.""") 45 | tf.app.flags.DEFINE_integer('display', 100, """Number of iterations to display training info.""") 46 | tf.app.flags.DEFINE_integer('val_interval', 1000, """Number of iterations to run a val""") 47 | tf.app.flags.DEFINE_integer('val_iter', 100, """Number of iterations during a val""") 48 | tf.app.flags.DEFINE_integer('checkpoint_interval', 10000, """Number of iterations to save parameters as a checkpoint""") 49 | tf.app.flags.DEFINE_float('gpu_fraction', 0.95, """The fraction of GPU memory to be allocated""") 50 | tf.app.flags.DEFINE_boolean('log_device_placement', False, """Whether to log device placement.""") 51 | tf.app.flags.DEFINE_string('basemodel', None, """Base model to load paramters""") 52 | tf.app.flags.DEFINE_string('checkpoint', None, """Model checkpoint to load""") 53 | 54 | FLAGS = 
def get_lr(initial_lr, lr_decay, lr_decay_steps, global_step):
    """Return the step-decayed learning rate for ``global_step``.

    Args:
        initial_lr: Learning rate before any decay.
        lr_decay: Factor applied once per boundary that has been reached.
        lr_decay_steps: Iterable of step boundaries (iterated exactly once).
        global_step: Current training step.

    Returns:
        ``initial_lr`` multiplied by ``lr_decay`` once for every boundary
        ``s`` with ``global_step >= s``.
    """
    reached = sum(1 for boundary in lr_decay_steps if global_step >= boundary)
    rate = initial_lr
    for _ in range(reached):
        rate *= lr_decay
    return rate
CIFAR-100 103 | print('Load CIFAR-100 dataset') 104 | train_dataset_path = os.path.join(FLAGS.data_dir, 'train') 105 | val_dataset_path = os.path.join(FLAGS.data_dir, 'val') 106 | print('\tLoading training data from %s' % train_dataset_path) 107 | with tf.variable_scope('train_image'): 108 | cifar100_train = cifar100.CIFAR100Runner(train_dataset_path, image_per_thread=32, 109 | shuffle=True, distort=True, capacity=10000) 110 | train_images, train_labels = cifar100_train.get_inputs(FLAGS.batch_size) 111 | print('\tLoading validation data from %s' % val_dataset_path) 112 | with tf.variable_scope('val_image'): 113 | cifar100_val = cifar100.CIFAR100Runner(val_dataset_path, image_per_thread=32, 114 | shuffle=False, distort=False, capacity=5000) 115 | # shuffle=False, distort=False, capacity=10000) 116 | val_images, val_labels = cifar100_val.get_inputs(FLAGS.batch_size) 117 | 118 | # Build a Graph that computes the predictions from the inference model. 119 | images = tf.placeholder(tf.float32, [FLAGS.batch_size, cifar100.IMAGE_SIZE, cifar100.IMAGE_SIZE, 3]) 120 | labels = tf.placeholder(tf.int32, [FLAGS.batch_size]) 121 | 122 | # Get splitted params 123 | if not FLAGS.basemodel: 124 | print('No basemodel found to load split params') 125 | sys.exit(-1) 126 | else: 127 | print('Load split params from %s' % FLAGS.basemodel) 128 | 129 | def get_perms(q_name, ngroups): 130 | split_alpha = reader.get_tensor(q_name+'/alpha') 131 | q_amax = np.argmax(split_alpha, axis=0) 132 | return [np.where(q_amax == i)[0] for i in range(ngroups)] 133 | 134 | reader = tf.train.NewCheckpointReader(FLAGS.basemodel) 135 | split_params = {} 136 | 137 | print('\tlogits...') 138 | base_logits_w = reader.get_tensor('logits/fc/weights') 139 | base_logits_b = reader.get_tensor('logits/fc/biases') 140 | split_p1_idxs = get_perms('group/split_p1', FLAGS.ngroups1) 141 | split_q1_idxs = get_perms('group/split_q1', FLAGS.ngroups1) 142 | 143 | logits_params = {'weights':[], 'biases':[], 'input_perms':[], 
'output_perms':[]} 144 | for i in range(FLAGS.ngroups1): 145 | logits_params['weights'].append(base_logits_w[split_p1_idxs[i], :][:, split_q1_idxs[i]]) 146 | logits_params['biases'].append(base_logits_b[split_q1_idxs[i]]) 147 | logits_params['input_perms'] = split_p1_idxs 148 | logits_params['output_perms'] = split_q1_idxs 149 | split_params['logits'] = logits_params 150 | 151 | if FLAGS.ngroups2 > 1: 152 | print('\tunit_3_x...') 153 | base_unit_3_0_shortcut_k = reader.get_tensor('unit_3_0/shortcut/kernel') 154 | base_unit_3_0_conv1_k = reader.get_tensor('unit_3_0/conv_1/kernel') 155 | base_unit_3_0_conv2_k = reader.get_tensor('unit_3_0/conv_2/kernel') 156 | base_unit_3_1_conv1_k = reader.get_tensor('unit_3_1/conv_1/kernel') 157 | base_unit_3_1_conv2_k = reader.get_tensor('unit_3_1/conv_2/kernel') 158 | split_p2_idxs = get_perms('group/split_p2', FLAGS.ngroups2) 159 | split_q2_idxs = _merge_split_idxs(split_p1_idxs, _get_even_merge_idxs(FLAGS.ngroups1, FLAGS.ngroups2)) 160 | split_r21_idxs = get_perms('group/split_r21', FLAGS.ngroups2) 161 | split_r22_idxs = get_perms('group/split_r22', FLAGS.ngroups2) 162 | 163 | unit_3_0_params = {'shortcut':[], 'conv1':[], 'conv2':[], 'p_perms':[], 'q_perms':[], 'r_perms':[]} 164 | for i in range(FLAGS.ngroups2): 165 | unit_3_0_params['shortcut'].append(base_unit_3_0_shortcut_k[:,:,split_p2_idxs[i],:][:,:,:,split_q2_idxs[i]]) 166 | unit_3_0_params['conv1'].append(base_unit_3_0_conv1_k[:,:,split_p2_idxs[i],:][:,:,:,split_r21_idxs[i]]) 167 | unit_3_0_params['conv2'].append(base_unit_3_0_conv2_k[:,:,split_r21_idxs[i],:][:,:,:,split_q2_idxs[i]]) 168 | unit_3_0_params['p_perms'] = split_p2_idxs 169 | unit_3_0_params['q_perms'] = split_q2_idxs 170 | unit_3_0_params['r_perms'] = split_r21_idxs 171 | split_params['unit_3_0'] = unit_3_0_params 172 | 173 | unit_3_1_params = {'conv1':[], 'conv2':[], 'p_perms':[], 'r_perms':[]} 174 | for i in range(FLAGS.ngroups2): 175 | 
unit_3_1_params['conv1'].append(base_unit_3_1_conv1_k[:,:,split_q2_idxs[i],:][:,:,:,split_r22_idxs[i]]) 176 | unit_3_1_params['conv2'].append(base_unit_3_1_conv2_k[:,:,split_r22_idxs[i],:][:,:,:,split_q2_idxs[i]]) 177 | unit_3_1_params['p_perms'] = split_q2_idxs 178 | unit_3_1_params['r_perms'] = split_r22_idxs 179 | split_params['unit_3_1'] = unit_3_1_params 180 | 181 | if FLAGS.ngroups3 > 1: 182 | print('\tconv4_x...') 183 | base_unit_2_0_shortcut_k = reader.get_tensor('unit_2_0/shortcut/kernel') 184 | base_unit_2_0_conv1_k = reader.get_tensor('unit_2_0/conv_1/kernel') 185 | base_unit_2_0_conv2_k = reader.get_tensor('unit_2_0/conv_2/kernel') 186 | base_unit_2_1_conv1_k = reader.get_tensor('unit_2_1/conv_1/kernel') 187 | base_unit_2_1_conv2_k = reader.get_tensor('unit_2_1/conv_2/kernel') 188 | split_p3_idxs = get_perms('group/split_p3', FLAGS.ngroups3) 189 | split_q3_idxs = _merge_split_idxs(split_p2_idxs, _get_even_merge_idxs(FLAGS.ngroups2, FLAGS.ngroups3)) 190 | split_r31_idxs = get_perms('group/split_r31', FLAGS.ngroups3) 191 | split_r32_idxs = get_perms('group/split_r32', FLAGS.ngroups3) 192 | 193 | unit_2_0_params = {'shortcut':[], 'conv1':[], 'conv2':[], 'p_perms':[], 'q_perms':[], 'r_perms':[]} 194 | for i in range(FLAGS.ngroups3): 195 | unit_2_0_params['shortcut'].append(base_unit_2_0_shortcut_k[:,:,split_p3_idxs[i],:][:,:,:,split_q3_idxs[i]]) 196 | unit_2_0_params['conv1'].append(base_unit_2_0_conv1_k[:,:,split_p3_idxs[i],:][:,:,:,split_r31_idxs[i]]) 197 | unit_2_0_params['conv2'].append(base_unit_2_0_conv2_k[:,:,split_r31_idxs[i],:][:,:,:,split_q3_idxs[i]]) 198 | unit_2_0_params['p_perms'] = split_p3_idxs 199 | unit_2_0_params['q_perms'] = split_q3_idxs 200 | unit_2_0_params['r_perms'] = split_r31_idxs 201 | split_params['unit_2_0'] = unit_2_0_params 202 | 203 | unit_2_1_params = {'conv1':[], 'conv2':[], 'p_perms':[], 'r_perms':[]} 204 | for i in range(FLAGS.ngroups3): 205 | 
unit_2_1_params['conv1'].append(base_unit_2_1_conv1_k[:,:,split_q3_idxs[i],:][:,:,:,split_r32_idxs[i]]) 206 | unit_2_1_params['conv2'].append(base_unit_2_1_conv2_k[:,:,split_r32_idxs[i],:][:,:,:,split_q3_idxs[i]]) 207 | unit_2_1_params['p_perms'] = split_q3_idxs 208 | unit_2_1_params['r_perms'] = split_r32_idxs 209 | split_params['unit_2_1'] = unit_2_1_params 210 | 211 | 212 | # Build model 213 | lr_decay_steps = map(float,FLAGS.lr_step_epoch.split(',')) 214 | lr_decay_steps = map(int,[s*FLAGS.num_train_instance/FLAGS.batch_size for s in lr_decay_steps]) 215 | hp = resnet.HParams(batch_size=FLAGS.batch_size, 216 | num_classes=FLAGS.num_classes, 217 | num_residual_units=FLAGS.num_residual_units, 218 | k=FLAGS.k, 219 | weight_decay=FLAGS.l2_weight, 220 | ngroups1=FLAGS.ngroups1, 221 | ngroups2=FLAGS.ngroups2, 222 | ngroups3=FLAGS.ngroups3, 223 | split_params=split_params, 224 | momentum=FLAGS.momentum, 225 | finetune=FLAGS.finetune) 226 | network = resnet.ResNet(hp, images, labels, global_step) 227 | network.build_model() 228 | network.build_train_op() 229 | print('Number of Weights: %d' % network._weights) 230 | print('FLOPs: %d' % network._flops) 231 | 232 | train_summary_op = tf.summary.merge_all() # Summaries(training) 233 | 234 | # Build an initialization operation to run below. 235 | init = tf.global_variables_initializer() 236 | 237 | # Start running operations on the Graph. 238 | sess = tf.Session(config=tf.ConfigProto( 239 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction), 240 | # allow_soft_placement=True, 241 | log_device_placement=FLAGS.log_device_placement)) 242 | 243 | '''debugging attempt 244 | from tensorflow.python import debug as tf_debug 245 | sess = tf_debug.LocalCLIDebugWrapperSession(sess) 246 | def _get_data(datum, tensor): 247 | return tensor == train_images 248 | sess.add_tensor_filter("get_data", _get_data) 249 | ''' 250 | 251 | sess.run(init) 252 | 253 | # Create a saver. 
254 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000) 255 | if FLAGS.checkpoint is not None: 256 | saver.restore(sess, FLAGS.checkpoint) 257 | init_step = global_step.eval(session=sess) 258 | print('Load checkpoint %s' % FLAGS.checkpoint) 259 | else: 260 | # Define a different saver to load model checkpoints 261 | # Select only base variables (exclude split layers) 262 | print('Load parameters from basemodel %s' % FLAGS.basemodel) 263 | variables = tf.global_variables() 264 | vars_restore = [var for var in variables 265 | if not "Momentum" in var.name and 266 | not "logits" in var.name and 267 | not "global_step" in var.name] 268 | if FLAGS.ngroups2 > 1: 269 | vars_restore = [var for var in vars_restore 270 | if not "unit_3_" in var.name] 271 | if FLAGS.ngroups3 > 1: 272 | vars_restore = [var for var in vars_restore 273 | if not "unit_2_" in var.name] 274 | saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000) 275 | saver_restore.restore(sess, FLAGS.basemodel) 276 | 277 | # Start queue runners & summary_writer 278 | cifar100_train.start_threads(sess, n_threads=20) 279 | cifar100_val.start_threads(sess, n_threads=1) 280 | 281 | if not os.path.exists(FLAGS.train_dir): 282 | os.mkdir(FLAGS.train_dir) 283 | summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))), 284 | sess.graph) 285 | 286 | # Training! 
287 | val_best_acc = 0.0 288 | for step in xrange(init_step, FLAGS.max_steps): 289 | # val 290 | if step % FLAGS.val_interval == 0: 291 | val_loss, val_acc = 0.0, 0.0 292 | for i in range(FLAGS.val_iter): 293 | val_images_val, val_labels_val = sess.run([val_images, val_labels]) 294 | loss_value, acc_value = sess.run([network.loss, network.acc], 295 | feed_dict={network.is_train:False, images:val_images_val, labels:val_labels_val}) 296 | val_loss += loss_value 297 | val_acc += acc_value 298 | val_loss /= FLAGS.val_iter 299 | val_acc /= FLAGS.val_iter 300 | val_best_acc = max(val_best_acc, val_acc) 301 | format_str = ('%s: (val) step %d, loss=%.4f, acc=%.4f') 302 | print (format_str % (datetime.now(), step, val_loss, val_acc)) 303 | 304 | val_summary = tf.Summary() 305 | val_summary.value.add(tag='val/loss', simple_value=val_loss) 306 | val_summary.value.add(tag='val/acc', simple_value=val_acc) 307 | val_summary.value.add(tag='val/best_acc', simple_value=val_best_acc) 308 | summary_writer.add_summary(val_summary, step) 309 | summary_writer.flush() 310 | 311 | # Train 312 | lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps, step) 313 | start_time = time.time() 314 | train_images_val, train_labels_val = sess.run([train_images, train_labels]) 315 | _, loss_value, acc_value, train_summary_str = \ 316 | sess.run([network.train_op, network.loss, network.acc, train_summary_op], 317 | feed_dict={network.is_train:True, network.lr:lr_value, images:train_images_val, labels:train_labels_val}) 318 | duration = time.time() - start_time 319 | 320 | assert not np.isnan(loss_value) 321 | 322 | # Display & Summary(training) 323 | if step % FLAGS.display == 0: 324 | num_examples_per_step = FLAGS.batch_size 325 | examples_per_sec = num_examples_per_step / duration 326 | sec_per_batch = float(duration) 327 | format_str = ('%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f ' 328 | 'sec/batch)') 329 | print (format_str % (datetime.now(), step, 
def _merge_split_q(q, merge_idxs, name='merge'):
    """Sum the rows of ``q`` that are assigned to the same merged group.

    Args:
        q: 2-D array with one row per original group.
        merge_idxs: For each row of ``q``, the index of the merged group it
            belongs to.
        name: Unused; kept for interface compatibility.

    Returns:
        Array with one row per merged group (``max(merge_idxs) + 1`` rows),
        each the element-wise sum of its member rows.
    """
    ngroups, _ = q.shape
    merged_rows = []
    for target in range(np.max(merge_idxs) + 1):
        members = [q[j, :] for j in range(ngroups) if merge_idxs[j] == target]
        merged_rows.append(np.sum(members, axis=0))
    return np.array(merged_rows)


def _merge_split_idxs(split_idxs, merge_idxs, name='merge'):
    """Concatenate the index arrays of groups that are merged together.

    Args:
        split_idxs: One index array per original group.
        merge_idxs: For each original group, the index of the merged group it
            belongs to.
        name: Unused; kept for interface compatibility.

    Returns:
        List with one concatenated index array per merged group.

    Raises:
        ValueError: From ``np.concatenate`` if a merged group has no members.
    """
    ngroups = len(split_idxs)
    merged = []
    for target in range(np.max(merge_idxs) + 1):
        members = [split_idxs[j] for j in range(ngroups) if merge_idxs[j] == target]
        merged.append(np.concatenate(members))
    return merged


def _get_even_merge_idxs(N, split):
    """Assign ``N`` groups to ``split`` merged groups as evenly as possible.

    Earlier merged groups receive the extra members when ``N`` is not a
    multiple of ``split``; e.g. ``N=5, split=2`` gives ``[0, 0, 0, 1, 1]``.

    Args:
        N: Number of original groups.
        split: Number of merged groups; must not exceed ``N``.

    Returns:
        List of length ``N`` mapping each original group to a merged group.
    """
    assert N >= split
    merge_idxs = []
    for group in range(split):
        # Floor division: on Python 3 `/` would produce a float and
        # `[group] * count` would raise TypeError; `//` is identical on Python 2.
        count = (N + split - group - 1) // split
        merge_idxs.extend([group] * count)
    return merge_idxs


def main(argv=None):  # pylint: disable=unused-argument
    """Entry point handed to ``tf.app.run``; starts training."""
    train()
tensorflow as tf 7 | import numpy as np 8 | import sys 9 | import select 10 | from IPython import embed 11 | from StringIO import StringIO 12 | import matplotlib.pyplot as plt 13 | 14 | import cifar100 15 | import resnet 16 | 17 | 18 | 19 | # Dataset Configuration 20 | tf.app.flags.DEFINE_string('data_dir', './cifar100/train_val_split', """Path to the CIFAR-100 data.""") 21 | tf.app.flags.DEFINE_integer('num_classes', 100, """Number of classes in the dataset.""") 22 | tf.app.flags.DEFINE_integer('num_train_instance', 45000, """Number of training images.""") 23 | tf.app.flags.DEFINE_integer('num_val_instance', 5000, """Number of val images.""") 24 | 25 | # Network Configuration 26 | tf.app.flags.DEFINE_integer('batch_size', 90, """Number of images to process in a batch.""") 27 | tf.app.flags.DEFINE_integer('num_residual_units', 2, """Number of residual block per group. 28 | Total number of conv layers will be 6n+4""") 29 | tf.app.flags.DEFINE_integer('k', 8, """Network width multiplier""") 30 | tf.app.flags.DEFINE_integer('ngroups1', 1, """Grouping number on logits""") 31 | tf.app.flags.DEFINE_integer('ngroups2', 1, """Grouping number on unit_3_x""") 32 | tf.app.flags.DEFINE_integer('ngroups3', 1, """Grouping number on unit_2_x""") 33 | 34 | # Optimization Configuration 35 | tf.app.flags.DEFINE_float('l2_weight', 0.0001, """L2 loss weight applied all the weights""") 36 | tf.app.flags.DEFINE_float('gamma1', 0.0, """split loss regularization paramter""") 37 | tf.app.flags.DEFINE_float('gamma2', 0.0, """overlap loss regularization parameter""") 38 | tf.app.flags.DEFINE_float('gamma3', 0.0, """uniform loss regularization parameter""") 39 | tf.app.flags.DEFINE_float('dropout_keep_prob', 1.0, """probability of dropouts on the split layers(1.0 not to use dropout)""") 40 | tf.app.flags.DEFINE_float('momentum', 0.9, """The momentum of MomentumOptimizer""") 41 | tf.app.flags.DEFINE_boolean('bn_no_scale', False, """Whether not to use trainable gamma in BN layers.""") 42 | 
def get_lr(initial_lr, lr_decay, lr_decay_steps, global_step):
    """Return the step-decayed learning rate for ``global_step``.

    Args:
        initial_lr: Learning rate before any decay.
        lr_decay: Factor applied once per boundary that has been reached.
        lr_decay_steps: Iterable of step boundaries (iterated exactly once).
        global_step: Current training step.

    Returns:
        ``initial_lr`` multiplied by ``lr_decay`` once for every boundary
        ``s`` with ``global_step >= s``.
    """
    reached = sum(1 for boundary in lr_decay_steps if global_step >= boundary)
    rate = initial_lr
    for _ in range(reached):
        rate *= lr_decay
    return rate
FLAGS.data_dir) 75 | print('\tNumber of classes: %d' % FLAGS.num_classes) 76 | print('\tNumber of training images: %d' % FLAGS.num_train_instance) 77 | print('\tNumber of val images: %d' % FLAGS.num_val_instance) 78 | 79 | print('[Network Configuration]') 80 | print('\tBatch size: %d' % FLAGS.batch_size) 81 | print('\tResidual blocks per group: %d' % FLAGS.num_residual_units) 82 | print('\tNetwork width multiplier: %d' % FLAGS.k) 83 | print('\tNumber of Groups: %d-%d-%d' % (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1)) 84 | print('\tBasemodel file: %s' % FLAGS.basemodel) 85 | 86 | print('[Optimization Configuration]') 87 | print('\tL2 loss weight: %f' % FLAGS.l2_weight) 88 | print('\tOverlap loss weight: %f' % FLAGS.gamma1) 89 | print('\tWeight split loss weight: %f' % FLAGS.gamma2) 90 | print('\tUniform loss weight: %f' % FLAGS.gamma3) 91 | print('\tDropout keep probability: %f' % FLAGS.dropout_keep_prob) 92 | print('\tThe momentum optimizer: %f' % FLAGS.momentum) 93 | print('\tNo update on BN scale parameter: %d' % FLAGS.bn_no_scale) 94 | print('\tWeighted split loss: %d' % FLAGS.weighted_group_loss) 95 | print('\tInitial learning rate: %f' % FLAGS.initial_lr) 96 | print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch) 97 | print('\tLearning rate decay: %f' % FLAGS.lr_decay) 98 | print('\tFinetune: %d' % FLAGS.finetune) 99 | 100 | print('[Training Configuration]') 101 | print('\tTrain dir: %s' % FLAGS.train_dir) 102 | print('\tTraining max steps: %d' % FLAGS.max_steps) 103 | print('\tSteps per displaying info: %d' % FLAGS.display) 104 | print('\tSteps per validation: %d' % FLAGS.val_interval) 105 | print('\tSteps during validation: %d' % FLAGS.val_iter) 106 | print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval) 107 | print('\tSteps per plot splits: %d' % FLAGS.group_summary_interval) 108 | print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction) 109 | print('\tLog device placement: %d' % FLAGS.log_device_placement) 110 | 111 | 112 | with 
tf.Graph().as_default(): 113 | init_step = 0 114 | global_step = tf.Variable(0, trainable=False, name='global_step') 115 | 116 | # Get images and labels of CIFAR-100 117 | print('Load CIFAR-100 dataset') 118 | train_dataset_path = os.path.join(FLAGS.data_dir, 'train') 119 | val_dataset_path = os.path.join(FLAGS.data_dir, 'val') 120 | print('\tLoading training data from %s' % train_dataset_path) 121 | with tf.variable_scope('train_image'): 122 | cifar100_train = cifar100.CIFAR100Runner(train_dataset_path, image_per_thread=32, 123 | shuffle=True, distort=True, capacity=10000) 124 | train_images, train_labels = cifar100_train.get_inputs(FLAGS.batch_size) 125 | print('\tLoading validation data from %s' % val_dataset_path) 126 | with tf.variable_scope('val_image'): 127 | cifar100_val = cifar100.CIFAR100Runner(val_dataset_path, image_per_thread=32, 128 | shuffle=False, distort=False, capacity=5000) 129 | # shuffle=False, distort=False, capacity=10000) 130 | val_images, val_labels = cifar100_val.get_inputs(FLAGS.batch_size) 131 | 132 | # Build a Graph that computes the predictions from the inference model. 
133 | images = tf.placeholder(tf.float32, [FLAGS.batch_size, cifar100.IMAGE_SIZE, cifar100.IMAGE_SIZE, 3]) 134 | labels = tf.placeholder(tf.int32, [FLAGS.batch_size]) 135 | 136 | # Build model 137 | lr_decay_steps = map(float,FLAGS.lr_step_epoch.split(',')) 138 | lr_decay_steps = map(int,[s*FLAGS.num_train_instance/FLAGS.batch_size for s in lr_decay_steps]) 139 | with tf.device('/GPU:0'): 140 | hp = resnet.HParams(batch_size=FLAGS.batch_size, 141 | num_classes=FLAGS.num_classes, 142 | num_residual_units=FLAGS.num_residual_units, 143 | k=FLAGS.k, 144 | weight_decay=FLAGS.l2_weight, 145 | ngroups1=FLAGS.ngroups1, 146 | ngroups2=FLAGS.ngroups2, 147 | ngroups3=FLAGS.ngroups3, 148 | gamma1=FLAGS.gamma1, 149 | gamma2=FLAGS.gamma2, 150 | gamma3=FLAGS.gamma3, 151 | dropout_keep_prob=FLAGS.dropout_keep_prob, 152 | momentum=FLAGS.momentum, 153 | bn_no_scale=FLAGS.bn_no_scale, 154 | weighted_group_loss=FLAGS.weighted_group_loss, 155 | finetune=FLAGS.finetune) 156 | network = resnet.ResNet(hp, images, labels, global_step) 157 | network.build_model() 158 | network.build_train_op() 159 | print('Number of Weights: %d' % network._weights) 160 | print('FLOPs: %d' % network._flops) 161 | 162 | train_summary_op = tf.summary.merge_all() # Summaries(training) 163 | 164 | # Build an initialization operation to run below. 165 | init = tf.global_variables_initializer() 166 | 167 | # Start running operations on the Graph. 168 | sess = tf.Session(config=tf.ConfigProto( 169 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction), 170 | allow_soft_placement=True, 171 | log_device_placement=FLAGS.log_device_placement)) 172 | sess.run(init) 173 | 174 | # Create a saver. 
175 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000) 176 | if FLAGS.checkpoint is not None: 177 | saver.restore(sess, FLAGS.checkpoint) 178 | init_step = global_step.eval(session=sess) 179 | print('Load checkpoint %s' % FLAGS.checkpoint) 180 | elif FLAGS.basemodel: 181 | # Define a different saver to load model checkpoints 182 | # Select only base variables (exclude split layers) 183 | print('Load parameters from basemodel %s' % FLAGS.basemodel) 184 | variables = tf.global_variables() 185 | vars_restore = [var for var in variables 186 | if not "Momentum" in var.name and 187 | not "group" in var.name and 188 | not "global_step" in var.name] 189 | # vars_restore = [var for var in variables 190 | # if not "alpha" in var.name and 191 | # not "fc_beta" in var.name and 192 | # not "unit_3" in var.name and 193 | # not "unit_last" in var.name and 194 | # not "logits" in var.name and 195 | # not "Momentum" in var.name and 196 | # not "global_step" in var.name] 197 | saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000) 198 | saver_restore.restore(sess, FLAGS.basemodel) 199 | else: 200 | print('No checkpoint file of basemodel found. Start from the scratch.') 201 | 202 | # Start queue runners & summary_writer 203 | cifar100_train.start_threads(sess, n_threads=20) 204 | cifar100_val.start_threads(sess, n_threads=1) 205 | 206 | if not os.path.exists(FLAGS.train_dir): 207 | os.mkdir(FLAGS.train_dir) 208 | summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))), 209 | sess.graph) 210 | 211 | # Training! 
212 | val_best_acc = 0.0 213 | for step in xrange(init_step, FLAGS.max_steps): 214 | # val 215 | if step % FLAGS.val_interval == 0: 216 | val_loss, val_acc = 0.0, 0.0 217 | for i in range(FLAGS.val_iter): 218 | val_images_val, val_labels_val = sess.run([val_images, val_labels]) 219 | loss_value, acc_value = sess.run([network.loss, network.acc], 220 | feed_dict={network.is_train:False, images:val_images_val, labels:val_labels_val}) 221 | val_loss += loss_value 222 | val_acc += acc_value 223 | val_loss /= FLAGS.val_iter 224 | val_acc /= FLAGS.val_iter 225 | val_best_acc = max(val_best_acc, val_acc) 226 | format_str = ('%s: (val) step %d, loss=%.4f, acc=%.4f') 227 | print (format_str % (datetime.now(), step, val_loss, val_acc)) 228 | 229 | val_summary = tf.Summary() 230 | val_summary.value.add(tag='val/loss', simple_value=val_loss) 231 | val_summary.value.add(tag='val/acc', simple_value=val_acc) 232 | val_summary.value.add(tag='val/best_acc', simple_value=val_best_acc) 233 | summary_writer.add_summary(val_summary, step) 234 | summary_writer.flush() 235 | 236 | # Train 237 | lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps, step) 238 | start_time = time.time() 239 | train_images_val, train_labels_val = sess.run([train_images, train_labels]) 240 | _, loss_value, acc_value, train_summary_str = \ 241 | sess.run([network.train_op, network.loss, network.acc, train_summary_op], 242 | feed_dict={network.is_train:True, network.lr:lr_value, images:train_images_val, labels:train_labels_val}) 243 | duration = time.time() - start_time 244 | 245 | assert not np.isnan(loss_value) 246 | 247 | # Display & Summary(training) 248 | if step % FLAGS.display == 0: 249 | num_examples_per_step = FLAGS.batch_size 250 | examples_per_sec = num_examples_per_step / duration 251 | sec_per_batch = float(duration) 252 | format_str = ('%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f ' 253 | 'sec/batch)') 254 | print (format_str % (datetime.now(), step, 
loss_value, acc_value, lr_value, 255 | examples_per_sec, sec_per_batch)) 256 | summary_writer.add_summary(train_summary_str, step) 257 | 258 | # Save the model checkpoint periodically. 259 | if (step > init_step and step % FLAGS.checkpoint_interval == 0) or (step + 1) == FLAGS.max_steps: 260 | checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt') 261 | saver.save(sess, checkpoint_path, global_step=step) 262 | 263 | if sys.stdin in select.select([sys.stdin], [], [], 0)[0]: 264 | char = sys.stdin.read(1) 265 | if char == 'b': 266 | embed() 267 | 268 | # Plot grouped weight matrices as image summary 269 | filters = [16, 16 * FLAGS.k, 32 * FLAGS.k, 64 * FLAGS.k] 270 | if FLAGS.group_summary_interval is not None: 271 | if step % FLAGS.group_summary_interval == 0: 272 | img_summaries = [] 273 | 274 | if FLAGS.ngroups1 > 1: 275 | logits_weights = get_var_value('logits/fc/weights', sess) 276 | split_p1 = softmax(get_var_value('group/split_p1/alpha', sess)) 277 | split_q1 = softmax(get_var_value('group/split_q1/alpha', sess)) 278 | feature_indices = np.argsort(np.argmax(split_p1, axis=0)) 279 | class_indices = np.argsort(np.argmax(split_q1, axis=0)) 280 | 281 | img_summaries.append(img_to_summary(np.repeat(split_p1[:, feature_indices], 20, axis=0), 'split_p1')) 282 | img_summaries.append(img_to_summary(np.repeat(split_q1[:, class_indices], 20, axis=0), 'split_q1')) 283 | img_summaries.append(img_to_summary(np.abs(logits_weights[feature_indices, :][:, class_indices]), 'logits')) 284 | 285 | if FLAGS.ngroups2 > 1: 286 | unit_3_0_shortcut = get_var_value('unit_3_0/shortcut/kernel', sess) 287 | unit_3_0_conv_1 = get_var_value('unit_3_0/conv_1/kernel', sess) 288 | unit_3_0_conv_2 = get_var_value('unit_3_0/conv_2/kernel', sess) 289 | unit_3_1_conv_1 = get_var_value('unit_3_1/conv_1/kernel', sess) 290 | unit_3_1_conv_2 = get_var_value('unit_3_1/conv_2/kernel', sess) 291 | split_p2 = softmax(get_var_value('group/split_p2/alpha', sess)) 292 | split_q2 = 
_merge_split_q(split_p1, _get_even_merge_idxs(FLAGS.ngroups1, FLAGS.ngroups2)) 293 | split_r21 = softmax(get_var_value('group/split_r21/alpha', sess)) 294 | split_r22 = softmax(get_var_value('group/split_r22/alpha', sess)) 295 | feature_indices1 = np.argsort(np.argmax(split_p2, axis=0)) 296 | feature_indices2 = np.argsort(np.argmax(split_q2, axis=0)) 297 | feature_indices3 = np.argsort(np.argmax(split_r21, axis=0)) 298 | feature_indices4 = np.argsort(np.argmax(split_r22, axis=0)) 299 | unit_3_0_shortcut_img = np.abs(unit_3_0_shortcut[:,:,feature_indices1,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[2], filters[3])) 300 | unit_3_0_conv_1_img = np.abs(unit_3_0_conv_1[:,:,feature_indices1,:][:,:,:,feature_indices3].transpose([2,0,3,1]).reshape(filters[2] * 3, filters[3] * 3)) 301 | unit_3_0_conv_2_img = np.abs(unit_3_0_conv_2[:,:,feature_indices3,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[3] * 3, filters[3] * 3)) 302 | unit_3_1_conv_1_img = np.abs(unit_3_1_conv_1[:,:,feature_indices2,:][:,:,:,feature_indices4].transpose([2,0,3,1]).reshape(filters[3] * 3, filters[3] * 3)) 303 | unit_3_1_conv_2_img = np.abs(unit_3_1_conv_2[:,:,feature_indices4,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[3] * 3, filters[3] * 3)) 304 | img_summaries.append(img_to_summary(np.repeat(split_p2[:, feature_indices1], 20, axis=0), 'split_p2')) 305 | img_summaries.append(img_to_summary(np.repeat(split_r21[:, feature_indices3], 20, axis=0), 'split_r21')) 306 | img_summaries.append(img_to_summary(np.repeat(split_r22[:, feature_indices4], 20, axis=0), 'split_r22')) 307 | img_summaries.append(img_to_summary(unit_3_0_shortcut_img, 'unit_3_0/shortcut_kernel')) 308 | img_summaries.append(img_to_summary(unit_3_0_conv_1_img, 'unit_3_0/conv_1_kernel')) 309 | img_summaries.append(img_to_summary(unit_3_0_conv_2_img, 'unit_3_0/conv_2_kernel')) 310 | img_summaries.append(img_to_summary(unit_3_1_conv_1_img, 'unit_3_1/conv_1_kernel')) 311 | 
img_summaries.append(img_to_summary(unit_3_1_conv_2_img, 'unit_3_1/conv_2_kernel')) 312 | 313 | if FLAGS.ngroups3 > 1: 314 | unit_2_0_shortcut = get_var_value('unit_2_0/shortcut/kernel', sess) 315 | unit_2_0_conv_1 = get_var_value('unit_2_0/conv_1/kernel', sess) 316 | unit_2_0_conv_2 = get_var_value('unit_2_0/conv_2/kernel', sess) 317 | unit_2_1_conv_1 = get_var_value('unit_2_1/conv_1/kernel', sess) 318 | unit_2_1_conv_2 = get_var_value('unit_2_1/conv_2/kernel', sess) 319 | split_p3 = softmax(get_var_value('group/split_p3/alpha', sess)) 320 | split_q3 = _merge_split_q(split_p2, _get_even_merge_idxs(FLAGS.ngroups2, FLAGS.ngroups3)) 321 | split_r31 = softmax(get_var_value('group/split_r31/alpha', sess)) 322 | split_r32 = softmax(get_var_value('group/split_r32/alpha', sess)) 323 | feature_indices1 = np.argsort(np.argmax(split_p3, axis=0)) 324 | feature_indices2 = np.argsort(np.argmax(split_q3, axis=0)) 325 | feature_indices3 = np.argsort(np.argmax(split_r31, axis=0)) 326 | feature_indices4 = np.argsort(np.argmax(split_r32, axis=0)) 327 | unit_2_0_shortcut_img = np.abs(unit_2_0_shortcut[:,:,feature_indices1,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[1], filters[2])) 328 | unit_2_0_conv_1_img = np.abs(unit_2_0_conv_1[:,:,feature_indices1,:][:,:,:,feature_indices3].transpose([2,0,3,1]).reshape(filters[1] * 3, filters[2] * 3)) 329 | unit_2_0_conv_2_img = np.abs(unit_2_0_conv_2[:,:,feature_indices3,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[2] * 3, filters[2] * 3)) 330 | unit_2_1_conv_1_img = np.abs(unit_2_1_conv_1[:,:,feature_indices2,:][:,:,:,feature_indices4].transpose([2,0,3,1]).reshape(filters[2] * 3, filters[2] * 3)) 331 | unit_2_1_conv_2_img = np.abs(unit_2_1_conv_2[:,:,feature_indices4,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[2] * 3, filters[2] * 3)) 332 | img_summaries.append(img_to_summary(np.repeat(split_p3[:, feature_indices1], 20, axis=0), 'split_p3')) 333 | 
img_summaries.append(img_to_summary(np.repeat(split_r31[:, feature_indices3], 20, axis=0), 'split_r31')) 334 | img_summaries.append(img_to_summary(np.repeat(split_r32[:, feature_indices4], 20, axis=0), 'split_r32')) 335 | img_summaries.append(img_to_summary(unit_2_0_shortcut_img, 'unit_2_0/shortcut_kernel')) 336 | img_summaries.append(img_to_summary(unit_2_0_conv_1_img, 'unit_2_0/conv_1_kernel')) 337 | img_summaries.append(img_to_summary(unit_2_0_conv_2_img, 'unit_2_0/conv_2_kernel')) 338 | img_summaries.append(img_to_summary(unit_2_1_conv_1_img, 'unit_2_1/conv_1_kernel')) 339 | img_summaries.append(img_to_summary(unit_2_1_conv_2_img, 'unit_2_1/conv_2_kernel')) 340 | 341 | if img_summaries: # If not empty 342 | img_summary = tf.Summary(value=img_summaries) 343 | summary_writer.add_summary(img_summary, step) 344 | summary_writer.flush() 345 | 346 | 347 | def get_var_value(var_name, sess): 348 | return [var for var in tf.global_variables() if var_name in var.name][0].eval(session=sess) 349 | 350 | 351 | def softmax(logits, axis=0): 352 | logits_diff = logits - np.max(logits, axis=axis, keepdims=True) 353 | exps = np.exp(logits_diff) 354 | return exps / np.sum(exps, axis=axis, keepdims=True) 355 | 356 | 357 | def img_to_summary(img, tag="img"): 358 | s = StringIO() 359 | plt.imsave(s, img, cmap='bone', format='png') 360 | summary = tf.Summary.Value(tag=tag, 361 | image=tf.Summary.Image(encoded_image_string=s.getvalue(), 362 | height=img.shape[0], 363 | width=img.shape[1])) 364 | return summary 365 | 366 | def _merge_split_q(q, merge_idxs, name='merge'): 367 | ngroups, dim = q.shape 368 | max_idx = np.max(merge_idxs) 369 | temp_list = [] 370 | for i in range(max_idx + 1): 371 | temp = [] 372 | for j in range(ngroups): 373 | if merge_idxs[j] == i: 374 | temp.append(q[i,:]) 375 | temp_list.append(np.sum(temp, axis=0)) 376 | ret = np.array(temp_list) 377 | 378 | return ret 379 | 380 | 381 | def _get_even_merge_idxs(N, split): 382 | assert N >= split 383 | num_elems = [(N 
+ split - i - 1)/split for i in range(split)] 384 | expand_split = [[i] * n for i, n in enumerate(num_elems)] 385 | return [t for l in expand_split for t in l] 386 | 387 | 388 | def main(argv=None): # pylint: disable=unused-argument 389 | train() 390 | 391 | 392 | if __name__ == '__main__': 393 | tf.app.run() 394 | --------------------------------------------------------------------------------