├── eval.sh
├── split.sh
├── .gitignore
├── group.sh
├── README.md
├── download_cifar100.py
├── cifar100.py
├── utils.py
├── resnet.py
├── eval.py
├── resnet_split.py
├── train_split.py
└── train.py
/eval.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Evaluate a split (fine-tuned) checkpoint with eval.py.
export CUDA_VISIBLE_DEVICES=0
export LD_PRELOAD="/usr/lib/libtcmalloc.so"
checkpoint="./split_2-2-2/model.ckpt-119999"
basemodel="./group_2-2-2/model.ckpt-199999"
output_file="./split_2-2-2/eval-119999.pkl"
# Default to the directory written by 'python download_cifar100.py';
# export DATA_DIR to use a copy of the dataset stored elsewhere.
data_dir="${DATA_DIR:-./cifar100/train_val_split}"

# Expansions are quoted so paths containing spaces survive word splitting.
python eval.py --checkpoint "$checkpoint" \
    --basemodel "$basemodel" \
    --output_file "$output_file" \
    --data_dir "$data_dir" \
    --batch_size 100 \
    --test_iter 100 \
    --num_residual_units 2 \
    --k 8 \
    --ngroups1 2 \
    --ngroups2 2 \
    --ngroups3 2 \
    --gpu_fraction 0.96 \
    --display 10

# Optional flags (append to the command above when needed):
#   --finetune True
#   --load_last_layers True
--------------------------------------------------------------------------------
/split.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Split the grouped network and fine-tune it with train_split.py.
export CUDA_VISIBLE_DEVICES=0
export LD_PRELOAD="/usr/lib/libtcmalloc.so"
train_dir="./split_2-2-2"
# Default to the directory written by 'python download_cifar100.py';
# export DATA_DIR to use a copy of the dataset stored elsewhere.
data_dir="${DATA_DIR:-./cifar100/train_val_split}"

# Expansions are quoted so paths containing spaces survive word splitting.
python train_split.py --train_dir "$train_dir" \
    --data_dir "$data_dir" \
    --batch_size 90 \
    --test_interval 500 \
    --test_iter 50 \
    --num_residual_units 2 \
    --k 8 \
    --ngroups1 2 \
    --ngroups2 2 \
    --ngroups3 2 \
    --l2_weight 0.0005 \
    --initial_lr 0.1 \
    --lr_step_epoch 100.0,140.0 \
    --lr_decay 0.1 \
    --max_steps 120000 \
    --checkpoint_interval 5000 \
    --gpu_fraction 0.96 \
    --display 100 \
    --basemodel "./group_2-2-2/model.ckpt-199999"

# Optional flags (append to the command above when needed):
#   --checkpoint "./split_2-2-2/model.ckpt-40000"
#   --finetune True
#   --load_last_layers True

# Finetune with Deep split
33 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## General
2 |
3 | # Compiled Object files
4 | *.slo
5 | *.lo
6 | *.o
7 | *.cuo
8 |
9 | # Compiled Dynamic libraries
10 | *.so
11 | *.dylib
12 |
13 | # Compiled Static libraries
14 | *.lai
15 | *.la
16 | *.a
17 |
18 | # Compiled protocol buffers
19 | *.pb.h
20 | *.pb.cc
21 | *_pb2.py
22 |
23 | # Compiled python
24 | *.pyc
25 |
26 | # Compiled MATLAB
27 | *.mex*
28 |
29 | # IPython notebook checkpoints
30 | .ipynb_checkpoints
31 |
32 | # Editor temporaries
33 | *.swp
34 | *~
35 |
36 | # Sublime Text settings
37 | *.sublime-workspace
38 | *.sublime-project
39 |
40 | # Eclipse Project settings
41 | *.*project
42 | .settings
43 |
44 | # QtCreator files
45 | *.user
46 |
47 | # PyCharm files
48 | .idea
49 |
50 | # OSX dir files
51 | .DS_Store
52 |
53 | ## Project specific
54 | # CIFAR-100 dataset
55 | cifar100*
56 | cifar-100-binary*
57 | !cifar100.py
58 |
59 | # Training & evaluation shell files
60 | *.sh
61 | !group.sh
62 | !split.sh
63 | !eval.sh
64 | scripts/
65 |
66 | # Trained checkpoints
67 | baseline*/
68 | group*/
69 | split*/
70 |
--------------------------------------------------------------------------------
/group.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Learn the group assignments (deep 2-2-2 split) with train.py.
export CUDA_VISIBLE_DEVICES=0
# Your tcmalloc .so path
export LD_PRELOAD="/usr/lib/libtcmalloc.so"
train_dir="./group_2-2-2"
# Our train/val split dataset
# Run 'python download_cifar100.py' before training
data_dir="./cifar100/train_val_split"

# Expansions are quoted so paths containing spaces survive word splitting.
python train.py --train_dir "$train_dir" \
    --data_dir "$data_dir" \
    --batch_size 90 \
    --test_interval 500 \
    --test_iter 50 \
    --num_residual_units 2 \
    --k 8 \
    --ngroups1 2 \
    --ngroups2 2 \
    --ngroups3 2 \
    --l2_weight 0.0001 \
    --gamma1 1.0 \
    --gamma2 1.0 \
    --gamma3 10.0 \
    --dropout_keep_prob 0.5 \
    --initial_lr 0.1 \
    --lr_step_epoch 240.0,300.0 \
    --lr_decay 0.1 \
    --bn_no_scale True \
    --weighted_group_loss True \
    --max_steps 200000 \
    --checkpoint_interval 5000 \
    --group_summary_interval 5000 \
    --gpu_fraction 0.96 \
    --display 100

# Optional flags (append to the command above when needed):
#   --checkpoint "./group_2-2-2/model.ckpt-149999"
#   --finetune True
#   --basemodel "./baseline/model.ckpt-449999"
#   --load_last_layers True

