├── README.md
├── cifar_deep.py
├── network_base.py
└── vecgen_tf.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## orthogonal-convolution
GitHub repo for my experiments with the orthogonal convolution idea.

### Run the orthogonal convolution experiment
```bash
python cifar_deep.py --ortho-conv --result-path result/
```

### To list all command-line arguments
```bash
python cifar_deep.py -h
```
--------------------------------------------------------------------------------
/cifar_deep.py:
--------------------------------------------------------------------------------
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A deep CIFAR-10 classifier using convolutional layers.

Adapted from the TensorFlow MNIST tutorial documented at
https://www.tensorflow.org/get_started/mnist/pros
"""
# Disable linter warnings to maintain consistency with the tutorial.
# pylint: disable=invalid-name
# pylint: disable=g-bad-import-order

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import time
from collections import defaultdict

import pandas as pd
import tensorflow as tf

from network_base import Network

# Unused below: evaluation currently runs on the full test set in one batch.
EVAL_BATCH_COUNT = 100


class deepnn(Network):
  def __init__(self, ortho_conv=False, nonlin='relu'):
    tf.set_random_seed(99)
    assert nonlin in ['relu', 'selu']
    if nonlin == 'relu':
      self._nonlin = self._relu
    elif nonlin == 'selu':
      self._nonlin = self._selu

    if ortho_conv:
      self._conv_weights_fn = self._ortho_weight_bias_variable
    else:
      self._conv_weights_fn = self._weight_bias_variable

  def forward(self, images, labels):
    num_classes = labels.get_shape().as_list()[-1]
    self.ortho_loss = 0
    with tf.variable_scope('conv1'):
      W_conv, b_conv = self._conv_weights_fn([5, 5, 3, 32])
      self.ortho_loss += self._ortho_loss(W_conv, b_conv)
      h = self._nonlin(self._conv2d(images, W_conv, b_conv))

    with tf.variable_scope('conv2'):
      W_conv, b_conv = self._conv_weights_fn([3, 3, 32, 32])
      self.ortho_loss += self._ortho_loss(W_conv, b_conv)
      h = self._nonlin(self._conv2d(h, W_conv, b_conv, padding='VALID'))

    h = self._max_pool_2x2(h)

    with tf.variable_scope('conv3'):
      W_conv, b_conv = self._conv_weights_fn([3, 3, 32, 64])
      self.ortho_loss += self._ortho_loss(W_conv, b_conv)
      h = self._nonlin(self._conv2d(h, W_conv, b_conv))

    with tf.variable_scope('conv4'):
      W_conv, b_conv = self._conv_weights_fn([3, 3, 64, 64])
      self.ortho_loss += self._ortho_loss(W_conv, b_conv)
      h = self._nonlin(self._conv2d(h, W_conv, b_conv, padding='VALID'))

    h = self._max_pool_2x2(h, padding='VALID')

    h = self._flatten(h)

    with tf.variable_scope('fc1'):
      W, b = self._weight_bias_variable([6 * 6 * 64, 512])
      h = self._nonlin(self._dense(h, W, b))

    with tf.variable_scope('fc2'):
      W, b = self._weight_bias_variable([512, num_classes])
      logits = self._dense(h, W, b)

    cross_entropy_vector = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    self.cross_entropy = tf.reduce_mean(cross_entropy_vector)

    correct_prediction = tf.equal(tf.argmax(labels, 1), tf.argmax(logits, 1))
    correct_prediction = tf.cast(correct_prediction, tf.float32)
    self.accuracy = tf.reduce_mean(correct_prediction)

    return logits  # returned so callers can build further ops on the logits
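
# Shape bookkeeping for forward() above: the 32x32x3 input passes through
# conv1 (5x5, SAME) -> 32x32x32, conv2 (3x3, VALID) -> 30x30x32,
# 2x2 max-pool (SAME) -> 15x15x32, conv3 (3x3, SAME) -> 15x15x64,
# conv4 (3x3, VALID) -> 13x13x64, and 2x2 max-pool (VALID) -> 6x6x64,
# which is where the 6 * 6 * 64 input size of fc1 comes from.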

def main():
  parser = argparse.ArgumentParser()
  parser.add_argument('--dataset', choices=['cifar10', 'cifar100'], default='cifar10',
                      type=str, help='Dataset to train the model on (default %(default)s)')
  parser.add_argument('--result-path', default='result', type=str,
                      help='Directory for storing training and eval logs')
  parser.add_argument('--ortho-conv', action='store_true', default=False,
                      help='use orthogonal convolution')
  parser.add_argument('--nonlin', choices=['relu', 'selu'], default='relu',
                      type=str, help='nonlinearity to use (default %(default)s)')
  parser.add_argument('--ortho-weight', default=0, type=float,
                      help='weight given to ortho loss (default %(default)s)')
  parser.add_argument('--num-epochs', default=50, type=int,
                      help='number of epochs (default %(default)s)')

  options = parser.parse_args()
  assert not os.path.exists(options.result_path), "result dir already exists!"

  if options.dataset == 'cifar10':
    (x_train, y_train), (x_test, y_test) = tf.contrib.keras.datasets.cifar10.load_data()
    num_classes = 10
  elif options.dataset == 'cifar100':
    (x_train, y_train), (x_test, y_test) = tf.contrib.keras.datasets.cifar100.load_data()
    num_classes = 100
  else:
    raise ValueError('Invalid dataset name')

  # Scale pixels to [0, 1] and one-hot encode the integer labels.
  x_train = x_train / 255.
  y_train = tf.contrib.keras.utils.to_categorical(y_train, num_classes)
  x_test = x_test / 255.
  y_test = tf.contrib.keras.utils.to_categorical(y_test, num_classes)

  result_path = options.result_path

  net = deepnn(ortho_conv=options.ortho_conv, nonlin=options.nonlin)
  images = tf.placeholder(tf.float32, [None, 32, 32, 3])
  labels = tf.placeholder(tf.float32, [None, num_classes])

  logits = net.forward(images, labels)
  cross_entropy = net.cross_entropy
  accuracy = net.accuracy
  ortho_loss = net.ortho_loss
  wtd_ortho_loss = ortho_loss * options.ortho_weight
  total_loss = cross_entropy + wtd_ortho_loss

  optimizer = tf.train.AdamOptimizer(learning_rate=3e-4)
  train_step = optimizer.minimize(total_loss)

  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  train_metrics = defaultdict(list)
  eval_metrics = defaultdict(list)
  os.makedirs(result_path)
  models_path = os.path.join(result_path, 'models')
  os.makedirs(models_path)  # Saver does not create missing directories
  # Build the Saver once; constructing it inside the loop would keep adding
  # save ops to the graph.
  saver = tf.train.Saver()

  sess.run(tf.global_variables_initializer())
  train_start_time = time.time()
  iters = 0
  for i in range(options.num_epochs):
    for j in range(0, len(x_train), 128):
      iters = iters + 1
      (_, train_cross_entropy, train_accuracy, train_ortho_loss,
       train_wtd_ortho_loss, train_total_loss) = sess.run(
           [train_step, cross_entropy, accuracy, ortho_loss, wtd_ortho_loss, total_loss],
           feed_dict={images: x_train[j:j + 128], labels: y_train[j:j + 128]})
      train_metrics['time_per_iter'].append((time.time() - train_start_time) / iters)
      train_metrics['iteration'].append(iters)
      train_metrics['cross_entropy'].append(train_cross_entropy)
      train_metrics['accuracy'].append(train_accuracy)
      train_metrics['ortho_loss'].append(train_ortho_loss)
      train_metrics['wtd_ortho_loss'].append(train_wtd_ortho_loss)

      if (iters - 1) % 100 == 0:
        eval_start_time = time.time()
        eval_cross_entropy, eval_accuracy = sess.run(
            [cross_entropy, accuracy], feed_dict={images: x_test, labels: y_test})
        eval_metrics['time_per_iter'].append(time.time() - eval_start_time)
        eval_metrics['iteration'].append(iters)
        eval_metrics['cross_entropy'].append(eval_cross_entropy)
        eval_metrics['accuracy'].append(eval_accuracy)
        print('step %d, train accuracy %g, train cross entropy %g, '
              'train ortho loss %g, train weighted ortho loss %g, '
              'train total loss %g' % (iters, train_accuracy, train_cross_entropy,
                                       train_ortho_loss, train_wtd_ortho_loss,
                                       train_total_loss))
        print('eval accuracy %g, eval loss %g' % (eval_accuracy, eval_cross_entropy))

        model_name = 'iter_{}'.format(iters)
        model_filepath = os.path.join(models_path, model_name)
        saver.save(sess, model_filepath, write_meta_graph=False, write_state=False)

  pd_train_metrics = pd.DataFrame(train_metrics)
  pd_eval_metrics = pd.DataFrame(eval_metrics)

  pd_train_metrics.to_csv(os.path.join(result_path, 'train_metrics.csv'))
  pd_eval_metrics.to_csv(os.path.join(result_path, 'eval_metrics.csv'))


if __name__ == '__main__':
  main()
--------------------------------------------------------------------------------
/network_base.py:
--------------------------------------------------------------------------------
from functools import reduce
from operator import mul

import tensorflow as tf

from vecgen_tf import orthoconv_filter


class Network(object):
  def _conv2d(self, x, weight, bias, padding='SAME'):
    """conv2d returns a 2d convolution layer with full stride."""
    return tf.nn.conv2d(x, weight, strides=[1, 1, 1, 1], padding=padding) + bias

  def _max_pool_2x2(self, x, padding='SAME'):
    """max_pool_2x2 downsamples a feature map by 2X."""
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding=padding)

  def _flatten(self, x):
    shape = x.get_shape().as_list()
    shape_product = reduce((lambda x, y: x * y), shape[1:])
    return tf.reshape(x, [-1, shape_product])

  def _dense(self, x, weight, bias):
    return tf.matmul(x, weight) + bias

  def _relu(self, x):
    return tf.nn.relu(x)

  def _selu(self, x):
    """selu, self-normalizing activation function"""
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946
    return scale * tf.where(tf.less(x, 0.0), alpha * tf.nn.elu(x), x)

  def _ortho_loss(self, weight, bias):
    """Sum of squared cosine similarities between distinct filter columns.

    The weight is flattened to [m, n] and stacked with the bias, so each
    output channel is one column; the diagonal (self-similarity) is excluded.
    """
    shape = weight.get_shape().as_list()
    m = reduce(mul, shape[:-1])
    n = shape[-1]
    weight = tf.reshape(weight, shape=[m, n])
    tensor_var = tf.concat([weight, bias], 0)
    tensor_mul = tf.matmul(tf.transpose(tensor_var), tensor_var)
    tensor_norm = tf.norm(tensor_var, axis=0)
    tensor_norm_mul = tf.matmul(tf.expand_dims(tensor_norm, axis=1),
                                tf.expand_dims(tensor_norm, axis=0))
    cosine_tensor = tf.divide(tensor_mul, tensor_norm_mul)
    cosine_tensor_sq = tf.square(cosine_tensor)
    loss = tf.reduce_sum(cosine_tensor_sq) - tf.trace(cosine_tensor_sq)
    return loss

  def _ortho_weight_bias_variable(self, shape):
    """generates an orthogonal weight and bias variable of a given shape."""
    assert len(shape) == 4
    [in_height, in_width, in_channels, out_channels] = shape
    weight, bias = orthoconv_filter(in_height, in_width, in_channels,
                                    out_channels, bias=True)
    # Learnable per-element scales on top of the orthogonal parameterization.
    scale_weight = tf.Variable(tf.constant(1., shape=weight.get_shape().as_list()))
    scale_bias = tf.Variable(tf.constant(1., shape=bias.get_shape().as_list()))
    return scale_weight * weight, scale_bias * bias

  def _weight_bias_variable(self, shape):
    """generates a weight and bias variable of a given shape."""
    # Weight and bias are drawn as one Glorot-style uniform block; the last
    # row is split off as the bias.
    m = reduce(mul, shape[:-1]) + 1
    n = shape[-1]
    limit = tf.sqrt(6. / (m + n))
    tensor_var = tf.Variable(tf.random_uniform([m, n], minval=-limit, maxval=limit))
    weight, bias = (tf.slice(tensor_var, [0, 0], [m - 1, n]),
                    tf.slice(tensor_var, [m - 1, 0], [1, n]))
    weight = tf.reshape(weight, shape)
    return weight, bias
--------------------------------------------------------------------------------
/vecgen_tf.py:
--------------------------------------------------------------------------------
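# vecgen_tf builds an orthogonal convolution filter as a product of
# Householder reflections H = I - 2*v*v^T. Each reflection comes from a free
# parameter vector v (zero-padded to length m, then l2-normalized); the
# reflections are multiplied together and the first out_channels columns of
# the resulting orthogonal matrix become the flattened filter. With bias=True,
# the last row is split off as the bias, so the filter and bias columns are
# jointly orthonormal.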
import numpy as np
import tensorflow as tf


def check(w, b):
  """Numerically verify that the stacked [w; b] columns are orthonormal."""
  w = np.reshape(w, [-1, w.shape[-1]])
  e = np.concatenate([w, b], axis=0)
  print(np.linalg.norm(e, axis=0))
  print(e.shape)
  min_dim = min(np.shape(e))
  ce = np.eye(min_dim, min_dim, dtype=np.float64)
  re = np.matmul(np.transpose(e), e)
  print(np.linalg.norm(ce - re))


def orthoconv_filter(in_height, in_width, in_channels, out_channels,
                     num_ortho=None, bias=False, mode='train'):
  if mode == 'eval':
    # Assumes the variables already exist in a reusing variable scope.
    conv_filter = tf.get_variable(name='conv_filter')
    conv_bias = tf.get_variable(name='conv_bias')
    return conv_filter, conv_bias
  m = in_height * in_width * in_channels
  n = out_channels
  if num_ortho is None:
    num_ortho = n
  if bias:
    m = m + 1
  # Integer division (// keeps this an int under Python 3): the k-th
  # Householder vector has m - k free entries.
  num_free_weights = n * m - (n * n + n) // 2
  free_weights = tf.get_variable(name='free', shape=[num_free_weights + n],
                                 initializer=tf.random_uniform_initializer(-1, 1))
  id_mat = tf.eye(m, dtype=tf.float32)
  initial_mat = tf.eye(m, n)
  vec_sizes = tf.range(m - num_ortho + 1, m + 1)
  start_indices = tf.cumsum(vec_sizes, exclusive=True)

  def find_HHvecs(start_idx, vec_size):
    free_weights_slice = tf.slice(free_weights, [start_idx], [vec_size])
    paddings = [[m - vec_size, 0]]
    weights_vec = tf.pad(free_weights_slice, paddings, "CONSTANT")
    weights_vec = tf.nn.l2_normalize(weights_vec, dim=0)
    return weights_vec

  weights_vecs = tf.map_fn(lambda x: find_HHvecs(x[0], x[1]),
                           (start_indices, vec_sizes), dtype=tf.float32)
  # Householder reflections I - 2*v*v^T, stacked along axis 0.
  ortho_matrices = (tf.expand_dims(id_mat, axis=0) -
                    2 * (tf.expand_dims(weights_vecs, axis=2) *
                         tf.expand_dims(weights_vecs, axis=1)))

  weights_matrix = tf.foldl(tf.matmul, ortho_matrices,
                            parallel_iterations=min(n // 2, 32))
  filter_tensor = tf.matmul(weights_matrix, initial_mat)
  bias_tensor = tf.zeros([1, n])
  if bias:
    filter_tensor, bias_tensor = (tf.slice(filter_tensor, [0, 0], [m - 1, n]),
                                  tf.slice(filter_tensor, [m - 1, 0], [1, n]))
  filter_tensor = tf.reshape(filter_tensor,
                             shape=[in_height, in_width, in_channels, out_channels])
  conv_filter = tf.get_variable(name='conv_filter',
                                shape=[in_height, in_width, in_channels, out_channels],
                                trainable=False, initializer=tf.zeros_initializer)
  conv_bias = tf.get_variable(name='conv_bias',
                              shape=[1, out_channels], trainable=False,
                              initializer=tf.zeros_initializer)

  conv_filter = conv_filter.assign(filter_tensor)
  conv_bias = conv_bias.assign(bias_tensor)
  tf.add_to_collection('ortho_compute', filter_tensor)
  tf.add_to_collection('ortho_compute', bias_tensor)
  tf.add_to_collection('ortho_assign', conv_filter)
  tf.add_to_collection('ortho_assign', conv_bias)
  return filter_tensor, bias_tensor


def main():
  # Sanity check: generate a filter and verify its columns are orthonormal.
  a = orthoconv_filter(3, 3, 16, 32, bias=True)
  init_op = tf.global_variables_initializer()
  with tf.Session() as sess:
    sess.run(init_op)
    x, y = sess.run([a[0], a[1]])
    check(x, y)


if __name__ == "__main__":
  main()
--------------------------------------------------------------------------------
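
A minimal sketch (not part of the repo) of what `Network._ortho_loss` measures: it sums squared cosine similarities between distinct columns of the stacked weight/bias matrix, so for columns that are already orthonormal the penalty should come out near zero. The QR construction below is hypothetical test scaffolding; the TF 1.x API matches what the repo uses.

```python
import numpy as np
import tensorflow as tf

from network_base import Network

# Build an orthonormal (m+1) x n block with QR, split it into a 5x5x3x32
# filter plus a 1x32 bias row, and confirm the pairwise-cosine penalty
# vanishes (hypothetical check, not part of the training pipeline).
q, _ = np.linalg.qr(np.random.randn(5 * 5 * 3 + 1, 32))
weight = tf.constant(q[:-1].reshape(5, 5, 3, 32), dtype=tf.float32)
bias = tf.constant(q[-1:], dtype=tf.float32)
loss = Network()._ortho_loss(weight, bias)

with tf.Session() as sess:
  print(sess.run(loss))  # expected: ~0 for orthonormal columns
```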