├── LICENSE ├── README.md ├── config.py ├── dataset.py ├── main.py ├── network ├── __init__.py ├── cond_resnet.py └── condconv.py └── trainer.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 prstrive 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CondConv-tensorflow 2 | Conditional convolution (Dynamic convolution) in TensorFlow 2.2.0. This repository implements the method described in the paper: 3 | 4 | >CondConv: Conditionally Parameterized Convolutions for Efficient Inference 5 | >Brandon Yang, Gabriel Bender, Quoc V. Le, Jiquan Ngiam 6 | >[Source PDF](https://papers.nips.cc/paper/8412-condconv-conditionally-parameterized-convolutions-for-efficient-inference.pdf) 7 | 8 | In addition, the softmax with a large temperature for kernel attention introduced in [Dynamic Convolution: Attention Over Convolution Kernels](https://arxiv.org/pdf/1912.03458.pdf) is adopted. 9 | 10 | Another similar paper: [DyNet: Dynamic Convolution for Accelerating Convolutional Neural Networks](https://arxiv.org/pdf/2004.10694.pdf). 11 | 12 | ### Start 13 | You can start training with the default arguments via `python main.py`.
Or specify the arguments: 14 | ```python 15 | python main.py --arch cond_cifar_resnet --num_layers 56 --num_experts 3 --dataset cifar10 --num_classes 10 16 | ``` 17 | 18 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | MEAN = {"imagenet": [0.485, 0.456, 0.406], "cifar": [0.4914, 0.4822, 0.4465]} 2 | STD = {"imagenet": [0.229, 0.224, 0.225], "cifar": [0.2023, 0.1994, 0.2010]} 3 | 4 | WEIGHT_DECAY = 2e-4 5 | 6 | LABEL_SMOOTH = 2e-1 7 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from config import * 3 | import tensorflow as tf 4 | 5 | 6 | def get_cifar_dataset(num_class, train_batch, val_batch): 7 | if num_class == 10: 8 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() 9 | else: 10 | (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data() 11 | 12 | def _parse_image_train(image, label): 13 | image = tf.image.convert_image_dtype(image, tf.float32) 14 | image = (image - MEAN["cifar"]) / STD["cifar"] 15 | 16 | image = tf.image.random_crop(tf.pad(image, [[4, 4], [4, 4], [0, 0]]), size=[32, 32, 3]) 17 | image = tf.image.random_flip_left_right(image) 18 | 19 | return image, label 20 | 21 | def _parse_image_val(image, label): 22 | image = tf.image.convert_image_dtype(image, tf.float32) 23 | image = (image - MEAN["cifar"]) / STD["cifar"] 24 | return image, label 25 | 26 | train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(_parse_image_train, 27 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 28 | train_dataset = train_dataset.shuffle(buffer_size=len(y_train)).batch(batch_size=train_batch).prefetch( 29 | buffer_size=tf.data.experimental.AUTOTUNE) 30 | 31 | val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).map(_parse_image_val, 32 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 33 | val_dataset = val_dataset.batch(batch_size=val_batch).prefetch(buffer_size=tf.data.experimental.AUTOTUNE) 34 | 35 | return train_dataset, np.ceil(len(y_train) / train_batch), val_dataset, np.ceil(len(y_test) / val_batch) 36 | 37 | 38 | def get_datasets(name, train_batch, val_batch): 39 | if name == "cifar10": 40 | return get_cifar_dataset(num_class=10, train_batch=train_batch, val_batch=val_batch) 41 | elif name == "cifar100": 42 | return get_cifar_dataset(num_class=100, train_batch=train_batch, val_batch=val_batch) 43 | else: 44 | raise ValueError("Dataset only support cifar10, cifar100 and ILSVRC2012, but get {}!".format(name)) 45 | 46 | 47 | if __name__ == '__main__': 48 | train_date, train_batch_num, val_data, val_batch_num = get_cifar_dataset(num_class=10, train_batch=70, val_batch=1) 49 | 50 | for b, (d, l) in enumerate(train_date): 51 | print(d.shape) 52 | print(l.shape) 53 | break 54 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from model import Model 3 | 4 | parser = argparse.ArgumentParser(description="CondConv args") 5 | parser.add_argument('--dataset', 6 | type=str, 7 | help='name of dataset', 8 | default="cifar10", 9 | choices=["cifar10", "cifar100"] 10 | ) 11 | parser.add_argument('--data_path', 12 | type=str, 13 | help='path to dataset', 14 | default="" 15 | ) 16 | 
parser.add_argument("--lr", 17 | type=float, 18 | default=0.1 19 | ) 20 | parser.add_argument("--train_batch", 21 | type=int, 22 | default=128 23 | ) 24 | parser.add_argument("--val_batch", 25 | type=int, 26 | default=128 27 | ) 28 | parser.add_argument("--epochs", 29 | type=int, 30 | default=200 31 | ) 32 | parser.add_argument("--arch", 33 | type=str, 34 | default="cond_cifar_resnet", 35 | choices=["cond_cifar_resnet"] 36 | ) 37 | parser.add_argument("--num_layers", 38 | type=int, 39 | default=56, 40 | choices=[20, 32, 44, 56, 110, 1202] 41 | ) 42 | parser.add_argument("--num_experts", 43 | type=int, 44 | default=3 45 | ) 46 | parser.add_argument("--num_classes", 47 | type=int, 48 | default=10 49 | ) 50 | parser.add_argument("--models_path", 51 | type=str, 52 | default="./models" 53 | ) 54 | parser.add_argument("--logs_path", 55 | type=str, 56 | default="./logs" 57 | ) 58 | # list of available gpu 59 | parser.add_argument("--gpu_ids", 60 | nargs='+', 61 | default=['0'] 62 | ) 63 | parser.add_argument("--val", 64 | action="store_true" 65 | ) 66 | parser.add_argument("--pretrained", 67 | action="store_true" 68 | ) 69 | parser.add_argument("--distribute", 70 | help='train model in distributed way', 71 | action="store_true" 72 | ) 73 | parser.add_argument("--resume", 74 | help='whether to load the latest checkpoint', 75 | action="store_true" 76 | ) 77 | args = parser.parse_args() 78 | 79 | if __name__ == '__main__': 80 | model = Model(args) 81 | model.main() 82 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import progressbar 4 | from config import * 5 | import tensorflow as tf 6 | from network import get_net 7 | from dataset import get_datasets 8 | from trainer import Trainer, DisTrainer 9 | 10 | 11 | class Model: 12 | def __init__(self, args): 13 | self.args = args 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(self.args.gpu_ids) 16 | 17 | gpus = tf.config.experimental.list_physical_devices('GPU') 18 | if gpus: 19 | for gpu in gpus: 20 | tf.config.experimental.set_memory_growth(gpu, True) 21 | 22 | if self.args.dataset == "cifar10": 23 | self.args.num_classes = 10 24 | elif self.args.dataset == "cifar100": 25 | self.args.num_classes = 100 26 | 27 | self.model_save_path = os.path.join(self.args.models_path, self.args.arch + str(self.args.num_layers)) 28 | os.makedirs(self.model_save_path, exist_ok=True) 29 | 30 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 31 | self.log_dir = os.path.join(self.args.logs_path, self.args.arch + str(self.args.num_layers), current_time) 32 | self.log_dir = os.path.join(self.args.logs_path, self.args.arch + str(self.args.num_layers), current_time) 33 | 34 | def main(self): 35 | if self.args.distribute: 36 | self.distribute_run() 37 | else: 38 | self.run() 39 | 40 | def run(self): 41 | train_date, train_batch_num, val_data, val_batch_num = get_datasets(name=self.args.dataset, train_batch=self.args.train_batch, 42 | val_batch=self.args.val_batch) 43 | model = get_net(arch=self.args.arch, num_layers=self.args.num_layers, num_experts=self.args.num_experts, 44 | num_classes=self.args.num_classes) 45 | model.build(input_shape=(None, 32, 32, 3)) 46 | model.summary() 47 | 48 | optimizer = tf.keras.optimizers.SGD(learning_rate=self.args.lr, momentum=0.9, decay=0.0001, nesterov=True) 49 | 50 | trainer = Trainer(model=model, optimizer=optimizer, epochs=self.args.epochs, 
val_data=val_data, 51 | train_batch=self.args.train_batch, val_batch=self.args.val_batch, train_data=train_date, 52 | log_dir=self.log_dir, model_save_path=self.model_save_path, train_batch_num=train_batch_num, 53 | val_batch_num=val_batch_num) 54 | 55 | trainer(resume=self.args.resume, val=self.args.val) 56 | 57 | def distribute_run(self): 58 | strategy = tf.distribute.MirroredStrategy() 59 | train_global_batch = self.args.train_batch * strategy.num_replicas_in_sync 60 | val_global_batch = self.args.val_batch * strategy.num_replicas_in_sync 61 | train_date, train_batch_num, val_data, val_batch_num = get_datasets(name=self.args.dataset, train_batch=train_global_batch, 62 | val_batch=val_global_batch) 63 | with strategy.scope(): 64 | model = get_net(arch=self.args.arch, num_layers=self.args.num_layers, num_experts=self.args.num_experts, 65 | num_classes=self.args.num_classes) 66 | model.build(input_shape=(None, 32, 32, 3)) 67 | model.summary() 68 | 69 | optimizer = tf.keras.optimizers.SGD(learning_rate=self.args.lr, momentum=0.9, decay=0.0001, nesterov=True) 70 | 71 | dis_trainer = DisTrainer(strategy=strategy, model=model, optimizer=optimizer, epochs=self.args.epochs, val_data=val_data, 72 | train_batch=self.args.train_batch, val_batch=self.args.val_batch, train_data=train_date, 73 | log_dir=self.log_dir, model_save_path=self.model_save_path, train_batch_num=train_batch_num, 74 | val_batch_num=val_batch_num) 75 | 76 | dis_trainer(resume=self.args.resume, val=self.args.val) 77 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from .cond_resnet import CondCifarResNet 2 | 3 | 4 | def get_net(arch, **kwargs): 5 | if arch == "cond_cifar_resnet": 6 | return CondCifarResNet(**kwargs) 7 | -------------------------------------------------------------------------------- /network/cond_resnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from config import WEIGHT_DECAY 3 | from tensorflow.keras import layers 4 | from tensorflow.keras.models import Sequential, Model 5 | from .condconv import CondConv2D 6 | 7 | 8 | class CondCifarBasicBlock(layers.Layer): 9 | expansion = 1 10 | 11 | def __init__(self, filters, stride=1, option='A', num_experts=3, **kwargs): 12 | super(CondCifarBasicBlock, self).__init__(**kwargs) 13 | self.bn1 = layers.BatchNormalization(name="bn1") 14 | self.conv1 = CondConv2D(kernel_size=3, filters=filters, stride=stride, use_bias=False, num_experts=num_experts, name="conv1") 15 | self.bn2 = layers.BatchNormalization(name="bn2") 16 | self.conv2 = CondConv2D(kernel_size=3, filters=filters, stride=1, use_bias=False, num_experts=num_experts, name="conv2") 17 | 18 | self.shortcut = None 19 | if stride != 1: 20 | if option == 'A': 21 | """ 22 | For CIFAR10 ResNet paper uses option A. 
23 | """ 24 | self.shortcut = lambda x: tf.pad(tf.nn.avg_pool2d(x, (2, 2), strides=(1, 2, 2, 1), padding='SAME'), 25 | [[0, 0], [0, 0], [0, 0], [filters // 4, filters // 4]]) 26 | elif option == 'B': 27 | self.shortcut = Sequential([ 28 | CondConv2D(kernel_size=1, filters=self.expansion * filters, stride=stride, use_bias=False, num_experts=num_experts), 29 | layers.BatchNormalization() 30 | ], name="shortcut") 31 | 32 | def call(self, inputs, training=None, **kwargs): 33 | 34 | out = tf.nn.relu(self.bn1(self.conv1(inputs), training=training)) 35 | out = self.bn2(self.conv2(out), training=training) 36 | 37 | if self.shortcut is not None: 38 | inputs = self.shortcut(inputs) 39 | 40 | out += inputs 41 | out = tf.nn.relu(out) 42 | return out 43 | 44 | 45 | class CondCifarResNet(Model): 46 | def __init__(self, num_layers, num_classes=10, num_experts=3): 47 | super(CondCifarResNet, self).__init__() 48 | 49 | block = CondCifarBasicBlock 50 | num_blocks = int((num_layers - 2) / 6) 51 | 52 | self.conv1 = CondConv2D(kernel_size=3, filters=16, stride=1, use_bias=False, num_experts=num_experts, name="conv1") 53 | self.bn1 = layers.BatchNormalization(name="bn1") 54 | self.layer1 = self._make_layer(block, 16, num_blocks, stride=1, num_experts=num_experts, name="layer1") 55 | self.layer2 = self._make_layer(block, 32, num_blocks, stride=2, num_experts=num_experts, name="layer2") 56 | self.layer3 = self._make_layer(block, 64, num_blocks, stride=2, num_experts=num_experts, name="layer3") 57 | self.gavgpool = layers.GlobalAveragePooling2D() 58 | self.fc = layers.Dense(units=num_classes, name="fc") 59 | 60 | def _make_layer(self, block, filters, num_blocks, stride, name, num_experts): 61 | blocks_list = [block(filters, stride, num_experts=num_experts)] 62 | for i in range(1, num_blocks): 63 | blocks_list.append(block(filters, 1, num_experts=num_experts)) 64 | 65 | return Sequential(blocks_list, name=name) 66 | 67 | def call(self, inputs, training=None, mask=None): 68 | out = tf.nn.relu(self.bn1(self.conv1(inputs), training=training)) 69 | out = self.layer1(out, training=training) 70 | out = self.layer2(out, training=training) 71 | out = self.layer3(out, training=training) 72 | out = self.gavgpool(out) 73 | out = self.fc(out) 74 | return out 75 | -------------------------------------------------------------------------------- /network/condconv.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from config import WEIGHT_DECAY 3 | from tensorflow.keras import layers 4 | 5 | 6 | def conv2d(kernel_size, stride, filters, kernel_regularizer=tf.keras.regularizers.l2(WEIGHT_DECAY), padding="same", use_bias=False, 7 | kernel_initializer="he_normal", **kwargs): 8 | return layers.Conv2D(kernel_size=kernel_size, strides=stride, filters=filters, kernel_regularizer=kernel_regularizer, padding=padding, 9 | use_bias=use_bias, kernel_initializer=kernel_initializer, **kwargs) 10 | 11 | 12 | class Routing(layers.Layer): 13 | def __init__(self, out_channels, dropout_rate, temperature=30, **kwargs): 14 | super(Routing, self).__init__(**kwargs) 15 | self.avgpool = layers.GlobalAveragePooling2D() 16 | self.dropout = layers.Dropout(rate=dropout_rate) 17 | self.fc = layers.Dense(units=out_channels) 18 | self.softmax = layers.Softmax() 19 | self.temperature = temperature 20 | 21 | def call(self, inputs, **kwargs): 22 | """ 23 | :param inputs: (b, c, h, w) 24 | :return: (b, out_features) 25 | """ 26 | out = self.avgpool(inputs) 27 | out = self.dropout(out) 28 | 29 | # 
refer to paper: https://arxiv.org/pdf/1912.03458.pdf 30 | out = self.softmax(self.fc(out) * 1.0 / self.temperature) 31 | return out 32 | 33 | 34 | class CondConv2D(layers.Layer): 35 | def __init__(self, filters, kernel_size, stride=1, use_bias=True, num_experts=3, padding="same", **kwargs): 36 | super(CondConv2D, self).__init__(**kwargs) 37 | 38 | self.routing = Routing(out_channels=num_experts, dropout_rate=0.2, name="routing_layer") 39 | self.convs = [] 40 | for _ in range(num_experts): 41 | self.convs.append(conv2d(filters=filters, stride=stride, kernel_size=kernel_size, use_bias=use_bias, padding=padding)) 42 | 43 | def call(self, inputs, **kwargs): 44 | """ 45 | :param inputs: (b, h, w, c) 46 | :return: (b, h_out, w_out, filters) 47 | """ 48 | routing_weights = self.routing(inputs) 49 | feature = routing_weights[:, 0] * tf.transpose(self.convs[0](inputs), perm=[1, 2, 3, 0]) 50 | for i in range(1, len(self.convs)): 51 | feature += routing_weights[:, i] * tf.transpose(self.convs[i](inputs), perm=[1, 2, 3, 0]) 52 | feature = tf.transpose(feature, perm=[3, 0, 1, 2]) 53 | return feature 54 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import progressbar 3 | import tensorflow as tf 4 | 5 | 6 | class Trainer: 7 | def __init__(self, model, optimizer, epochs, train_batch, val_batch, train_data, val_data, log_dir, model_save_path, train_batch_num, 8 | val_batch_num): 9 | 10 | self.model = model 11 | self.optimizer = optimizer 12 | self.epochs = epochs 13 | self.train_data = train_data 14 | self.val_data = val_data 15 | self.train_batch = train_batch 16 | self.val_batch = val_batch 17 | self.train_batch_num = train_batch_num 18 | self.val_batch_num = val_batch_num 19 | 20 | self.model_save_path = model_save_path 21 | os.makedirs(self.model_save_path, exist_ok=True) 22 | 23 | self.train_summary_writer = tf.summary.create_file_writer(log_dir + "/train") 24 | self.val_summary_writer = tf.summary.create_file_writer(log_dir + "/val") 25 | 26 | self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 27 | 28 | self.train_loss = tf.keras.metrics.Mean(name='train_loss') 29 | self.train_accuracy_top1 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name="train_accuracy_top1") 30 | self.train_accuracy_top5 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="train_accuracy_top5") 31 | 32 | self.val_loss = tf.keras.metrics.Mean(name='val_loss') 33 | self.val_accuracy_top1 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name="val_accuracy_top1") 34 | self.val_accuracy_top5 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="val_accuracy_top5") 35 | 36 | def lr_decay(self, epoch): 37 | if epoch < 60: 38 | return 0.1 39 | elif epoch < 100: 40 | return 0.01 41 | elif epoch < 150: 42 | return 0.001 43 | else: 44 | return 0.0001 45 | 46 | def train_epoch(self, curr_epoch): 47 | 48 | pwidgets = [progressbar.Percentage(), " ", progressbar.Counter(format='%(value)02d/%(max_value)d'), " ", progressbar.Bar(), " ", 49 | progressbar.Timer(), ", ", progressbar.Variable('LR', width=1, precision=4), ", ", 50 | progressbar.Variable('Top1', width=2, precision=4), ", ", progressbar.Variable('Top5', width=2, precision=4), ", ", 51 | progressbar.Variable('Loss', width=2, precision=4)] 52 | pbar = progressbar.ProgressBar(widgets=pwidgets, max_value=self.train_batch_num, 53 | prefix="Epoch {}/{}: ".format(curr_epoch, 
self.epochs)).start() 54 | 55 | self.train_loss.reset_states() 56 | self.train_accuracy_top1.reset_states() 57 | self.train_accuracy_top5.reset_states() 58 | 59 | for batch, (images, labels) in enumerate(self.train_data): 60 | loss = self.train_step(images, labels) 61 | self.train_loss(loss) 62 | pbar.update(batch, LR=self.optimizer.learning_rate.numpy(), Top1=self.train_accuracy_top1.result().numpy(), 63 | Top5=self.train_accuracy_top5.result().numpy(), Loss=self.train_loss.result().numpy()) 64 | pbar.finish() 65 | 66 | @tf.function 67 | def train_step(self, images, labels): 68 | with tf.GradientTape(persistent=True) as tape: 69 | predictions = self.model(images, training=True) 70 | cross_entropy_loss = self.loss_object(labels, predictions) 71 | regularization_losses = self.model.losses 72 | total_loss = tf.add_n(regularization_losses + [cross_entropy_loss]) 73 | gradients = tape.gradient(total_loss, self.model.trainable_variables) 74 | self.optimizer.apply_gradients(grads_and_vars=zip(gradients, self.model.trainable_variables)) 75 | 76 | self.train_accuracy_top1(y_true=labels, y_pred=predictions) 77 | self.train_accuracy_top5(y_true=labels, y_pred=predictions) 78 | 79 | return total_loss 80 | 81 | def validate_epoch(self): 82 | pwidgets = [progressbar.Percentage(), " ", progressbar.Counter(format='%(value)02d/%(max_value)d'), " ", progressbar.Bar(), " ", 83 | progressbar.Timer(), ", ", progressbar.Variable('Top1', width=2, precision=4), ", ", 84 | progressbar.Variable('Top5', width=2, precision=4), ", ", progressbar.Variable('Loss', width=2, precision=4)] 85 | pbar = progressbar.ProgressBar(widgets=pwidgets, max_value=self.val_batch_num, prefix="Val: ").start() 86 | 87 | self.val_loss.reset_states() 88 | self.val_accuracy_top1.reset_states() 89 | self.val_accuracy_top5.reset_states() 90 | 91 | for batch, (images, labels) in enumerate(self.val_data): 92 | self.validate_step(images, labels) 93 | 94 | pbar.update(batch, Top1=self.val_accuracy_top1.result().numpy(), Top5=self.val_accuracy_top5.result().numpy(), 95 | Loss=self.val_loss.result().numpy()) 96 | 97 | pbar.finish() 98 | 99 | @tf.function 100 | def validate_step(self, images, labels): 101 | predictions = self.model(images, training=False) 102 | regularization_losses = self.model.losses 103 | 104 | cross_entropy_loss = self.loss_object(labels, predictions) 105 | total_loss = tf.add_n(regularization_losses + [cross_entropy_loss]) 106 | self.val_loss(total_loss) 107 | self.val_accuracy_top1(y_true=labels, y_pred=predictions) 108 | self.val_accuracy_top5(y_true=labels, y_pred=predictions) 109 | 110 | def __call__(self, resume=False, val=False): 111 | best_top1 = 0 112 | start_epoch = 0 113 | 114 | checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer, best_top1=tf.Variable(0), epoch=tf.Variable(0)) 115 | checkpointManager = tf.train.CheckpointManager(checkpoint, directory=self.model_save_path, max_to_keep=1, 116 | checkpoint_name="model_best.ckpt") 117 | if resume: 118 | checkpoint.restore(checkpointManager.latest_checkpoint) 119 | best_top1 = checkpoint.best_top1.numpy() 120 | start_epoch = checkpoint.epoch.numpy() + 1 # if resume, start from next epoch 121 | 122 | if val: 123 | self.validate_epoch() 124 | return 125 | 126 | for epoch in range(start_epoch, self.epochs): 127 | self.optimizer.learning_rate = self.lr_decay(epoch) 128 | 129 | self.train_epoch(epoch) 130 | 131 | with self.train_summary_writer.as_default(): 132 | tf.summary.scalar('loss', self.train_loss.result(), step=epoch) 133 | 
tf.summary.scalar('accuracy_top1', self.train_accuracy_top1.result(), step=epoch) 134 | tf.summary.scalar('accuracy_top5', self.train_accuracy_top5.result(), step=epoch) 135 | 136 | self.validate_epoch() 137 | 138 | with self.val_summary_writer.as_default(): 139 | tf.summary.scalar('loss', self.val_loss.result(), step=epoch) 140 | tf.summary.scalar('accuracy_top1', self.val_accuracy_top1.result(), step=epoch) 141 | tf.summary.scalar('accuracy_top5', self.val_accuracy_top5.result(), step=epoch) 142 | 143 | val_top1 = self.val_accuracy_top1.result().numpy() 144 | if val_top1 > best_top1: 145 | best_top1 = val_top1 146 | checkpoint.best_top1.assign(best_top1) 147 | checkpointManager.save() 148 | 149 | checkpoint.epoch.assign_add(1) 150 | 151 | 152 | class DisTrainer(Trainer): 153 | def __init__(self, strategy, *args, **kwargs): 154 | super(DisTrainer, self).__init__(*args, **kwargs) 155 | self.strategy = strategy 156 | self.train_global_batch = self.train_batch * self.strategy.num_replicas_in_sync 157 | self.val_global_batch = self.val_batch * self.strategy.num_replicas_in_sync 158 | self.train_data = self.strategy.experimental_distribute_dataset(self.train_data) 159 | self.val_data = self.strategy.experimental_distribute_dataset(self.val_data) 160 | self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE, from_logits=True) 161 | 162 | def compute_loss(self, labels, predictions): 163 | per_example_loss = self.loss_object(labels, predictions) 164 | return tf.nn.compute_average_loss(per_example_loss, global_batch_size=self.train_global_batch) 165 | 166 | # distribute train step 167 | @tf.function 168 | def train_step(self, dis_images, dis_labels): 169 | def step_fn(images, labels): 170 | with tf.GradientTape() as tape: 171 | predictions = self.model(images, training=True) 172 | cross_entropy_loss = self.loss_object(labels, predictions) 173 | regularization_losses = self.model.losses 174 | total_loss = tf.add_n(regularization_losses + [cross_entropy_loss]) 175 | gradients = tape.gradient(total_loss, self.model.trainable_variables) 176 | self.optimizer.apply_gradients(grads_and_vars=zip(gradients, self.model.trainable_variables)) 177 | 178 | self.train_accuracy_top1(y_true=labels, y_pred=predictions) 179 | self.train_accuracy_top5(y_true=labels, y_pred=predictions) 180 | 181 | return total_loss 182 | 183 | per_replica_losses = self.strategy.run(step_fn, args=(dis_images, dis_labels)) 184 | return self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) 185 | 186 | # distribute validate step 187 | @tf.function 188 | def validate_step(self, dis_images, dis_labels): 189 | def step_fn(images, labels): 190 | predictions = self.model(images, training=False) 191 | regularization_losses = self.model.losses 192 | 193 | cross_entropy_loss = self.loss_object(labels, predictions) 194 | total_loss = tf.add_n(regularization_losses + [cross_entropy_loss]) 195 | self.val_loss(total_loss) 196 | self.val_accuracy_top1(y_true=labels, y_pred=predictions) 197 | self.val_accuracy_top5(y_true=labels, y_pred=predictions) 198 | 199 | return self.strategy.run(step_fn, args=(dis_images, dis_labels)) 200 | --------------------------------------------------------------------------------
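Assuming the repository root is on the Python path, a minimal forward-pass smoke test of the network defined above might look like the sketch below (batch size and depth are arbitrary; `num_layers` must satisfy `(num_layers - 2) % 6 == 0`):

```python
import tensorflow as tf
from network import get_net  # the factory defined in network/__init__.py

# Build the conditional CIFAR ResNet and run one forward pass on random input.
model = get_net(arch="cond_cifar_resnet", num_layers=20, num_experts=3, num_classes=10)
model.build(input_shape=(None, 32, 32, 3))

images = tf.random.normal([4, 32, 32, 3])   # a fake batch of CIFAR-sized images
logits = model(images, training=False)      # un-normalized class scores
print(logits.shape)                         # (4, 10)
```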
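The option-A shortcut in `CondCifarBasicBlock` is the parameter-free downsampling path from the original CIFAR ResNet: average pooling with stride 2 halves the spatial resolution, and zero-padding the channel axis matches the widened residual branch. A quick shape check of that expression, assuming the first stride-2 block (16 input channels, 32 output filters):

```python
import tensorflow as tf

x = tf.random.normal([2, 32, 32, 16])   # input to the first stride-2 block
filters = 32                            # width of that block's residual branch
shortcut = tf.pad(tf.nn.avg_pool2d(x, (2, 2), strides=(1, 2, 2, 1), padding='SAME'),
                  [[0, 0], [0, 0], [0, 0], [filters // 4, filters // 4]])
print(shortcut.shape)                   # (2, 16, 16, 32): halved spatially, padded 16 -> 32 channels
```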
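For reference, the mixture computed by `CondConv2D` can be written out in isolation. The sketch below (illustrative only; a single example with bias-free experts, names and shapes are not from the repository) builds routing weights with the large-temperature softmax from the Dynamic Convolution paper and checks that weighting the expert outputs, as the layer above does, equals weighting the expert kernels before a single convolution, which is the formulation in the CondConv paper:

```python
import tensorflow as tf

num_experts, k, c_in, c_out, tau = 3, 3, 16, 16, 30.0
x = tf.random.normal([1, 32, 32, c_in])                                # one input example
kernels = [tf.random.normal([k, k, c_in, c_out]) for _ in range(num_experts)]

# Routing: global average pool -> dense -> softmax with temperature tau.
pooled = tf.reduce_mean(x, axis=[1, 2])                                # (1, c_in)
logits = tf.keras.layers.Dense(num_experts)(pooled)                    # (1, num_experts)
r = tf.nn.softmax(logits / tau)                                        # per-example expert weights

# Form 1: run every expert convolution and mix the outputs.
mix_outputs = tf.add_n([r[0, i] * tf.nn.conv2d(x, kernels[i], 1, "SAME")
                        for i in range(num_experts)])

# Form 2: mix the kernels first, then convolve once.
mixed_kernel = tf.add_n([r[0, i] * kernels[i] for i in range(num_experts)])
mix_kernels = tf.nn.conv2d(x, mixed_kernel, 1, "SAME")

# Convolution is linear in its kernel, so both forms agree up to float error.
print(float(tf.reduce_max(tf.abs(mix_outputs - mix_kernels))))
```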
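`DisTrainer.compute_loss` relies on `tf.nn.compute_average_loss`, which divides each replica's summed per-example loss by the global batch size, so adding the per-replica results back together reproduces the plain mean over the whole batch. A standalone numeric illustration with made-up values:

```python
import tensorflow as tf

per_example = tf.constant([0.2, 0.4, 0.6, 0.8])          # losses for a global batch of 4
replica_a, replica_b = per_example[:2], per_example[2:]   # pretend 2 replicas, 2 examples each

scaled_a = tf.nn.compute_average_loss(replica_a, global_batch_size=4)
scaled_b = tf.nn.compute_average_loss(replica_b, global_batch_size=4)

print(float(scaled_a + scaled_b))           # 0.5
print(float(tf.reduce_mean(per_example)))   # 0.5, the same mean over the global batch
```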