# Deep split(2-2-2)
# Dropout with prob 0.5
# Softmax reparametrization & Training from scratch
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # splitnet-wrn
2 |
3 | TensorFlow implementation of SplitNet: Learning to Semantically Split Deep Networks for Parameter Reduction and Model Parallelization, ICML 2017
4 |
5 | 
6 |
7 | - Juyong Kim\*(SNU), Yookoon Park\*(SNU), Gunhee Kim(SNU), and Sung Ju Hwang(UNIST) (\*: Equal contributions)
8 |
9 | We propose a novel deep neural network that is both lightweight and effectively structured for model parallelization. Our network, which we name as *SplitNet*, automatically learns to split the network weights into either a set or a hierarchy of multiple groups that use disjoint sets of features, by learning both the class-to-group and feature-to-group assignment matrices along with the network weights. This produces a tree-structured network that involves no connection between branched subtrees of semantically disparate class groups. SplitNet thus greatly reduces the number of parameters and required computations, and is also embarrassingly model-parallelizable at test time, since the evaluation for each subnetwork is completely independent except for the shared lower layer weights that can be duplicated over multiple processors, or assigned to a separate processor. We validate our method with two different deep network models (ResNet and AlexNet) on two datasets (CIFAR-100 and ILSVRC 2012) for image classification, on which our method obtains networks with significantly reduced number of parameters while achieving comparable or superior accuracies over original full deep networks, and accelerated test speed with multiple GPUs.
10 |
11 | ## Prerequisite
12 |
13 | 1. TensorFlow
14 | 2. Train/val/test split of CIFAR-100 dataset (please run `python download_cifar100.py`)
15 |
16 | ## How To Run
17 |
18 | ```shell
19 | # Clone the repo.
20 | git clone https://github.com/dalgu90/splitnet-wrn.git
21 | cd splitnet-wrn
22 |
23 | # Download CIFAR-100 dataset and split train set into train/val
24 | python download_cifar100.py
25 |
26 | # Find grouping of deep(2-2-2) split of WRN-16-8
27 | ./group.sh
28 |
29 | # Split and finetune
30 | ./split.sh
31 |
32 | # To evaluate
33 | ./eval.sh
34 | ```
35 |
36 | ## Acknowledgement
37 |
38 | This work was supported by Samsung Research Funding Center of Samsung Electronics under Project Number SRFC-IT150203.
39 |
40 |
41 | ## Authors
42 |
43 | [Juyong Kim](http://juyongkim.com/)*1, Yookoon Park*1, [Gunhee Kim](http://www.cs.cmu.edu/~gunhee/)1, and [Sung Ju Hwang](http://www.sungjuhwang.com/)2
44 |
45 | 1: [Vision and Learning Lab](http://vision.snu.ac.kr/) @ Computer Science and Engineering, Seoul National University, Seoul, Korea
46 | 2: [MLVR Lab](http://ml.unist.ac.kr/) @ School of Electrical and Computer Engineering, UNIST, Ulsan, South Korea
47 | \*: Equal contribution
48 |
49 |
50 | ## License
51 | ```
52 | MIT license
53 | ```
54 | If you find any problem, please feel free to contact the authors. :^)
55 |
--------------------------------------------------------------------------------
/download_cifar100.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import tarfile
5 | import cPickle as pickle
6 | import numpy as np
7 | import argparse
8 |
# Command-line interface: one optional argument naming the output directory.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", help="The directory to save splited CIFAR-100 dataset")
args = parser.parse_args()

# Where the CIFAR-100 archive is fetched from and where it is unpacked
cifar_url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
cifar_dpath = 'cifar100'
cifar_py_name = 'cifar-100-python'
cifar_fname = '%s.tar.gz' % cifar_py_name

# Output directory of the train/val split (falls back to the default when the
# option is missing or empty) and the number of samples in each subset
dataset_path = os.path.expanduser(args.dataset_path or 'cifar100/train_val_split')
num_train_instance = 45000
num_val_instance = 50000 - num_train_instance
num_test_instance = 10000
25 |
def download_file(url, path):
    """Download ``url`` into the directory ``path``, printing progress.

    The file is stored under the last path component of ``url``.

    Args:
        url: Address of the file to fetch.
        path: Directory in which the downloaded file is written.
    """
    try:
        from urllib2 import urlopen  # Python 2
    except ImportError:
        from urllib.request import urlopen  # Python 3

    file_name = url.split('/')[-1]
    response = urlopen(url)
    try:
        # The header may be missing; in that case only the byte count is
        # reported instead of failing before the download starts.
        content_length = response.info().get("Content-Length")
        file_size = int(content_length) if content_length else None
        print("Downloading: %s Bytes: %s" % (file_name, file_size))

        download_size = 0
        block_size = 8192
        # 'with' guarantees the output file is closed even if a read fails.
        with open(os.path.join(path, file_name), 'wb') as f:
            while True:
                buf = response.read(block_size)
                if not buf:
                    break
                download_size += len(buf)
                f.write(buf)
                if file_size:
                    status = "\r%12d [%3.2f%%]" % (download_size, download_size * 100. / file_size)
                else:
                    status = "\r%12d" % download_size
                sys.stdout.write(status)
                sys.stdout.flush()
    finally:
        response.close()
47 |
# Check if the dataset split already exists
if os.path.exists(os.path.join(dataset_path, 'train')):
    print('CIFAR-100 train/val split exists\nNothing to be done... Quit!')
    sys.exit(0)

# Download and extract CIFAR-100 unless the extracted 'train' file is present
if not os.path.exists(os.path.join(cifar_dpath, cifar_py_name, 'train')):
    print('Downloading CIFAR-100')
    if not os.path.exists(cifar_dpath):
        os.makedirs(cifar_dpath)
    tar_fpath = os.path.join(cifar_dpath, cifar_fname)
    # Re-download when the archive is missing or its size differs from
    # 169001437 bytes (presumably the size of the complete archive).
    if not os.path.exists(tar_fpath) or os.path.getsize(tar_fpath) != 169001437:
        download_file(cifar_url, cifar_dpath)
    print('Extracting CIFAR-100')
    # NOTE: extractall() trusts the member paths stored in the archive; only
    # use it on the archive fetched from cifar_url.
    with tarfile.open(tar_fpath) as tar:
        tar.extractall(path=cifar_dpath)

# Load the dataset and split
# Pickle data is binary, so the files are opened in 'rb' mode.
print('Load CIFAR-100 dataset')
with open(os.path.join(cifar_dpath, cifar_py_name, 'train'), 'rb') as fd:
    train_orig = pickle.load(fd)
    train_orig_data = train_orig['data']
    train_orig_label = np.array(train_orig['fine_labels'], dtype=np.uint8)
with open(os.path.join(cifar_dpath, cifar_py_name, 'test'), 'rb') as fd:
    test_orig = pickle.load(fd)
    test_orig_data = test_orig['data']
    test_orig_label = np.array(test_orig['fine_labels'], dtype=np.uint8)

# Split the dataset: a random permutation of the original training indices,
# the first num_train_instance go to 'train' and the rest to 'val'
print('Split the dataset')
train_val_idxs = list(range(num_train_instance + num_val_instance))
random.shuffle(train_val_idxs)
train = {'data': train_orig_data[train_val_idxs[:num_train_instance], :],
         'labels': train_orig_label[train_val_idxs[:num_train_instance]]}
val = {'data': train_orig_data[train_val_idxs[num_train_instance:], :],
       'labels': train_orig_label[train_val_idxs[num_train_instance:]]}
train_val = {'data': train_orig_data, 'labels': train_orig_label}
test = {'data': test_orig_data, 'labels': test_orig_label}

# Save the dataset split
print('Save the dataset split')
if not os.path.exists(dataset_path):
    os.makedirs(dataset_path)
for name, data in zip(['train', 'val', 'train_val', 'test'], [train, val, train_val, test]):
    print('[%s] ' % name + ', '.join(['%s: %s' % (k, str(v.shape)) for k, v in data.items()]))
    with open(os.path.join(dataset_path, name), 'wb') as fd:
        pickle.dump(data, fd)

print('done')
98 |
--------------------------------------------------------------------------------
/cifar100.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import random
7 | import threading
8 | import cPickle as pickle
9 | import numpy as np
10 | import skimage.util
11 |
12 | import tensorflow as tf
13 |
14 | # Process images of this size, which equals the original CIFAR image size
15 | # of 32 x 32. If one alters this number, then the entire model
16 | # architecture will change and any model would need to be retrained.
17 | IMAGE_SIZE = 32
18 |
19 | # Global constants describing the CIFAR-100 data set.
20 | NUM_CLASSES = 100
21 |
22 |
class ThreadsafeIter(object):
    """Takes an iterator/generator and makes it thread-safe by
    serializing calls to the `next` method of the given iterator/generator.

    Both the Python 3 (`__next__`) and the Python 2 (`next`) protocol names
    are provided, so the wrapper works with `for` loops, the `next()` builtin
    and explicit `.next()` calls.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # Only one thread at a time may advance the wrapped iterator.
        with self.lock:
            return next(self.it)

    # Python 2 spelling of the iterator protocol.
    next = __next__
37 |
38 |
class CIFAR100Runner(object):
    """
    This class manages the background threads needed to fill
    a queue full of data.

    Images and labels are read from a pickle file, preprocessed with NumPy
    (crop, optional augmentation, per-image whitening) and pushed into a
    TensorFlow FIFO queue from which batches are dequeued.
    """
    # Class-level flag: the image summary is registered by the first call to
    # get_inputs() only, whichever instance makes it.
    _image_summary_added = False

    def __init__(self, pkl_path, shuffle=False, distort=False,
                 capacity=2000, image_per_thread=16):
        """Load the dataset and build the queue and enqueue op.

        Args:
            pkl_path: Path of a pickle file holding a dict with 'data' and
                'labels' entries.
            shuffle: Whether the sample order is reshuffled on every pass.
            distort: Whether random crop / flip augmentation is applied.
            capacity: Capacity of the FIFO queue.
            image_per_thread: Number of samples enqueued per enqueue call.
        """
        self._shuffle = shuffle
        self._distort = distort
        with open(pkl_path, 'rb') as fd:
            data = pickle.load(fd)
        # Rows are reshaped to (channel, height, width) and then transposed
        # to (height, width, channel).
        self._images = data['data'].reshape([-1, 3, 32, 32]).transpose((0, 2, 3, 1)).copy(order='C')
        self._labels = data['labels']  # numpy 1-D array
        self.size = len(self._labels)

        self.queue = tf.FIFOQueue(shapes=[[32,32,3], []],
                                  dtypes=[tf.float32, tf.int32],
                                  capacity=capacity)
        self.dataX = tf.placeholder(dtype=tf.float32, shape=[None,32,32,3])
        self.dataY = tf.placeholder(dtype=tf.int32, shape=[None,])
        self.enqueue_op = self.queue.enqueue_many([self.dataX, self.dataY])
        self.image_per_thread = image_per_thread

        # Kept from the original constructor; get_inputs() consults the
        # class-level flag, not this instance attribute.
        self._image_summary_added = False

    def _preprocess_image(self, input_image):
        """Preprocess a single image by crop and whitening (and augmenting if needed).

        Args:
            input_image: An image. 3D array of [height, width, channel] size.

        Returns:
            output_image: Preprocessed image of size [32, 32, channel].
        """
        crop_size = 32

        # Crop
        image = input_image
        if self._distort:
            # Reflect-pad 4 pixels on each spatial side, then crop at a
            # random offset.
            image = np.pad(image, ((4, 4), (4, 4), (0, 0)), 'reflect')
            crop_h = image.shape[0] - crop_size
            top = random.choice(range(crop_h))
            crop_w = image.shape[1] - crop_size
            left = random.choice(range(crop_w))
            image = image[top:top + crop_size, left:left + crop_size, :]
        else:
            crop_h = image.shape[0] - crop_size
            crop_w = image.shape[1] - crop_size
            if crop_w != 0 or crop_h != 0:
                # Center crop. Floor division is required because this module
                # enables true division, which would produce float offsets.
                top = crop_h // 2
                left = crop_w // 2
                image = image[top:top + crop_size, left:left + crop_size, :]

        # Random horizontal flip (a reversed view; the input is not modified)
        if self._distort and random.choice(range(2)) == 1:
            image = image[:, ::-1, :]

        # Image whitening (per channel)
        mean = np.mean(image, axis=(0, 1), dtype=np.float32)
        std = np.std(image, axis=(0, 1), dtype=np.float32)
        # Lower-bound the deviation so a constant channel does not divide by 0.
        min_std = np.float32(1.0 / np.sqrt(image.shape[0] * image.shape[1]))
        output_image = (image - mean) / np.maximum(std, min_std)

        return output_image

    def _preprocess_images(self, input_images):
        """Apply _preprocess_image to every image and return a float32 array."""
        output_images = np.zeros_like(input_images, dtype=np.float32)
        for i in range(output_images.shape[0]):
            output_images[i] = self._preprocess_image(input_images[i])

        return output_images

    def get_inputs(self, batch_size):
        """
        Returns tensors containing a batch of images and labels
        """
        images_batch, labels_batch = self.queue.dequeue_many(batch_size)
        if not CIFAR100Runner._image_summary_added:
            tf.summary.image('images', images_batch)
            CIFAR100Runner._image_summary_added = True

        return images_batch, labels_batch

    def data_iterator(self):
        """Yield (images, labels) lists of image_per_thread samples forever.

        When the end of the dataset is reached inside a chunk, the position
        wraps to the start (reshuffling first if shuffling is enabled) and
        the chunk is completed from there.
        """
        idxs_idx = 0
        idxs = np.arange(0, self.size)
        if self._shuffle:
            random.shuffle(idxs)

        while True:
            images_batch = []
            labels_batch = []
            batch_cnt = 0
            while True:
                # Take what is still missing from the chunk, limited to the
                # samples left in the current pass.
                if idxs_idx + (self.image_per_thread - batch_cnt) < self.size:
                    temp_cnt = self.image_per_thread - batch_cnt
                else:
                    temp_cnt = self.size - idxs_idx

                images_batch.extend(self._images[idxs[idxs_idx:idxs_idx+temp_cnt]])
                labels_batch.extend(self._labels[idxs[idxs_idx:idxs_idx+temp_cnt]])
                idxs_idx += temp_cnt
                batch_cnt += temp_cnt

                if idxs_idx == self.size:
                    idxs_idx = 0
                    if self._shuffle:
                        random.shuffle(idxs)

                if batch_cnt == self.image_per_thread:
                    break
            yield images_batch, labels_batch

    def thread_main(self, sess, iterator):
        """
        Function run on alternate thread. Basically, keep adding data to the queue.
        """
        while True:
            images_val, labels_val = iterator.next()
            process_images_val = self._preprocess_images(images_val)
            sess.run(self.enqueue_op, feed_dict={self.dataX:process_images_val, self.dataY:labels_val})

    def start_threads(self, sess, n_threads=1):
        """ Start background threads to feed queue """
        iterator = ThreadsafeIter(self.data_iterator())
        threads = []
        for n in range(n_threads):
            t = threading.Thread(target=self.thread_main, args=(sess,iterator,))
            t.daemon = True  # thread will close when parent quits
            t.start()
            threads.append(t)
        return threads
176 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | ## TensorFlow helper functions
5 |
6 | WEIGHT_DECAY_KEY = 'WEIGHT_DECAY'
7 |
def _relu(x, leakness=0.0, name=None):
    """Apply ReLU, or leaky ReLU when ``leakness`` is positive.

    Args:
        x: Input tensor.
        leakness: Slope used for negative inputs; values <= 0 select plain ReLU.
        name: Optional op name. Defaults to 'lrelu' or 'relu' respectively.

    Returns:
        The activated tensor.
    """
    # The caller-supplied name is now honored; previously it was computed but
    # the literal default was passed to the op instead.
    if leakness > 0.0:
        name = 'lrelu' if name is None else name
        return tf.maximum(x, x*leakness, name=name)
    else:
        name = 'relu' if name is None else name
        return tf.nn.relu(x, name=name)
15 |
16 |
def _dropout(x, keep_prob=1.0, name=None):
    """Apply dropout to ``x``; a no-op when ``keep_prob`` is exactly 1.0.

    Args:
        x: Input tensor.
        keep_prob: Probability of keeping a unit, within [0, 1].
        name: Optional op name.

    Returns:
        ``x`` itself when nothing is dropped, otherwise the dropout output.
    """
    assert 0.0 <= keep_prob <= 1.0
    if keep_prob != 1.0:
        return tf.nn.dropout(x, keep_prob, name=name)
    return x
23 |
24 |
def _conv(x, filter_size, out_channel, strides, pad='SAME', input_q=None, output_q=None, name='conv'):
    """2-D convolution with an optional split regularizer on its kernel.

    Args:
        x: Input tensor; its 4th dimension is read as the input channel count.
        filter_size: Height and width of the square kernel.
        out_channel: Number of output channels.
        strides: Stride used for both spatial dimensions.
        pad: Padding mode passed to ``tf.nn.conv2d``.
        input_q: Group assignment matrix for the input channels, or None.
        output_q: Group assignment matrix for the output channels, or None.
        name: Variable scope of the layer.

    Returns:
        The convolution output.

    Raises:
        ValueError: If exactly one of ``input_q`` / ``output_q`` is given.
    """
    # Identity checks: '== None' would invoke the operands' own __eq__.
    if (input_q is None) != (output_q is None):
        raise ValueError('Input/Output splits are not correctly given.')

    in_shape = x.get_shape().as_list()
    with tf.variable_scope(name):
        # Main operation: conv2d
        kernel = tf.get_variable('kernel', [filter_size, filter_size, in_shape[3], out_channel],
                        tf.float32, initializer=tf.random_normal_initializer(
                            stddev=np.sqrt(1.0/filter_size/filter_size/in_shape[3])))
        # Register the kernel for weight decay only once.
        if kernel not in tf.get_collection(WEIGHT_DECAY_KEY):
            tf.add_to_collection(WEIGHT_DECAY_KEY, kernel)
        conv = tf.nn.conv2d(x, kernel, [1, strides, strides, 1], pad)

        # Split and split loss
        if (input_q is not None) and (output_q is not None):
            _add_split_loss(kernel, input_q, output_q)

    return conv
45 |
46 |
def _conv_with_init(x, filter_size, out_channel, strides, pad='SAME', init_k=None, name='conv'):
    """2-D convolution whose kernel may start from given values.

    Args:
        x: Input tensor; its 4th dimension is read as the input channel count.
        filter_size: Height and width of the square kernel.
        out_channel: Number of output channels.
        strides: Stride used for both spatial dimensions.
        pad: Padding mode passed to ``tf.nn.conv2d``.
        init_k: Initial kernel values, or None for random normal initialization.
        name: Variable scope of the layer.

    Returns:
        The convolution output.
    """
    in_channel = x.get_shape().as_list()[3]
    with tf.variable_scope(name):
        # Pick the kernel initializer
        if init_k is None:
            kernel_init = tf.random_normal_initializer(
                stddev=np.sqrt(1.0/filter_size/filter_size/in_channel))
        else:
            kernel_init = tf.constant_initializer(init_k)

        kernel_shape = [filter_size, filter_size, in_channel, out_channel]
        kernel = tf.get_variable('kernel', kernel_shape, tf.float32, initializer=kernel_init)
        # Register the kernel for weight decay only once.
        if kernel not in tf.get_collection(WEIGHT_DECAY_KEY):
            tf.add_to_collection(WEIGHT_DECAY_KEY, kernel)
        conv = tf.nn.conv2d(x, kernel, [1, strides, strides, 1], pad)

    return conv
63 |
64 |
def _fc(x, out_dim, input_q=None, output_q=None, name='fc'):
    """Fully connected layer with an optional split regularizer on its weights.

    Args:
        x: Input tensor; its 2nd dimension is read as the input feature count.
        out_dim: Number of output features.
        input_q: Group assignment matrix for the input features, or None.
        output_q: Group assignment matrix for the output features, or None.
        name: Variable scope of the layer.

    Returns:
        ``x * weights + biases``.

    Raises:
        ValueError: If exactly one of ``input_q`` / ``output_q`` is given.
    """
    # Identity checks: '== None' would invoke the operands' own __eq__.
    if (input_q is None) != (output_q is None):
        raise ValueError('Input/Output splits are not correctly given.')

    with tf.variable_scope(name):
        # Main operation: fc
        w = tf.get_variable('weights', [x.get_shape()[1], out_dim],
                        tf.float32, initializer=tf.random_normal_initializer(
                            stddev=np.sqrt(1.0/x.get_shape().as_list()[1])))
        # Register the weights for weight decay only once.
        if w not in tf.get_collection(WEIGHT_DECAY_KEY):
            tf.add_to_collection(WEIGHT_DECAY_KEY, w)
        b = tf.get_variable('biases', [out_dim], tf.float32,
                            initializer=tf.constant_initializer(0.0))
        fc = tf.nn.bias_add(tf.matmul(x, w), b)

        # Split loss
        if (input_q is not None) and (output_q is not None):
            _add_split_loss(w, input_q, output_q)

    return fc
86 |
87 |
def _fc_with_init(x, out_dim, init_w=None, init_b=None, name='fc'):
    """Fully connected layer whose parameters may start from given values.

    Args:
        x: Input tensor; its 2nd dimension is read as the input feature count.
        out_dim: Number of output features.
        init_w: Initial weight values, or None for random normal initialization.
        init_b: Initial bias values, or None for zeros.
        name: Variable scope of the layer.

    Returns:
        ``x * weights + biases``.
    """
    with tf.variable_scope(name):
        # Pick the initializers
        if init_w is None:
            weights_init = tf.random_normal_initializer(
                stddev=np.sqrt(1.0/x.get_shape().as_list()[1]))
        else:
            weights_init = tf.constant_initializer(init_w)
        biases_init = tf.constant_initializer(0.0 if init_b is None else init_b)

        w = tf.get_variable('weights', [x.get_shape()[1], out_dim],
                            tf.float32, initializer=weights_init)
        b = tf.get_variable('biases', [out_dim], tf.float32,
                            initializer=biases_init)
        # Register the weights for weight decay only once.
        if w not in tf.get_collection(WEIGHT_DECAY_KEY):
            tf.add_to_collection(WEIGHT_DECAY_KEY, w)
        fc = tf.nn.bias_add(tf.matmul(x, w), b)

    return fc
110 |
111 |
def _get_split_q(ngroups, dim, name='split'):
    """Create a trainable [ngroups, dim] matrix normalized over the groups.

    A variable 'alpha' is created under the scope ``name`` and passed through
    a softmax along dimension 0, so each column of the result sums to one.

    Args:
        ngroups: Number of groups (rows).
        dim: Number of columns.
        name: Variable scope holding 'alpha'.

    Returns:
        The softmax output tensor.
    """
    with tf.variable_scope(name):
        alpha_init = tf.random_normal_initializer(stddev=0.01)
        alpha = tf.get_variable('alpha', shape=[ngroups, dim], dtype=tf.float32,
                                initializer=alpha_init)
        q = tf.nn.softmax(alpha, dim=0, name='q')

    return q
119 |
def _merge_split_q(q, merge_idxs, name='merge'):
    """Sum rows of ``q`` that share the same merge index.

    Args:
        q: 2-D tensor of shape [ngroups, dim].
        merge_idxs: Sequence of length ngroups; entry j is the index of the
            output row that row j of ``q`` is added to.
        name: Variable scope for the created ops.

    Returns:
        Tensor with ``max(merge_idxs) + 1`` rows and ``dim`` columns.
    """
    assert len(q.get_shape()) == 2
    ngroups, dim = q.get_shape().as_list()
    assert ngroups == len(merge_idxs)

    with tf.variable_scope(name):
        merged_rows = []
        for target in range(np.max(merge_idxs) + 1):
            # Collect every row assigned to this target, in row order
            members = [tf.slice(q, [row, 0], [1, dim])
                       for row in range(ngroups) if merge_idxs[row] == target]
            merged_rows.append(tf.add_n(members))
        ret = tf.concat(merged_rows, 0)

    return ret
137 |
138 |
def _get_even_merge_idxs(N, split):
    """Assign ``N`` consecutive items to ``split`` buckets as evenly as possible.

    Earlier buckets receive the extra items when ``N`` is not divisible by
    ``split``.

    Args:
        N: Number of items.
        split: Number of buckets; must not exceed ``N``.

    Returns:
        List of length ``N`` with the bucket index of every item, e.g.
        ``_get_even_merge_idxs(5, 2) == [0, 0, 0, 1, 1]``.
    """
    assert N >= split
    # Floor division keeps the counts integral under true division as well.
    num_elems = [(N + split - i - 1) // split for i in range(split)]
    return [i for i, n in enumerate(num_elems) for _ in range(n)]
144 |
145 |
def _add_split_loss(w, input_q, output_q):
    """Register the split regularization terms for weight ``w``.

    The following graph collections are extended:
      * 'OVERLAP_LOSS_WEIGHTS': assignment matrices already processed.
      * 'OVERLAP_LOSS' / 'UNIFORM_LOSS': terms computed from the assignment
        matrices (pairwise row products and squared row sums).
      * 'WEIGHT_SPLIT_WEIGHTS' / 'WEIGHT_SPLIT': the weight and its
        group-masked norm term.
    A histogram summary of ``w`` is added as well.

    Args:
        w: 2-D (fc) or 4-D (conv) weight; its last two dimensions are read
            as (in_dim, out_dim).
        input_q: [ngroups, in_dim] assignment matrix.
        output_q: [ngroups, out_dim] assignment matrix.
    """
    # NOTE(review): block nesting was reconstructed; confirm that everything
    # after the 'WEIGHT_SPLIT_WEIGHTS' check belongs inside that branch.
    # Check input tensors' measurements
    assert len(w.get_shape()) == 2 or len(w.get_shape()) == 4
    in_dim, out_dim = w.get_shape().as_list()[-2:]
    assert len(input_q.get_shape()) == 2
    assert len(output_q.get_shape()) == 2
    assert in_dim == input_q.get_shape().as_list()[1]
    assert out_dim == output_q.get_shape().as_list()[1]
    assert input_q.get_shape().as_list()[0] == output_q.get_shape().as_list()[0] # ngroups
    ngroups = input_q.get_shape().as_list()[0]
    assert ngroups > 1

    # Add split losses to collections
    T_list = []
    U_list = []
    # Each assignment matrix contributes only once; tensors whose op name
    # contains "concat" are skipped (presumably the merged matrices, which
    # end in a tf.concat -- confirm).
    if input_q not in tf.get_collection('OVERLAP_LOSS_WEIGHTS') \
            and not "concat" in input_q.op.name:
        tf.add_to_collection('OVERLAP_LOSS_WEIGHTS', input_q)
        print('\t\tAdd overlap & split loss for %s' % input_q.name)
        T_temp, U_temp = ([], [])
        for i in range(ngroups):
            for j in range(ngroups):
                # Only pairs with i > j are accumulated.
                if i <= j:
                    continue
                T_temp.append(tf.reduce_sum(input_q[i,:] * input_q[j,:]))
            U_temp.append(tf.square(tf.reduce_sum(input_q[i,:])))
        T_list.append(tf.reduce_sum(T_temp)/(float(in_dim*(ngroups-1))/float(2*ngroups)))
        U_list.append(tf.reduce_sum(U_temp)/(float(in_dim*in_dim)/float(ngroups)))
    if output_q not in tf.get_collection('OVERLAP_LOSS_WEIGHTS') \
            and not "concat" in output_q.op.name:
        print('\t\tAdd overlap & split loss for %s' % output_q.name)
        tf.add_to_collection('OVERLAP_LOSS_WEIGHTS', output_q)
        T_temp, U_temp = ([], [])
        for i in range(ngroups):
            for j in range(ngroups):
                if i <= j:
                    continue
                T_temp.append(tf.reduce_sum(output_q[i,:] * output_q[j,:]))
            U_temp.append(tf.square(tf.reduce_sum(output_q[i,:])))
        T_list.append(tf.reduce_sum(T_temp)/(float(out_dim*(ngroups-1))/float(2*ngroups)))
        U_list.append(tf.reduce_sum(U_temp)/(float(out_dim*out_dim)/float(ngroups)))
    # Average the terms gathered in this call (one per newly seen matrix).
    if T_list:
        tf.add_to_collection('OVERLAP_LOSS', tf.add_n(T_list)/len(T_list))
    if U_list:
        tf.add_to_collection('UNIFORM_LOSS', tf.add_n(U_list)/len(U_list))

    S_list = []
    if w not in tf.get_collection('WEIGHT_SPLIT_WEIGHTS'):
        tf.add_to_collection('WEIGHT_SPLIT_WEIGHTS', w)

        ones_col = tf.ones((in_dim,), dtype=tf.float32)
        ones_row = tf.ones((out_dim,), dtype=tf.float32)
        if len(w.get_shape()) == 4:
            # Conv kernel: mean of squared weights over the two leading
            # (spatial) dimensions gives an [in_dim, out_dim] matrix.
            w_reduce = tf.reduce_mean(tf.square(w), [0, 1])
            w_norm = w_reduce
            std_dev = np.sqrt(1.0/float(w.get_shape().as_list()[0])**2/in_dim)
            # w_norm = w_reduce / tf.reduce_sum(w_reduce)
        else:
            w_norm = w
            std_dev = np.sqrt(1.0/float(in_dim))
            # w_norm = w / tf.sqrt(tf.reduce_sum(tf.square(w)))

        for i in range(ngroups):
            # wg_row: weights scaled by output_q[i] on the output side and by
            # (1 - input_q[i]) on the input side; wg_col is the reverse.
            # In the 4-D case w_norm already holds squared values, so the
            # masks are squared instead of squaring the product.
            if len(w.get_shape()) == 4:
                wg_row = tf.transpose(tf.transpose(w_norm * tf.square(output_q[i,:])) * tf.square(ones_col - input_q[i,:]))
                wg_row_l2 = tf.reduce_sum(tf.sqrt(tf.reduce_sum(wg_row, 1))) / (in_dim*np.sqrt(out_dim))
                wg_col = tf.transpose(tf.transpose(w_norm * tf.square(ones_row - output_q[i,:])) * tf.square(input_q[i,:]))
                wg_col_l2 = tf.reduce_sum(tf.sqrt(tf.reduce_sum(wg_col, 0))) / (np.sqrt(in_dim)*out_dim)
            else:  # len(w.get_shape()) == 2
                wg_row = tf.transpose(tf.transpose(w_norm * output_q[i,:]) * (ones_col - input_q[i,:]))
                wg_row_l2 = tf.reduce_sum(tf.sqrt(tf.reduce_sum(wg_row * wg_row, 1))) / (in_dim*np.sqrt(out_dim))
                wg_col = tf.transpose(tf.transpose(w_norm * (ones_row - output_q[i,:])) * input_q[i,:])
                wg_col_l2 = tf.reduce_sum(tf.sqrt(tf.reduce_sum(wg_col * wg_col, 0))) / (np.sqrt(in_dim)*out_dim)
            S_list.append(wg_row_l2 + wg_col_l2)
        # S = tf.add_n(S_list)/((ngroups-1)/ngroups)
        S = tf.add_n(S_list)/(2*(ngroups-1)*std_dev/ngroups)
        tf.add_to_collection('WEIGHT_SPLIT', S)

        # Add histogram for w if split losses are added
        scope_name = tf.get_variable_scope().name
        tf.summary.histogram("%s/" % scope_name, w)
        print('\t\tAdd split loss for %s(%dx%d, %d groups)' \
              % (tf.get_variable_scope().name, in_dim, out_dim, ngroups))

    return
231 |
232 |
def _bn(x, is_train, global_step=None, name='bn', no_scale=False):
    """Batch normalization over dimensions [0, 1, 2] of ``x``.

    Batch statistics are used when ``is_train`` is true, otherwise the moving
    statistics stored in 'mu' (mean) and 'sigma' (variance). The ops updating
    the moving statistics are added to ``tf.GraphKeys.UPDATE_OPS``.

    Args:
        x: Input tensor.
        is_train: Boolean tensor selecting batch or moving statistics.
        global_step: Not used by the current implementation.
        name: Variable scope of the layer.
        no_scale: When true, 'gamma' is created as non-trainable.

    Returns:
        The normalized tensor.
    """
    moving_average_decay = 0.9
    with tf.variable_scope(name):
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
        mean_shape = batch_mean.get_shape()
        var_shape = batch_var.get_shape()

        # Moving statistics (not trained by the optimizer)
        mu = tf.get_variable('mu', mean_shape, tf.float32,
                             initializer=tf.zeros_initializer(), trainable=False)
        sigma = tf.get_variable('sigma', var_shape, tf.float32,
                                initializer=tf.ones_initializer(), trainable=False)
        # Shift and scale
        beta = tf.get_variable('beta', mean_shape, tf.float32,
                               initializer=tf.zeros_initializer())
        gamma = tf.get_variable('gamma', var_shape, tf.float32,
                                initializer=tf.ones_initializer(), trainable=(not no_scale))

        # Exponential moving average: stat -= (1 - decay) * (stat - batch_stat)
        update = 1.0 - moving_average_decay
        for moving_stat, batch_stat in ((mu, batch_mean), (sigma, batch_var)):
            update_op = moving_stat.assign_sub(update*(moving_stat - batch_stat))
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)

        mean, var = tf.cond(is_train, lambda: (batch_mean, batch_var),
                            lambda: (mu, sigma))
        bn = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)

    return bn
274 |
275 |
276 | ## Other helper functions
277 |
278 |
279 |
280 |
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from collections import namedtuple
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 |
8 | import utils
9 |
10 |
# Hyperparameter record for the ResNet model; fields are listed one group per
# line: general settings, group counts, gamma values, and remaining switches.
HParams = namedtuple('HParams', [
    'batch_size', 'num_classes', 'num_residual_units', 'k', 'weight_decay', 'momentum', 'finetune',
    'ngroups1', 'ngroups2', 'ngroups3',
    'gamma1', 'gamma2', 'gamma3',
    'dropout_keep_prob', 'bn_no_scale', 'weighted_group_loss',
])
15 |
16 | class ResNet(object):
def __init__(self, hp, images, labels, global_step):
    """Store the model inputs and create the feed placeholders.

    Args:
        hp: Hyperparameters (the attributes read later are those of HParams).
        images: Input image tensor.
        labels: Label tensor.
        global_step: Global step tensor/variable.
    """
    self._hp = hp # Hyperparameters
    self._images = images # Input image
    self._labels = labels
    self._global_step = global_step
    # Values fed at run time: learning rate and train/eval switch.
    self.lr = tf.placeholder(tf.float32)
    self.is_train = tf.placeholder(tf.bool)
    # Bookkeeping initialized empty/zero; presumably used for FLOP and
    # weight counting elsewhere in the class -- confirm.
    self._counted_scope = []
    self._flops = 0
    self._weights = 0
27 |
28 | def build_model(self):
29 | print('Building model')
30 | filters = [16, 16 * self._hp.k, 32 * self._hp.k, 64 * self._hp.k]
31 | strides = [1, 2, 2]
32 |
33 | with tf.variable_scope("group"):
34 | if self._hp.ngroups1 > 1:
35 | self.split_q1 = utils._get_split_q(self._hp.ngroups1, self._hp.num_classes, name='split_q1')
36 | self.split_p1 = utils._get_split_q(self._hp.ngroups1, filters[3], name='split_p1')
37 | tf.summary.histogram("group/split_p1/", self.split_p1)
38 | tf.summary.histogram("group/split_q1/", self.split_q1)
39 | else:
40 | self.split_q1 = None
41 | self.split_p1 = None
42 |
43 | if self._hp.ngroups2 > 1:
44 | self.split_q2 = utils._merge_split_q(self.split_p1, utils._get_even_merge_idxs(self._hp.ngroups1, self._hp.ngroups2), name='split_q2')
45 | self.split_p2 = utils._get_split_q(self._hp.ngroups2, filters[2], name='split_p2')
46 | self.split_r21 = utils._get_split_q(self._hp.ngroups2, filters[3], name='split_r21')
47 | self.split_r22 = utils._get_split_q(self._hp.ngroups2, filters[3], name='split_r22')
48 | tf.summary.histogram("group/split_q2/", self.split_q2)
49 | tf.summary.histogram("group/split_p2/", self.split_p2)
50 | tf.summary.histogram("group/split_r21/", self.split_r21)
51 | tf.summary.histogram("group/split_r22/", self.split_r22)
52 | else:
53 | self.split_p2 = None
54 | self.split_q2 = None
55 | self.split_r21 = None
56 | self.split_r22 = None
57 |
58 | if self._hp.ngroups3 > 1:
59 | self.split_q3 = utils._merge_split_q(self.split_p2, utils._get_even_merge_idxs(self._hp.ngroups2, self._hp.ngroups3), name='split_q3')
60 | self.split_p3 = utils._get_split_q(self._hp.ngroups3, filters[1], name='split_p3')
61 | self.split_r31 = utils._get_split_q(self._hp.ngroups3, filters[2], name='split_r31')
62 | self.split_r32 = utils._get_split_q(self._hp.ngroups3, filters[2], name='split_r32')
63 | tf.summary.histogram("group/split_q3/", self.split_q3)
64 | tf.summary.histogram("group/split_p3/", self.split_p3)
65 | tf.summary.histogram("group/split_r31/", self.split_r31)
66 | tf.summary.histogram("group/split_r32/", self.split_r32)
67 | else:
68 | self.split_p3 = None
69 | self.split_q3 = None
70 | self.split_r31 = None
71 | self.split_r32 = None
72 |
73 | # Init. conv.
74 | print('\tBuilding unit: init_conv')
75 | x = utils._conv(self._images, 3, filters[0], 1, name='init_conv')
76 |
77 | x = self._residual_block_first(x, filters[1], strides[0], name='unit_1_0')
78 | x = self._residual_block(x, name='unit_1_1')
79 |
80 | x = self._residual_block_first(x, filters[2], strides[1], input_q=self.split_p3, output_q=self.split_q3, split_r=self.split_r31, name='unit_2_0')
81 | x = self._residual_block(x, split_q=self.split_q3, split_r=self.split_r32, name='unit_2_1')
82 |
83 | x = self._residual_block_first(x, filters[3], strides[2], input_q=self.split_p2, output_q=self.split_q2, split_r=self.split_r21, name='unit_3_0')
84 | x = self._residual_block(x, split_q=self.split_q2, split_r=self.split_r22, name='unit_3_1')
85 |
86 | # Last unit
87 | with tf.variable_scope('unit_last') as scope:
88 | print('\tBuilding unit: %s' % scope.name)
89 | x = utils._bn(x, self.is_train, self._global_step)
90 | x = utils._relu(x)
91 | x = tf.reduce_mean(x, [1, 2])
92 |
93 | # Logit
94 | with tf.variable_scope('logits') as scope:
95 | print('\tBuilding unit: %s' % scope.name)
96 | x_shape = x.get_shape().as_list()
97 | x = tf.reshape(x, [-1, x_shape[1]])
98 | if self.split_p1 is not None and self.split_q1 is not None:
99 | x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout')
100 | x = self._fc(x, self._hp.num_classes, input_q=self.split_p1, output_q=self.split_q1)
101 |
102 | self._logits = x
103 |
104 | # Probs & preds & acc
105 | self.probs = tf.nn.softmax(x, name='probs')
106 | self.preds = tf.to_int32(tf.argmax(self._logits, 1, name='preds'))
107 | ones = tf.constant(np.ones([self._hp.batch_size]), dtype=tf.float32)
108 | zeros = tf.constant(np.zeros([self._hp.batch_size]), dtype=tf.float32)
109 | correct = tf.where(tf.equal(self.preds, self._labels), ones, zeros)
110 | self.acc = tf.reduce_mean(correct, name='acc')
111 | tf.summary.scalar('accuracy', self.acc)
112 |
113 | # Loss & acc
114 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=x, labels=self._labels)
115 | self.loss = tf.reduce_mean(loss)
116 | tf.summary.scalar('cross_entropy', self.loss)
117 |
118 |
119 | def _residual_block_first(self, x, out_channel, strides, input_q=None, output_q=None, split_r=None, name="unit"):
120 | in_channel = x.get_shape().as_list()[-1]
121 | with tf.variable_scope(name) as scope:
122 | print('\tBuilding residual unit: %s' % scope.name)
123 | x = self._bn(x, name='bn_1', no_scale=self._hp.bn_no_scale)
124 | x = self._relu(x, name='relu_1')
125 | if input_q is not None and output_q is not None and split_r is not None:
126 | x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout_1')
127 | # Shortcut connection
128 | if in_channel == out_channel:
129 | if strides == 1:
130 | shortcut = tf.identity(x)
131 | else:
132 | shortcut = tf.nn.max_pool(x, [1, strides, strides, 1], [1, strides, strides, 1], 'VALID')
133 | else:
134 | shortcut = self._conv(x, 1, out_channel, strides, input_q=input_q, output_q=output_q, name='shortcut')
135 | # Residual
136 | x = self._conv(x, 3, out_channel, strides, input_q=input_q, output_q=split_r, name='conv_1')
137 | x = self._bn(x, name='bn_2', no_scale=self._hp.bn_no_scale)
138 | x = self._relu(x, name='relu_2')
139 | if input_q is not None and output_q is not None and split_r is not None:
140 | x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout_2')
141 | x = self._conv(x, 3, out_channel, 1, input_q=split_r, output_q=output_q, name='conv_2')
142 | # Merge
143 | x = x + shortcut
144 | return x
145 |
146 | def _residual_block(self, x, split_q=None, split_r=None, name="unit"):
147 | num_channel = x.get_shape().as_list()[-1]
148 | with tf.variable_scope(name) as scope:
149 | print('\tBuilding residual unit: %s' % scope.name)
150 | # Shortcut connection
151 | shortcut = x
152 | # Residual
153 | x = self._bn(x, name='bn_1', no_scale=self._hp.bn_no_scale)
154 | x = self._relu(x, name='relu_1')
155 | if split_q is not None and split_r is not None:
156 | x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout_1')
157 | x = self._conv(x, 3, num_channel, 1, input_q=split_q, output_q=split_r, name='conv_1')
158 | x = self._bn(x, name='bn_2', no_scale=self._hp.bn_no_scale)
159 | x = self._relu(x, name='relu_2')
160 | if split_q is not None and split_r is not None:
161 | x = self._dropout(x, self._hp.dropout_keep_prob, name='dropout_2')
162 | x = self._conv(x, 3, num_channel, 1, input_q=split_r, output_q=split_q, name='conv_2')
163 | # Merge
164 | x = x + shortcut
165 | return x
166 |
167 | def build_train_op(self):
168 | # Learning rate
169 | tf.summary.scalar('learing_rate', self.lr)
170 |
171 | losses = [self.loss]
172 |
173 | # Add l2 loss
174 | with tf.variable_scope('l2_loss'):
175 | costs = [tf.nn.l2_loss(var) for var in tf.get_collection(utils.WEIGHT_DECAY_KEY)]
176 | l2_loss = tf.multiply(self._hp.weight_decay, tf.add_n(costs))
177 | losses.append(l2_loss)
178 |
179 | # Add group split loss
180 | with tf.variable_scope('group/'):
181 | if tf.get_collection('OVERLAP_LOSS') and self._hp.gamma1 > 0:
182 | cost1 = tf.reduce_mean(tf.get_collection('OVERLAP_LOSS'))
183 | cost1 = cost1 * self._hp.gamma1
184 | tf.summary.scalar('group/overlap_loss/', cost1)
185 | losses.append(cost1)
186 |
187 | if tf.get_collection('WEIGHT_SPLIT') and self._hp.gamma2 > 0:
188 | if self._hp.weighted_group_loss:
189 | reg_weights = [tf.stop_gradient(x) for x in tf.get_collection('WEIGHT_SPLIT')]
190 | regs = [tf.stop_gradient(x) * x for x in tf.get_collection('WEIGHT_SPLIT')]
191 | cost2 = tf.reduce_sum(regs) / tf.reduce_sum(reg_weights)
192 | else:
193 | cost2 = tf.reduce_mean(tf.get_collection('WEIGHT_SPLIT'))
194 | cost2 = cost2 * self._hp.gamma2
195 | tf.summary.scalar('group/weight_split_loss/', cost2)
196 | losses.append(cost2)
197 |
198 | if tf.get_collection('UNIFORM_LOSS') and self._hp.gamma3 > 0:
199 | cost3 = tf.reduce_mean(tf.get_collection('UNIFORM_LOSS'))
200 | cost3 = cost3 * self._hp.gamma3
201 | tf.summary.scalar('group/group_uniform_loss/', cost3)
202 | losses.append(cost3)
203 |
204 | self._total_loss = tf.add_n(losses)
205 |
206 | # Gradient descent step
207 | opt = tf.train.MomentumOptimizer(self.lr, self._hp.momentum)
208 | grads_and_vars = opt.compute_gradients(self._total_loss, tf.trainable_variables())
209 | if self._hp.finetune:
210 | for idx, (grad, var) in enumerate(grads_and_vars):
211 | if "unit3" in var.op.name or \
212 | "unit_last" in var.op.name or \
213 | "logits" in var.op.name:
214 | print('Scale up learning rate of % s by 10.0' % var.op.name)
215 | grad = 10.0 * grad
216 | grads_and_vars[idx] = (grad,var)
217 |
218 | apply_grad_op = opt.apply_gradients(grads_and_vars, global_step=self._global_step)
219 |
220 |
221 | # Batch normalization moving average update
222 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
223 | if update_ops:
224 | with tf.control_dependencies(update_ops+[apply_grad_op]):
225 | self.train_op = tf.no_op()
226 | else:
227 | self.train_op = apply_grad_op
228 |
229 | # Helper functions(counts FLOPs and number of weights)
230 | def _conv(self, x, filter_size, out_channel, stride, pad="SAME", input_q=None, output_q=None, name="conv"):
231 | b, h, w, in_channel = x.get_shape().as_list()
232 | x = utils._conv(x, filter_size, out_channel, stride, pad, input_q, output_q, name)
233 | f = 2 * (h/stride) * (w/stride) * in_channel * out_channel * filter_size * filter_size
234 | w = in_channel * out_channel * filter_size * filter_size
235 | scope_name = tf.get_variable_scope().name + "/" + name
236 | self._add_flops_weights(scope_name, f, w)
237 | return x
238 |
239 | def _fc(self, x, out_dim, input_q=None, output_q=None, name="fc"):
240 | b, in_dim = x.get_shape().as_list()
241 | x = utils._fc(x, out_dim, input_q, output_q, name)
242 | f = 2 * (in_dim + 1) * out_dim
243 | w = (in_dim + 1) * out_dim
244 | scope_name = tf.get_variable_scope().name + "/" + name
245 | self._add_flops_weights(scope_name, f, w)
246 | return x
247 |
248 | def _bn(self, x, name="bn", no_scale=False):
249 | x = utils._bn(x, self.is_train, self._global_step, name, no_scale=no_scale)
250 | # f = 8 * self._get_data_size(x)
251 | # w = 4 * x.get_shape().as_list()[-1]
252 | # scope_name = tf.get_variable_scope().name + "/" + name
253 | # self._add_flops_weights(scope_name, f, w)
254 | return x
255 |
256 | def _relu(self, x, name="relu"):
257 | x = utils._relu(x, 0.0, name)
258 | # f = self._get_data_size(x)
259 | # scope_name = tf.get_variable_scope().name + "/" + name
260 | # self._add_flops_weights(scope_name, f, 0)
261 | return x
262 |
263 | def _dropout(self, x, keep_prob, name="dropout"):
264 | x = utils._dropout(x, keep_prob, name)
265 | return x
266 |
267 | def _get_data_size(self, x):
268 | return np.prod(x.get_shape().as_list()[1:])
269 |
270 | def _add_flops_weights(self, scope_name, f, w):
271 | if scope_name not in self._counted_scope:
272 | self._flops += f
273 | self._weights += w
274 | self._counted_scope.append(scope_name)
275 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import sys
5 | from datetime import datetime
6 | import time
7 | import tensorflow as tf
8 | import numpy as np
9 | import cPickle as pickle
10 |
11 | import cifar100
12 | import resnet_split as resnet
13 |
14 |
15 |
# Command-line flags of the evaluation script.
# NOTE: initial_lr, lr_step_epoch and lr_decay are defined here but are not
# read anywhere in this script; l2_weight, momentum and finetune are only
# forwarded into HParams.

# Dataset Configuration
tf.app.flags.DEFINE_string('data_dir', './cifar100/train_val_split', """Path to the CIFAR-100 data.""")
tf.app.flags.DEFINE_integer('num_classes', 100, """Number of classes in the dataset.""")
tf.app.flags.DEFINE_integer('num_test_instance', 10000, """Number of test images.""")

# Network Configuration
tf.app.flags.DEFINE_integer('batch_size', 100, """Number of images to process in a batch.""")
tf.app.flags.DEFINE_integer('num_residual_units', 2, """Number of residual block per group.
Total number of conv layers will be 6n+4""")
tf.app.flags.DEFINE_integer('k', 8, """Network width multiplier""")
tf.app.flags.DEFINE_integer('ngroups1', 1, """Grouping number on logits""")
tf.app.flags.DEFINE_integer('ngroups2', 1, """Grouping number on unit_3_x""")
tf.app.flags.DEFINE_integer('ngroups3', 1, """Grouping number on unit_2_x""")

# Optimization Configuration
tf.app.flags.DEFINE_float('l2_weight', 0.0001, """L2 loss weight applied all the weights""")
tf.app.flags.DEFINE_float('momentum', 0.9, """The momentum of MomentumOptimizer""")
tf.app.flags.DEFINE_float('initial_lr', 0.1, """Initial learning rate""")
tf.app.flags.DEFINE_string('lr_step_epoch', "80.0,120.0,160.0", """Epochs after which learing rate decays""")
tf.app.flags.DEFINE_float('lr_decay', 0.1, """Learning rate decay factor""")
tf.app.flags.DEFINE_boolean('finetune', False, """Whether to finetune.""")

# Evaluation Configuration
tf.app.flags.DEFINE_string('basemodel', './group/model.ckpt-199999', """Base model to load paramters""")
tf.app.flags.DEFINE_string('checkpoint', './split/model.ckpt-149999', """Path to the model checkpoint file""")
tf.app.flags.DEFINE_string('output_file', './split/eval.pkl', """Path to the result pkl file""")
tf.app.flags.DEFINE_integer('test_iter', 100, """Number of test batches during the evaluation""")
tf.app.flags.DEFINE_integer('display', 10, """Number of iterations to display training info.""")
tf.app.flags.DEFINE_float('gpu_fraction', 0.95, """The fraction of GPU memory to be allocated""")
tf.app.flags.DEFINE_boolean('log_device_placement', False, """Whether to log device placement.""")

# Parsed flag values, read throughout the script.
FLAGS = tf.app.flags.FLAGS
48 |
49 |
def get_lr(initial_lr, lr_decay, lr_decay_steps, global_step):
    """Return the step-decayed learning rate at `global_step`.

    The rate starts at `initial_lr` and is multiplied by `lr_decay` once for
    every boundary in `lr_decay_steps` that `global_step` has reached.
    """
    rate = initial_lr
    for boundary in lr_decay_steps:
        if boundary <= global_step:
            rate = rate * lr_decay
    return rate
56 |
57 |
def train():
    """Evaluate a split network on the CIFAR-100 test set.

    Despite its name this function only evaluates.  It
      1. reads the group assignments and the dense kernels from
         FLAGS.basemodel and slices the kernels into per-group pieces,
      2. builds the split ResNet from those pieces,
      3. restores FLAGS.checkpoint, runs FLAGS.test_iter test batches and
      4. pickles accuracy, confusion matrix and timing to FLAGS.output_file.

    Exits the process if no basemodel or no checkpoint is given.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of test images: %d' % FLAGS.num_test_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)
    print('\tNumber of Groups: %d-%d-%d' % (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1))
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Evaluation Configuration]')
    print('\tCheckpoint file: %s' % FLAGS.checkpoint)
    print('\tOutput file path: %s' % FLAGS.output_file)
    print('\tTest iterations: %d' % FLAGS.test_iter)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)


    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        print('Load CIFAR-100 dataset')
        test_dataset_path = os.path.join(FLAGS.data_dir, 'test')
        with tf.variable_scope('test_image'):
            cifar100_test = cifar100.CIFAR100Runner(test_dataset_path, image_per_thread=1,
                                                    shuffle=False, distort=False, capacity=5000)
            test_images, test_labels = cifar100_test.get_inputs(FLAGS.batch_size)

        # Build a Graph that computes the predictions from the inference model.
        # The network reads from placeholders; batches fetched from the input
        # pipeline are fed into them in the test loop below.
        images = tf.placeholder(tf.float32, [FLAGS.batch_size, cifar100.IMAGE_SIZE, cifar100.IMAGE_SIZE, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Get splitted params
        if not FLAGS.basemodel:
            print('No basemodel found to load split params')
            sys.exit(-1)
        else:
            print('Load split params from %s' % FLAGS.basemodel)

        def get_perms(q_name, ngroups):
            # Hard-assign every column of the stored "<q_name>/alpha" tensor to
            # the group with the largest value and return, per group, the
            # indices assigned to it.
            split_alpha = reader.get_tensor(q_name+'/alpha')
            q_amax = np.argmax(split_alpha, axis=0)
            return [np.where(q_amax == i)[0] for i in range(ngroups)]

        reader = tf.train.NewCheckpointReader(FLAGS.basemodel)
        split_params = {}

        # NOTE(review): the split_p1 / split_q1 tensors are read unconditionally,
        # so the basemodel presumably must have been trained with
        # ngroups1 > 1 -- confirm against the training script.
        print('\tlogits...')
        base_logits_w = reader.get_tensor('logits/fc/weights')
        base_logits_b = reader.get_tensor('logits/fc/biases')
        split_p1_idxs = get_perms('group/split_p1', FLAGS.ngroups1)
        split_q1_idxs = get_perms('group/split_q1', FLAGS.ngroups1)

        logits_params = {'weights':[], 'biases':[], 'input_perms':[], 'output_perms':[]}
        for i in range(FLAGS.ngroups1):
            # Rows are selected by the input group, columns by the class group.
            logits_params['weights'].append(base_logits_w[split_p1_idxs[i], :][:, split_q1_idxs[i]])
            logits_params['biases'].append(base_logits_b[split_q1_idxs[i]])
        logits_params['input_perms'] = split_p1_idxs
        logits_params['output_perms'] = split_q1_idxs
        split_params['logits'] = logits_params

        if FLAGS.ngroups2 > 1:
            print('\tunit_3_x...')
            base_unit_3_0_shortcut_k = reader.get_tensor('unit_3_0/shortcut/kernel')
            base_unit_3_0_conv1_k = reader.get_tensor('unit_3_0/conv_1/kernel')
            base_unit_3_0_conv2_k = reader.get_tensor('unit_3_0/conv_2/kernel')
            base_unit_3_1_conv1_k = reader.get_tensor('unit_3_1/conv_1/kernel')
            base_unit_3_1_conv2_k = reader.get_tensor('unit_3_1/conv_2/kernel')
            split_p2_idxs = get_perms('group/split_p2', FLAGS.ngroups2)
            # The output grouping of unit_3_x is the logits input grouping
            # merged down to ngroups2 groups.
            split_q2_idxs = _merge_split_idxs(split_p1_idxs, _get_even_merge_idxs(FLAGS.ngroups1, FLAGS.ngroups2))
            split_r21_idxs = get_perms('group/split_r21', FLAGS.ngroups2)
            split_r22_idxs = get_perms('group/split_r22', FLAGS.ngroups2)

            unit_3_0_params = {'shortcut':[], 'conv1':[], 'conv2':[], 'p_perms':[], 'q_perms':[], 'r_perms':[]}
            for i in range(FLAGS.ngroups2):
                # Kernel axis 2 is indexed by the input group, axis 3 by the
                # output group.
                unit_3_0_params['shortcut'].append(base_unit_3_0_shortcut_k[:,:,split_p2_idxs[i],:][:,:,:,split_q2_idxs[i]])
                unit_3_0_params['conv1'].append(base_unit_3_0_conv1_k[:,:,split_p2_idxs[i],:][:,:,:,split_r21_idxs[i]])
                unit_3_0_params['conv2'].append(base_unit_3_0_conv2_k[:,:,split_r21_idxs[i],:][:,:,:,split_q2_idxs[i]])
            unit_3_0_params['p_perms'] = split_p2_idxs
            unit_3_0_params['q_perms'] = split_q2_idxs
            unit_3_0_params['r_perms'] = split_r21_idxs
            split_params['unit_3_0'] = unit_3_0_params

            unit_3_1_params = {'conv1':[], 'conv2':[], 'p_perms':[], 'r_perms':[]}
            for i in range(FLAGS.ngroups2):
                unit_3_1_params['conv1'].append(base_unit_3_1_conv1_k[:,:,split_q2_idxs[i],:][:,:,:,split_r22_idxs[i]])
                unit_3_1_params['conv2'].append(base_unit_3_1_conv2_k[:,:,split_r22_idxs[i],:][:,:,:,split_q2_idxs[i]])
            unit_3_1_params['p_perms'] = split_q2_idxs
            unit_3_1_params['r_perms'] = split_r22_idxs
            split_params['unit_3_1'] = unit_3_1_params

        if FLAGS.ngroups3 > 1:
            # NOTE(review): this branch uses split_p2_idxs, which is only
            # defined when ngroups2 > 1.  The message below says "conv4_x"
            # although the tensors read belong to unit_2_x.
            print('\tconv4_x...')
            base_unit_2_0_shortcut_k = reader.get_tensor('unit_2_0/shortcut/kernel')
            base_unit_2_0_conv1_k = reader.get_tensor('unit_2_0/conv_1/kernel')
            base_unit_2_0_conv2_k = reader.get_tensor('unit_2_0/conv_2/kernel')
            base_unit_2_1_conv1_k = reader.get_tensor('unit_2_1/conv_1/kernel')
            base_unit_2_1_conv2_k = reader.get_tensor('unit_2_1/conv_2/kernel')
            split_p3_idxs = get_perms('group/split_p3', FLAGS.ngroups3)
            split_q3_idxs = _merge_split_idxs(split_p2_idxs, _get_even_merge_idxs(FLAGS.ngroups2, FLAGS.ngroups3))
            split_r31_idxs = get_perms('group/split_r31', FLAGS.ngroups3)
            split_r32_idxs = get_perms('group/split_r32', FLAGS.ngroups3)

            unit_2_0_params = {'shortcut':[], 'conv1':[], 'conv2':[], 'p_perms':[], 'q_perms':[], 'r_perms':[]}
            for i in range(FLAGS.ngroups3):
                unit_2_0_params['shortcut'].append(base_unit_2_0_shortcut_k[:,:,split_p3_idxs[i],:][:,:,:,split_q3_idxs[i]])
                unit_2_0_params['conv1'].append(base_unit_2_0_conv1_k[:,:,split_p3_idxs[i],:][:,:,:,split_r31_idxs[i]])
                unit_2_0_params['conv2'].append(base_unit_2_0_conv2_k[:,:,split_r31_idxs[i],:][:,:,:,split_q3_idxs[i]])
            unit_2_0_params['p_perms'] = split_p3_idxs
            unit_2_0_params['q_perms'] = split_q3_idxs
            unit_2_0_params['r_perms'] = split_r31_idxs
            split_params['unit_2_0'] = unit_2_0_params

            unit_2_1_params = {'conv1':[], 'conv2':[], 'p_perms':[], 'r_perms':[]}
            for i in range(FLAGS.ngroups3):
                unit_2_1_params['conv1'].append(base_unit_2_1_conv1_k[:,:,split_q3_idxs[i],:][:,:,:,split_r32_idxs[i]])
                unit_2_1_params['conv2'].append(base_unit_2_1_conv2_k[:,:,split_r32_idxs[i],:][:,:,:,split_q3_idxs[i]])
            unit_2_1_params['p_perms'] = split_q3_idxs
            unit_2_1_params['r_perms'] = split_r32_idxs
            split_params['unit_2_1'] = unit_2_1_params


        # Build model
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            ngroups1=FLAGS.ngroups1,
                            ngroups2=FLAGS.ngroups2,
                            ngroups3=FLAGS.ngroups3,
                            split_params=split_params,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.build_model()
        print('Number of Weights: %d' % network._weights)
        print('FLOPs: %d' % network._flops)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))

        '''debugging attempt
        from tensorflow.python import debug as tf_debug
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        def _get_data(datum, tensor):
            return tensor == train_images
        sess.add_tensor_filter("get_data", _get_data)
        '''

        sess.run(init)

        # Create a saver.
        # The restore overwrites the values produced by `init` above.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        else:
            print('No checkpoint file found.')
            sys.exit(1)

        # Start queue runners & summary_writer
        cifar100_test.start_threads(sess, n_threads=1)

        # Test!
        test_loss = 0.0
        test_acc = 0.0
        test_time = 0.0
        # confusion_matrix[true_label, predicted_label] counts test examples.
        confusion_matrix = np.zeros((FLAGS.num_classes, FLAGS.num_classes), dtype=np.int32)
        for i in range(FLAGS.test_iter):
            test_images_val, test_labels_val = sess.run([test_images, test_labels])
            # Only the forward pass is timed, not the batch fetch above.
            start_time = time.time()
            loss_value, acc_value, pred_value = sess.run([network.loss, network.acc, network.preds],
                        feed_dict={network.is_train:False, images:test_images_val, labels:test_labels_val})
            duration = time.time() - start_time
            test_loss += loss_value
            test_acc += acc_value
            test_time += duration
            for l, p in zip(test_labels_val, pred_value):
                confusion_matrix[l, p] += 1

            if i % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: iter %d, loss=%.4f, acc=%.4f (%.1f examples/sec; %.3f sec/batch)')
                print (format_str % (datetime.now(), i, loss_value, acc_value,
                                     examples_per_sec, sec_per_batch))
        # Averages over batches (all batches have FLAGS.batch_size examples).
        test_loss /= FLAGS.test_iter
        test_acc /= FLAGS.test_iter

        # Print and save results
        sec_per_image = test_time/FLAGS.test_iter/FLAGS.batch_size
        print ('Done! Acc: %.6f, Test time: %.3f sec, %.7f sec/example' % (test_acc, test_time, sec_per_image))
        print ('Saving result... ')
        result = {'accuracy': test_acc, 'confusion_matrix': confusion_matrix,
                  'test_time': test_time, 'sec_per_image': sec_per_image}
        with open(FLAGS.output_file, 'wb') as fd:
            pickle.dump(result, fd)
        print ('done!')
