├── LICENSE ├── README.md ├── imgs ├── Fractal.png ├── FractalBlock.png ├── Seperated.png └── fractalnet.png └── src ├── cifar_demo.py ├── fractal_block.py └── mnist_demo.py /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 edgelord 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FractalNet 2 | Implementation of FractalNet (https://arxiv.org/pdf/1605.07648v1.pdf) in TensorFlow. 3 | 4 | 5 | 6 | Fractal blocks have components with separated columns, as well as the fractal layout. 7 | For rounds of training with global drop path, the separated layout is used. 8 | In all other cases, the fractal layout is used. 
9 | 10 | # Fractal 11 | ![Fractal Visualization](imgs/Fractal.png) 12 | 13 | # Separated 14 | ![Separated Visualization](imgs/Seperated.png) 15 | 16 | # Entire Block 17 | ![Block Visualization](imgs/FractalBlock.png) 18 | -------------------------------------------------------------------------------- /imgs/Fractal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorpro/FractalNet/6ab34b74af4aa2c7897f4b7d36ad87ed3c991047/imgs/Fractal.png -------------------------------------------------------------------------------- /imgs/FractalBlock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorpro/FractalNet/6ab34b74af4aa2c7897f4b7d36ad87ed3c991047/imgs/FractalBlock.png -------------------------------------------------------------------------------- /imgs/Seperated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorpro/FractalNet/6ab34b74af4aa2c7897f4b7d36ad87ed3c991047/imgs/Seperated.png -------------------------------------------------------------------------------- /imgs/fractalnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorpro/FractalNet/6ab34b74af4aa2c7897f4b7d36ad87ed3c991047/imgs/fractalnet.png -------------------------------------------------------------------------------- /src/cifar_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ Convolutional network applied to CIFAR-10 dataset classification task. 4 | 5 | References: 6 | Learning Multiple Layers of Features from Tiny Images, A. Krizhevsky, 2009.
7 | 8 | Links: 9 | [CIFAR-10 Dataset](https://www.cs.toronto.edu/~kriz/cifar.html) 10 | 11 | """ 12 | from __future__ import division, print_function, absolute_import 13 | 14 | import tflearn 15 | from tflearn.data_utils import shuffle, to_categorical 16 | from tflearn.layers.core import input_data, dropout, fully_connected 17 | from tflearn.layers.conv import conv_2d, max_pool_2d, global_avg_pool 18 | from tflearn.layers.estimator import regression 19 | from tflearn.data_preprocessing import ImagePreprocessing 20 | from tflearn.data_augmentation import ImageAugmentation 21 | from tflearn.activations import softmax 22 | from fractal_block import fractal_conv2d 23 | from tensorflow.contrib import slim 24 | from tflearn.layers.normalization import batch_normalization 25 | 26 | # Data loading and preprocessing 27 | from tflearn.datasets import cifar10 28 | (X, Y), (X_test, Y_test) = cifar10.load_data() 29 | X, Y = shuffle(X, Y) 30 | Y = to_categorical(Y, 10) 31 | Y_test = to_categorical(Y_test, 10) 32 | 33 | # Real-time data preprocessing 34 | img_prep = ImagePreprocessing() 35 | img_prep.add_featurewise_zero_center() 36 | img_prep.add_featurewise_stdnorm() 37 | 38 | # Real-time data augmentation 39 | img_aug = ImageAugmentation() 40 | img_aug.add_random_flip_leftright() 41 | img_aug.add_random_rotation(max_angle=25.)
42 | 43 | # FractalNet building: four 4-column fractal blocks (64/128/256/512 filters), each followed by 2x2 max-pooling, then a final 512-filter fractal block 44 | net = input_data(shape=[None, 32, 32, 3], 45 | data_preprocessing=img_prep, 46 | data_augmentation=img_aug) 47 | 48 | filters = [64,128,256,512] 49 | for f in filters: 50 | net = fractal_conv2d(net, 4, f, 3, 51 | normalizer_fn=batch_normalization) 52 | net = slim.max_pool2d(net,2, 2) 53 | 54 | net = fractal_conv2d(net, 4, 512, 2, 55 | normalizer_fn=batch_normalization) 56 | 57 | # Classifier head: a 1x1 conv to 10 class maps, then global average pooling instead of a fully connected layer 58 | net = conv_2d(net, 10, 1) 59 | net = global_avg_pool(net) 60 | net = softmax(net) 61 | 62 | net = regression(net, optimizer='adam', 63 | loss='categorical_crossentropy', 64 | learning_rate=.002) 65 | 66 | # Train using classifier 67 | model = tflearn.DNN(net, tensorboard_verbose=0) 68 | model.fit(X, Y, n_epoch=400, shuffle=True, validation_set=(X_test, Y_test), 69 | show_metric=True, batch_size=32, run_id='cifar10_cnn') 70 | -------------------------------------------------------------------------------- /src/fractal_block.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from tensorflow import transpose 8 | from tensorflow import mul 9 | from tensorflow import nn 10 | from tensorflow.python.ops import nn 11 | from tensorflow.contrib import slim 12 | from tensorflow.contrib.framework.python.ops import add_arg_scope 13 | from tensorflow.contrib.layers.python.layers import initializers, utils 14 | import tflearn 15 | from tflearn import get_training_mode 16 | 17 | def tensor_shape(tensor): 18 | """Return the static shape of `tensor` as a Python list.""" 19 | return tensor.get_shape().as_list() 20 | 21 | def apply_mask(mask, 22 | columns): 23 | """Zero out the columns whose `mask` entry is False. 24 | 25 | Used instead of boolean masking so that the output has the same 26 | shape as the input. 27 | 28 | Args: 29 | mask: boolean tensor with one entry per column (axis 0 of `columns`). 30 | columns: columns of fractal block, stacked along axis 0 (a list of tensors is converted with tf.convert_to_tensor).
31 | """ 32 | tensor = tf.convert_to_tensor(columns) 33 | mask = tf.cast(mask, tensor.dtype) 34 | return transpose(mul(transpose(tensor), mask)) 35 | 36 | def random_column(columns): 37 | """Zeros out all except one of `columns`. 38 | 39 | Used for rounds with global drop path. 40 | 41 | Args: 42 | columns: the columns of a fractal block to be selected from. 43 | """ 44 | num_columns = tensor_shape(columns)[0] 45 | mask = tf.random_shuffle([True]+[False]*(num_columns-1)) 46 | return apply_mask(mask, columns)* num_columns 47 | 48 | def drop_some(columns, 49 | drop_prob=.15): 50 | """Zeros out columns with probability `drop_prob`. 51 | 52 | Used for rounds of local drop path. 53 | """ 54 | num_columns = tensor_shape(columns)[0] 55 | mask = tf.random_uniform([num_columns])>drop_prob 56 | scale = num_columns/tf.reduce_sum(tf.cast(mask, tf.float32)) 57 | 58 | return tf.cond(tf.reduce_any(mask), 59 | lambda : apply_mask(mask, columns) * scale, 60 | lambda : random_column(columns)) 61 | 62 | def coin_flip(prob=.5): 63 | """Random boolean variable, with `prob` chance of being true. 64 | 65 | Used to choose between local and global drop path. 66 | 67 | Args: 68 | prob:float, probability of being True. 69 | """ 70 | with tf.variable_op_scope([],None,"CoinFlip"): 71 | coin = tf.random_uniform([1])[0]>prob 72 | return coin 73 | 74 | def drop_path(columns, 75 | coin): 76 | with tf.variable_op_scope([columns], None, "DropPath"): 77 | out = tf.cond(coin, 78 | lambda : drop_some(columns), 79 | lambda : random_column(columns)) 80 | return out 81 | 82 | def join(columns, 83 | coin): 84 | """Takes mean of the columns, applies drop path if 85 | `tflearn.get_training_mode()` is True. 86 | 87 | Args: 88 | columns: columns of fractal block. 89 | is_training: boolean in tensor form. Determines whether drop path 90 | should be used. 91 | coin: boolean in tensor form. Determines whether drop path is 92 | local or global. 
93 | """ 94 | if len(columns)==1: 95 | return columns[0] 96 | with tf.variable_op_scope(columns, None, "Join"): 97 | columns = tf.convert_to_tensor(columns) 98 | columns = tf.cond(tflearn.get_training_mode(), 99 | lambda: drop_path(columns, coin), 100 | lambda: columns) 101 | out = tf.reduce_mean(columns, 0) 102 | return out 103 | 104 | def fractal_template(inputs, 105 | num_columns, 106 | block_fn, 107 | block_asc, 108 | joined=True, 109 | is_training=True, 110 | reuse=False, 111 | scope=None): 112 | """Template for making fractal blocks. 113 | 114 | Given a function and a corresponding arg_scope `fractal_template` 115 | will build a truncated fractal with `num_columns` columns. Drop path is gated on `tflearn.get_training_mode()` inside `join`; `is_training` is currently unused. 116 | 117 | Args: 118 | inputs: a 4-D tensor `[batch_size, height, width, channels]`. 119 | num_columns: integer, the columns in the fractal. 120 | block_fn: function to be called within each fractal. 121 | block_asc: a function returning the arg_scope context entered around each `block_fn` call. 122 | joined: boolean, whether the output columns should be joined; if False, the list of columns is returned. 123 | reuse: whether or not the layer and its variables should be reused. To be 124 | able to reuse the layer, `scope` must be given. 125 | scope: Optional scope for `variable_scope`.
126 | """ 127 | 128 | def fractal_expand(inputs, num_columns, joined): 129 | '''Recursively build a fractal with `num_columns` columns: f_C(x) = join([block_fn(x)] + columns of f_{C-1}(f_{C-1}(x))). Returns the joined tensor, or the list of columns when `joined` is False.''' 130 | with block_asc(): 131 | output = lambda cols: join(cols, coin) if joined else cols 132 | if num_columns == 1: 133 | return output([block_fn(inputs)]) 134 | left = block_fn(inputs) 135 | right = fractal_expand(inputs, num_columns-1, joined=True) 136 | right = fractal_expand(right, num_columns-1, joined=False) 137 | cols=[left]+right 138 | return output(cols) 139 | 140 | with tf.variable_op_scope([inputs], scope, 'Fractal', 141 | reuse=reuse) as scope: 142 | coin = coin_flip()  # one local/global draw shared by every join in this block; fractal_expand reads it via closure, so it must be bound before the call below 143 | net=fractal_expand(inputs, num_columns, joined) 144 | 145 | return net 146 | 147 | def fractal_conv2d(inputs, 148 | num_columns, 149 | num_outputs, 150 | kernel_size, 151 | joined=True, 152 | stride=1, 153 | padding='SAME', 154 | # rate=1, 155 | activation_fn=nn.relu, 156 | normalizer_fn=slim.batch_norm, 157 | normalizer_params=None, 158 | weights_initializer=initializers.xavier_initializer(), 159 | weights_regularizer=None, 160 | biases_initializer=None, 161 | biases_regularizer=None, 162 | reuse=None, 163 | variables_collections=None, 164 | outputs_collections=None, 165 | is_training=True, 166 | trainable=True, 167 | scope=None): 168 | """Builds a fractal block with slim.conv2d. 169 | The fractal will have `num_columns` columns; every slim.conv2d inside it receives this function's keyword arguments (except the fractal-specific ones) through an arg_scope. 170 | Args: 171 | inputs: a 4-D tensor `[batch_size, height, width, channels]`. 172 | num_columns: integer, the columns in the fractal.
173 | """ 174 | locs = locals() 175 | fractal_args = ['inputs','num_columns','joined','is_training'] 176 | asc_fn = lambda : slim.arg_scope([slim.conv2d], 177 | **{arg:val for (arg,val) in locs.items() 178 | if arg not in fractal_args}) 179 | return fractal_template(inputs, num_columns, slim.conv2d, asc_fn, 180 | joined, is_training, reuse, scope) 181 | -------------------------------------------------------------------------------- /src/mnist_demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | 3 | import tflearn 4 | from tflearn.layers.core import input_data, dropout, fully_connected 5 | from tflearn.layers.conv import conv_2d, max_pool_2d 6 | from tflearn.layers.normalization import local_response_normalization 7 | from tflearn.layers.estimator import regression 8 | from fractal_block import fractal_conv2d 9 | 10 | 11 | # Data loading and preprocessing 12 | import tflearn.datasets.mnist as mnist 13 | X, Y, testX, testY = mnist.load_data(one_hot=True) 14 | X = X.reshape([-1, 28, 28, 1]) 15 | testX = testX.reshape([-1, 28, 28, 1]) 16 | from tensorflow.contrib import slim 17 | # Building convolutional network 18 | net = input_data(shape=[None, 28, 28, 1], name='input') 19 | 20 | # filters = [32,64,128,256] 21 | filters = [4,8] 22 | for f in filters: 23 | net = fractal_conv2d(net, 4, f,3) 24 | net = slim.max_pool2d(net,2) 25 | 26 | 27 | net = fully_connected(net, 10, activation='softmax') 28 | net = regression(net, optimizer='adam', learning_rate=0.01, 29 | loss='categorical_crossentropy', name='target') 30 | 31 | 32 | # Training 33 | model = tflearn.DNN(net, tensorboard_verbose=0) 34 | model.fit({'input': X}, {'target': Y}, n_epoch=20, 35 | validation_set=({'input': testX}, {'target': testY}), 36 | snapshot_step=100, show_metric=True, run_id='convnet_mnist') 37 | 38 | --------------------------------------------------------------------------------