├── README.md
├── op_utils.py
├── test.py
├── dataloader
│   ├── CIFAR.py
│   └── ILSVRC.py
├── nets
│   ├── ResNet.py
│   └── tcl.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | ## TensorFlow 2 with JIT (XLA) compiling on multi-GPU training
2 | - CIFAR and ILSVRC training code with **JIT compiling** and **distributed learning** on a multi-GPU system.
3 | - I highly recommend using JIT compiling because most of the algorithm is static and can be compiled, which reduces memory usage and improves training speed.
4 | - This repository is built with **custom layers** and a **custom training loop** for my project, but if you only want to see how to use JIT compiling with distributed learning, check 'train.py' and 'op_utils.py' (a minimal sketch of the pattern is also included at the end of this README).
5 |
6 | ## Requirement
7 | - **TensorFlow >= 2.5**
8 | - Pillow
9 |
10 | ## How to run
11 | - ILSVRC
12 | ```
13 | python train.py --compile --gpu_id {} --dataset ILSVRC --data_path /path/to/your/ILSVRC/home --train_path /path/to/log
14 | ```
15 |
16 | - CIFAR{10,100}
17 | ```
18 | python train.py --compile --gpu_id {} --dataset CIFAR{10,100} --train_path /path/to/log
19 | ```
20 |
21 | ## Experimental results
22 | - I used four GTX 1080 Ti GPUs.
23 | - JIT compiling reduces the training time by about 40%.
24 |
25 | |                      | Accuracy (%) | Training time |
26 | |----------------------|--------------|---------------|
27 | | Distributed only     | 75.83        | 94.61         |
28 | | Distributed with JIT | 75.57        | 56.98         |
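29 |
30 | ## Minimal sketch: JIT compiling with distributed learning
31 | - A condensed sketch of the pattern used in 'train.py' and 'op_utils.py'. The toy model and the `global_batch_size` value below are placeholders for illustration, not the code of this repository.
32 | ```
33 | import tensorflow as tf
34 |
35 | global_batch_size = 256  # placeholder value
36 |
37 | # Build the model, optimizer, and loss under the strategy scope so that
38 | # variables are mirrored across all visible GPUs.
39 | strategy = tf.distribute.MirroredStrategy()
40 | with strategy.scope():
41 |     model = tf.keras.Sequential([
42 |         tf.keras.layers.Conv2D(16, 3, activation='relu'),
43 |         tf.keras.layers.GlobalAveragePooling2D(),
44 |         tf.keras.layers.Dense(10)])
45 |     model.build([None, 224, 224, 3])  # create variables outside the compiled step
46 |     optimizer = tf.keras.optimizers.SGD(0.1, momentum=0.9, nesterov=True)
47 |     loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
48 |         from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
49 |
50 | # The static part (forward and backward pass) is XLA-compiled.
51 | @tf.function(jit_compile=True)
52 | def compiled_step(images, labels):
53 |     with tf.GradientTape() as tape:
54 |         pred = model(images, training=True)
55 |         loss = loss_object(labels, pred) / global_batch_size
56 |     return loss, tape.gradient(loss, model.trainable_variables)
57 |
58 | # Variable updates (the optimizer step) stay outside the compiled function.
59 | def train_step(images, labels):
60 |     loss, grads = compiled_step(images, labels)
61 |     optimizer.apply_gradients(zip(grads, model.trainable_variables))
62 |     return loss
63 |
64 | # Each replica runs train_step on its shard of the global batch.
65 | @tf.function
66 | def train_step_dist(images, labels):
67 |     strategy.run(train_step, args=(images, labels))
68 | ```
69 | - To feed it, wrap a `tf.data.Dataset` with `strategy.experimental_distribute_dataset` and call `train_step_dist` on each distributed batch, as 'train.py' does.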
--------------------------------------------------------------------------------
/op_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def Optimizer(args, model, strategy):
4 | loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction = tf.keras.losses.Reduction.SUM)
5 | optimizer = tf.keras.optimizers.SGD(args.learning_rate, .9, nesterov=True)
6 |
7 | train_loss = tf.keras.metrics.Mean(name='train_loss')
8 | train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
9 |
10 | @tf.function(jit_compile = args.compile)
11 | def compiled_step(images, labels):
12 | with tf.GradientTape() as tape:
13 | pred = model(images, training = True)
14 | total_loss = loss_object(labels, pred)/args.batch_size
15 | gradients = tape.gradient(total_loss, model.trainable_variables)
16 | update_vars = [model.Layers[k].update_var if hasattr(model.Layers[k], 'update_var') else None for k in model.Layers ]
17 | return total_loss, pred, gradients, update_vars
18 |
19 | def train_step(images, labels):
20 | total_loss, pred, gradients, update_vars = compiled_step(images, labels)
21 | if args.weight_decay > 0.:
22 | gradients = [g+v*args.weight_decay for g,v in zip(gradients, model.trainable_variables)]
23 | optimizer.apply_gradients(zip(gradients, model.trainable_variables))
24 | for k, v in zip(model.Layers, update_vars):
25 | if hasattr(model.Layers[k], 'update'):
26 | model.Layers[k].update(v)
27 |
28 | train_loss.update_state(total_loss)
29 | train_accuracy.update_state(labels, pred)
30 |
31 | @tf.function
32 | def train_step_dist(image, labels):
33 | strategy.run(train_step, args= (image, labels))
34 |
35 | return train_step_dist, train_loss, train_accuracy, optimizer
36 |
37 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 |
4 | import argparse
5 | import tensorflow as tf
6 |
7 | from dataloader import ILSVRC, CIFAR
8 | import utils
9 |
10 | home_path = os.path.dirname(os.path.abspath(__file__))
11 | parser = argparse.ArgumentParser(description='')
12 |
13 | parser.add_argument("--arch", default='ResNet-50', type=str)
14 | parser.add_argument("--dataset", default='ILSVRC', type=str)
15 |
16 | parser.add_argument("--val_batch_size", default=256, type=int)
17 | parser.add_argument("--trained_param", default = 'res50_ilsvrc.pkl',type=str)
18 | parser.add_argument("--data_path", default = '/home/cvip/nas/ssd/ILSVRC2012',type=str)
19 |
20 | parser.add_argument("--gpu_id", default= [0], type=int, nargs = '+')
21 | parser.add_argument("--compile", default = False, action = 'store_true')
22 |
23 | args = parser.parse_args()
24 |
25 | args.home_path = os.path.dirname(os.path.abspath(__file__))
26 | args.input_size = [224,224,3]
27 |
28 | if __name__ == '__main__':
29 | gpus = tf.config.list_physical_devices('GPU')
30 | tf.config.set_visible_devices([tf.config.list_physical_devices('GPU')[i] for i in args.gpu_id], 'GPU')
31 | for gpu_id in args.gpu_id:
32 | tf.config.experimental.set_memory_growth(gpus[gpu_id], True)
33 |     devices = ['/gpu:{}'.format(i) for i in range(len(args.gpu_id))] # visible GPUs are re-indexed from 0 after set_visible_devices
34 | strategy = tf.distribute.MirroredStrategy(devices, cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
35 |
36 | with strategy.scope():
37 | if args.dataset == 'ILSVRC':
38 | datasets = ILSVRC.build_dataset_providers(args, strategy, test_only = True)
39 | elif 'CIFAR' in args.dataset:
40 | datasets = CIFAR.build_dataset_providers(args, strategy, test_only = True)
41 |
42 | model = utils.load_model(args, datasets['num_classes'], args.trained_param)
43 |
44 | top1_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='top1_accuracy')
45 | top5_accuracy = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top5_accuracy')
46 |
47 |     @tf.function(jit_compile = args.compile)
48 | def compiled_step(images):
49 | return model(images, training = False)
50 |
51 | def test_step(images, labels):
52 |         pred = compiled_step(images)
53 | top1_accuracy.update_state(labels, pred)
54 | top5_accuracy.update_state(labels, pred)
55 |
56 | @tf.function
57 | def test_step_dist(images, labels):
58 | strategy.run(test_step, args=(images, labels))
59 |
60 | for i, (test_images, test_labels) in enumerate(datasets['test']):
61 | test_step_dist(test_images, test_labels)
62 |
63 | top1_acc = top1_accuracy.result().numpy()
64 | top5_acc = top5_accuracy.result().numpy()
65 | top1_accuracy.reset_states()
66 | top5_accuracy.reset_states()
67 | print ('Test ACC. Top-1: %.4f, Top-5: %.4f'%(top1_acc, top5_acc))
--------------------------------------------------------------------------------
/dataloader/CIFAR.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | def build_dataset_providers(args, strategy, test_only = False):
5 | if args.dataset == 'CIFAR10':
6 | train_images, train_labels, test_images, test_labels, pre_processing = Cifar10(args)
7 | if args.dataset == 'CIFAR100':
8 | train_images, train_labels, test_images, test_labels, pre_processing = Cifar100(args)
9 |
10 | options = tf.data.Options()
11 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
12 |
13 | test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
14 | test_ds = test_ds.map(pre_processing(is_training = False), num_parallel_calls=tf.data.experimental.AUTOTUNE)
15 | test_ds = test_ds.batch(args.val_batch_size).cache()
16 | test_ds = test_ds.with_options(options)
17 | test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)
18 |
19 | if test_only:
20 | return {'test': test_ds, 'num_classes' : int(args.dataset[5:])}
21 |
22 | train_ds = [train_images, train_labels]
23 |
24 | train_ds = tf.data.Dataset.from_tensor_slices(tuple(train_ds)).cache()
25 | train_ds = train_ds.map(pre_processing(is_training = True), num_parallel_calls=tf.data.experimental.AUTOTUNE)
26 | train_ds = train_ds.shuffle(100*args.batch_size).batch(args.batch_size)
27 | train_ds = train_ds.with_options(options)
28 | train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
29 |
30 | datasets = {
31 | 'train': train_ds.repeat( args.train_epoch ),
32 | 'test': test_ds,
33 | }
34 |
35 | datasets = {k:strategy.experimental_distribute_dataset(datasets[k]) for k in datasets}
36 | datasets['train_len'] = train_ds.cardinality().numpy()
37 | datasets['num_classes'] = int(args.dataset[5:])
38 | args.input_size = [32,32,3]
39 |
40 | print('Datasets are built')
41 | return datasets
42 |
43 | def Cifar10(args):
44 | from tensorflow.keras.datasets.cifar10 import load_data
45 | (train_images, train_labels), (val_images, val_labels) = load_data()
46 |
47 | def pre_processing(is_training = False):
48 | def training(image, *argv):
49 | sz = tf.shape(image)
50 |
51 | image = tf.cast(image, tf.float32)
52 |
53 | image0 = image
54 |
55 | image0 = (image0-np.array([113.9,123.0,125.3]))/np.array([66.7,62.1,63.0])
56 | image0 = tf.image.random_flip_left_right(image0)
57 | image0 = tf.pad(image0, [[4,4],[4,4],[0,0]], 'REFLECT')
58 | image0 = tf.image.random_crop(image0,sz)
59 |
60 | return [image0] + [arg for arg in argv]
61 |
62 | def inference(image, label):
63 | image = tf.cast(image, tf.float32)
64 | image = (image-np.array([113.9,123.0,125.3]))/np.array([66.7,62.1,63.0])
65 | return image, label
66 |
67 | return training if is_training else inference
68 | return train_images, train_labels, val_images, val_labels, pre_processing
69 |
70 | def Cifar100(args):
71 | from tensorflow.keras.datasets.cifar100 import load_data
72 | (train_images, train_labels), (val_images, val_labels) = load_data()
73 |
74 | def pre_processing(is_training = False):
75 | @tf.function
76 | def training(image, *argv):
77 | sz = tf.shape(image)
78 |
79 | image = tf.cast(image, tf.float32)
80 | image0 = image
81 |
82 | image0 = (image0-np.array([112,124,129]))/np.array([70,65,68])
83 | image0 = tf.image.random_flip_left_right(image0)
84 | image0 = tf.pad(image0, [[4,4],[4,4],[0,0]], 'REFLECT')
85 | image0 = tf.image.random_crop(image0,sz)
86 |
87 | return [image0] + [arg for arg in argv]
88 |
89 | @tf.function
90 | def inference(image, label):
91 | image = tf.cast(image, tf.float32)
92 | image = (image-np.array([112,124,129]))/np.array([70,65,68])
93 | return image, label
94 |
95 | return training if is_training else inference
96 | return train_images, train_labels, val_images, val_labels, pre_processing
97 |
--------------------------------------------------------------------------------
/nets/ResNet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from nets import tcl
3 |
4 | class Model(tf.keras.Model):
5 | def __init__(self, num_layers, num_class, name = 'ResNet', trainable = True, **kwargs):
6 | super(Model, self).__init__(name = name, **kwargs)
7 | def kwargs(**kwargs):
8 | return kwargs
9 | setattr(tcl.Conv2d, 'pre_defined', kwargs(use_biases = False, activation_fn = None, trainable = trainable))
10 | setattr(tcl.BatchNorm, 'pre_defined', kwargs(trainable = trainable))
11 | setattr(tcl.FC, 'pre_defined', kwargs(trainable = trainable))
12 |
13 | self.num_layers = num_layers
14 |
15 | self.Layers = {}
16 |         network_arguments = {
17 | ## ILSVRC
18 | 18 : {'blocks' : [2,2,2,2],'depth' : [64,128,256,512], 'strides' : [1,2,2,2]},
19 | 50 : {'blocks' : [3,4,6,3],'depth' : [64,128,256,512], 'strides' : [1,2,2,2]},
20 |
21 | ## CIFAR
22 | 56 : {'blocks' : [9,9,9],'depth' : [16,32,64], 'strides' : [1,2,2]},
23 | }
24 |         self.net_args = network_arguments[self.num_layers]
25 |
26 |
27 | if num_class == 1000:
28 | self.Layers['conv'] = tcl.Conv2d([7,7], self.net_args['depth'][0], strides = 2, name = 'conv')
29 | self.Layers['bn'] = tcl.BatchNorm(name = 'bn')
30 | self.maxpool_3x3 = tf.keras.layers.MaxPool2D((3,3), strides = 2, padding = 'SAME')
31 |
32 | else:
33 | self.Layers['conv'] = tcl.Conv2d([3,3], self.net_args['depth'][0], name = 'conv')
34 | self.Layers['bn'] = tcl.BatchNorm(name = 'bn')
35 |
36 | self.expansion = 1 if self.num_layers in {18, 56} else 4
37 | in_depth = self.net_args['depth'][0]
38 | for i, (nb_resnet_layers, depth, strides) in enumerate(zip(self.net_args['blocks'], self.net_args['depth'], self.net_args['strides'])):
39 | for j in range(nb_resnet_layers):
40 | name = 'BasicBlock%d.%d/'%(i,j)
41 | if j != 0:
42 | strides = 1
43 |
44 | if strides > 1 or depth * self.expansion != in_depth:
45 | self.Layers[name + 'conv3'] = tcl.Conv2d([1,1], depth * self.expansion, strides = strides, name = name +'conv3')
46 | self.Layers[name + 'bn3'] = tcl.BatchNorm(name = name + 'bn3')
47 |
48 | if self.num_layers in {18, 56}:
49 | self.Layers[name + 'conv1'] = tcl.Conv2d([3,3], depth, strides = strides, name = name + 'conv1')
50 | self.Layers[name + 'bn1'] = tcl.BatchNorm( name = name + 'bn1')
51 | self.Layers[name + 'conv2'] = tcl.Conv2d([3,3], depth * self.expansion, name = name + 'conv2')
52 | self.Layers[name + 'bn2'] = tcl.BatchNorm( name = name + 'bn2')
53 |
54 | else:
55 | self.Layers[name + 'conv0'] = tcl.Conv2d([1,1], depth, name = name + 'conv0')
56 | self.Layers[name + 'bn0'] = tcl.BatchNorm( name = name + 'bn0')
57 | self.Layers[name + 'conv1'] = tcl.Conv2d([3,3], depth, strides = strides, name = name + 'conv1')
58 | self.Layers[name + 'bn1'] = tcl.BatchNorm( name = name + 'bn1')
59 | self.Layers[name + 'conv2'] = tcl.Conv2d([1,1], depth * self.expansion, name = name + 'conv2')
60 | self.Layers[name + 'bn2'] = tcl.BatchNorm( name = name + 'bn2',)
61 | #param_initializers = {'gamma': tf.keras.initializers.Zeros()})
62 |
63 | in_depth = depth * self.expansion
64 |
65 | self.Layers['fc'] = tcl.FC(num_class, name = 'fc')
66 |
67 | def call(self, x, training=None):
68 | x = self.Layers['conv'](x)
69 | x = self.Layers['bn'](x)
70 | x = tf.nn.relu(x)
71 | if hasattr(self, 'maxpool_3x3'):
72 | x = self.maxpool_3x3(x)
73 |
74 | in_depth = self.net_args['depth'][0]
75 |
76 | for i, (nb_resnet_layers, depth, strides) in enumerate(zip(self.net_args['blocks'], self.net_args['depth'], self.net_args['strides'])):
77 | for j in range(nb_resnet_layers):
78 | name = 'BasicBlock%d.%d/'%(i,j)
79 | if j != 0:
80 | strides = 1
81 |
82 | if strides > 1 or depth * self.expansion != in_depth:
83 | residual = self.Layers[name + 'conv3'](x)
84 | residual = self.Layers[name + 'bn3'](residual)
85 | else:
86 | residual = x
87 |
88 | if self.num_layers not in {18, 56}:
89 | x = self.Layers[name + 'conv0'](x)
90 | x = self.Layers[name + 'bn0'](x)
91 | x = tf.nn.relu(x)
92 |
93 | x = self.Layers[name + 'conv1'](x)
94 | x = self.Layers[name + 'bn1'](x)
95 | x = tf.nn.relu(x)
96 |
97 | x = self.Layers[name + 'conv2'](x)
98 | x = self.Layers[name + 'bn2'](x)
99 | x = tf.nn.relu(x + residual)
100 | in_depth = depth * self.expansion
101 |
102 | x = tf.reduce_mean(x, [1,2])
103 | x = self.Layers['fc'](x)
104 |
105 | return x
106 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os, time, argparse
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 | tf.debugging.set_log_device_placement(False)
7 |
8 | from dataloader import ILSVRC, CIFAR
9 | import op_utils, utils
10 |
11 | parser = argparse.ArgumentParser(description='')
12 | parser.add_argument("--train_path", default="test", type=str, help = 'path to log')
13 | parser.add_argument("--data_path", default="E:/ILSVRC2012", type=str, help = 'home path for ILSVRC dataset')
14 | parser.add_argument("--arch", default='ResNet-50', type=str, help = 'network architecture. currently ResNet is only available')
15 | parser.add_argument("--dataset", default='ILSVRC', type=str, help = 'ILSVRC or CIFAR{10,100}')
16 |
17 | parser.add_argument("--learning_rate", default = 1e-1, type=float, help = 'initial learning rate')
18 | parser.add_argument("--decay_points", default = [.3, .6, .9], type=float, nargs = '+', help = 'learning rate decay point')
19 | parser.add_argument("--decay_rate", default=.1, type=float, help = 'rate to decay at each decay points')
20 | parser.add_argument("--weight_decay", default=1e-4, type=float, help = 'decay parameter for l2 regularizer')
21 | parser.add_argument("--batch_size", default = 256, type=int, help = 'training batch size')
22 | parser.add_argument("--val_batch_size", default=256, type=int, help = 'validation batch size')
23 | parser.add_argument("--train_epoch", default=100, type=int, help = 'total training epoch')
24 |
25 | parser.add_argument("--gpu_id", default= [0], type=int, nargs = '+', help = 'denote which gpus are used')
26 | parser.add_argument("--do_log", default=200, type=int, help = 'logging period')
27 | parser.add_argument("--compile", default=False, action = 'store_true', help = 'denote use compile or not. True is recommended in this repo')
28 | args = parser.parse_args()
29 |
30 | args.home_path = os.path.dirname(os.path.abspath(__file__))
31 | args.decay_points = [int(dp*args.train_epoch) if dp < 1 else int(dp) for dp in args.decay_points]
32 |
33 | if args.dataset == 'ILSVRC':
34 | args.weight_decay /= len(args.gpu_id)
35 | args.learning_rate *= args.batch_size/256
36 |
37 | if __name__ == '__main__':
38 | gpus = tf.config.list_physical_devices('GPU')
39 | tf.config.set_visible_devices([tf.config.list_physical_devices('GPU')[i] for i in args.gpu_id], 'GPU')
40 | for gpu_id in args.gpu_id:
41 | tf.config.experimental.set_memory_growth(gpus[gpu_id], True)
42 |     devices = ['/gpu:{}'.format(i) for i in range(len(args.gpu_id))] # visible GPUs are re-indexed from 0 after set_visible_devices
43 | strategy = tf.distribute.MirroredStrategy(devices, cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
44 |
45 | with strategy.scope():
46 | if args.dataset == 'ILSVRC':
47 | datasets = ILSVRC.build_dataset_providers(args, strategy)
48 | elif 'CIFAR' in args.dataset:
49 | datasets = CIFAR.build_dataset_providers(args, strategy)
50 | model = utils.load_model(args, datasets['num_classes'])
51 |
52 | summary_writer = tf.summary.create_file_writer(args.train_path, flush_millis = 30000)
53 | with summary_writer.as_default():
54 | utils.save_code_and_augments(args)
55 | total_train_time = 0
56 |
57 | train_step, train_loss, train_accuracy, optimizer = op_utils.Optimizer(args, model, strategy )
58 |
59 | loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction = tf.keras.losses.Reduction.SUM)
60 | Eval = utils.Evaluation(args, model, strategy, datasets['test'], loss_object)
61 |
62 | print ('Training starts')
63 | fine_tuning_time = 0
64 | tic = time.time()
65 | for step, data in enumerate(datasets['train']):
66 | epoch = step//datasets['train_len']
67 | lr = utils.scheduler(args, optimizer, epoch)
68 | train_step(*data)
69 |
70 | step += 1
71 | if step % args.do_log == 0:
72 | template = 'Global step {0:5d}: loss = {1:0.4f} ({2:1.3f} sec/step)'
73 | train_time = time.time() - tic
74 | print (template.format(step, train_loss.result()*len(args.gpu_id), train_time/args.do_log))
75 | fine_tuning_time += train_time
76 | tic = time.time()
77 |
78 | if step % datasets['train_len'] == 0:
79 | tic_ = time.time()
80 | test_acc, test_loss = Eval.run(False)
81 |
82 | tf.summary.scalar('Categorical_loss/train', train_loss.result()*len(args.gpu_id), step=epoch+1)
83 | tf.summary.scalar('Categorical_loss/test', test_loss*len(args.gpu_id), step=epoch+1)
84 | tf.summary.scalar('Accuracy/train', train_accuracy.result()*100, step=epoch+1)
85 | tf.summary.scalar('Accuracy/test', test_acc*100, step=epoch+1)
86 | tf.summary.scalar('learning_rate', lr, step=epoch)
87 | summary_writer.flush()
88 |
89 | template = 'Epoch: {0:3d}, train_loss: {1:0.4f}, train_Acc.: {2:2.2f}, val_loss: {3:0.4f}, val_Acc.: {4:2.2f}'
90 | print (template.format(epoch+1, train_loss.result()*len(args.gpu_id), train_accuracy.result()*100,
91 | test_loss*len(args.gpu_id), test_acc*100))
92 |
93 | train_loss.reset_states()
94 | train_accuracy.reset_states()
95 | tic += time.time() - tic_
96 |
97 | utils.save_model(args, model, 'trained_params')
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os, shutil, glob, pickle, json
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 | tf.debugging.set_log_device_placement(False)
7 |
8 | from nets import ResNet
9 |
10 | def scheduler(args, optimizer, epoch):
11 | lr = args.learning_rate
12 | for dp in args.decay_points:
13 | if epoch >= dp:
14 | lr *= args.decay_rate
15 | if epoch in args.decay_points:
16 | optimizer.learning_rate = lr
17 | return lr
18 |
19 | def save_code_and_augments(args):
20 | if os.path.isdir(os.path.join(args.train_path,'codes')):
21 | print ('============================================')
22 |         print ('The folder already exists. It will be overwritten.')
23 | print ('============================================')
24 | else:
25 | os.mkdir(os.path.join(args.train_path,'codes'))
26 |
27 | for code in glob.glob(args.home_path + '/*.py'):
28 | shutil.copyfile(code, os.path.join(args.train_path, 'codes', os.path.split(code)[-1]))
29 |
30 | with open(os.path.join(args.train_path, 'arguments.txt'), 'w') as f:
31 | json.dump(args.__dict__, f, indent=2)
32 |
33 | class Evaluation:
34 | def __init__(self, args, model, strategy, dataset, loss_object):
35 | self.test_loss = tf.keras.metrics.Mean(name='test_loss')
36 | self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
37 |
38 | @tf.function(jit_compile=args.compile)
39 | def compiled_step(images, labels, training):
40 | pred = model(images, training = training)
41 | loss = loss_object(labels, pred)/args.val_batch_size
42 | return pred, loss
43 |
44 | def eval_step(images, labels, training):
45 | pred, loss = compiled_step(images, labels, training)
46 | self.test_loss.update_state(loss)
47 | self.test_accuracy.update_state(labels, pred)
48 |
49 |
50 | @tf.function
51 | def eval_step_dist(images, labels, training):
52 | strategy.run(eval_step, args=(images, labels, training))
53 |
54 | self.dataset = dataset
55 | self.step = eval_step_dist
56 |
57 | def run(self, training):
58 | for images, labels in self.dataset:
59 | self.step(images, labels, training)
60 | loss = self.test_loss.result().numpy()
61 | acc = self.test_accuracy.result().numpy()
62 | self.test_loss.reset_states()
63 | self.test_accuracy.reset_states()
64 | return acc, loss
65 |
66 | def load_model(args, num_class, trained_param = None):
67 | if 'ResNet' in args.arch:
68 | arch = int(args.arch.split('-')[1])
69 | model = ResNet.Model(num_layers = arch, num_class = num_class, name = 'ResNet', trainable = True)
70 |
71 | if trained_param is not None:
72 | with open(trained_param, 'rb') as f:
73 | trained = pickle.load(f)
74 | n = 0
75 | for k in model.Layers.keys():
76 | layer = model.Layers[k]
77 | if 'conv' in k or 'fc' in k:
78 | kernel = trained[layer.name + '/kernel:0']
79 | layer.kernel_initializer = tf.constant_initializer(kernel)
80 | n += 1
81 | if layer.use_biases:
82 | layer.biases_initializer = tf.constant_initializer(trained[layer.name + '/biases:0'])
83 | n += 1
84 | layer.num_outputs = kernel.shape[-1]
85 |
86 | elif 'bn' in k:
87 | moving_mean = trained[layer.name + '/moving_mean:0']
88 | moving_variance = trained[layer.name + '/moving_variance:0']
89 | param_initializers = {'moving_mean' : tf.constant_initializer(moving_mean),
90 | 'moving_variance': tf.constant_initializer(moving_variance)}
91 | n += 2
92 | if layer.scale:
93 | param_initializers['gamma'] = tf.constant_initializer(trained[layer.name + '/gamma:0'])
94 | n += 1
95 | if layer.center:
96 | param_initializers['beta'] = tf.constant_initializer(trained[layer.name + '/beta:0'])
97 | n += 1
98 | layer.param_initializers = param_initializers
99 | print (n, 'params loaded')
100 | return model
101 |
129 | def save_model(args, model, name):
130 | params = {}
131 | for v in model.variables:
132 | if model.name in v.name:
133 | params[v.name[len(model.name)+1:]] = v.numpy()
134 | with open(os.path.join(args.train_path, name + '.pkl'), 'wb') as f:
135 | pickle.dump(params, f)
136 |
--------------------------------------------------------------------------------
/dataloader/ILSVRC.py:
--------------------------------------------------------------------------------
1 | import glob, os
2 | import tensorflow as tf
3 | import numpy as np
4 | from PIL import Image
5 |
6 | JPEG_OPT = {'fancy_upscaling': True, 'dct_method': 'INTEGER_ACCURATE'}
7 |
8 | def build_dataset_providers(args, strategy, test_only = False):
9 | options = tf.data.Options()
10 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
11 |
12 | test_ds = ILSVRC(args, 'val', shuffle = False)
13 | test_ds = test_ds.map(pre_processing(is_training = False), num_parallel_calls=tf.data.experimental.AUTOTUNE)
14 | test_ds = test_ds.batch(args.val_batch_size).map(pre_processing_batched(is_training = False), num_parallel_calls=tf.data.experimental.AUTOTUNE)
15 | test_ds = test_ds.with_options(options)
16 | test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)
17 |
18 | if test_only:
19 | return {'test': test_ds, 'num_classes': 1000}
20 |
21 |
22 | train_ds = ILSVRC(args, 'train', shuffle = True)
23 | train_ds = train_ds.map(pre_processing(is_training = True), num_parallel_calls=tf.data.experimental.AUTOTUNE)
24 | train_ds = train_ds.shuffle(100*args.batch_size).batch(args.batch_size).map(pre_processing_batched(is_training = True), num_parallel_calls=tf.data.experimental.AUTOTUNE)
25 | train_ds = train_ds.with_options(options)
26 | train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
27 |
28 | datasets = {
29 | 'train': train_ds.repeat( args.train_epoch ),
30 | 'test': test_ds
31 | }
32 |
33 | datasets = {k:strategy.experimental_distribute_dataset(datasets[k]) for k in datasets}
34 | datasets['train_len'] = train_ds.cardinality().numpy()
35 | datasets['num_classes'] = 1000
36 | args.input_size = [224,224,3]
37 |
38 | print('Datasets are built')
39 | return datasets
40 |
41 | def ILSVRC(args, split = 'train', sample_rate = None, shuffle = False, seed = None, sub_ds = None, saved = False):
42 | if split == 'train':
43 | with open(os.path.join(args.data_path, 'class_to_label.txt'),'r') as f:
44 | CLASSES = f.readlines()
45 | CLASSES = { name.replace('\n','') : l for l, name in enumerate(CLASSES) }
46 |
47 | label_pathes = glob.glob(os.path.join(args.data_path, split, '*'))
48 |
49 | if sample_rate is not None:
50 | if abs(sample_rate) < 1:
51 | min_num_label = min([len(glob.glob(os.path.join(l, '*'))) for l in label_pathes])
52 | sampled_data_len = int(abs(sample_rate) * min_num_label)
53 | else:
54 | sampled_data_len = abs(sample_rate)
55 |
56 | image_paths = []
57 | labels = []
58 | class_names = []
59 | for name in label_pathes:
60 | image_path = glob.glob(os.path.join(name, '*'))
61 | image_path = [p for p in image_path if 'n02105855_2933.JPEG' not in p]
62 |
63 | if sample_rate is not None:
64 | np.random.seed(seed)
65 | np.random.shuffle(image_path)
66 | if sample_rate < 0:
67 | image_path = image_path[::-1]
68 | image_path = image_path[:sampled_data_len]
69 | image_paths += image_path
70 | labels += [CLASSES[os.path.split(name)[1]]] * len(image_path)
71 | class_names.append(os.path.split(name)[1])
72 |
73 | if shuffle:
74 | np.random.seed(seed)
75 | idx = np.arange(len(image_paths))
76 | np.random.shuffle(idx)
77 | image_paths = [image_paths[i] for i in idx]
78 | labels = [labels[i] for i in idx]
79 |
80 | elif split == 'val':
81 | image_paths = glob.glob(os.path.join(args.data_path, split, '*'))
82 | image_paths.sort()
83 |
84 | with open(os.path.join(args.home_path, 'val_gt.txt'),'r') as f:
85 | labels = f.readlines()
86 |
87 | print (split + ' dataset length :', len(image_paths))
88 | label_arry = np.int64(labels)
89 | img_ds = tf.data.Dataset.from_tensor_slices(image_paths)
90 | label_ds = tf.data.Dataset.from_tensor_slices(label_arry)
91 | dataset = tf.data.Dataset.zip((img_ds, label_ds))
92 |
93 | return dataset
94 |
95 | def get_size(path):
96 | img = Image.open(path.decode("utf-8"))
97 | w,h = img.size
98 | return np.int32(h), np.int32(w)
99 |
100 | def random_resize_crop(image, height, width):
101 | shape = tf.stack([height, width, 3])
102 | bbox_begin, bbox_size, distort_bbox = tf.image.sample_distorted_bounding_box(
103 | shape,
104 | bounding_boxes=tf.zeros(shape=[0, 0, 4]),
105 | min_object_covered=0,
106 | aspect_ratio_range=[0.75, 1.33],
107 | area_range=[0.08, 1.0],
108 | max_attempts=10,
109 | use_image_if_no_bounding_boxes=True)
110 |
111 | is_bad = tf.reduce_sum(tf.cast(tf.equal(bbox_size, shape), tf.int32)) >= 2
112 |
113 | if is_bad:
114 | image = tf.image.decode_jpeg(image, channels = 3, **JPEG_OPT)
115 | newh, neww = tf.numpy_function(resizeshortest, [tf.shape(image, tf.int32), 256], [tf.int32, tf.int32])
116 | image = tf.image.resize(image, (newh,neww), method='bicubic')
117 | image = tf.slice(image, [newh//2-112,neww//2-112,0],[224,224,-1])
118 | else:
119 | offset_y, offset_x, _ = tf.unstack(bbox_begin)
120 | target_height, target_width, _ = tf.unstack(bbox_size)
121 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
122 |
123 | image = tf.image.decode_and_crop_jpeg(image, crop_window, channels = 3, **JPEG_OPT)
124 | image = tf.image.resize(image, (224,224), method='bicubic')
125 | image.set_shape((224,224,3))
126 |
127 | return image
128 |
129 | def resizeshortest(shape, size):
130 | h, w = shape[:2]
131 | scale = size / min(h, w)
132 | if h < w:
133 | newh, neww = size, int(scale * w + 0.5)
134 | else:
135 | newh, neww = int(scale * h + 0.5), size
136 | return np.int32(newh), np.int32(neww)
137 |
138 | def lighting(image, std, eigval, eigvec):
139 | v = tf.random.normal(shape=[3], stddev=std) * eigval
140 | inc = tf.matmul(eigvec, tf.reshape(v, [3, 1]))
141 | image = image + tf.reshape(inc, [3])
142 | return image
143 |
144 | def pre_processing(is_training = False, contrastive = False):
145 | @tf.function
146 | def training(path, label):
147 |         height, width = tf.numpy_function(get_size, [path], [tf.int32, tf.int32])
148 |         image = tf.io.read_file(path)
152 | image = random_resize_crop(image, height, width)
153 | return image, label
154 |
155 | @tf.function
156 | def test(path, label):
157 | image = tf.io.read_file(path)
158 | image = tf.io.decode_jpeg(image, channels = 3, **JPEG_OPT)
159 |
160 | newh,neww = tf.numpy_function(resizeshortest, [tf.shape(image, tf.int32), 256], [tf.int32, tf.int32])
161 | image = tf.image.resize(image, (newh,neww), method='bicubic')
162 | image = tf.slice(image, [newh//2-112,neww//2-112,0],[224,224,-1])
163 | return image, label
164 | return training if is_training else test
165 |
166 | def pre_processing_batched(is_training = False, contrastive = False, mode = 0):
167 | @tf.function
168 | def training(image, *argv):
169 | image = tf.image.random_flip_left_right(image)
170 | image = (image-np.array([123.675, 116.28 , 103.53 ]))/np.array([58.395, 57.12, 57.375])
173 | return [image] + [arg for arg in argv]
174 |
175 | @tf.function
176 | def test(image, *argv):
177 | shape = image.shape
178 | image = (image-np.array([123.675, 116.28 , 103.53 ]))/np.array([58.395, 57.12, 57.375])
179 | return [image] + [arg for arg in argv]
180 | return training if is_training else test
181 |
182 |
--------------------------------------------------------------------------------
/nets/tcl.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def arg_scope(func):
4 | def func_with_args(self, *args, **kwargs):
5 | if hasattr(self, 'pre_defined'):
6 | for k in self.pre_defined.keys():
7 | if k not in kwargs.keys():
8 | kwargs[k] = self.pre_defined[k]
9 | return func(self, *args, **kwargs)
10 | return func_with_args
11 |
12 | class Conv2d(tf.keras.layers.Layer):
13 | @arg_scope
14 | def __init__(self, kernel_size, num_outputs, strides = 1, dilations = 1, padding = 'SAME',
15 | kernel_initializer = tf.keras.initializers.VarianceScaling(scale = 2., mode='fan_out'),
16 | use_biases = True,
17 | biases_initializer = tf.keras.initializers.Zeros(),
18 | activation_fn = None,
19 | name = 'conv',
20 | trainable = True,
21 | **kwargs):
22 | super(Conv2d, self).__init__(name = name, trainable = trainable, **kwargs)
23 |
24 | self.kernel_size = kernel_size
25 | self.num_outputs = num_outputs
26 | self.strides = strides
27 | self.padding = padding
28 | self.dilations = dilations
29 | self.kernel_initializer = kernel_initializer
30 |
31 | self.use_biases = use_biases
32 | self.biases_initializer = biases_initializer
33 |
34 | self.activation_fn = activation_fn
35 |
36 | def build(self, input_shape):
37 | super(Conv2d, self).build(input_shape)
38 | self.kernel = self.add_weight(name = 'kernel',
39 | shape = self.kernel_size + [input_shape[-1], self.num_outputs],
40 | initializer=self.kernel_initializer,
41 | trainable = self.trainable)
42 | if self.use_biases:
43 | self.biases = self.add_weight(name = "biases",
44 | shape=[1,1,1,self.num_outputs],
45 | initializer = self.biases_initializer,
46 | trainable = self.trainable)
47 | self.ori_shape = self.kernel.shape
48 |
49 | def call(self, input):
50 | kernel = self.kernel
51 |
52 | conv = tf.nn.conv2d(input, kernel, self.strides, self.padding,
53 | dilations=self.dilations, name=None)
54 | if self.use_biases:
55 | conv += self.biases
56 |
57 | if self.activation_fn:
58 | conv = self.activation_fn(conv)
59 |
60 | return conv
61 |
62 | class DepthwiseConv2d(tf.keras.layers.Layer):
63 | @arg_scope
64 | def __init__(self, kernel_size, multiplier = 1, strides = [1,1,1,1], dilations = [1,1], padding = 'SAME',
65 | kernel_initializer = tf.keras.initializers.VarianceScaling(scale = 2., mode='fan_in'),
66 | use_biases = True,
67 | biases_initializer = tf.keras.initializers.Zeros(),
68 | activation_fn = None,
69 | name = 'conv',
70 | trainable = True,
71 | **kwargs):
72 | super(DepthwiseConv2d, self).__init__(name = name, trainable = trainable, **kwargs)
73 |
74 | self.kernel_size = kernel_size
75 | self.strides = strides if isinstance(strides, list) else [1, strides, strides, 1]
76 | self.padding = padding
77 | self.dilations = dilations if isinstance(dilations, list) else [dilations, dilations]
78 | self.multiplier = multiplier
79 | self.kernel_initializer = kernel_initializer
80 |
81 | self.use_biases = use_biases
82 | self.biases_initializer = biases_initializer
83 |
84 | self.activation_fn = activation_fn
85 |
86 | def build(self, input_shape):
87 | super(DepthwiseConv2d, self).build(input_shape)
88 | self.kernel = self.add_weight(name = 'kernel',
89 | shape = self.kernel_size + [input_shape[-1], self.multiplier],
90 | initializer=self.kernel_initializer,
91 | trainable = self.trainable)
92 | if self.use_biases:
93 | self.biases = self.add_weight(name = "biases",
94 | shape=[1,1,1, input_shape[-1]*self.multiplier],
95 | initializer = self.biases_initializer,
96 | trainable = self.trainable)
97 | self.ori_shape = self.kernel.shape
98 |
99 | def call(self, input):
100 | kernel = self.kernel
101 | conv = tf.nn.depthwise_conv2d(input, kernel, strides = self.strides, padding = self.padding, dilations=self.dilations)
102 | if self.use_biases:
103 | conv += self.biases
104 | if self.activation_fn:
105 | conv = self.activation_fn(conv)
106 | return conv
107 |
108 | class FC(tf.keras.layers.Layer):
109 | @arg_scope
110 | def __init__(self, num_outputs,
111 | kernel_initializer = tf.keras.initializers.random_normal(stddev = 1e-2),
112 | use_biases = True,
113 | biases_initializer = tf.keras.initializers.Zeros(),
114 | activation_fn = None,
115 | name = 'fc',
116 | trainable = True, **kwargs):
117 | super(FC, self).__init__(name = name, trainable = trainable, **kwargs)
118 | self.num_outputs = num_outputs
119 | self.kernel_initializer = kernel_initializer
120 |
121 | self.use_biases = use_biases
122 | self.biases_initializer = biases_initializer
123 |
124 | self.activation_fn = activation_fn
125 |
126 | def build(self, input_shape):
127 | super(FC, self).build(input_shape)
128 | self.kernel = self.add_weight(name = 'kernel',
129 | shape = [int(input_shape[-1]), self.num_outputs],
130 | initializer=self.kernel_initializer,
131 | trainable = self.trainable)
132 | if self.use_biases:
133 | self.biases = self.add_weight(name = "biases",
134 | shape=[1,self.num_outputs],
135 | initializer = self.biases_initializer,
136 | trainable = self.trainable)
137 | self.ori_shape = self.kernel.shape
138 | def call(self, input):
139 | kernel = self.kernel
140 |
141 | fc = tf.matmul(input, kernel)
142 | if self.use_biases:
143 | fc += self.biases
144 | if self.activation_fn:
145 | fc = self.activation_fn(fc)
146 |
147 | return fc
148 |
149 | class BatchNorm(tf.keras.layers.Layer):
150 | @arg_scope
151 | def __init__(self, param_initializers = None,
152 | scale = True,
153 | center = True,
154 | alpha = 0.9,
155 | epsilon = 1e-5,
156 | activation_fn = None,
157 | name = 'bn',
158 | trainable = True,
159 | **kwargs):
160 | super(BatchNorm, self).__init__(name = name, trainable = trainable, **kwargs)
161 |         if param_initializers is None:
162 | param_initializers = {}
163 | if not(param_initializers.get('moving_mean')):
164 | param_initializers['moving_mean'] = tf.keras.initializers.Zeros()
165 | if not(param_initializers.get('moving_variance')):
166 | param_initializers['moving_variance'] = tf.keras.initializers.Ones()
167 | if not(param_initializers.get('gamma')) and scale:
168 | param_initializers['gamma'] = tf.keras.initializers.Ones()
169 | if not(param_initializers.get('beta')) and center:
170 | param_initializers['beta'] = tf.keras.initializers.Zeros()
171 |
172 | self.param_initializers = param_initializers
173 | self.scale = scale
174 | self.center = center
175 | self.alpha = alpha
176 | self.epsilon = epsilon
177 | self.activation_fn = activation_fn
178 |
179 | def build(self, input_shape):
180 | super(BatchNorm, self).build(input_shape)
181 | self.moving_mean = self.add_weight(name = 'moving_mean', trainable = False,
182 | shape = [1]*(len(input_shape)-1)+[int(input_shape[-1])],
183 | initializer=self.param_initializers['moving_mean'],
184 | aggregation=tf.VariableAggregation.MEAN,
185 | )
186 | self.moving_variance = self.add_weight(name = 'moving_variance', trainable = False,
187 | shape = [1]*(len(input_shape)-1)+[int(input_shape[-1])],
188 | initializer=self.param_initializers['moving_variance'],
189 | aggregation=tf.VariableAggregation.MEAN,
190 | )
191 | if self.scale:
192 | self.gamma = self.add_weight(name = 'gamma',
193 | shape = [1]*(len(input_shape)-1)+[int(input_shape[-1])],
194 | initializer=self.param_initializers['gamma'],
195 | trainable = self.trainable)
196 | else:
197 | self.gamma = 1.
198 | if self.center:
199 | self.beta = self.add_weight(name = 'beta',
200 | shape = [1]*(len(input_shape)-1)+[int(input_shape[-1])],
201 | initializer=self.param_initializers['beta'],
202 | trainable = self.trainable)
203 | else:
204 | self.beta = 0.
205 | self.ori_shape = self.moving_mean.shape[-1]
206 |
207 | def EMA(self, variable, value):
208 | update_delta = (variable - value) * (1-self.alpha)
209 | variable.assign_sub(update_delta)
210 |
211 | def update(self, update_var):
212 | mean, var = update_var
213 | self.EMA(self.moving_mean, mean)
214 | self.EMA(self.moving_variance, var)
215 |
216 | def call(self, input, training=None):
217 | if training:
218 | mean, var = tf.nn.moments(input, list(range(len(input.shape)-1)), keepdims=True)
219 | self.update_var = [mean, var]
220 | else:
221 | mean = self.moving_mean
222 | var = self.moving_variance
223 |
224 | gamma, beta = self.gamma, self.beta
225 | bn = tf.nn.batch_normalization(input, mean, var, offset = beta, scale = gamma, variance_epsilon = self.epsilon)
226 |
227 | if self.activation_fn:
228 | bn = self.activation_fn(bn)
229 |
230 | return bn
231 |
232 |
--------------------------------------------------------------------------------