269 |
270 |
271 | def _merge_split_q(q, merge_idxs, name='merge'):
272 | ngroups, dim = q.shape
273 | max_idx = np.max(merge_idxs)
274 | temp_list = []
275 | for i in range(max_idx + 1):
276 | temp = []
277 | for j in range(ngroups):
278 | if merge_idxs[j] == i:
279 | temp.append(q[j,:])
280 | temp_list.append(np.sum(temp, axis=0))
281 | ret = np.array(temp_list)
282 |
283 | return ret
284 |
285 | def _merge_split_idxs(split_idxs, merge_idxs, name='merge'):
286 | ngroups = len(split_idxs)
287 | max_idx = np.max(merge_idxs)
288 | ret = []
289 | for i in range(max_idx + 1):
290 | temp = []
291 | for j in range(ngroups):
292 | if merge_idxs[j] == i:
293 | temp.append(split_idxs[j])
294 | ret.append(np.concatenate(temp))
295 |
296 | return ret
297 |
298 | def _get_even_merge_idxs(N, split):
299 | assert N >= split
300 | num_elems = [(N + split - i - 1)/split for i in range(split)]
301 | expand_split = [[i] * n for i, n in enumerate(num_elems)]
302 | return [t for l in expand_split for t in l]
303 |
304 |
def main(argv=None):  # pylint: disable=unused-argument
    """Entry point handed to tf.app.run(); `argv` is accepted but ignored."""
    train()
