├── README.md
├── op_utils.py
├── test.py
├── dataloader
│   ├── CIFAR.py
│   └── ILSVRC.py
├── nets
│   ├── ResNet.py
│   └── tcl.py
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | ## TensorFlow 2 with JIT (XLA) compiling for multi-GPU training
2 | - CIFAR and ILSVRC training code with **JIT compiling** and **distributed learning** on a multi-GPU system.
3 | - I highly recommend enabling JIT compiling: most of the training graph is static and can be compiled, which reduces memory usage and speeds up training.
4 | - This repository is built with **custom layers** and a **custom training loop** for my own project; if you only want to see how JIT compiling is combined with distributed learning, look at 'train.py' and 'op_utils.py' (a minimal sketch of the pattern is also included at the end of this README).
5 | 
6 | ## Requirements
7 | - **TensorFlow >= 2.5**
8 | - Pillow
9 | 
10 | ## How to run
11 | - ILSVRC
12 | ```
13 | python train.py --compile --gpu_id {} --dataset ILSVRC --data_path /path/to/your/ILSVRC/home --train_path /path/to/log
14 | ```
15 | 
16 | - CIFAR{10,100}
17 | ```
18 | python train.py --compile --gpu_id {} --dataset CIFAR{10,100} --train_path /path/to/log
19 | ```
20 | 
21 | ## Experimental results
22 | - I used four GTX 1080 Ti GPUs.
23 | - JIT compiling reduces training time by about 40%.
24 | 
25 | | | Accuracy | Training time |
26 | |------------|-----------|---------------|
27 | | Distributed only | 75.83 | 94.61 |
28 | | Distributed with JIT | 75.57 | 56.98 |
29 | 
30 | 

31 | *Training plot of ResNet-56 on ILSVRC-2012*
32 | 

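## Minimal JIT + MirroredStrategy sketch

The core pattern used throughout this repository is a jit-compiled inner step (forward pass, loss, and gradients) wrapped by `strategy.run` inside a custom training loop. The snippet below is a minimal, self-contained sketch of that pattern under the same assumptions (SGD, sum-reduced loss scaled by the global batch size); `toy_model`, `toy_ds`, and the hyper-parameters are placeholders for illustration, not the repository's own objects — see `op_utils.py` and `train.py` for the real implementation.

```
import tensorflow as tf

# A jit-compiled step wrapped by strategy.run in a custom loop (illustrative only).
strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH = 32

with strategy.scope():
    toy_model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    optimizer = tf.keras.optimizers.SGD(0.1, momentum=0.9, nesterov=True)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.SUM)

@tf.function(jit_compile=True)                     # compile the heavy math with XLA
def compiled_step(images, labels):
    with tf.GradientTape() as tape:
        logits = toy_model(images, training=True)
        loss = loss_fn(labels, logits) / GLOBAL_BATCH   # scale by the global batch size
    grads = tape.gradient(loss, toy_model.trainable_variables)
    return loss, grads

def train_step(images, labels):
    loss, grads = compiled_step(images, labels)
    # apply_gradients aggregates (all-reduces) gradients across replicas.
    optimizer.apply_gradients(zip(grads, toy_model.trainable_variables))
    return loss

@tf.function
def dist_train_step(images, labels):
    strategy.run(train_step, args=(images, labels))

toy_ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([64, 8]),
     tf.random.uniform([64], maxval=10, dtype=tf.int64))).batch(GLOBAL_BATCH)

for x, y in strategy.experimental_distribute_dataset(toy_ds):
    dist_train_step(x, y)
```

Keeping the gradient application and metric updates outside the compiled function mirrors `op_utils.py`: only the static part of the step is handed to XLA.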
33 | -------------------------------------------------------------------------------- /op_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def Optimizer(args, model, strategy): 4 | loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction = tf.keras.losses.Reduction.SUM) 5 | optimizer = tf.keras.optimizers.SGD(args.learning_rate, .9, nesterov=True) 6 | 7 | train_loss = tf.keras.metrics.Mean(name='train_loss') 8 | train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') 9 | 10 | @tf.function(jit_compile = args.compile) 11 | def compiled_step(images, labels): 12 | with tf.GradientTape() as tape: 13 | pred = model(images, training = True) 14 | total_loss = loss_object(labels, pred)/args.batch_size 15 | gradients = tape.gradient(total_loss, model.trainable_variables) 16 | update_vars = [model.Layers[k].update_var if hasattr(model.Layers[k], 'update_var') else None for k in model.Layers ] 17 | return total_loss, pred, gradients, update_vars 18 | 19 | def train_step(images, labels): 20 | total_loss, pred, gradients, update_vars = compiled_step(images, labels) 21 | if args.weight_decay > 0.: 22 | gradients = [g+v*args.weight_decay for g,v in zip(gradients, model.trainable_variables)] 23 | optimizer.apply_gradients(zip(gradients, model.trainable_variables)) 24 | for k, v in zip(model.Layers, update_vars): 25 | if hasattr(model.Layers[k], 'update'): 26 | model.Layers[k].update(v) 27 | 28 | train_loss.update_state(total_loss) 29 | train_accuracy.update_state(labels, pred) 30 | 31 | @tf.function 32 | def train_step_dist(image, labels): 33 | strategy.run(train_step, args= (image, labels)) 34 | 35 | return train_step_dist, train_loss, train_accuracy, optimizer 36 | 37 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | 4 | import argparse 5 | import tensorflow as tf 6 | 7 | from dataloader import ILSVRC, CIFAR 8 | import utils 9 | 10 | home_path = os.path.dirname(os.path.abspath(__file__)) 11 | parser = argparse.ArgumentParser(description='') 12 | 13 | parser.add_argument("--arch", default='ResNet-50', type=str) 14 | parser.add_argument("--dataset", default='ILSVRC', type=str) 15 | 16 | parser.add_argument("--val_batch_size", default=256, type=int) 17 | parser.add_argument("--trained_param", default = 'res50_ilsvrc.pkl',type=str) 18 | parser.add_argument("--data_path", default = '/home/cvip/nas/ssd/ILSVRC2012',type=str) 19 | 20 | parser.add_argument("--gpu_id", default= [0], type=int, nargs = '+') 21 | parser.add_argument("--compile", default = False, action = 'store_true') 22 | 23 | args = parser.parse_args() 24 | 25 | args.home_path = os.path.dirname(os.path.abspath(__file__)) 26 | args.input_size = [224,224,3] 27 | 28 | if __name__ == '__main__': 29 | gpus = tf.config.list_physical_devices('GPU') 30 | tf.config.set_visible_devices([tf.config.list_physical_devices('GPU')[i] for i in args.gpu_id], 'GPU') 31 | for gpu_id in args.gpu_id: 32 | tf.config.experimental.set_memory_growth(gpus[gpu_id], True) 33 | devices = ['/gpu:{}'.format(i) for i in args.gpu_id] 34 | strategy = tf.distribute.MirroredStrategy(devices, cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()) 35 | 36 | with strategy.scope(): 37 | if args.dataset == 'ILSVRC': 38 | datasets = ILSVRC.build_dataset_providers(args, 
strategy, test_only = True) 39 | elif 'CIFAR' in args.dataset: 40 | datasets = CIFAR.build_dataset_providers(args, strategy, test_only = True) 41 | 42 | model = utils.load_model(args, datasets['num_classes'], args.trained_param) 43 | 44 | top1_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='top1_accuracy') 45 | top5_accuracy = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top5_accuracy') 46 | 47 | @tf.function(jit_compile = args.compile) 48 | def compiled_step(images): 49 | return model(images, training = False) 50 | 51 | def test_step(images, labels): 52 | pred = compiled_step(images) 53 | top1_accuracy.update_state(labels, pred) 54 | top5_accuracy.update_state(labels, pred) 55 | 56 | @tf.function 57 | def test_step_dist(images, labels): 58 | strategy.run(test_step, args=(images, labels)) 59 | 60 | for i, (test_images, test_labels) in enumerate(datasets['test']): 61 | test_step_dist(test_images, test_labels) 62 | 63 | top1_acc = top1_accuracy.result().numpy() 64 | top5_acc = top5_accuracy.result().numpy() 65 | top1_accuracy.reset_states() 66 | top5_accuracy.reset_states() 67 | print ('Test ACC. Top-1: %.4f, Top-5: %.4f'%(top1_acc, top5_acc)) -------------------------------------------------------------------------------- /dataloader/CIFAR.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def build_dataset_providers(args, strategy, test_only = False): 5 | if args.dataset == 'CIFAR10': 6 | train_images, train_labels, test_images, test_labels, pre_processing = Cifar10(args) 7 | if args.dataset == 'CIFAR100': 8 | train_images, train_labels, test_images, test_labels, pre_processing = Cifar100(args) 9 | 10 | options = tf.data.Options() 11 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA 12 | 13 | test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels)) 14 | test_ds = test_ds.map(pre_processing(is_training = False), num_parallel_calls=tf.data.experimental.AUTOTUNE) 15 | test_ds = test_ds.batch(args.val_batch_size).cache() 16 | test_ds = test_ds.with_options(options) 17 | test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE) 18 | 19 | if test_only: 20 | return {'test': test_ds, 'num_classes' : int(args.dataset[5:])} 21 | 22 | train_ds = [train_images, train_labels] 23 | 24 | train_ds = tf.data.Dataset.from_tensor_slices(tuple(train_ds)).cache() 25 | train_ds = train_ds.map(pre_processing(is_training = True), num_parallel_calls=tf.data.experimental.AUTOTUNE) 26 | train_ds = train_ds.shuffle(100*args.batch_size).batch(args.batch_size) 27 | train_ds = train_ds.with_options(options) 28 | train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE) 29 | 30 | datasets = { 31 | 'train': train_ds.repeat( args.train_epoch ), 32 | 'test': test_ds, 33 | } 34 | 35 | datasets = {k:strategy.experimental_distribute_dataset(datasets[k]) for k in datasets} 36 | datasets['train_len'] = train_ds.cardinality().numpy() 37 | datasets['num_classes'] = int(args.dataset[5:]) 38 | args.input_size = [32,32,3] 39 | 40 | print('Datasets are built') 41 | return datasets 42 | 43 | def Cifar10(args): 44 | from tensorflow.keras.datasets.cifar10 import load_data 45 | (train_images, train_labels), (val_images, val_labels) = load_data() 46 | 47 | def pre_processing(is_training = False): 48 | def training(image, *argv): 49 | sz = tf.shape(image) 50 | 51 | image = tf.cast(image, tf.float32) 52 | 53 | image0 = image 54 | 55 | image0 = 
(image0-np.array([113.9,123.0,125.3]))/np.array([66.7,62.1,63.0]) 56 | image0 = tf.image.random_flip_left_right(image0) 57 | image0 = tf.pad(image0, [[4,4],[4,4],[0,0]], 'REFLECT') 58 | image0 = tf.image.random_crop(image0,sz) 59 | 60 | return [image0] + [arg for arg in argv] 61 | 62 | def inference(image, label): 63 | image = tf.cast(image, tf.float32) 64 | image = (image-np.array([113.9,123.0,125.3]))/np.array([66.7,62.1,63.0]) 65 | return image, label 66 | 67 | return training if is_training else inference 68 | return train_images, train_labels, val_images, val_labels, pre_processing 69 | 70 | def Cifar100(args): 71 | from tensorflow.keras.datasets.cifar100 import load_data 72 | (train_images, train_labels), (val_images, val_labels) = load_data() 73 | 74 | def pre_processing(is_training = False): 75 | @tf.function 76 | def training(image, *argv): 77 | sz = tf.shape(image) 78 | 79 | image = tf.cast(image, tf.float32) 80 | image0 = image 81 | 82 | image0 = (image0-np.array([112,124,129]))/np.array([70,65,68]) 83 | image0 = tf.image.random_flip_left_right(image0) 84 | image0 = tf.pad(image0, [[4,4],[4,4],[0,0]], 'REFLECT') 85 | image0 = tf.image.random_crop(image0,sz) 86 | 87 | return [image0] + [arg for arg in argv] 88 | 89 | @tf.function 90 | def inference(image, label): 91 | image = tf.cast(image, tf.float32) 92 | image = (image-np.array([112,124,129]))/np.array([70,65,68]) 93 | return image, label 94 | 95 | return training if is_training else inference 96 | return train_images, train_labels, val_images, val_labels, pre_processing 97 | -------------------------------------------------------------------------------- /nets/ResNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from nets import tcl 3 | 4 | class Model(tf.keras.Model): 5 | def __init__(self, num_layers, num_class, name = 'ResNet', trainable = True, **kwargs): 6 | super(Model, self).__init__(name = name, **kwargs) 7 | def kwargs(**kwargs): 8 | return kwargs 9 | setattr(tcl.Conv2d, 'pre_defined', kwargs(use_biases = False, activation_fn = None, trainable = trainable)) 10 | setattr(tcl.BatchNorm, 'pre_defined', kwargs(trainable = trainable)) 11 | setattr(tcl.FC, 'pre_defined', kwargs(trainable = trainable)) 12 | 13 | self.num_layers = num_layers 14 | 15 | self.Layers = {} 16 | network_argments = { 17 | ## ILSVRC 18 | 18 : {'blocks' : [2,2,2,2],'depth' : [64,128,256,512], 'strides' : [1,2,2,2]}, 19 | 50 : {'blocks' : [3,4,6,3],'depth' : [64,128,256,512], 'strides' : [1,2,2,2]}, 20 | 21 | ## CIFAR 22 | 56 : {'blocks' : [9,9,9],'depth' : [16,32,64], 'strides' : [1,2,2]}, 23 | } 24 | self.net_args = network_argments[self.num_layers] 25 | 26 | 27 | if num_class == 1000: 28 | self.Layers['conv'] = tcl.Conv2d([7,7], self.net_args['depth'][0], strides = 2, name = 'conv') 29 | self.Layers['bn'] = tcl.BatchNorm(name = 'bn') 30 | self.maxpool_3x3 = tf.keras.layers.MaxPool2D((3,3), strides = 2, padding = 'SAME') 31 | 32 | else: 33 | self.Layers['conv'] = tcl.Conv2d([3,3], self.net_args['depth'][0], name = 'conv') 34 | self.Layers['bn'] = tcl.BatchNorm(name = 'bn') 35 | 36 | self.expansion = 1 if self.num_layers in {18, 56} else 4 37 | in_depth = self.net_args['depth'][0] 38 | for i, (nb_resnet_layers, depth, strides) in enumerate(zip(self.net_args['blocks'], self.net_args['depth'], self.net_args['strides'])): 39 | for j in range(nb_resnet_layers): 40 | name = 'BasicBlock%d.%d/'%(i,j) 41 | if j != 0: 42 | strides = 1 43 | 44 | if strides > 1 or depth * self.expansion != 
in_depth: 45 | self.Layers[name + 'conv3'] = tcl.Conv2d([1,1], depth * self.expansion, strides = strides, name = name +'conv3') 46 | self.Layers[name + 'bn3'] = tcl.BatchNorm(name = name + 'bn3') 47 | 48 | if self.num_layers in {18, 56}: 49 | self.Layers[name + 'conv1'] = tcl.Conv2d([3,3], depth, strides = strides, name = name + 'conv1') 50 | self.Layers[name + 'bn1'] = tcl.BatchNorm( name = name + 'bn1') 51 | self.Layers[name + 'conv2'] = tcl.Conv2d([3,3], depth * self.expansion, name = name + 'conv2') 52 | self.Layers[name + 'bn2'] = tcl.BatchNorm( name = name + 'bn2') 53 | 54 | else: 55 | self.Layers[name + 'conv0'] = tcl.Conv2d([1,1], depth, name = name + 'conv0') 56 | self.Layers[name + 'bn0'] = tcl.BatchNorm( name = name + 'bn0') 57 | self.Layers[name + 'conv1'] = tcl.Conv2d([3,3], depth, strides = strides, name = name + 'conv1') 58 | self.Layers[name + 'bn1'] = tcl.BatchNorm( name = name + 'bn1') 59 | self.Layers[name + 'conv2'] = tcl.Conv2d([1,1], depth * self.expansion, name = name + 'conv2') 60 | self.Layers[name + 'bn2'] = tcl.BatchNorm( name = name + 'bn2',) 61 | #param_initializers = {'gamma': tf.keras.initializers.Zeros()}) 62 | 63 | in_depth = depth * self.expansion 64 | 65 | self.Layers['fc'] = tcl.FC(num_class, name = 'fc') 66 | 67 | def call(self, x, training=None): 68 | x = self.Layers['conv'](x) 69 | x = self.Layers['bn'](x) 70 | x = tf.nn.relu(x) 71 | if hasattr(self, 'maxpool_3x3'): 72 | x = self.maxpool_3x3(x) 73 | 74 | in_depth = self.net_args['depth'][0] 75 | 76 | for i, (nb_resnet_layers, depth, strides) in enumerate(zip(self.net_args['blocks'], self.net_args['depth'], self.net_args['strides'])): 77 | for j in range(nb_resnet_layers): 78 | name = 'BasicBlock%d.%d/'%(i,j) 79 | if j != 0: 80 | strides = 1 81 | 82 | if strides > 1 or depth * self.expansion != in_depth: 83 | residual = self.Layers[name + 'conv3'](x) 84 | residual = self.Layers[name + 'bn3'](residual) 85 | else: 86 | residual = x 87 | 88 | if self.num_layers not in {18, 56}: 89 | x = self.Layers[name + 'conv0'](x) 90 | x = self.Layers[name + 'bn0'](x) 91 | x = tf.nn.relu(x) 92 | 93 | x = self.Layers[name + 'conv1'](x) 94 | x = self.Layers[name + 'bn1'](x) 95 | x = tf.nn.relu(x) 96 | 97 | x = self.Layers[name + 'conv2'](x) 98 | x = self.Layers[name + 'bn2'](x) 99 | x = tf.nn.relu(x + residual) 100 | in_depth = depth * self.expansion 101 | 102 | x = tf.reduce_mean(x, [1,2]) 103 | x = self.Layers['fc'](x) 104 | 105 | return x 106 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, time, argparse 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | tf.debugging.set_log_device_placement(False) 7 | 8 | from dataloader import ILSVRC, CIFAR 9 | import op_utils, utils 10 | 11 | parser = argparse.ArgumentParser(description='') 12 | parser.add_argument("--train_path", default="test", type=str, help = 'path to log') 13 | parser.add_argument("--data_path", default="E:/ILSVRC2012", type=str, help = 'home path for ILSVRC dataset') 14 | parser.add_argument("--arch", default='ResNet-50', type=str, help = 'network architecture. 
currently ResNet is only available') 15 | parser.add_argument("--dataset", default='ILSVRC', type=str, help = 'ILSVRC or CIFAR{10,100}') 16 | 17 | parser.add_argument("--learning_rate", default = 1e-1, type=float, help = 'initial learning rate') 18 | parser.add_argument("--decay_points", default = [.3, .6, .9], type=float, nargs = '+', help = 'learning rate decay point') 19 | parser.add_argument("--decay_rate", default=.1, type=float, help = 'rate to decay at each decay points') 20 | parser.add_argument("--weight_decay", default=1e-4, type=float, help = 'decay parameter for l2 regularizer') 21 | parser.add_argument("--batch_size", default = 256, type=int, help = 'training batch size') 22 | parser.add_argument("--val_batch_size", default=256, type=int, help = 'validation batch size') 23 | parser.add_argument("--train_epoch", default=100, type=int, help = 'total training epoch') 24 | 25 | parser.add_argument("--gpu_id", default= [0], type=int, nargs = '+', help = 'denote which gpus are used') 26 | parser.add_argument("--do_log", default=200, type=int, help = 'logging period') 27 | parser.add_argument("--compile", default=False, action = 'store_true', help = 'denote use compile or not. True is recommended in this repo') 28 | args = parser.parse_args() 29 | 30 | args.home_path = os.path.dirname(os.path.abspath(__file__)) 31 | args.decay_points = [int(dp*args.train_epoch) if dp < 1 else int(dp) for dp in args.decay_points] 32 | 33 | if args.dataset == 'ILSVRC': 34 | args.weight_decay /= len(args.gpu_id) 35 | args.learning_rate *= args.batch_size/256 36 | 37 | if __name__ == '__main__': 38 | gpus = tf.config.list_physical_devices('GPU') 39 | tf.config.set_visible_devices([tf.config.list_physical_devices('GPU')[i] for i in args.gpu_id], 'GPU') 40 | for gpu_id in args.gpu_id: 41 | tf.config.experimental.set_memory_growth(gpus[gpu_id], True) 42 | devices = ['/gpu:{}'.format(i) for i in args.gpu_id] 43 | strategy = tf.distribute.MirroredStrategy(devices, cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()) 44 | 45 | with strategy.scope(): 46 | if args.dataset == 'ILSVRC': 47 | datasets = ILSVRC.build_dataset_providers(args, strategy) 48 | elif 'CIFAR' in args.dataset: 49 | datasets = CIFAR.build_dataset_providers(args, strategy) 50 | model = utils.load_model(args, datasets['num_classes']) 51 | 52 | summary_writer = tf.summary.create_file_writer(args.train_path, flush_millis = 30000) 53 | with summary_writer.as_default(): 54 | utils.save_code_and_augments(args) 55 | total_train_time = 0 56 | 57 | train_step, train_loss, train_accuracy, optimizer = op_utils.Optimizer(args, model, strategy ) 58 | 59 | loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction = tf.keras.losses.Reduction.SUM) 60 | Eval = utils.Evaluation(args, model, strategy, datasets['test'], loss_object) 61 | 62 | print ('Training starts') 63 | fine_tuning_time = 0 64 | tic = time.time() 65 | for step, data in enumerate(datasets['train']): 66 | epoch = step//datasets['train_len'] 67 | lr = utils.scheduler(args, optimizer, epoch) 68 | train_step(*data) 69 | 70 | step += 1 71 | if step % args.do_log == 0: 72 | template = 'Global step {0:5d}: loss = {1:0.4f} ({2:1.3f} sec/step)' 73 | train_time = time.time() - tic 74 | print (template.format(step, train_loss.result()*len(args.gpu_id), train_time/args.do_log)) 75 | fine_tuning_time += train_time 76 | tic = time.time() 77 | 78 | if step % datasets['train_len'] == 0: 79 | tic_ = time.time() 80 | test_acc, test_loss = Eval.run(False) 81 | 82 | 
tf.summary.scalar('Categorical_loss/train', train_loss.result()*len(args.gpu_id), step=epoch+1) 83 | tf.summary.scalar('Categorical_loss/test', test_loss*len(args.gpu_id), step=epoch+1) 84 | tf.summary.scalar('Accuracy/train', train_accuracy.result()*100, step=epoch+1) 85 | tf.summary.scalar('Accuracy/test', test_acc*100, step=epoch+1) 86 | tf.summary.scalar('learning_rate', lr, step=epoch) 87 | summary_writer.flush() 88 | 89 | template = 'Epoch: {0:3d}, train_loss: {1:0.4f}, train_Acc.: {2:2.2f}, val_loss: {3:0.4f}, val_Acc.: {4:2.2f}' 90 | print (template.format(epoch+1, train_loss.result()*len(args.gpu_id), train_accuracy.result()*100, 91 | test_loss*len(args.gpu_id), test_acc*100)) 92 | 93 | train_loss.reset_states() 94 | train_accuracy.reset_states() 95 | tic += time.time() - tic_ 96 | 97 | utils.save_model(args, model, 'trained_params') -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os, shutil, glob, pickle, json 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | tf.debugging.set_log_device_placement(False) 7 | 8 | from nets import ResNet 9 | 10 | def scheduler(args, optimizer, epoch): 11 | lr = args.learning_rate 12 | for dp in args.decay_points: 13 | if epoch >= dp: 14 | lr *= args.decay_rate 15 | if epoch in args.decay_points: 16 | optimizer.learning_rate = lr 17 | return lr 18 | 19 | def save_code_and_augments(args): 20 | if os.path.isdir(os.path.join(args.train_path,'codes')): 21 | print ('============================================') 22 | print ('The folder already exists. It will be overwritten.') 23 | print ('============================================') 24 | else: 25 | os.mkdir(os.path.join(args.train_path,'codes')) 26 | 27 | for code in glob.glob(args.home_path + '/*.py'): 28 | shutil.copyfile(code, os.path.join(args.train_path, 'codes', os.path.split(code)[-1])) 29 | 30 | with open(os.path.join(args.train_path, 'arguments.txt'), 'w') as f: 31 | json.dump(args.__dict__, f, indent=2) 32 | 33 | class Evaluation: 34 | def __init__(self, args, model, strategy, dataset, loss_object): 35 | self.test_loss = tf.keras.metrics.Mean(name='test_loss') 36 | self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') 37 | 38 | @tf.function(jit_compile=args.compile) 39 | def compiled_step(images, labels, training): 40 | pred = model(images, training = training) 41 | loss = loss_object(labels, pred)/args.val_batch_size 42 | return pred, loss 43 | 44 | def eval_step(images, labels, training): 45 | pred, loss = compiled_step(images, labels, training) 46 | self.test_loss.update_state(loss) 47 | self.test_accuracy.update_state(labels, pred) 48 | 49 | 50 | @tf.function 51 | def eval_step_dist(images, labels, training): 52 | strategy.run(eval_step, args=(images, labels, training)) 53 | 54 | self.dataset = dataset 55 | self.step = eval_step_dist 56 | 57 | def run(self, training): 58 | for images, labels in self.dataset: 59 | self.step(images, labels, training) 60 | loss = self.test_loss.result().numpy() 61 | acc = self.test_accuracy.result().numpy() 62 | self.test_loss.reset_states() 63 | self.test_accuracy.reset_states() 64 | return acc, loss 65 | 66 | def load_model(args, num_class, trained_param = None): 67 | if 'ResNet' in args.arch: 68 | arch = int(args.arch.split('-')[1]) 69 | model = ResNet.Model(num_layers = arch, num_class = num_class, name = 'ResNet', trainable = True) 70 | 71 | 
if trained_param is not None: 72 | with open(trained_param, 'rb') as f: 73 | trained = pickle.load(f) 74 | n = 0 75 | for k in model.Layers.keys(): 76 | layer = model.Layers[k] 77 | if 'conv' in k or 'fc' in k: 78 | kernel = trained[layer.name + '/kernel:0'] 79 | layer.kernel_initializer = tf.constant_initializer(kernel) 80 | n += 1 81 | if layer.use_biases: 82 | layer.biases_initializer = tf.constant_initializer(trained[layer.name + '/biases:0']) 83 | n += 1 84 | layer.num_outputs = kernel.shape[-1] 85 | 86 | elif 'bn' in k: 87 | moving_mean = trained[layer.name + '/moving_mean:0'] 88 | moving_variance = trained[layer.name + '/moving_variance:0'] 89 | param_initializers = {'moving_mean' : tf.constant_initializer(moving_mean), 90 | 'moving_variance': tf.constant_initializer(moving_variance)} 91 | n += 2 92 | if layer.scale: 93 | param_initializers['gamma'] = tf.constant_initializer(trained[layer.name + '/gamma:0']) 94 | n += 1 95 | if layer.center: 96 | param_initializers['beta'] = tf.constant_initializer(trained[layer.name + '/beta:0']) 97 | n += 1 98 | layer.param_initializers = param_initializers 99 | print (n, 'params loaded') 100 | return model 101 | 102 | def build_dataset_providers(args, strategy): 103 | options = tf.data.Options() 104 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA 105 | 106 | train_ds = ILSVRC(args, 'train', shuffle = True) 107 | train_ds = train_ds.map(pre_processing(is_training = True, contrastive = args.Knowledge), num_parallel_calls=tf.data.experimental.AUTOTUNE) 108 | train_ds = train_ds.shuffle(100*args.batch_size).batch(args.batch_size).map(pre_processing_batched(is_training = True), num_parallel_calls=tf.data.experimental.AUTOTUNE) 109 | train_ds = train_ds.with_options(options) 110 | train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE) 111 | 112 | test_ds = ILSVRC(args, 'val', shuffle = False) 113 | test_ds = test_ds.map(pre_processing(is_training = False), num_parallel_calls=tf.data.experimental.AUTOTUNE) 114 | test_ds = test_ds.batch(args.val_batch_size).map(pre_processing_batched(is_training = False), num_parallel_calls=tf.data.experimental.AUTOTUNE) 115 | test_ds = test_ds.with_options(options) 116 | test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE) 117 | 118 | datasets = { 119 | 'train': train_ds.repeat( args.train_epoch ), 120 | 'test': test_ds 121 | } 122 | 123 | datasets = {k:strategy.experimental_distribute_dataset(datasets[k]) for k in datasets} 124 | datasets['train_len'] = train_ds.cardinality().numpy() 125 | 126 | print('Datasets are built') 127 | return datasets 128 | 129 | def save_model(args, model, name): 130 | params = {} 131 | for v in model.variables: 132 | if model.name in v.name: 133 | params[v.name[len(model.name)+1:]] = v.numpy() 134 | with open(os.path.join(args.train_path, name + '.pkl'), 'wb') as f: 135 | pickle.dump(params, f) 136 | -------------------------------------------------------------------------------- /dataloader/ILSVRC.py: -------------------------------------------------------------------------------- 1 | import glob, os 2 | import tensorflow as tf 3 | import numpy as np 4 | from PIL import Image 5 | 6 | JPEG_OPT = {'fancy_upscaling': True, 'dct_method': 'INTEGER_ACCURATE'} 7 | 8 | def build_dataset_providers(args, strategy, test_only = False): 9 | options = tf.data.Options() 10 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA 11 | 12 | test_ds = ILSVRC(args, 'val', shuffle = False) 13 | test_ds = 
test_ds.map(pre_processing(is_training = False), num_parallel_calls=tf.data.experimental.AUTOTUNE) 14 | test_ds = test_ds.batch(args.val_batch_size).map(pre_processing_batched(is_training = False), num_parallel_calls=tf.data.experimental.AUTOTUNE) 15 | test_ds = test_ds.with_options(options) 16 | test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE) 17 | 18 | if test_only: 19 | return {'test': test_ds, 'num_classes': 1000} 20 | 21 | 22 | train_ds = ILSVRC(args, 'train', shuffle = True) 23 | train_ds = train_ds.map(pre_processing(is_training = True), num_parallel_calls=tf.data.experimental.AUTOTUNE) 24 | train_ds = train_ds.shuffle(100*args.batch_size).batch(args.batch_size).map(pre_processing_batched(is_training = True), num_parallel_calls=tf.data.experimental.AUTOTUNE) 25 | train_ds = train_ds.with_options(options) 26 | train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE) 27 | 28 | datasets = { 29 | 'train': train_ds.repeat( args.train_epoch ), 30 | 'test': test_ds 31 | } 32 | 33 | datasets = {k:strategy.experimental_distribute_dataset(datasets[k]) for k in datasets} 34 | datasets['train_len'] = train_ds.cardinality().numpy() 35 | datasets['num_classes'] = 1000 36 | args.input_size = [224,224,3] 37 | 38 | print('Datasets are built') 39 | return datasets 40 | 41 | def ILSVRC(args, split = 'train', sample_rate = None, shuffle = False, seed = None, sub_ds = None, saved = False): 42 | if split == 'train': 43 | with open(os.path.join(args.data_path, 'class_to_label.txt'),'r') as f: 44 | CLASSES = f.readlines() 45 | CLASSES = { name.replace('\n','') : l for l, name in enumerate(CLASSES) } 46 | 47 | label_pathes = glob.glob(os.path.join(args.data_path, split, '*')) 48 | 49 | if sample_rate is not None: 50 | if abs(sample_rate) < 1: 51 | min_num_label = min([len(glob.glob(os.path.join(l, '*'))) for l in label_pathes]) 52 | sampled_data_len = int(abs(sample_rate) * min_num_label) 53 | else: 54 | sampled_data_len = abs(sample_rate) 55 | 56 | image_paths = [] 57 | labels = [] 58 | class_names = [] 59 | for name in label_pathes: 60 | image_path = glob.glob(os.path.join(name, '*')) 61 | image_path = [p for p in image_path if 'n02105855_2933.JPEG' not in p] 62 | 63 | if sample_rate is not None: 64 | np.random.seed(seed) 65 | np.random.shuffle(image_path) 66 | if sample_rate < 0: 67 | image_path = image_path[::-1] 68 | image_path = image_path[:sampled_data_len] 69 | image_paths += image_path 70 | labels += [CLASSES[os.path.split(name)[1]]] * len(image_path) 71 | class_names.append(os.path.split(name)[1]) 72 | 73 | if shuffle: 74 | np.random.seed(seed) 75 | idx = np.arange(len(image_paths)) 76 | np.random.shuffle(idx) 77 | image_paths = [image_paths[i] for i in idx] 78 | labels = [labels[i] for i in idx] 79 | 80 | elif split == 'val': 81 | image_paths = glob.glob(os.path.join(args.data_path, split, '*')) 82 | image_paths.sort() 83 | 84 | with open(os.path.join(args.home_path, 'val_gt.txt'),'r') as f: 85 | labels = f.readlines() 86 | 87 | print (split + ' dataset length :', len(image_paths)) 88 | label_arry = np.int64(labels) 89 | img_ds = tf.data.Dataset.from_tensor_slices(image_paths) 90 | label_ds = tf.data.Dataset.from_tensor_slices(label_arry) 91 | dataset = tf.data.Dataset.zip((img_ds, label_ds)) 92 | 93 | return dataset 94 | 95 | def get_size(path): 96 | img = Image.open(path.decode("utf-8")) 97 | w,h = img.size 98 | return np.int32(h), np.int32(w) 99 | 100 | def random_resize_crop(image, height, width): 101 | shape = tf.stack([height, width, 3]) 102 | bbox_begin, bbox_size, distort_bbox 
= tf.image.sample_distorted_bounding_box( 103 | shape, 104 | bounding_boxes=tf.zeros(shape=[0, 0, 4]), 105 | min_object_covered=0, 106 | aspect_ratio_range=[0.75, 1.33], 107 | area_range=[0.08, 1.0], 108 | max_attempts=10, 109 | use_image_if_no_bounding_boxes=True) 110 | 111 | is_bad = tf.reduce_sum(tf.cast(tf.equal(bbox_size, shape), tf.int32)) >= 2 112 | 113 | if is_bad: 114 | image = tf.image.decode_jpeg(image, channels = 3, **JPEG_OPT) 115 | newh, neww = tf.numpy_function(resizeshortest, [tf.shape(image, tf.int32), 256], [tf.int32, tf.int32]) 116 | image = tf.image.resize(image, (newh,neww), method='bicubic') 117 | image = tf.slice(image, [newh//2-112,neww//2-112,0],[224,224,-1]) 118 | else: 119 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 120 | target_height, target_width, _ = tf.unstack(bbox_size) 121 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) 122 | 123 | image = tf.image.decode_and_crop_jpeg(image, crop_window, channels = 3, **JPEG_OPT) 124 | image = tf.image.resize(image, (224,224), method='bicubic') 125 | image.set_shape((224,224,3)) 126 | 127 | return image 128 | 129 | def resizeshortest(shape, size): 130 | h, w = shape[:2] 131 | scale = size / min(h, w) 132 | if h < w: 133 | newh, neww = size, int(scale * w + 0.5) 134 | else: 135 | newh, neww = int(scale * h + 0.5), size 136 | return np.int32(newh), np.int32(neww) 137 | 138 | def lighting(image, std, eigval, eigvec): 139 | v = tf.random.normal(shape=[3], stddev=std) * eigval 140 | inc = tf.matmul(eigvec, tf.reshape(v, [3, 1])) 141 | image = image + tf.reshape(inc, [3]) 142 | return image 143 | 144 | def pre_processing(is_training = False, contrastive = False): 145 | @tf.function 146 | def training(path, label): 147 | height, width = tf.numpy_function(get_size, [path], [tf.int32, tf.int32]) 148 | image = tf.io.read_file(path) 149 | 150 | 151 | 152 | image = random_resize_crop(image, height, width) 153 | return image, label 154 | 155 | @tf.function 156 | def test(path, label): 157 | image = tf.io.read_file(path) 158 | image = tf.io.decode_jpeg(image, channels = 3, **JPEG_OPT) 159 | 160 | newh,neww = tf.numpy_function(resizeshortest, [tf.shape(image, tf.int32), 256], [tf.int32, tf.int32]) 161 | image = tf.image.resize(image, (newh,neww), method='bicubic') 162 | image = tf.slice(image, [newh//2-112,neww//2-112,0],[224,224,-1]) 163 | return image, label 164 | return training if is_training else test 165 | 166 | def pre_processing_batched(is_training = False, contrastive = False, mode = 0): 167 | @tf.function 168 | def training(image, *argv): 169 | image = tf.image.random_flip_left_right(image) 170 | image = (image-np.array([123.675, 116.28 , 103.53 ]))/np.array([58.395, 57.12, 57.375]) 171 | if len(argv) == 2: 172 | image = tf.reshape(image, [B,N,H,W,D]) 173 | return [image] + [arg for arg in argv] 174 | 175 | @tf.function 176 | def test(image, *argv): 177 | shape = image.shape 178 | image = (image-np.array([123.675, 116.28 , 103.53 ]))/np.array([58.395, 57.12, 57.375]) 179 | return [image] + [arg for arg in argv] 180 | return training if is_training else test 181 | 182 | -------------------------------------------------------------------------------- /nets/tcl.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def arg_scope(func): 4 | def func_with_args(self, *args, **kwargs): 5 | if hasattr(self, 'pre_defined'): 6 | for k in 
self.pre_defined.keys(): 7 | if k not in kwargs.keys(): 8 | kwargs[k] = self.pre_defined[k] 9 | return func(self, *args, **kwargs) 10 | return func_with_args 11 | 12 | class Conv2d(tf.keras.layers.Layer): 13 | @arg_scope 14 | def __init__(self, kernel_size, num_outputs, strides = 1, dilations = 1, padding = 'SAME', 15 | kernel_initializer = tf.keras.initializers.VarianceScaling(scale = 2., mode='fan_out'), 16 | use_biases = True, 17 | biases_initializer = tf.keras.initializers.Zeros(), 18 | activation_fn = None, 19 | name = 'conv', 20 | trainable = True, 21 | **kwargs): 22 | super(Conv2d, self).__init__(name = name, trainable = trainable, **kwargs) 23 | 24 | self.kernel_size = kernel_size 25 | self.num_outputs = num_outputs 26 | self.strides = strides 27 | self.padding = padding 28 | self.dilations = dilations 29 | self.kernel_initializer = kernel_initializer 30 | 31 | self.use_biases = use_biases 32 | self.biases_initializer = biases_initializer 33 | 34 | self.activation_fn = activation_fn 35 | 36 | def build(self, input_shape): 37 | super(Conv2d, self).build(input_shape) 38 | self.kernel = self.add_weight(name = 'kernel', 39 | shape = self.kernel_size + [input_shape[-1], self.num_outputs], 40 | initializer=self.kernel_initializer, 41 | trainable = self.trainable) 42 | if self.use_biases: 43 | self.biases = self.add_weight(name = "biases", 44 | shape=[1,1,1,self.num_outputs], 45 | initializer = self.biases_initializer, 46 | trainable = self.trainable) 47 | self.ori_shape = self.kernel.shape 48 | 49 | def call(self, input): 50 | kernel = self.kernel 51 | 52 | conv = tf.nn.conv2d(input, kernel, self.strides, self.padding, 53 | dilations=self.dilations, name=None) 54 | if self.use_biases: 55 | conv += self.biases 56 | 57 | if self.activation_fn: 58 | conv = self.activation_fn(conv) 59 | 60 | return conv 61 | 62 | class DepthwiseConv2d(tf.keras.layers.Layer): 63 | @arg_scope 64 | def __init__(self, kernel_size, multiplier = 1, strides = [1,1,1,1], dilations = [1,1], padding = 'SAME', 65 | kernel_initializer = tf.keras.initializers.VarianceScaling(scale = 2., mode='fan_in'), 66 | use_biases = True, 67 | biases_initializer = tf.keras.initializers.Zeros(), 68 | activation_fn = None, 69 | name = 'conv', 70 | trainable = True, 71 | **kwargs): 72 | super(DepthwiseConv2d, self).__init__(name = name, trainable = trainable, **kwargs) 73 | 74 | self.kernel_size = kernel_size 75 | self.strides = strides if isinstance(strides, list) else [1, strides, strides, 1] 76 | self.padding = padding 77 | self.dilations = dilations if isinstance(dilations, list) else [dilations, dilations] 78 | self.multiplier = multiplier 79 | self.kernel_initializer = kernel_initializer 80 | 81 | self.use_biases = use_biases 82 | self.biases_initializer = biases_initializer 83 | 84 | self.activation_fn = activation_fn 85 | 86 | def build(self, input_shape): 87 | super(DepthwiseConv2d, self).build(input_shape) 88 | self.kernel = self.add_weight(name = 'kernel', 89 | shape = self.kernel_size + [input_shape[-1], self.multiplier], 90 | initializer=self.kernel_initializer, 91 | trainable = self.trainable) 92 | if self.use_biases: 93 | self.biases = self.add_weight(name = "biases", 94 | shape=[1,1,1, input_shape[-1]*self.multiplier], 95 | initializer = self.biases_initializer, 96 | trainable = self.trainable) 97 | self.ori_shape = self.kernel.shape 98 | 99 | def call(self, input): 100 | kernel = self.kernel 101 | conv = tf.nn.depthwise_conv2d(input, kernel, strides = self.strides, padding = self.padding, dilations=self.dilations) 102 | 
if self.use_biases: 103 | conv += self.biases 104 | if self.activation_fn: 105 | conv = self.activation_fn(conv) 106 | return conv 107 | 108 | class FC(tf.keras.layers.Layer): 109 | @arg_scope 110 | def __init__(self, num_outputs, 111 | kernel_initializer = tf.keras.initializers.random_normal(stddev = 1e-2), 112 | use_biases = True, 113 | biases_initializer = tf.keras.initializers.Zeros(), 114 | activation_fn = None, 115 | name = 'fc', 116 | trainable = True, **kwargs): 117 | super(FC, self).__init__(name = name, trainable = trainable, **kwargs) 118 | self.num_outputs = num_outputs 119 | self.kernel_initializer = kernel_initializer 120 | 121 | self.use_biases = use_biases 122 | self.biases_initializer = biases_initializer 123 | 124 | self.activation_fn = activation_fn 125 | 126 | def build(self, input_shape): 127 | super(FC, self).build(input_shape) 128 | self.kernel = self.add_weight(name = 'kernel', 129 | shape = [int(input_shape[-1]), self.num_outputs], 130 | initializer=self.kernel_initializer, 131 | trainable = self.trainable) 132 | if self.use_biases: 133 | self.biases = self.add_weight(name = "biases", 134 | shape=[1,self.num_outputs], 135 | initializer = self.biases_initializer, 136 | trainable = self.trainable) 137 | self.ori_shape = self.kernel.shape 138 | def call(self, input): 139 | kernel = self.kernel 140 | 141 | fc = tf.matmul(input, kernel) 142 | if self.use_biases: 143 | fc += self.biases 144 | if self.activation_fn: 145 | fc = self.activation_fn(fc) 146 | 147 | return fc 148 | 149 | class BatchNorm(tf.keras.layers.Layer): 150 | @arg_scope 151 | def __init__(self, param_initializers = None, 152 | scale = True, 153 | center = True, 154 | alpha = 0.9, 155 | epsilon = 1e-5, 156 | activation_fn = None, 157 | name = 'bn', 158 | trainable = True, 159 | **kwargs): 160 | super(BatchNorm, self).__init__(name = name, trainable = trainable, **kwargs) 161 | if param_initializers == None: 162 | param_initializers = {} 163 | if not(param_initializers.get('moving_mean')): 164 | param_initializers['moving_mean'] = tf.keras.initializers.Zeros() 165 | if not(param_initializers.get('moving_variance')): 166 | param_initializers['moving_variance'] = tf.keras.initializers.Ones() 167 | if not(param_initializers.get('gamma')) and scale: 168 | param_initializers['gamma'] = tf.keras.initializers.Ones() 169 | if not(param_initializers.get('beta')) and center: 170 | param_initializers['beta'] = tf.keras.initializers.Zeros() 171 | 172 | self.param_initializers = param_initializers 173 | self.scale = scale 174 | self.center = center 175 | self.alpha = alpha 176 | self.epsilon = epsilon 177 | self.activation_fn = activation_fn 178 | 179 | def build(self, input_shape): 180 | super(BatchNorm, self).build(input_shape) 181 | self.moving_mean = self.add_weight(name = 'moving_mean', trainable = False, 182 | shape = [1]*(len(input_shape)-1)+[int(input_shape[-1])], 183 | initializer=self.param_initializers['moving_mean'], 184 | aggregation=tf.VariableAggregation.MEAN, 185 | ) 186 | self.moving_variance = self.add_weight(name = 'moving_variance', trainable = False, 187 | shape = [1]*(len(input_shape)-1)+[int(input_shape[-1])], 188 | initializer=self.param_initializers['moving_variance'], 189 | aggregation=tf.VariableAggregation.MEAN, 190 | ) 191 | if self.scale: 192 | self.gamma = self.add_weight(name = 'gamma', 193 | shape = [1]*(len(input_shape)-1)+[int(input_shape[-1])], 194 | initializer=self.param_initializers['gamma'], 195 | trainable = self.trainable) 196 | else: 197 | self.gamma = 1. 
198 | if self.center: 199 | self.beta = self.add_weight(name = 'beta', 200 | shape = [1]*(len(input_shape)-1)+[int(input_shape[-1])], 201 | initializer=self.param_initializers['beta'], 202 | trainable = self.trainable) 203 | else: 204 | self.beta = 0. 205 | self.ori_shape = self.moving_mean.shape[-1] 206 | 207 | def EMA(self, variable, value): 208 | update_delta = (variable - value) * (1-self.alpha) 209 | variable.assign_sub(update_delta) 210 | 211 | def update(self, update_var): 212 | mean, var = update_var 213 | self.EMA(self.moving_mean, mean) 214 | self.EMA(self.moving_variance, var) 215 | 216 | def call(self, input, training=None): 217 | if training: 218 | mean, var = tf.nn.moments(input, list(range(len(input.shape)-1)), keepdims=True) 219 | self.update_var = [mean, var] 220 | else: 221 | mean = self.moving_mean 222 | var = self.moving_variance 223 | 224 | gamma, beta = self.gamma, self.beta 225 | bn = tf.nn.batch_normalization(input, mean, var, offset = beta, scale = gamma, variance_epsilon = self.epsilon) 226 | 227 | if self.activation_fn: 228 | bn = self.activation_fn(bn) 229 | 230 | return bn 231 | 232 | --------------------------------------------------------------------------------
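A note on how batch normalization interacts with JIT compiling in this repository: the custom `BatchNorm` in `nets/tcl.py` does not assign its moving statistics inside the compiled step. Instead, `call()` stashes the batch mean/variance in `update_var`, the jit-compiled step in `op_utils.py` returns those tensors, and the EMA update of the moving statistics is applied eagerly afterwards. The sketch below is a stripped-down illustration of that split; `ToyBN` and its constants are invented for the example and are not the repository's classes.

```
import tensorflow as tf

class ToyBN(tf.keras.layers.Layer):
    """Illustrative BN-like layer: batch stats are stashed in call(), not assigned."""
    def build(self, input_shape):
        c = int(input_shape[-1])
        self.moving_mean = self.add_weight('moving_mean', shape=[c],
                                           initializer='zeros', trainable=False)
        self.moving_var = self.add_weight('moving_variance', shape=[c],
                                          initializer='ones', trainable=False)

    def call(self, x, training=None):
        if training:
            mean, var = tf.nn.moments(x, axes=[0])
            self.update_var = [mean, var]          # captured by the compiled step below
        else:
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) * tf.math.rsqrt(var + 1e-5)

    def update(self, update_var, alpha=0.9):
        # EMA update of the moving statistics, executed eagerly (outside XLA).
        mean, var = update_var
        self.moving_mean.assign_sub((self.moving_mean - mean) * (1. - alpha))
        self.moving_var.assign_sub((self.moving_var - var) * (1. - alpha))

bn = ToyBN()
x = tf.random.normal([16, 8])
_ = bn(x)                                          # build the variables eagerly first

@tf.function(jit_compile=True)
def compiled_step(inputs):
    y = bn(inputs, training=True)
    # Return the batch statistics instead of assigning variables in the compiled
    # function; the assignment happens outside, as op_utils.py does for the model.
    return y, bn.update_var

y, stats = compiled_step(x)
bn.update(stats)
```

This is why `op_utils.py` collects `update_var` from every layer inside `compiled_step` and calls each layer's `update(...)` only after the compiled function has returned.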