307 |
308 |
# Start the TensorFlow app runner only when this file is executed as a script.
if __name__ == '__main__':
    tf.app.run()
311 |
--------------------------------------------------------------------------------
/resnet_split.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | import utils
7 |
8 |
# Immutable bundle of the hyperparameters read by the split ResNet below.
# Field order is part of the interface (positional construction keeps working).
HParams = namedtuple(
    'HParams',
    [
        # Data / architecture / optimisation
        'batch_size', 'num_classes', 'num_residual_units', 'k',
        'weight_decay', 'momentum', 'finetune',
        # Number of groups for logits, unit_3_x and unit_2_x respectively
        'ngroups1', 'ngroups2', 'ngroups3',
        # Dict of per-layer split kernels and index permutations
        'split_params',
    ])
12 |
13 | class ResNet(object):
def __init__(self, hp, images, labels, global_step, name=None, reuse_weights=False):
    """Keep the graph inputs and create the run-time placeholders.

    `name` and `reuse_weights` are accepted but not used here.
    """
    # Inputs supplied by the caller: hyperparameters, image and label tensors.
    self._hp, self._images, self._labels = hp, images, labels
    self._global_step = global_step

    # Values fed at run time: learning rate and train/eval switch.
    self.lr = tf.placeholder(tf.float32)
    self.is_train = tf.placeholder(tf.bool)

    # Bookkeeping for the FLOP / weight totals.
    self._counted_scope = []
    self._flops = self._weights = 0
24 |
def build_model(self):
    """Build the forward graph of the split network.

    unit_1_x is always dense.  unit_2_x (controlled by ngroups3) and
    unit_3_x (controlled by ngroups2) are either dense residual units or,
    when their group count differs from 1, built from the pre-sliced kernels
    and index permutations in `hp.split_params`.  The classifier is always
    built from `hp.split_params['logits']`.

    Sets `_logits`, `probs`, `preds`, `acc` and `loss`.
    """
    print('Building model')
    filters = [16, 16 * self._hp.k, 32 * self._hp.k, 64 * self._hp.k]
    strides = [1, 2, 2]

    # Init. conv.
    print('\tBuilding unit: init_conv')
    x = utils._conv(self._images, 3, filters[0], 1, name='init_conv')

    # unit_1_x
    x = self._residual_block_first(x, filters[1], strides[0], name='unit_1_0')
    x = self._residual_block(x, name='unit_1_1')

    # unit_2_x
    if self._hp.ngroups3 == 1:
        x = self._residual_block_first(x, filters[2], strides[1], name='unit_2_0')
        x = self._residual_block(x, name='unit_2_1')
    else:
        unit_2_0_shortcut_kernel = self._hp.split_params['unit_2_0']['shortcut']
        unit_2_0_conv1_kernel = self._hp.split_params['unit_2_0']['conv1']
        unit_2_0_conv2_kernel = self._hp.split_params['unit_2_0']['conv2']
        unit_2_0_p_perms = self._hp.split_params['unit_2_0']['p_perms']
        unit_2_0_q_perms = self._hp.split_params['unit_2_0']['q_perms']
        unit_2_0_r_perms = self._hp.split_params['unit_2_0']['r_perms']

        # Split first unit: p_perms index the input channels, r_perms the
        # intermediate channels and q_perms the output channels.
        with tf.variable_scope('unit_2_0'):
            shortcut = self._conv_split(x, filters[2], strides[1], unit_2_0_shortcut_kernel, unit_2_0_p_perms, unit_2_0_q_perms, name='shortcut')
            x = self._conv_split(x, filters[2], strides[1], unit_2_0_conv1_kernel, unit_2_0_p_perms, unit_2_0_r_perms, name='conv_1')
            x = self._bn(x, name='bn_1')
            x = self._relu(x, name='relu_1')
            x = self._conv_split(x, filters[2], 1, unit_2_0_conv2_kernel, unit_2_0_r_perms, unit_2_0_q_perms, name='conv_2')
            x = self._bn(x, name='bn_2')
            x = x + shortcut
            x = self._relu(x, name='relu_2')

        unit_2_1_conv1_kernel = self._hp.split_params['unit_2_1']['conv1']
        unit_2_1_conv2_kernel = self._hp.split_params['unit_2_1']['conv2']
        unit_2_1_p_perms = self._hp.split_params['unit_2_1']['p_perms']
        unit_2_1_r_perms = self._hp.split_params['unit_2_1']['r_perms']

        # Split second unit: input and output share p_perms.
        with tf.variable_scope('unit_2_1'):
            shortcut = x
            x = self._conv_split(x, filters[2], 1, unit_2_1_conv1_kernel, unit_2_1_p_perms, unit_2_1_r_perms, name='conv_1')
            x = self._bn(x, name='bn_1')
            x = self._relu(x, name='relu_1')
            x = self._conv_split(x, filters[2], 1, unit_2_1_conv2_kernel, unit_2_1_r_perms, unit_2_1_p_perms, name='conv_2')
            x = self._bn(x, name='bn_2')
            x = x + shortcut
            x = self._relu(x, name='relu_2')

    # unit_3_x
    if self._hp.ngroups2 == 1:
        x = self._residual_block_first(x, filters[3], strides[2], name='unit_3_0')
        x = self._residual_block(x, name='unit_3_1')
    else:
        unit_3_0_shortcut_kernel = self._hp.split_params['unit_3_0']['shortcut']
        unit_3_0_conv1_kernel = self._hp.split_params['unit_3_0']['conv1']
        unit_3_0_conv2_kernel = self._hp.split_params['unit_3_0']['conv2']
        unit_3_0_p_perms = self._hp.split_params['unit_3_0']['p_perms']
        unit_3_0_q_perms = self._hp.split_params['unit_3_0']['q_perms']
        unit_3_0_r_perms = self._hp.split_params['unit_3_0']['r_perms']

        # Same structure as the split unit_2_0 above.
        with tf.variable_scope('unit_3_0'):
            shortcut = self._conv_split(x, filters[3], strides[2], unit_3_0_shortcut_kernel, unit_3_0_p_perms, unit_3_0_q_perms, name='shortcut')
            x = self._conv_split(x, filters[3], strides[2], unit_3_0_conv1_kernel, unit_3_0_p_perms, unit_3_0_r_perms, name='conv_1')
            x = self._bn(x, name='bn_1')
            x = self._relu(x, name='relu_1')
            x = self._conv_split(x, filters[3], 1, unit_3_0_conv2_kernel, unit_3_0_r_perms, unit_3_0_q_perms, name='conv_2')
            x = self._bn(x, name='bn_2')
            x = x + shortcut
            x = self._relu(x, name='relu_2')

        unit_3_1_conv1_kernel = self._hp.split_params['unit_3_1']['conv1']
        unit_3_1_conv2_kernel = self._hp.split_params['unit_3_1']['conv2']
        unit_3_1_p_perms = self._hp.split_params['unit_3_1']['p_perms']
        unit_3_1_r_perms = self._hp.split_params['unit_3_1']['r_perms']

        # Same structure as the split unit_2_1 above.
        with tf.variable_scope('unit_3_1'):
            shortcut = x
            x = self._conv_split(x, filters[3], 1, unit_3_1_conv1_kernel, unit_3_1_p_perms, unit_3_1_r_perms, name='conv_1')
            x = self._bn(x, name='bn_1')
            x = self._relu(x, name='relu_1')
            x = self._conv_split(x, filters[3], 1, unit_3_1_conv2_kernel, unit_3_1_r_perms, unit_3_1_p_perms, name='conv_2')
            x = self._bn(x, name='bn_2')
            x = x + shortcut
            x = self._relu(x, name='relu_2')

    # Last unit
    with tf.variable_scope('unit_last') as scope:
        print('\tBuilding unit: %s' % scope.name)
        x = utils._bn(x, self.is_train, self._global_step)
        x = utils._relu(x)
        x = tf.reduce_mean(x, [1, 2])  # global average pooling over H, W

    # Logit
    logits_weights = self._hp.split_params['logits']['weights']
    logits_biases = self._hp.split_params['logits']['biases']
    logits_input_perms = self._hp.split_params['logits']['input_perms']
    logits_output_perms = self._hp.split_params['logits']['output_perms']
    with tf.variable_scope('logits') as scope:
        print('\tBuilding unit: %s - %d split' % (scope.name, len(logits_weights)))
        # NOTE(review): x_offset is accumulated below but never read.
        x_offset = 0
        x_list = []
        for i, (w, b, p) in enumerate(zip(logits_weights, logits_biases, logits_input_perms)):
            in_dim, out_dim = w.shape
            # Select the feature columns listed in p (gather acts on axis 0,
            # hence the transposes).
            x_split = tf.transpose(tf.gather(tf.transpose(x), p))
            x_split = self._fc_with_init(x_split, out_dim, init_w=w, init_b=b, name='split%d' % (i+1))
            x_list.append(x_split)
            x_offset += in_dim
        x = tf.concat(x_list, 1)
        # The concatenated outputs are ordered group by group; build the
        # inverse permutation to put the classes back in 0..num_classes-1 order.
        output_forward_idx = list(np.concatenate(logits_output_perms))
        output_inverse_idx = [output_forward_idx.index(i) for i in range(self._hp.num_classes)]
        x = tf.transpose(tf.gather(tf.transpose(x), output_inverse_idx))

    self._logits = x

    # Probs & preds & acc
    self.probs = tf.nn.softmax(x, name='probs')
    self.preds = tf.to_int32(tf.argmax(self._logits, 1, name='preds'))
    ones = tf.constant(np.ones([self._hp.batch_size]), dtype=tf.float32)
    zeros = tf.constant(np.zeros([self._hp.batch_size]), dtype=tf.float32)
    correct = tf.where(tf.equal(self.preds, self._labels), ones, zeros)
    self.acc = tf.reduce_mean(correct, name='acc')
    tf.summary.scalar('accuracy', self.acc)

    # Loss & acc
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=x, labels=self._labels)
    self.loss = tf.reduce_mean(loss)
    tf.summary.scalar('cross_entropy', self.loss)
154 |
155 |
def build_train_op(self):
    """Create `_total_loss` (cross entropy + weight decay) and `train_op`.

    With `hp.finetune` set, gradients of selected variables are multiplied
    by 10 before being applied by the momentum optimizer.
    """
    print('Building train ops')

    # Learning rate
    tf.summary.scalar('learing_rate', self.lr)

    # Weight decay over every variable registered for it.
    with tf.variable_scope('l2_loss'):
        decay_terms = [tf.nn.l2_loss(v) for v in tf.get_collection(utils.WEIGHT_DECAY_KEY)]
        weight_penalty = tf.multiply(self._hp.weight_decay, tf.add_n(decay_terms))

    self._total_loss = tf.add_n([self.loss, weight_penalty])

    # Gradient descent step
    opt = tf.train.MomentumOptimizer(self.lr, self._hp.momentum)
    grads_and_vars = opt.compute_gradients(self._total_loss, tf.trainable_variables())
    if self._hp.finetune:
        for pos, (gradient, variable) in enumerate(grads_and_vars):
            var_name = variable.op.name
            in_unit_1 = "unit_1_0" in var_name or "unit_1_1" in var_name
            in_unit_2 = "unit_2_0" in var_name or "unit_2_1" in var_name
            in_unit_3 = "unit_3_0" in var_name or "unit_3_1" in var_name
            boosted = ("group" in var_name
                       or (in_unit_1 and self._hp.ngroups3 > 1)
                       or (in_unit_2 and self._hp.ngroups2 > 1)
                       or in_unit_3
                       or "logits" in var_name)
            if not boosted:
                continue
            print('\tScale up learning rate of % s by 10.0' % var_name)
            grads_and_vars[pos] = (10.0 * gradient, variable)

    # Apply gradient
    apply_grad_op = opt.apply_gradients(grads_and_vars, global_step=self._global_step)

    # Batch normalization moving average update: run the update ops
    # together with the gradient step when there are any.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if not update_ops:
        self.train_op = apply_grad_op
    else:
        with tf.control_dependencies(update_ops + [apply_grad_op]):
            self.train_op = tf.no_op()
196 |
197 |
def _residual_block_first(self, x, out_channel, strides, input_q=None, output_q=None, split_r=None, name="unit"):
    """Build a residual unit that may change channel count and/or stride.

    Layout: BN -> ReLU -> (dropout) -> conv3x3(strides) -> BN -> ReLU ->
    (dropout) -> conv3x3(1), added to a shortcut taken after the first
    BN/ReLU/(dropout). Dropout is inserted only when ``input_q``,
    ``output_q`` and ``split_r`` are all given; these three are forwarded to
    ``self._conv`` as its ``input_q``/``output_q`` arguments.
    """
    in_channel = x.get_shape().as_list()[-1]
    use_dropout = not (input_q is None or output_q is None or split_r is None)

    with tf.variable_scope(name) as scope:
        print('\tBuilding residual unit: %s' % scope.name)
        pre = self._relu(self._bn(x, name='bn_1'), name='relu_1')
        if use_dropout:
            pre = self._dropout(pre, self._hp.dropout_keep_prob, name='dropout_1')

        # Shortcut connection: projection when the channel count changes,
        # otherwise identity (stride 1) or max-pool down-sampling.
        if in_channel != out_channel:
            shortcut = self._conv(pre, 1, out_channel, strides,
                                  input_q=input_q, output_q=output_q, name='shortcut')
        elif strides == 1:
            shortcut = tf.identity(pre)
        else:
            shortcut = tf.nn.max_pool(pre, [1, strides, strides, 1],
                                      [1, strides, strides, 1], 'VALID')

        # Residual branch
        res = self._conv(pre, 3, out_channel, strides,
                         input_q=input_q, output_q=split_r, name='conv_1')
        res = self._relu(self._bn(res, name='bn_2'), name='relu_2')
        if use_dropout:
            res = self._dropout(res, self._hp.dropout_keep_prob, name='dropout_2')
        res = self._conv(res, 3, out_channel, 1,
                         input_q=split_r, output_q=output_q, name='conv_2')

        # Merge
        return res + shortcut
224 |
def _residual_block(self, x, split_q=None, split_r=None, name="unit"):
    """Build a residual unit that keeps the channel count and resolution.

    Layout: BN -> ReLU -> (dropout) -> conv3x3 -> BN -> ReLU -> (dropout) ->
    conv3x3, added to the unmodified input. Dropout is inserted only when
    both ``split_q`` and ``split_r`` are given; they are forwarded to
    ``self._conv`` as its ``input_q``/``output_q`` arguments.
    """
    num_channel = x.get_shape().as_list()[-1]
    use_dropout = not (split_q is None or split_r is None)

    with tf.variable_scope(name) as scope:
        print('\tBuilding residual unit: %s' % scope.name)
        # Shortcut connection is the raw input.
        identity = x

        # Residual branch
        out = self._relu(self._bn(x, name='bn_1'), name='relu_1')
        if use_dropout:
            out = self._dropout(out, self._hp.dropout_keep_prob, name='dropout_1')
        out = self._conv(out, 3, num_channel, 1,
                         input_q=split_q, output_q=split_r, name='conv_1')
        out = self._relu(self._bn(out, name='bn_2'), name='relu_2')
        if use_dropout:
            out = self._dropout(out, self._hp.dropout_keep_prob, name='dropout_2')
        out = self._conv(out, 3, num_channel, 1,
                         input_q=split_r, output_q=split_q, name='conv_2')

        # Merge
        return out + identity
245 |
246 |
def _conv_split(self, x, out_channel, strides, kernels, input_perms, output_perms, name="unit"):
    """Apply one convolution per split and reassemble the output channels.

    For each ``(kernel, perm)`` pair the input channels listed in ``perm``
    are gathered and convolved with a layer initialised from ``kernel``.
    The per-split outputs are concatenated along the channel axis and then
    re-ordered with the inverse of the concatenated ``output_perms``.
    """
    # Unpacking requires a statically known 4-D shape; the values are unused.
    _dim0, _dim1, _dim2, _in_channel = x.get_shape().as_list()

    def _select_channels(tensor, channel_idx):
        # tf.gather indexes axis 0, so move channels first, gather, move back.
        channels_first = tf.transpose(tensor, (3, 0, 1, 2))
        return tf.transpose(tf.gather(channels_first, channel_idx), (1, 2, 3, 0))

    branches = []
    with tf.variable_scope(name) as scope:
        print('\tBuilding residual unit: %s - %d split' % (scope.name, len(kernels)))
        for split_no, (kernel, perm) in enumerate(zip(kernels, input_perms), 1):
            kernel_size, _in_dim, out_dim = kernel.shape[-3:]
            branch = _select_channels(x, perm)
            branch = self._conv_with_init(branch, kernel_size, out_dim, strides,
                                          init_k=kernel, name="split%d" % split_no)
            branches.append(branch)
        merged = tf.concat(branches, 3)

        # Position of every original channel inside the concatenated output.
        forward_order = list(np.concatenate(output_perms))
        inverse_order = [forward_order.index(ch) for ch in range(out_channel)]
        merged = _select_channels(merged, inverse_order)
    return merged
262 |
263 |
264 | # Helper functions(counts FLOPs and number of weights)
def _conv(self, x, filter_size, out_channel, stride, pad="SAME", input_q=None, output_q=None, name="conv"):
    """Wrap ``utils._conv`` and record the layer's FLOPs and weight count.

    The counts are registered once per variable-scope path through
    ``self._add_flops_weights``.
    """
    _batch, height, width, in_channel = x.get_shape().as_list()
    out = utils._conv(x, filter_size, out_channel, stride, pad, input_q, output_q, name)

    # '/' follows the interpreter's division semantics, as in the original.
    flops = 2 * (height/stride) * (width/stride) * in_channel * out_channel * filter_size * filter_size
    n_weights = in_channel * out_channel * filter_size * filter_size

    layer_scope = tf.get_variable_scope().name + "/" + name
    self._add_flops_weights(layer_scope, flops, n_weights)
    return out
273 |
def _conv_with_init(self, x, filter_size, out_channel, stride, pad="SAME", init_k=None, name="conv"):
    """Wrap ``utils._conv_with_init`` and record FLOPs and weight count.

    ``init_k`` is passed through to the utility unchanged; the bookkeeping
    mirrors ``_conv``.
    """
    _batch, height, width, in_channel = x.get_shape().as_list()
    out = utils._conv_with_init(x, filter_size, out_channel, stride, pad, init_k, name)

    # '/' follows the interpreter's division semantics, as in the original.
    flops = 2 * (height/stride) * (width/stride) * in_channel * out_channel * filter_size * filter_size
    n_weights = in_channel * out_channel * filter_size * filter_size

    layer_scope = tf.get_variable_scope().name + "/" + name
    self._add_flops_weights(layer_scope, flops, n_weights)
    return out
282 |
def _fc(self, x, out_dim, input_q=None, output_q=None, name="fc"):
    """Wrap ``utils._fc`` and record the layer's FLOPs and weight count.

    Counts use ``(in_dim + 1) * out_dim`` weights (the ``+ 1`` accounts for
    one extra input per output unit) and twice that many FLOPs.
    """
    _batch, in_dim = x.get_shape().as_list()
    out = utils._fc(x, out_dim, input_q, output_q, name)

    n_weights = (in_dim + 1) * out_dim
    flops = 2 * (in_dim + 1) * out_dim

    layer_scope = tf.get_variable_scope().name + "/" + name
    self._add_flops_weights(layer_scope, flops, n_weights)
    return out
291 |
def _fc_with_init(self, x, out_dim, init_w=None, init_b=None, name="fc"):
    """Wrap ``utils._fc_with_init`` and record FLOPs and weight count.

    ``init_w`` and ``init_b`` are passed through unchanged; the bookkeeping
    mirrors ``_fc``.
    """
    _batch, in_dim = x.get_shape().as_list()
    out = utils._fc_with_init(x, out_dim, init_w, init_b, name)

    n_weights = (in_dim + 1) * out_dim
    flops = 2 * (in_dim + 1) * out_dim

    layer_scope = tf.get_variable_scope().name + "/" + name
    self._add_flops_weights(layer_scope, flops, n_weights)
    return out
300 |
def _bn(self, x, name="bn"):
    """Apply ``utils._bn`` using the model's ``is_train`` flag and global step.

    FLOPs/weights of this layer are deliberately not counted; the disabled
    accounting was ``f = 8 * self._get_data_size(x)`` and
    ``w = 4 * x.get_shape().as_list()[-1]``.
    """
    return utils._bn(x, self.is_train, self._global_step, name)
308 |
def _relu(self, x, name="relu"):
    """Apply ``utils._relu`` with its second argument fixed to ``0.0``.

    FLOPs of this layer are deliberately not counted; the disabled
    accounting was ``f = self._get_data_size(x)`` with zero weights.
    """
    return utils._relu(x, 0.0, name)
315 |
def _dropout(self, x, keep_prob, name="dropout"):
    """Forward to ``utils._dropout`` with the given ``keep_prob`` and name."""
    return utils._dropout(x, keep_prob, name)
319 |
def _get_data_size(self, x):
    """Return the product of all static dimensions of ``x`` except the first."""
    static_dims = x.get_shape().as_list()
    return np.prod(static_dims[1:])
322 |
def _add_flops_weights(self, scope_name, f, w):
    """Add ``f`` FLOPs and ``w`` weights to the totals, once per scope name.

    A scope that was already recorded in ``self._counted_scope`` is ignored,
    so rebuilding the same layer does not inflate the totals.
    """
    if scope_name in self._counted_scope:
        return
    self._counted_scope.append(scope_name)
    self._flops += f
    self._weights += w
328 |
--------------------------------------------------------------------------------
/train_split.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | from datetime import datetime
5 | import time
6 | import tensorflow as tf
7 | import numpy as np
8 | import sys
9 | import select
10 | from IPython import embed
11 | from StringIO import StringIO
12 | import matplotlib.pyplot as plt
13 |
14 | import cifar100
15 | import resnet_split as resnet
16 |
17 |
18 |
# Command-line flags for the split-network training script. They are read
# through the module-level FLAGS object defined at the end of this section.

# Dataset Configuration
tf.app.flags.DEFINE_string('data_dir', './cifar100/train_val_split', """Path to the CIFAR-100 data.""")
tf.app.flags.DEFINE_integer('num_classes', 100, """Number of classes in the dataset.""")
# num_train_instance is used together with batch_size to convert the epochs
# in lr_step_epoch into step counts inside train().
tf.app.flags.DEFINE_integer('num_train_instance', 45000, """Number of training images.""")
tf.app.flags.DEFINE_integer('num_val_instance', 5000, """Number of val images.""")

# Network Configuration
tf.app.flags.DEFINE_integer('batch_size', 90, """Number of images to process in a batch.""")
tf.app.flags.DEFINE_integer('num_residual_units', 2, """Number of residual block per group.
Total number of conv layers will be 6n+4""")
tf.app.flags.DEFINE_integer('k', 8, """Network width multiplier""")
# Number of splits per stage. train() only loads split parameters for
# unit_3_x / unit_2_x when the corresponding value is greater than 1.
tf.app.flags.DEFINE_integer('ngroups1', 1, """Grouping number on logits""")
tf.app.flags.DEFINE_integer('ngroups2', 1, """Grouping number on unit_3_x""")
tf.app.flags.DEFINE_integer('ngroups3', 1, """Grouping number on unit_2_x""")

# Optimization Configuration
tf.app.flags.DEFINE_float('l2_weight', 0.0001, """L2 loss weight applied all the weights""")
tf.app.flags.DEFINE_float('momentum', 0.9, """The momentum of MomentumOptimizer""")
tf.app.flags.DEFINE_float('initial_lr', 0.1, """Initial learning rate""")
# Comma-separated list of epochs; parsed with float() in train().
tf.app.flags.DEFINE_string('lr_step_epoch', "80.0,120.0,160.0", """Epochs after which learing rate decays""")
tf.app.flags.DEFINE_float('lr_decay', 0.1, """Learning rate decay factor""")
tf.app.flags.DEFINE_boolean('finetune', False, """Whether to finetune.""")

# Training Configuration
tf.app.flags.DEFINE_string('train_dir', './train', """Directory where to write log and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 100000, """Number of batches to run.""")
tf.app.flags.DEFINE_integer('display', 100, """Number of iterations to display training info.""")
# NOTE(review): split.sh passes --test_interval / --test_iter, but only
# val_interval / val_iter are defined here -- confirm which names are intended.
tf.app.flags.DEFINE_integer('val_interval', 1000, """Number of iterations to run a val""")
tf.app.flags.DEFINE_integer('val_iter', 100, """Number of iterations during a val""")
tf.app.flags.DEFINE_integer('checkpoint_interval', 10000, """Number of iterations to save parameters as a checkpoint""")
tf.app.flags.DEFINE_float('gpu_fraction', 0.95, """The fraction of GPU memory to be allocated""")
tf.app.flags.DEFINE_boolean('log_device_placement', False, """Whether to log device placement.""")
# basemodel is mandatory for this script: train() exits when it is not set.
tf.app.flags.DEFINE_string('basemodel', None, """Base model to load paramters""")
# When checkpoint is set, training resumes from it instead of the basemodel.
tf.app.flags.DEFINE_string('checkpoint', None, """Model checkpoint to load""")

FLAGS = tf.app.flags.FLAGS
55 |
56 |
def get_lr(initial_lr, lr_decay, lr_decay_steps, global_step):
    """Return the learning rate in effect at ``global_step``.

    Starting from ``initial_lr``, the rate is multiplied by ``lr_decay`` once
    for every boundary in ``lr_decay_steps`` that ``global_step`` has reached.
    """
    rate = initial_lr
    for boundary in lr_decay_steps:
        if boundary > global_step:
            continue
        rate = rate * lr_decay
    return rate
63 |
64 |
def _slice_kernel(kernel, in_idxs, out_idxs):
    """Return ``kernel`` restricted to the given input (axis 2) and output (axis 3) channels."""
    return kernel[:, :, in_idxs, :][:, :, :, out_idxs]


def _build_stage_split_params(reader, stage, ngroups, p_idxs, q_idxs, r1_idxs, r2_idxs):
    """Slice the base kernels of ``unit_<stage>_0`` and ``unit_<stage>_1`` per group.

    Args:
        reader: checkpoint reader of the base model (provides ``get_tensor``).
        stage: stage number used in the variable names (e.g. 3 for unit_3_x).
        ngroups: number of groups the stage is split into.
        p_idxs: per-group input-channel indices of the first unit.
        q_idxs: per-group output-channel indices of both units.
        r1_idxs: per-group indices of the first unit's intermediate channels.
        r2_idxs: per-group indices of the second unit's intermediate channels.

    Returns:
        ``(first_unit_params, second_unit_params)`` dictionaries with the
        per-group kernels and the index lists that were used to cut them.
    """
    first = 'unit_%d_0' % stage
    second = 'unit_%d_1' % stage
    shortcut_k = reader.get_tensor(first + '/shortcut/kernel')
    first_conv1_k = reader.get_tensor(first + '/conv_1/kernel')
    first_conv2_k = reader.get_tensor(first + '/conv_2/kernel')
    second_conv1_k = reader.get_tensor(second + '/conv_1/kernel')
    second_conv2_k = reader.get_tensor(second + '/conv_2/kernel')

    groups = range(ngroups)
    first_params = {
        'shortcut': [_slice_kernel(shortcut_k, p_idxs[i], q_idxs[i]) for i in groups],
        'conv1': [_slice_kernel(first_conv1_k, p_idxs[i], r1_idxs[i]) for i in groups],
        'conv2': [_slice_kernel(first_conv2_k, r1_idxs[i], q_idxs[i]) for i in groups],
        'p_perms': p_idxs,
        'q_perms': q_idxs,
        'r_perms': r1_idxs,
    }
    second_params = {
        'conv1': [_slice_kernel(second_conv1_k, q_idxs[i], r2_idxs[i]) for i in groups],
        'conv2': [_slice_kernel(second_conv2_k, r2_idxs[i], q_idxs[i]) for i in groups],
        'p_perms': q_idxs,
        'r_perms': r2_idxs,
    }
    return first_params, second_params


def train():
    """Train the split network initialised from a grouped base model.

    Steps:
      1. Print the configuration held in ``FLAGS``.
      2. Create the CIFAR-100 train/val input runners and the feed placeholders.
      3. Read the group-assignment variables (``group/split_*/alpha``) from
         ``FLAGS.basemodel`` and cut the base weights into per-group pieces.
      4. Build the split ResNet and its train op, then either restore
         ``FLAGS.checkpoint`` or load the non-split variables from the base model.
      5. Run the training loop with periodic validation, summaries and
         checkpoints. Typing ``b`` on stdin drops into an IPython shell.

    Exits the process with status -1 when ``FLAGS.basemodel`` is not set.
    """
    print('[Dataset Configuration]')
    print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
    print('\tNumber of classes: %d' % FLAGS.num_classes)
    print('\tNumber of training images: %d' % FLAGS.num_train_instance)
    print('\tNumber of val images: %d' % FLAGS.num_val_instance)

    print('[Network Configuration]')
    print('\tBatch size: %d' % FLAGS.batch_size)
    print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
    print('\tNetwork width multiplier: %d' % FLAGS.k)
    print('\tNumber of Groups: %d-%d-%d' % (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1))
    print('\tBasemodel file: %s' % FLAGS.basemodel)

    print('[Optimization Configuration]')
    print('\tL2 loss weight: %f' % FLAGS.l2_weight)
    print('\tThe momentum optimizer: %f' % FLAGS.momentum)
    print('\tInitial learning rate: %f' % FLAGS.initial_lr)
    print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
    print('\tLearning rate decay: %f' % FLAGS.lr_decay)
    print('\tFinetune: %d' % FLAGS.finetune)

    print('[Training Configuration]')
    print('\tTrain dir: %s' % FLAGS.train_dir)
    print('\tTraining max steps: %d' % FLAGS.max_steps)
    print('\tSteps per displaying info: %d' % FLAGS.display)
    print('\tSteps per validation: %d' % FLAGS.val_interval)
    print('\tSteps during validation: %d' % FLAGS.val_iter)
    print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
    print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
    print('\tLog device placement: %d' % FLAGS.log_device_placement)

    with tf.Graph().as_default():
        init_step = 0
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get images and labels of CIFAR-100
        print('Load CIFAR-100 dataset')
        train_dataset_path = os.path.join(FLAGS.data_dir, 'train')
        val_dataset_path = os.path.join(FLAGS.data_dir, 'val')
        print('\tLoading training data from %s' % train_dataset_path)
        with tf.variable_scope('train_image'):
            cifar100_train = cifar100.CIFAR100Runner(train_dataset_path, image_per_thread=32,
                                                     shuffle=True, distort=True, capacity=10000)
            train_images, train_labels = cifar100_train.get_inputs(FLAGS.batch_size)
        print('\tLoading validation data from %s' % val_dataset_path)
        with tf.variable_scope('val_image'):
            cifar100_val = cifar100.CIFAR100Runner(val_dataset_path, image_per_thread=32,
                                                   shuffle=False, distort=False, capacity=5000)
            val_images, val_labels = cifar100_val.get_inputs(FLAGS.batch_size)

        # The network is fed through placeholders so the same graph serves
        # both training and validation batches.
        images = tf.placeholder(tf.float32, [FLAGS.batch_size, cifar100.IMAGE_SIZE, cifar100.IMAGE_SIZE, 3])
        labels = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Get splitted params
        if not FLAGS.basemodel:
            print('No basemodel found to load split params')
            sys.exit(-1)
        else:
            print('Load split params from %s' % FLAGS.basemodel)

        reader = tf.train.NewCheckpointReader(FLAGS.basemodel)

        def get_perms(q_name, ngroups):
            # Each column of alpha is assigned to the group with the largest
            # value; return the column indices belonging to every group.
            split_alpha = reader.get_tensor(q_name+'/alpha')
            q_amax = np.argmax(split_alpha, axis=0)
            return [np.where(q_amax == i)[0] for i in range(ngroups)]

        split_params = {}

        print('\tlogits...')
        base_logits_w = reader.get_tensor('logits/fc/weights')
        base_logits_b = reader.get_tensor('logits/fc/biases')
        split_p1_idxs = get_perms('group/split_p1', FLAGS.ngroups1)
        split_q1_idxs = get_perms('group/split_q1', FLAGS.ngroups1)

        split_params['logits'] = {
            'weights': [base_logits_w[split_p1_idxs[i], :][:, split_q1_idxs[i]]
                        for i in range(FLAGS.ngroups1)],
            'biases': [base_logits_b[split_q1_idxs[i]] for i in range(FLAGS.ngroups1)],
            'input_perms': split_p1_idxs,
            'output_perms': split_q1_idxs,
        }

        if FLAGS.ngroups2 > 1:
            print('\tunit_3_x...')
            split_p2_idxs = get_perms('group/split_p2', FLAGS.ngroups2)
            # The output grouping of unit_3_x is the logits input grouping
            # merged down to ngroups2 groups.
            split_q2_idxs = _merge_split_idxs(split_p1_idxs, _get_even_merge_idxs(FLAGS.ngroups1, FLAGS.ngroups2))
            split_r21_idxs = get_perms('group/split_r21', FLAGS.ngroups2)
            split_r22_idxs = get_perms('group/split_r22', FLAGS.ngroups2)
            split_params['unit_3_0'], split_params['unit_3_1'] = _build_stage_split_params(
                reader, 3, FLAGS.ngroups2,
                split_p2_idxs, split_q2_idxs, split_r21_idxs, split_r22_idxs)

        if FLAGS.ngroups3 > 1:
            # NOTE(review): the label says conv4_x although this branch
            # handles unit_2_x; kept unchanged to preserve log output.
            print('\tconv4_x...')
            split_p3_idxs = get_perms('group/split_p3', FLAGS.ngroups3)
            # split_p2_idxs exists only when ngroups2 > 1, i.e. this branch
            # requires ngroups2 >= ngroups3 (also asserted by
            # _get_even_merge_idxs).
            split_q3_idxs = _merge_split_idxs(split_p2_idxs, _get_even_merge_idxs(FLAGS.ngroups2, FLAGS.ngroups3))
            split_r31_idxs = get_perms('group/split_r31', FLAGS.ngroups3)
            split_r32_idxs = get_perms('group/split_r32', FLAGS.ngroups3)
            split_params['unit_2_0'], split_params['unit_2_1'] = _build_stage_split_params(
                reader, 2, FLAGS.ngroups3,
                split_p3_idxs, split_q3_idxs, split_r31_idxs, split_r32_idxs)

        # Build model
        # Convert decay epochs into step counts. This must be a real list:
        # get_lr() iterates it on every step, and a one-shot iterator (what
        # map() returns on Python 3) would be exhausted after the first call.
        lr_decay_steps = [int(float(epoch) * FLAGS.num_train_instance / FLAGS.batch_size)
                          for epoch in FLAGS.lr_step_epoch.split(',')]
        hp = resnet.HParams(batch_size=FLAGS.batch_size,
                            num_classes=FLAGS.num_classes,
                            num_residual_units=FLAGS.num_residual_units,
                            k=FLAGS.k,
                            weight_decay=FLAGS.l2_weight,
                            ngroups1=FLAGS.ngroups1,
                            ngroups2=FLAGS.ngroups2,
                            ngroups3=FLAGS.ngroups3,
                            split_params=split_params,
                            momentum=FLAGS.momentum,
                            finetune=FLAGS.finetune)
        network = resnet.ResNet(hp, images, labels, global_step)
        network.build_model()
        network.build_train_op()
        print('Number of Weights: %d' % network._weights)
        print('FLOPs: %d' % network._flops)

        train_summary_op = tf.summary.merge_all()  # Summaries(training)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
            # allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))

        # Debugging attempt (disabled):
        #   from tensorflow.python import debug as tf_debug
        #   sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        #   def _get_data(datum, tensor):
        #       return tensor == train_images
        #   sess.add_tensor_filter("get_data", _get_data)

        sess.run(init)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)
            init_step = global_step.eval(session=sess)
            print('Load checkpoint %s' % FLAGS.checkpoint)
        else:
            # Define a different saver to load model checkpoints
            # Select only base variables (exclude split layers)
            print('Load parameters from basemodel %s' % FLAGS.basemodel)
            variables = tf.global_variables()
            vars_restore = [var for var in variables
                            if not "Momentum" in var.name and
                            not "logits" in var.name and
                            not "global_step" in var.name]
            if FLAGS.ngroups2 > 1:
                vars_restore = [var for var in vars_restore
                                if not "unit_3_" in var.name]
            if FLAGS.ngroups3 > 1:
                vars_restore = [var for var in vars_restore
                                if not "unit_2_" in var.name]
            saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
            saver_restore.restore(sess, FLAGS.basemodel)

        # Start queue runners & summary_writer
        cifar100_train.start_threads(sess, n_threads=20)
        cifar100_val.start_threads(sess, n_threads=1)

        if not os.path.exists(FLAGS.train_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(FLAGS.train_dir)
        summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))),
                                               sess.graph)

        # Training!
        val_best_acc = 0.0
        for step in xrange(init_step, FLAGS.max_steps):
            # val
            if step % FLAGS.val_interval == 0:
                val_loss, val_acc = 0.0, 0.0
                for i in range(FLAGS.val_iter):
                    val_images_val, val_labels_val = sess.run([val_images, val_labels])
                    loss_value, acc_value = sess.run([network.loss, network.acc],
                                feed_dict={network.is_train:False, images:val_images_val, labels:val_labels_val})
                    val_loss += loss_value
                    val_acc += acc_value
                val_loss /= FLAGS.val_iter
                val_acc /= FLAGS.val_iter
                val_best_acc = max(val_best_acc, val_acc)
                format_str = ('%s: (val)     step %d, loss=%.4f, acc=%.4f')
                print (format_str % (datetime.now(), step, val_loss, val_acc))

                val_summary = tf.Summary()
                val_summary.value.add(tag='val/loss', simple_value=val_loss)
                val_summary.value.add(tag='val/acc', simple_value=val_acc)
                val_summary.value.add(tag='val/best_acc', simple_value=val_best_acc)
                summary_writer.add_summary(val_summary, step)
                summary_writer.flush()

            # Train
            lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps, step)
            start_time = time.time()
            train_images_val, train_labels_val = sess.run([train_images, train_labels])
            _, loss_value, acc_value, train_summary_str = \
                    sess.run([network.train_op, network.loss, network.acc, train_summary_op],
                             feed_dict={network.is_train:True, network.lr:lr_value, images:train_images_val, labels:train_labels_val})
            duration = time.time() - start_time

            assert not np.isnan(loss_value)

            # Display & Summary(training)
            if step % FLAGS.display == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                format_str = ('%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print (format_str % (datetime.now(), step, loss_value, acc_value, lr_value,
                                     examples_per_sec, sec_per_batch))
                summary_writer.add_summary(train_summary_str, step)

            # Save the model checkpoint periodically.
            if (step > init_step and step % FLAGS.checkpoint_interval == 0) or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            # Non-blocking stdin poll: typing 'b' opens an IPython shell.
            if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                char = sys.stdin.read(1)
                if char == 'b':
                    embed()
342 |
343 |
def _merge_split_q(q, merge_idxs, name='merge'):
    """Sum the rows of ``q`` that share the same target index.

    Row ``j`` of the 2-D array ``q`` is added into output row
    ``merge_idxs[j]``. The result has ``max(merge_idxs) + 1`` rows.
    ``name`` is accepted but not used.
    """
    ngroups, _dim = q.shape
    n_merged = np.max(merge_idxs) + 1
    merged_rows = []
    for target in range(n_merged):
        members = [q[src, :] for src in range(ngroups) if merge_idxs[src] == target]
        merged_rows.append(np.sum(members, axis=0))
    return np.array(merged_rows)
357 |
def _merge_split_idxs(split_idxs, merge_idxs, name='merge'):
    """Concatenate index arrays that share the same target group.

    ``split_idxs[j]`` is appended to output group ``merge_idxs[j]``, keeping
    the order of ``j``. Returns a list with ``max(merge_idxs) + 1`` arrays.
    ``name`` is accepted but not used.
    """
    ngroups = len(split_idxs)
    n_merged = np.max(merge_idxs) + 1
    merged = []
    for target in range(n_merged):
        members = [split_idxs[src] for src in range(ngroups) if merge_idxs[src] == target]
        merged.append(np.concatenate(members))
    return merged
370 |
def _get_even_merge_idxs(N, split):
    """Assign ``N`` consecutive items to ``split`` groups as evenly as possible.

    Returns a list of length ``N`` whose ``j``-th entry is the group index of
    item ``j``. Group sizes differ by at most one, with the larger groups
    first (e.g. ``N=5, split=2`` gives ``[0, 0, 0, 1, 1]``).

    Raises:
        AssertionError: if ``N < split``.
    """
    assert N >= split
    # Floor division keeps the sizes integral; with true division the sizes
    # become floats and the list repetition below raises TypeError.
    group_sizes = [(N + split - i - 1) // split for i in range(split)]
    return [group for group, size in enumerate(group_sizes) for _ in range(size)]
376 |
377 |
def main(argv=None):  # pylint: disable=unused-argument
    """Entry point handed to ``tf.app.run()``; ignores ``argv`` and calls ``train()``."""
    train()
380 |
381 |
# When executed as a script, hand control to tf.app.run(), which is
# expected to parse the command-line flags and then call main().
if __name__ == '__main__':
    tf.app.run()
384 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | from datetime import datetime
5 | import time
6 | import tensorflow as tf
7 | import numpy as np
8 | import sys
9 | import select
10 | from IPython import embed
11 | from StringIO import StringIO
12 | import matplotlib.pyplot as plt
13 |
14 | import cifar100
15 | import resnet
16 |
17 |
18 |
# Command-line flags for the grouped-network training script. They are read
# through the module-level FLAGS object defined at the end of this section.

# Dataset Configuration
tf.app.flags.DEFINE_string('data_dir', './cifar100/train_val_split', """Path to the CIFAR-100 data.""")
tf.app.flags.DEFINE_integer('num_classes', 100, """Number of classes in the dataset.""")
# num_train_instance is used together with batch_size to convert the epochs
# in lr_step_epoch into step counts inside train().
tf.app.flags.DEFINE_integer('num_train_instance', 45000, """Number of training images.""")
tf.app.flags.DEFINE_integer('num_val_instance', 5000, """Number of val images.""")

# Network Configuration
tf.app.flags.DEFINE_integer('batch_size', 90, """Number of images to process in a batch.""")
tf.app.flags.DEFINE_integer('num_residual_units', 2, """Number of residual block per group.
Total number of conv layers will be 6n+4""")
tf.app.flags.DEFINE_integer('k', 8, """Network width multiplier""")
tf.app.flags.DEFINE_integer('ngroups1', 1, """Grouping number on logits""")
tf.app.flags.DEFINE_integer('ngroups2', 1, """Grouping number on unit_3_x""")
tf.app.flags.DEFINE_integer('ngroups3', 1, """Grouping number on unit_2_x""")

# Optimization Configuration
tf.app.flags.DEFINE_float('l2_weight', 0.0001, """L2 loss weight applied all the weights""")
# NOTE(review): train() prints gamma1 as "Overlap loss weight" and gamma2 as
# "Weight split loss weight", which does not match the help strings below --
# confirm which description is correct.
tf.app.flags.DEFINE_float('gamma1', 0.0, """split loss regularization paramter""")
tf.app.flags.DEFINE_float('gamma2', 0.0, """overlap loss regularization parameter""")
tf.app.flags.DEFINE_float('gamma3', 0.0, """uniform loss regularization parameter""")
tf.app.flags.DEFINE_float('dropout_keep_prob', 1.0, """probability of dropouts on the split layers(1.0 not to use dropout)""")
tf.app.flags.DEFINE_float('momentum', 0.9, """The momentum of MomentumOptimizer""")
tf.app.flags.DEFINE_boolean('bn_no_scale', False, """Whether not to use trainable gamma in BN layers.""")
tf.app.flags.DEFINE_boolean('weighted_group_loss', False, """Whether to normalize weight split loss where coeffs are propotional to its values.""")
tf.app.flags.DEFINE_float('initial_lr', 0.1, """Initial learning rate""")
# Comma-separated list of epochs; parsed with float() in train().
tf.app.flags.DEFINE_string('lr_step_epoch', "80.0,120.0,160.0", """Epochs after which learing rate decays""")
tf.app.flags.DEFINE_float('lr_decay', 0.1, """Learning rate decay factor""")
tf.app.flags.DEFINE_boolean('finetune', False, """Whether to finetune.""")

# Training Configuration
tf.app.flags.DEFINE_string('train_dir', './train', """Directory where to write log and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 100000, """Number of batches to run.""")
tf.app.flags.DEFINE_integer('display', 100, """Number of iterations to display training info.""")
tf.app.flags.DEFINE_integer('val_interval', 1000, """Number of iterations to run a val""")
tf.app.flags.DEFINE_integer('val_iter', 100, """Number of iterations during a val""")
tf.app.flags.DEFINE_integer('checkpoint_interval', 10000, """Number of iterations to save parameters as a checkpoint""")
tf.app.flags.DEFINE_integer('group_summary_interval', 2500, """Number of iterations to plot grouping variables and weights""")
tf.app.flags.DEFINE_float('gpu_fraction', 0.95, """The fraction of GPU memory to be allocated""")
tf.app.flags.DEFINE_boolean('log_device_placement', False, """Whether to log device placement.""")
tf.app.flags.DEFINE_string('basemodel', None, """Base model to load paramters""")
tf.app.flags.DEFINE_string('checkpoint', None, """Model checkpoint to load""")

FLAGS = tf.app.flags.FLAGS
62 |
63 |
def get_lr(initial_lr, lr_decay, lr_decay_steps, global_step):
    """Return the learning rate in effect at ``global_step``.

    Starting from ``initial_lr``, the rate is multiplied by ``lr_decay`` once
    for every boundary in ``lr_decay_steps`` that ``global_step`` has reached.
    """
    rate = initial_lr
    for boundary in lr_decay_steps:
        if boundary > global_step:
            continue
        rate = rate * lr_decay
    return rate
70 |
71 |
72 | def train():
73 | print('[Dataset Configuration]')
74 | print('\tCIFAR-100 dir: %s' % FLAGS.data_dir)
75 | print('\tNumber of classes: %d' % FLAGS.num_classes)
76 | print('\tNumber of training images: %d' % FLAGS.num_train_instance)
77 | print('\tNumber of val images: %d' % FLAGS.num_val_instance)
78 |
79 | print('[Network Configuration]')
80 | print('\tBatch size: %d' % FLAGS.batch_size)
81 | print('\tResidual blocks per group: %d' % FLAGS.num_residual_units)
82 | print('\tNetwork width multiplier: %d' % FLAGS.k)
83 | print('\tNumber of Groups: %d-%d-%d' % (FLAGS.ngroups3, FLAGS.ngroups2, FLAGS.ngroups1))
84 | print('\tBasemodel file: %s' % FLAGS.basemodel)
85 |
86 | print('[Optimization Configuration]')
87 | print('\tL2 loss weight: %f' % FLAGS.l2_weight)
88 | print('\tOverlap loss weight: %f' % FLAGS.gamma1)
89 | print('\tWeight split loss weight: %f' % FLAGS.gamma2)
90 | print('\tUniform loss weight: %f' % FLAGS.gamma3)
91 | print('\tDropout keep probability: %f' % FLAGS.dropout_keep_prob)
92 | print('\tThe momentum optimizer: %f' % FLAGS.momentum)
93 | print('\tNo update on BN scale parameter: %d' % FLAGS.bn_no_scale)
94 | print('\tWeighted split loss: %d' % FLAGS.weighted_group_loss)
95 | print('\tInitial learning rate: %f' % FLAGS.initial_lr)
96 | print('\tEpochs per lr step: %s' % FLAGS.lr_step_epoch)
97 | print('\tLearning rate decay: %f' % FLAGS.lr_decay)
98 | print('\tFinetune: %d' % FLAGS.finetune)
99 |
100 | print('[Training Configuration]')
101 | print('\tTrain dir: %s' % FLAGS.train_dir)
102 | print('\tTraining max steps: %d' % FLAGS.max_steps)
103 | print('\tSteps per displaying info: %d' % FLAGS.display)
104 | print('\tSteps per validation: %d' % FLAGS.val_interval)
105 | print('\tSteps during validation: %d' % FLAGS.val_iter)
106 | print('\tSteps per saving checkpoints: %d' % FLAGS.checkpoint_interval)
107 | print('\tSteps per plot splits: %d' % FLAGS.group_summary_interval)
108 | print('\tGPU memory fraction: %f' % FLAGS.gpu_fraction)
109 | print('\tLog device placement: %d' % FLAGS.log_device_placement)
110 |
111 |
112 | with tf.Graph().as_default():
113 | init_step = 0
114 | global_step = tf.Variable(0, trainable=False, name='global_step')
115 |
116 | # Get images and labels of CIFAR-100
117 | print('Load CIFAR-100 dataset')
118 | train_dataset_path = os.path.join(FLAGS.data_dir, 'train')
119 | val_dataset_path = os.path.join(FLAGS.data_dir, 'val')
120 | print('\tLoading training data from %s' % train_dataset_path)
121 | with tf.variable_scope('train_image'):
122 | cifar100_train = cifar100.CIFAR100Runner(train_dataset_path, image_per_thread=32,
123 | shuffle=True, distort=True, capacity=10000)
124 | train_images, train_labels = cifar100_train.get_inputs(FLAGS.batch_size)
125 | print('\tLoading validation data from %s' % val_dataset_path)
126 | with tf.variable_scope('val_image'):
127 | cifar100_val = cifar100.CIFAR100Runner(val_dataset_path, image_per_thread=32,
128 | shuffle=False, distort=False, capacity=5000)
129 | # shuffle=False, distort=False, capacity=10000)
130 | val_images, val_labels = cifar100_val.get_inputs(FLAGS.batch_size)
131 |
132 | # Build a Graph that computes the predictions from the inference model.
133 | images = tf.placeholder(tf.float32, [FLAGS.batch_size, cifar100.IMAGE_SIZE, cifar100.IMAGE_SIZE, 3])
134 | labels = tf.placeholder(tf.int32, [FLAGS.batch_size])
135 |
136 | # Build model
137 | lr_decay_steps = map(float,FLAGS.lr_step_epoch.split(','))
138 | lr_decay_steps = map(int,[s*FLAGS.num_train_instance/FLAGS.batch_size for s in lr_decay_steps])
139 | with tf.device('/GPU:0'):
140 | hp = resnet.HParams(batch_size=FLAGS.batch_size,
141 | num_classes=FLAGS.num_classes,
142 | num_residual_units=FLAGS.num_residual_units,
143 | k=FLAGS.k,
144 | weight_decay=FLAGS.l2_weight,
145 | ngroups1=FLAGS.ngroups1,
146 | ngroups2=FLAGS.ngroups2,
147 | ngroups3=FLAGS.ngroups3,
148 | gamma1=FLAGS.gamma1,
149 | gamma2=FLAGS.gamma2,
150 | gamma3=FLAGS.gamma3,
151 | dropout_keep_prob=FLAGS.dropout_keep_prob,
152 | momentum=FLAGS.momentum,
153 | bn_no_scale=FLAGS.bn_no_scale,
154 | weighted_group_loss=FLAGS.weighted_group_loss,
155 | finetune=FLAGS.finetune)
156 | network = resnet.ResNet(hp, images, labels, global_step)
157 | network.build_model()
158 | network.build_train_op()
159 | print('Number of Weights: %d' % network._weights)
160 | print('FLOPs: %d' % network._flops)
161 |
162 | train_summary_op = tf.summary.merge_all() # Summaries(training)
163 |
164 | # Build an initialization operation to run below.
165 | init = tf.global_variables_initializer()
166 |
167 | # Start running operations on the Graph.
168 | sess = tf.Session(config=tf.ConfigProto(
169 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_fraction),
170 | allow_soft_placement=True,
171 | log_device_placement=FLAGS.log_device_placement))
172 | sess.run(init)
173 |
174 | # Create a saver.
175 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000)
176 | if FLAGS.checkpoint is not None:
177 | saver.restore(sess, FLAGS.checkpoint)
178 | init_step = global_step.eval(session=sess)
179 | print('Load checkpoint %s' % FLAGS.checkpoint)
180 | elif FLAGS.basemodel:
181 | # Define a different saver to load model checkpoints
182 | # Select only base variables (exclude split layers)
183 | print('Load parameters from basemodel %s' % FLAGS.basemodel)
184 | variables = tf.global_variables()
185 | vars_restore = [var for var in variables
186 | if not "Momentum" in var.name and
187 | not "group" in var.name and
188 | not "global_step" in var.name]
189 | # vars_restore = [var for var in variables
190 | # if not "alpha" in var.name and
191 | # not "fc_beta" in var.name and
192 | # not "unit_3" in var.name and
193 | # not "unit_last" in var.name and
194 | # not "logits" in var.name and
195 | # not "Momentum" in var.name and
196 | # not "global_step" in var.name]
197 | saver_restore = tf.train.Saver(vars_restore, max_to_keep=10000)
198 | saver_restore.restore(sess, FLAGS.basemodel)
199 | else:
200 | print('No checkpoint file of basemodel found. Start from the scratch.')
201 |
202 | # Start queue runners & summary_writer
203 | cifar100_train.start_threads(sess, n_threads=20)
204 | cifar100_val.start_threads(sess, n_threads=1)
205 |
206 | if not os.path.exists(FLAGS.train_dir):
207 | os.mkdir(FLAGS.train_dir)
208 | summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_dir, str(global_step.eval(session=sess))),
209 | sess.graph)
210 |
211 | # Training!
212 | val_best_acc = 0.0
213 | for step in xrange(init_step, FLAGS.max_steps):
214 | # val
215 | if step % FLAGS.val_interval == 0:
216 | val_loss, val_acc = 0.0, 0.0
217 | for i in range(FLAGS.val_iter):
218 | val_images_val, val_labels_val = sess.run([val_images, val_labels])
219 | loss_value, acc_value = sess.run([network.loss, network.acc],
220 | feed_dict={network.is_train:False, images:val_images_val, labels:val_labels_val})
221 | val_loss += loss_value
222 | val_acc += acc_value
223 | val_loss /= FLAGS.val_iter
224 | val_acc /= FLAGS.val_iter
225 | val_best_acc = max(val_best_acc, val_acc)
226 | format_str = ('%s: (val) step %d, loss=%.4f, acc=%.4f')
227 | print (format_str % (datetime.now(), step, val_loss, val_acc))
228 |
229 | val_summary = tf.Summary()
230 | val_summary.value.add(tag='val/loss', simple_value=val_loss)
231 | val_summary.value.add(tag='val/acc', simple_value=val_acc)
232 | val_summary.value.add(tag='val/best_acc', simple_value=val_best_acc)
233 | summary_writer.add_summary(val_summary, step)
234 | summary_writer.flush()
235 |
236 | # Train
237 | lr_value = get_lr(FLAGS.initial_lr, FLAGS.lr_decay, lr_decay_steps, step)
238 | start_time = time.time()
239 | train_images_val, train_labels_val = sess.run([train_images, train_labels])
240 | _, loss_value, acc_value, train_summary_str = \
241 | sess.run([network.train_op, network.loss, network.acc, train_summary_op],
242 | feed_dict={network.is_train:True, network.lr:lr_value, images:train_images_val, labels:train_labels_val})
243 | duration = time.time() - start_time
244 |
245 | assert not np.isnan(loss_value)
246 |
247 | # Display & Summary(training)
248 | if step % FLAGS.display == 0:
249 | num_examples_per_step = FLAGS.batch_size
250 | examples_per_sec = num_examples_per_step / duration
251 | sec_per_batch = float(duration)
252 | format_str = ('%s: (Training) step %d, loss=%.4f, acc=%.4f, lr=%f (%.1f examples/sec; %.3f '
253 | 'sec/batch)')
254 | print (format_str % (datetime.now(), step, loss_value, acc_value, lr_value,
255 | examples_per_sec, sec_per_batch))
256 | summary_writer.add_summary(train_summary_str, step)
257 |
258 | # Save the model checkpoint periodically.
259 | if (step > init_step and step % FLAGS.checkpoint_interval == 0) or (step + 1) == FLAGS.max_steps:
260 | checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
261 | saver.save(sess, checkpoint_path, global_step=step)
262 |
263 | if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
264 | char = sys.stdin.read(1)
265 | if char == 'b':
266 | embed()
267 |
268 | # Plot grouped weight matrices as image summary
269 | filters = [16, 16 * FLAGS.k, 32 * FLAGS.k, 64 * FLAGS.k]
270 | if FLAGS.group_summary_interval is not None:
271 | if step % FLAGS.group_summary_interval == 0:
272 | img_summaries = []
273 |
274 | if FLAGS.ngroups1 > 1:
275 | logits_weights = get_var_value('logits/fc/weights', sess)
276 | split_p1 = softmax(get_var_value('group/split_p1/alpha', sess))
277 | split_q1 = softmax(get_var_value('group/split_q1/alpha', sess))
278 | feature_indices = np.argsort(np.argmax(split_p1, axis=0))
279 | class_indices = np.argsort(np.argmax(split_q1, axis=0))
280 |
281 | img_summaries.append(img_to_summary(np.repeat(split_p1[:, feature_indices], 20, axis=0), 'split_p1'))
282 | img_summaries.append(img_to_summary(np.repeat(split_q1[:, class_indices], 20, axis=0), 'split_q1'))
283 | img_summaries.append(img_to_summary(np.abs(logits_weights[feature_indices, :][:, class_indices]), 'logits'))
284 |
285 | if FLAGS.ngroups2 > 1:
286 | unit_3_0_shortcut = get_var_value('unit_3_0/shortcut/kernel', sess)
287 | unit_3_0_conv_1 = get_var_value('unit_3_0/conv_1/kernel', sess)
288 | unit_3_0_conv_2 = get_var_value('unit_3_0/conv_2/kernel', sess)
289 | unit_3_1_conv_1 = get_var_value('unit_3_1/conv_1/kernel', sess)
290 | unit_3_1_conv_2 = get_var_value('unit_3_1/conv_2/kernel', sess)
291 | split_p2 = softmax(get_var_value('group/split_p2/alpha', sess))
292 | split_q2 = _merge_split_q(split_p1, _get_even_merge_idxs(FLAGS.ngroups1, FLAGS.ngroups2))
293 | split_r21 = softmax(get_var_value('group/split_r21/alpha', sess))
294 | split_r22 = softmax(get_var_value('group/split_r22/alpha', sess))
295 | feature_indices1 = np.argsort(np.argmax(split_p2, axis=0))
296 | feature_indices2 = np.argsort(np.argmax(split_q2, axis=0))
297 | feature_indices3 = np.argsort(np.argmax(split_r21, axis=0))
298 | feature_indices4 = np.argsort(np.argmax(split_r22, axis=0))
299 | unit_3_0_shortcut_img = np.abs(unit_3_0_shortcut[:,:,feature_indices1,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[2], filters[3]))
300 | unit_3_0_conv_1_img = np.abs(unit_3_0_conv_1[:,:,feature_indices1,:][:,:,:,feature_indices3].transpose([2,0,3,1]).reshape(filters[2] * 3, filters[3] * 3))
301 | unit_3_0_conv_2_img = np.abs(unit_3_0_conv_2[:,:,feature_indices3,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[3] * 3, filters[3] * 3))
302 | unit_3_1_conv_1_img = np.abs(unit_3_1_conv_1[:,:,feature_indices2,:][:,:,:,feature_indices4].transpose([2,0,3,1]).reshape(filters[3] * 3, filters[3] * 3))
303 | unit_3_1_conv_2_img = np.abs(unit_3_1_conv_2[:,:,feature_indices4,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[3] * 3, filters[3] * 3))
304 | img_summaries.append(img_to_summary(np.repeat(split_p2[:, feature_indices1], 20, axis=0), 'split_p2'))
305 | img_summaries.append(img_to_summary(np.repeat(split_r21[:, feature_indices3], 20, axis=0), 'split_r21'))
306 | img_summaries.append(img_to_summary(np.repeat(split_r22[:, feature_indices4], 20, axis=0), 'split_r22'))
307 | img_summaries.append(img_to_summary(unit_3_0_shortcut_img, 'unit_3_0/shortcut_kernel'))
308 | img_summaries.append(img_to_summary(unit_3_0_conv_1_img, 'unit_3_0/conv_1_kernel'))
309 | img_summaries.append(img_to_summary(unit_3_0_conv_2_img, 'unit_3_0/conv_2_kernel'))
310 | img_summaries.append(img_to_summary(unit_3_1_conv_1_img, 'unit_3_1/conv_1_kernel'))
311 | img_summaries.append(img_to_summary(unit_3_1_conv_2_img, 'unit_3_1/conv_2_kernel'))
312 |
313 | if FLAGS.ngroups3 > 1:
314 | unit_2_0_shortcut = get_var_value('unit_2_0/shortcut/kernel', sess)
315 | unit_2_0_conv_1 = get_var_value('unit_2_0/conv_1/kernel', sess)
316 | unit_2_0_conv_2 = get_var_value('unit_2_0/conv_2/kernel', sess)
317 | unit_2_1_conv_1 = get_var_value('unit_2_1/conv_1/kernel', sess)
318 | unit_2_1_conv_2 = get_var_value('unit_2_1/conv_2/kernel', sess)
319 | split_p3 = softmax(get_var_value('group/split_p3/alpha', sess))
320 | split_q3 = _merge_split_q(split_p2, _get_even_merge_idxs(FLAGS.ngroups2, FLAGS.ngroups3))
321 | split_r31 = softmax(get_var_value('group/split_r31/alpha', sess))
322 | split_r32 = softmax(get_var_value('group/split_r32/alpha', sess))
323 | feature_indices1 = np.argsort(np.argmax(split_p3, axis=0))
324 | feature_indices2 = np.argsort(np.argmax(split_q3, axis=0))
325 | feature_indices3 = np.argsort(np.argmax(split_r31, axis=0))
326 | feature_indices4 = np.argsort(np.argmax(split_r32, axis=0))
327 | unit_2_0_shortcut_img = np.abs(unit_2_0_shortcut[:,:,feature_indices1,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[1], filters[2]))
328 | unit_2_0_conv_1_img = np.abs(unit_2_0_conv_1[:,:,feature_indices1,:][:,:,:,feature_indices3].transpose([2,0,3,1]).reshape(filters[1] * 3, filters[2] * 3))
329 | unit_2_0_conv_2_img = np.abs(unit_2_0_conv_2[:,:,feature_indices3,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[2] * 3, filters[2] * 3))
330 | unit_2_1_conv_1_img = np.abs(unit_2_1_conv_1[:,:,feature_indices2,:][:,:,:,feature_indices4].transpose([2,0,3,1]).reshape(filters[2] * 3, filters[2] * 3))
331 | unit_2_1_conv_2_img = np.abs(unit_2_1_conv_2[:,:,feature_indices4,:][:,:,:,feature_indices2].transpose([2,0,3,1]).reshape(filters[2] * 3, filters[2] * 3))
332 | img_summaries.append(img_to_summary(np.repeat(split_p3[:, feature_indices1], 20, axis=0), 'split_p3'))
333 | img_summaries.append(img_to_summary(np.repeat(split_r31[:, feature_indices3], 20, axis=0), 'split_r31'))
334 | img_summaries.append(img_to_summary(np.repeat(split_r32[:, feature_indices4], 20, axis=0), 'split_r32'))
335 | img_summaries.append(img_to_summary(unit_2_0_shortcut_img, 'unit_2_0/shortcut_kernel'))
336 | img_summaries.append(img_to_summary(unit_2_0_conv_1_img, 'unit_2_0/conv_1_kernel'))
337 | img_summaries.append(img_to_summary(unit_2_0_conv_2_img, 'unit_2_0/conv_2_kernel'))
338 | img_summaries.append(img_to_summary(unit_2_1_conv_1_img, 'unit_2_1/conv_1_kernel'))
339 | img_summaries.append(img_to_summary(unit_2_1_conv_2_img, 'unit_2_1/conv_2_kernel'))
340 |
341 | if img_summaries: # If not empty
342 | img_summary = tf.Summary(value=img_summaries)
343 | summary_writer.add_summary(img_summary, step)
344 | summary_writer.flush()
345 |
346 |
def get_var_value(var_name, sess):
    """Evaluate the first global variable whose name contains `var_name`.

    Matching is by substring against `var.name`, in the order returned by
    `tf.global_variables()`, so a short pattern may match more than one
    variable; only the first match is evaluated.

    Args:
        var_name: substring to look for in the variable names.
        sess: session passed to `var.eval(session=...)`.

    Returns:
        Whatever `var.eval(session=sess)` returns for the matched variable.

    Raises:
        IndexError: if no global variable name contains `var_name`.
    """
    for var in tf.global_variables():
        if var_name in var.name:
            return var.eval(session=sess)
    # Same exception type the original list-indexing raised, but with a
    # message that says which pattern failed to match.
    raise IndexError('No global variable whose name contains %r' % (var_name,))
349 |
350 |
def softmax(logits, axis=0):
    """Return the softmax of `logits` taken along `axis`.

    The maximum along `axis` is subtracted before exponentiating so that
    large inputs do not overflow in `np.exp`.

    Args:
        logits: array-like of scores.
        axis: axis along which the outputs are normalized (default 0).

    Returns:
        Array with the same shape as `logits`.
    """
    peak = np.max(logits, axis=axis, keepdims=True)
    unnormalized = np.exp(logits - peak)
    normalizer = np.sum(unnormalized, axis=axis, keepdims=True)
    return unnormalized / normalizer
355 |
356 |
def img_to_summary(img, tag="img"):
    """Encode a 2-D array as a PNG and wrap it in a `tf.Summary.Value`.

    Args:
        img: array rendered with the 'bone' colormap; `img.shape[0]` and
            `img.shape[1]` are recorded as the image height and width.
        tag: tag attached to the summary value.

    Returns:
        A `tf.Summary.Value` holding the encoded image.
    """
    # NOTE(review): the PNG is written into a StringIO buffer; on Python 3
    # this would presumably need to be a bytes buffer -- confirm against the
    # file's StringIO import.
    png_buffer = StringIO()
    plt.imsave(png_buffer, img, cmap='bone', format='png')

    n_rows, n_cols = img.shape[0], img.shape[1]
    encoded_image = tf.Summary.Image(encoded_image_string=png_buffer.getvalue(),
                                     height=n_rows,
                                     width=n_cols)
    return tf.Summary.Value(tag=tag, image=encoded_image)
365 |
def _merge_split_q(q, merge_idxs, name='merge'):
    """Sum the rows of `q` that are assigned to the same merged group.

    Args:
        q: 2-D numpy array of shape (ngroups, dim), one row per source group.
        merge_idxs: sequence where `merge_idxs[j]` is the index of the merged
            group that row `j` of `q` is added into.
        name: unused; kept so existing call sites remain valid.

    Returns:
        Array of shape (max(merge_idxs) + 1, dim) whose row `i` is the sum of
        every row `j` of `q` with `merge_idxs[j] == i`. A merged index that no
        row maps to yields a row of zeros.
    """
    ngroups, dim = q.shape
    n_merged = np.max(merge_idxs) + 1
    merged_rows = []
    for target in range(n_merged):
        # Collect the SOURCE rows (index j) that map to this merged group.
        # The previous code appended q[target, :] here, which repeated one row
        # instead of summing the members of the group.
        members = [q[j, :] for j in range(ngroups) if merge_idxs[j] == target]
        if members:
            merged_rows.append(np.sum(members, axis=0))
        else:
            # Keep the result rectangular when a merged index has no members.
            merged_rows.append(np.zeros(dim, dtype=q.dtype))
    return np.array(merged_rows)
379 |
380 |
def _get_even_merge_idxs(N, split):
    """Assign `N` consecutive items to `split` buckets as evenly as possible.

    Earlier buckets receive the extra items when `N` is not a multiple of
    `split`, e.g. N=5, split=2 gives [0, 0, 0, 1, 1].

    Args:
        N: number of items to distribute; must be >= `split`.
        split: number of buckets.

    Returns:
        List of length `N` of non-decreasing bucket indices in
        [0, split - 1].
    """
    assert N >= split, 'cannot spread %d items over %d buckets' % (N, split)
    # Floor division keeps the counts integral regardless of whether `/`
    # means integer or true division in the running interpreter.
    num_elems = [(N + split - i - 1) // split for i in range(split)]
    return [bucket for bucket, count in enumerate(num_elems) for _ in range(count)]
386 |
387 |
def main(argv=None):  # pylint: disable=unused-argument
    """Script entry point: run the training loop.

    Args:
        argv: ignored; accepted so the function can be used as the `main`
            callable of `tf.app.run()`.
    """
    train()
390 |
391 |
if __name__ == '__main__':
    # Let TensorFlow parse the command-line flags, then call main().
    tf.app.run(main=main)
394 |
--------------------------------------------------------------------------------