├── LICENSE
├── README.md
├── ckpt
│   └── Readme.md
├── config.py
├── data
│   └── Readme.md
├── data_loader.py
├── images
│   └── Overview.png
├── logs
│   └── Readme.md
├── main.py
├── resnet.py
├── trainer.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 XyChen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep-Mutual-Learning
2 | An unofficial PyTorch implementation of Deep Mutual Learning (DML) for classification on CIFAR-100.
3 | The algorithm was proposed in *Deep Mutual Learning* (CVPR 2018).
4 | # Dependencies
5 | PyTorch 1.0.0
6 | tensorboard 1.14.0
7 | # Overview
8 | Overview of the algorithm:
9 | 
10 | # Usage
11 | The default network for DML is ResNet32.
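Each network is trained with its own cross-entropy loss plus the averaged KL divergence from the predictions of the other networks (see `train_one_epoch` in trainer.py). A minimal sketch of the per-network loss (the helper name `dml_loss` is illustrative, and `.detach()` is used here to treat the peers' predictions as fixed targets):
```
import torch.nn.functional as F

def dml_loss(logits, labels, peer_logits):
    """Cross-entropy plus the mean KL divergence from each peer's soft predictions."""
    ce = F.cross_entropy(logits, labels)
    kl = sum(F.kl_div(F.log_softmax(logits, dim=1),
                      F.softmax(peer.detach(), dim=1),
                      reduction='batchmean')
             for peer in peer_logits) / len(peer_logits)
    return ce + kl
```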
12 | Train 2 models with DML via main.py:
13 | ```
14 | python main.py --model_num 2
15 | ```
16 | Use TensorBoard to monitor the training process on a chosen port:
17 | ```
18 | tensorboard --logdir logs --port 6006
19 | ```
20 | # Result
21 | | Network  | Independent avg. acc. | DML avg. acc. |
22 | |----------|:---------------------:|:-------------:|
23 | | ResNet32 | 69.83%                | **71.03%**    |
--------------------------------------------------------------------------------
/ckpt/Readme.md:
--------------------------------------------------------------------------------
1 | Store the checkpoint
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Tue Jul 23 10:32:34 2019
4 | 
5 | @author: chxy
6 | """
7 | 
8 | import argparse
9 | 
10 | arg_lists = []
11 | parser = argparse.ArgumentParser(description='DML classification')
12 | 
13 | 
14 | def str2bool(v):
15 |     return v.lower() in ('true', '1')
16 | 
17 | 
18 | def add_argument_group(name):
19 |     arg = parser.add_argument_group(name)
20 |     arg_lists.append(arg)
21 |     return arg
22 | 
23 | # data params
24 | data_arg = add_argument_group('Data Params')
25 | data_arg.add_argument('--num_classes', type=int, default=100,
26 |                       help='Number of classes to classify')
27 | data_arg.add_argument('--batch_size', type=int, default=256,
28 |                       help='# of images in each batch of data')
29 | data_arg.add_argument('--num_workers', type=int, default=4,
30 |                       help='# of subprocesses to use for data loading')
31 | data_arg.add_argument('--pin_memory', type=str2bool, default=True,
32 |                       help='whether to copy tensors into CUDA pinned memory')
33 | data_arg.add_argument('--shuffle', type=str2bool, default=True,
34 |                       help='Whether to shuffle the train indices')
35 | 
36 | 
37 | # training params
38 | train_arg = add_argument_group('Training Params')
39 | train_arg.add_argument('--is_train', type=str2bool, default=True,
40 |                        help='Whether to train or test the model')
41 | train_arg.add_argument('--momentum', type=float, default=0.9,
42 |                        help='Momentum value')
43 | train_arg.add_argument('--epochs', type=int, default=200,
44 |                        help='# of epochs to train for')
45 | train_arg.add_argument('--init_lr', type=float, default=0.1,
46 |                        help='Initial learning rate value')
47 | train_arg.add_argument('--weight_decay', type=float, default=5e-4,
48 |                        help='value of weight decay for regularization')
49 | train_arg.add_argument('--nesterov', type=str2bool, default=True,
50 |                        help='Whether to use Nesterov momentum')
51 | train_arg.add_argument('--lr_patience', type=int, default=10,
52 |                        help='Number of epochs to wait before reducing lr')
53 | train_arg.add_argument('--train_patience', type=int, default=100,
54 |                        help='Number of epochs to wait before stopping train')
55 | train_arg.add_argument('--gamma', type=float, default=0.1,
56 |                        help='value of learning rate decay')
57 | 
58 | # other params
59 | misc_arg = add_argument_group('Misc.')
60 | misc_arg.add_argument('--use_gpu', type=str2bool, default=True,
61 |                       help="Whether to run on the GPU")
62 | misc_arg.add_argument('--best', type=str2bool, default=False,
63 |                       help='Load best model or most recent for testing')
64 | misc_arg.add_argument('--random_seed', type=int, default=1,
65 |                       help='Seed to ensure reproducibility')
66 | misc_arg.add_argument('--data_dir', type=str, default='./data/cifar100',
67 |                       help='Directory in which data is stored')
68 | misc_arg.add_argument('--ckpt_dir', type=str, default='./ckpt',
69 |                       help='Directory in which to save model checkpoints')
70 | misc_arg.add_argument('--logs_dir', type=str, default='./logs/',
71 |                       help='Directory in which TensorBoard logs will be stored')
72 | misc_arg.add_argument('--use_tensorboard', type=str2bool, default=True,
73 |                       help='Whether to use tensorboard for visualization')
74 | misc_arg.add_argument('--resume', type=str2bool, default=False,
75 |                       help='Whether to resume training from checkpoint')
76 | misc_arg.add_argument('--print_freq', type=int, default=10,
77 |                       help='How frequently to print training details')
78 | misc_arg.add_argument('--save_name', type=str, default='model',
79 |                       help='Name of the model to save as')
80 | misc_arg.add_argument('--model_num', type=int, default=2,
81 |                       help='Number of models to train for DML')
82 | 
83 | def get_config():
84 |     config, unparsed = parser.parse_known_args()
85 |     return config, unparsed
86 | 
--------------------------------------------------------------------------------
/data/Readme.md:
--------------------------------------------------------------------------------
1 | Store the dataset
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Jul 10 14:12:10 2019
4 | 
5 | @author: chxy
6 | """
7 | 
8 | import numpy as np
9 | 
10 | import torch
11 | from torchvision import datasets
12 | from torchvision import transforms
13 | 
14 | def get_train_loader(data_dir,
15 |                      batch_size,
16 |                      random_seed,
17 |                      shuffle=True,
18 |                      num_workers=4,
19 |                      pin_memory=True):
20 |     """
21 |     Utility function for loading and returning a multi-process
22 |     train iterator over the CIFAR100 dataset.
23 | 
24 |     If using CUDA, set pin_memory to True.
25 | 
26 |     Args
27 |     ----
28 |     - data_dir: path directory to the dataset.
29 |     - batch_size: how many samples per batch to load.
30 |     - num_workers: number of subprocesses to use when loading the dataset.
31 |     - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
32 |       True if using GPU.
33 | 
34 |     Returns
35 |     -------
36 |     - data_loader: train set iterator.
37 |     """
38 | 
39 |     # define transforms
40 |     trans = transforms.Compose([
41 |         transforms.RandomCrop(32, padding=4),  # random 32x32 crop after 4-pixel padding
42 |         transforms.RandomHorizontalFlip(),  # random horizontal flip
43 |         transforms.RandomRotation(degrees=15),  # random rotation of up to 15 degrees
44 |         transforms.ToTensor(),  # convert the image to a Tensor
45 |         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # normalize (ImageNet mean/std)
46 |     ])
47 | 
48 |     # load dataset
49 |     dataset = datasets.CIFAR100(root=data_dir,
50 |                                 transform=trans,
51 |                                 download=False,
52 |                                 train=True)
53 |     if shuffle:
54 |         np.random.seed(random_seed)
55 | 
56 |     train_loader = torch.utils.data.DataLoader(
57 |         dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory,
58 |     )
59 | 
60 |     return train_loader
61 | 
62 | 
63 | 
64 | def get_test_loader(data_dir,
65 |                     batch_size,
66 |                     num_workers=4,
67 |                     pin_memory=True):
68 |     """
69 |     Utility function for loading and returning a multi-process
70 |     test iterator over the CIFAR100 dataset.
71 | 
72 |     If using CUDA, set pin_memory to True.
73 | 
74 |     Args
75 |     ----
76 |     - data_dir: path directory to the dataset.
77 |     - batch_size: how many samples per batch to load.
78 |     - num_workers: number of subprocesses to use when loading the dataset.
79 |     - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
80 |       True if using GPU.
81 | 
82 |     Returns
83 |     -------
84 |     - data_loader: test set iterator.
85 |     """
86 |     # define transforms
87 |     trans = transforms.Compose([
88 |         transforms.ToTensor(),  # convert the image to a Tensor
89 |         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # normalize (ImageNet mean/std)
90 |     ])
91 | 
92 |     # load dataset
93 |     dataset = datasets.CIFAR100(
94 |         data_dir, train=False, download=False, transform=trans
95 |     )
96 | 
97 |     data_loader = torch.utils.data.DataLoader(
98 |         dataset, batch_size=batch_size, shuffle=False,
99 |         num_workers=num_workers, pin_memory=pin_memory,
100 |     )
101 | 
102 |     return data_loader
103 | 
--------------------------------------------------------------------------------
/images/Overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chxy95/Deep-Mutual-Learning/650e2767684ec481841581b2ff74bfd88fce1d12/images/Overview.png
--------------------------------------------------------------------------------
/logs/Readme.md:
--------------------------------------------------------------------------------
1 | tensorboard logs
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sun Jul 28 11:07:46 2019
4 | 
5 | @author: chxy
6 | """
7 | 
8 | import torch
9 | 
10 | from trainer import Trainer
11 | from config import get_config
12 | from utils import prepare_dirs, save_config
13 | from data_loader import get_test_loader, get_train_loader
14 | 
15 | 
16 | def main(config):
17 | 
18 |     # ensure directories are set up
19 |     prepare_dirs(config)
20 | 
21 |     # ensure reproducibility
22 |     #torch.manual_seed(config.random_seed)
23 |     kwargs = {}
24 |     if config.use_gpu:
25 |         #torch.cuda.manual_seed_all(config.random_seed)
26 |         kwargs = {'num_workers': config.num_workers, 'pin_memory': config.pin_memory}
27 |         #torch.backends.cudnn.deterministic = True
28 | 
29 |     # instantiate data loaders
30 |     test_data_loader = get_test_loader(
31 |         config.data_dir, config.batch_size, **kwargs
32 |     )
33 | 
34 |     if config.is_train:
35 |         train_data_loader = get_train_loader(
36 |             config.data_dir, config.batch_size,
37 |             config.random_seed, config.shuffle, **kwargs
38 |         )
39 |         data_loader = (train_data_loader, test_data_loader)
40 |     else:
41 |         data_loader = test_data_loader
42 | 
43 |     # instantiate trainer
44 |     trainer = Trainer(config, data_loader)
45 | 
46 |     # either train
47 |     if config.is_train:
48 |         save_config(config)
49 |         trainer.train()
50 | 
51 |     # or load a pretrained model and test
52 |     else:
53 |         trainer.test()
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     config, unparsed = get_config()
58 |     main(config)
59 | 
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1].
3 | 
4 | The implementation and structure of this file are hugely influenced by [2],
5 | which is implemented for ImageNet and doesn't have option A for identity.
6 | Moreover, most of the implementations on the web are copy-pasted from
7 | torchvision's resnet and have the wrong number of params.
8 | 
9 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
10 | number of layers and parameters:
11 | 
12 | name      | layers | params
13 | ResNet20  |     20 | 0.27M
14 | ResNet32  |     32 | 0.46M
15 | ResNet44  |     44 | 0.66M
16 | ResNet56  |     56 | 0.85M
17 | ResNet110 |    110 | 1.7M
18 | ResNet1202|   1202 | 19.4M
19 | 
20 | which this implementation indeed has.
21 | 
22 | Reference:
23 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
24 |     Deep Residual Learning for Image Recognition. arXiv:1512.03385
25 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
26 | 
27 | If you use this implementation in your work, please don't forget to mention the
28 | author, Yerlan Idelbayev.
29 | '''
30 | import torch
31 | import torch.nn as nn
32 | import torch.nn.functional as F
33 | import torch.nn.init as init
34 | 
35 | from torch.autograd import Variable
36 | 
37 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
38 | 
39 | def _weights_init(m):
40 |     classname = m.__class__.__name__
41 |     #print(classname)
42 |     if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
43 |         init.kaiming_normal_(m.weight)
44 | 
45 | class LambdaLayer(nn.Module):
46 |     def __init__(self, lambd):
47 |         super(LambdaLayer, self).__init__()
48 |         self.lambd = lambd
49 | 
50 |     def forward(self, x):
51 |         return self.lambd(x)
52 | 
53 | 
54 | class BasicBlock(nn.Module):
55 |     expansion = 1
56 | 
57 |     def __init__(self, in_planes, planes, stride=1, option='A'):
58 |         super(BasicBlock, self).__init__()
59 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
60 |         self.bn1 = nn.BatchNorm2d(planes)
61 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
62 |         self.bn2 = nn.BatchNorm2d(planes)
63 | 
64 |         self.shortcut = nn.Sequential()
65 |         if stride != 1 or in_planes != planes:
66 |             if option == 'A':
67 |                 """
68 |                 For CIFAR10 ResNet paper uses option A.
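                The identity shortcut subsamples the input spatially with a stride-2 slice
                and zero-pads the channel dimension up to `planes`, so it adds no parameters.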
69 | """ 70 | self.shortcut = LambdaLayer(lambda x: 71 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 72 | elif option == 'B': 73 | self.shortcut = nn.Sequential( 74 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 75 | nn.BatchNorm2d(self.expansion * planes) 76 | ) 77 | 78 | def forward(self, x): 79 | out = F.relu(self.bn1(self.conv1(x))) 80 | out = self.bn2(self.conv2(out)) 81 | out += self.shortcut(x) 82 | out = F.relu(out) 83 | return out 84 | 85 | 86 | class ResNet(nn.Module): 87 | def __init__(self, block, num_blocks, num_classes=100): 88 | super(ResNet, self).__init__() 89 | self.in_planes = 16 90 | 91 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 92 | self.bn1 = nn.BatchNorm2d(16) 93 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 94 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 95 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 96 | self.linear = nn.Linear(64, num_classes) 97 | 98 | self.apply(_weights_init) 99 | 100 | def _make_layer(self, block, planes, num_blocks, stride): 101 | strides = [stride] + [1]*(num_blocks-1) 102 | layers = [] 103 | for stride in strides: 104 | layers.append(block(self.in_planes, planes, stride)) 105 | self.in_planes = planes * block.expansion 106 | 107 | return nn.Sequential(*layers) 108 | 109 | def forward(self, x): 110 | out = F.relu(self.bn1(self.conv1(x))) 111 | out = self.layer1(out) 112 | out = self.layer2(out) 113 | out = self.layer3(out) 114 | out = F.avg_pool2d(out, out.size()[3]) 115 | out = out.view(out.size(0), -1) 116 | out = self.linear(out) 117 | return out 118 | 119 | 120 | def resnet20(): 121 | return ResNet(BasicBlock, [3, 3, 3]) 122 | 123 | 124 | def resnet32(): 125 | return ResNet(BasicBlock, [5, 5, 5]) 126 | 127 | 128 | def resnet44(): 129 | return ResNet(BasicBlock, [7, 7, 7]) 130 | 131 | 132 | def resnet56(): 133 | return ResNet(BasicBlock, [9, 9, 9]) 134 | 135 | 136 | def resnet110(): 137 | return ResNet(BasicBlock, [18, 18, 18]) 138 | 139 | 140 | def resnet1202(): 141 | return ResNet(BasicBlock, [200, 200, 200]) 142 | 143 | 144 | def test(net): 145 | import numpy as np 146 | total_params = 0 147 | 148 | for x in filter(lambda p: p.requires_grad, net.parameters()): 149 | total_params += np.prod(x.data.numpy().shape) 150 | print("Total number of params", total_params) 151 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 152 | 153 | 154 | if __name__ == "__main__": 155 | for net_name in __all__: 156 | if net_name.startswith('resnet'): 157 | #print(net_name) 158 | test(globals()[net_name]()) 159 | print() 160 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Jul 23 10:45:48 2019 4 | 5 | @author: chxy 6 | """ 7 | 8 | import torch 9 | import torchvision 10 | import torch.nn as nn 11 | from torch.autograd import Variable 12 | import torch.optim as optim 13 | from torch.optim.lr_scheduler import ReduceLROnPlateau 14 | import torch.nn.functional as F 15 | 16 | import os 17 | import time 18 | import shutil 19 | 20 | from tqdm import tqdm 21 | from utils import accuracy, AverageMeter 22 | from resnet import resnet32 23 | from tensorboard_logger import configure, log_value 24 | 25 | class Trainer(object): 26 | """ 27 | Trainer 
encapsulates all the logic necessary for
28 |     training the DML models.
29 | 
30 |     All hyperparameters are provided by the user via the
31 |     command-line arguments defined in config.py.
32 |     """
33 |     def __init__(self, config, data_loader):
34 |         """
35 |         Construct a new Trainer instance.
36 | 
37 |         Args
38 |         ----
39 |         - config: object containing command line arguments.
40 |         - data_loader: data iterator
41 |         """
42 |         self.config = config
43 | 
44 |         # data params
45 |         if config.is_train:
46 |             self.train_loader = data_loader[0]
47 |             self.valid_loader = data_loader[1]
48 |             self.num_train = len(self.train_loader.dataset)
49 |             self.num_valid = len(self.valid_loader.dataset)
50 |         else:
51 |             self.test_loader = data_loader
52 |             self.num_test = len(self.test_loader.dataset)
53 |         self.num_classes = config.num_classes
54 | 
55 |         # training params
56 |         self.epochs = config.epochs
57 |         self.start_epoch = 0
58 |         self.momentum = config.momentum
59 |         self.lr = config.init_lr
60 |         self.weight_decay = config.weight_decay
61 |         self.nesterov = config.nesterov
62 |         self.gamma = config.gamma
63 |         # misc params
64 |         self.use_gpu = config.use_gpu
65 |         self.best = config.best
66 |         self.ckpt_dir = config.ckpt_dir
67 |         self.logs_dir = config.logs_dir
68 |         self.counter = 0
69 |         self.lr_patience = config.lr_patience
70 |         self.train_patience = config.train_patience
71 |         self.use_tensorboard = config.use_tensorboard
72 |         self.resume = config.resume
73 |         self.print_freq = config.print_freq
74 |         self.model_name = config.save_name
75 | 
76 |         self.model_num = config.model_num
77 |         self.models = []
78 |         self.optimizers = []
79 |         self.schedulers = []
80 | 
81 |         self.loss_kl = nn.KLDivLoss(reduction='batchmean')
82 |         self.loss_ce = nn.CrossEntropyLoss()
83 |         self.best_valid_accs = [0.] * self.model_num
84 | 
85 |         # configure tensorboard logging
86 |         if self.use_tensorboard:
87 |             tensorboard_dir = self.logs_dir + self.model_name
88 |             print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
89 |             if not os.path.exists(tensorboard_dir):
90 |                 os.makedirs(tensorboard_dir)
91 |             configure(tensorboard_dir)
92 | 
93 |         for i in range(self.model_num):
94 |             # build models
95 |             model = resnet32()
96 |             if self.use_gpu:
97 |                 model.cuda()
98 | 
99 |             self.models.append(model)
100 | 
101 |             # initialize optimizer and scheduler
102 |             optimizer = optim.SGD(model.parameters(), lr=self.lr, momentum=self.momentum,
103 |                                   weight_decay=self.weight_decay, nesterov=self.nesterov)
104 | 
105 |             self.optimizers.append(optimizer)
106 | 
107 |             # set learning rate decay
108 |             scheduler = optim.lr_scheduler.StepLR(self.optimizers[i], step_size=60, gamma=self.gamma, last_epoch=-1)
109 |             self.schedulers.append(scheduler)
110 | 
111 |         print('[*] Number of parameters of one model: {:,}'.format(
112 |             sum([p.data.nelement() for p in self.models[0].parameters()])))
113 | 
114 |     def train(self):
115 |         """
116 |         Train the models on the training set.
117 | 
118 |         A checkpoint of each model is saved after each epoch
119 |         and if the validation accuracy is improved upon,
120 |         a separate ckpt is created for use on the test set.
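        Checkpoints are written per model as '<save_name><k>_ckpt.pth.tar', and the
        best one is copied to '<save_name><k>_model_best.pth.tar' (see save_checkpoint).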
121 | """ 122 | # load the most recent checkpoint 123 | if self.resume: 124 | self.load_checkpoint(best=False) 125 | 126 | print("\n[*] Train on {} samples, validate on {} samples".format( 127 | self.num_train, self.num_valid) 128 | ) 129 | 130 | for epoch in range(self.start_epoch, self.epochs): 131 | 132 | for scheduler in self.schedulers: 133 | scheduler.step(epoch) 134 | 135 | print( 136 | '\nEpoch: {}/{} - LR: {:.6f}'.format( 137 | epoch+1, self.epochs, self.optimizers[0].param_groups[0]['lr'],) 138 | ) 139 | 140 | # train for 1 epoch 141 | train_losses, train_accs = self.train_one_epoch(epoch) 142 | 143 | # evaluate on validation set 144 | valid_losses, valid_accs = self.validate(epoch) 145 | 146 | for i in range(self.model_num): 147 | is_best = valid_accs[i].avg> self.best_valid_accs[i] 148 | msg1 = "model_{:d}: train loss: {:.3f} - train acc: {:.3f} " 149 | msg2 = "- val loss: {:.3f} - val acc: {:.3f}" 150 | if is_best: 151 | #self.counter = 0 152 | msg2 += " [*]" 153 | msg = msg1 + msg2 154 | print(msg.format(i+1, train_losses[i].avg, train_accs[i].avg, valid_losses[i].avg, valid_accs[i].avg)) 155 | 156 | # check for improvement 157 | #if not is_best: 158 | #self.counter += 1 159 | #if self.counter > self.train_patience: 160 | #print("[!] No improvement in a while, stopping training.") 161 | #return 162 | self.best_valid_accs[i] = max(valid_accs[i].avg, self.best_valid_accs[i]) 163 | self.save_checkpoint(i, 164 | {'epoch': epoch + 1, 165 | 'model_state': self.models[i].state_dict(), 166 | 'optim_state': self.optimizers[i].state_dict(), 167 | 'best_valid_acc': self.best_valid_accs[i], 168 | }, is_best 169 | ) 170 | 171 | def train_one_epoch(self, epoch): 172 | """ 173 | Train the model for 1 epoch of the training set. 174 | 175 | An epoch corresponds to one full pass through the entire 176 | training set in successive mini-batches. 177 | 178 | This is used by train() and should not be called manually. 
179 | """ 180 | batch_time = AverageMeter() 181 | losses = [] 182 | accs = [] 183 | 184 | for i in range(self.model_num): 185 | self.models[i].train() 186 | losses.append(AverageMeter()) 187 | accs.append(AverageMeter()) 188 | 189 | 190 | tic = time.time() 191 | with tqdm(total=self.num_train) as pbar: 192 | for i, (images, labels) in enumerate(self.train_loader): 193 | if self.use_gpu: 194 | images, labels = images.cuda(), labels.cuda() 195 | images, labels = Variable(images), Variable(labels) 196 | 197 | #forward pass 198 | outputs=[] 199 | for model in self.models: 200 | outputs.append(model(images)) 201 | for i in range(self.model_num): 202 | ce_loss = self.loss_ce(outputs[i], labels) 203 | kl_loss = 0 204 | for j in range(self.model_num): 205 | if i!=j: 206 | kl_loss += self.loss_kl(F.log_softmax(outputs[i], dim = 1), 207 | F.softmax(Variable(outputs[j]), dim=1)) 208 | loss = ce_loss + kl_loss / (self.model_num - 1) 209 | 210 | # measure accuracy and record loss 211 | prec = accuracy(outputs[i].data, labels.data, topk=(1,))[0] 212 | losses[i].update(loss.item(), images.size()[0]) 213 | accs[i].update(prec.item(), images.size()[0]) 214 | 215 | 216 | # compute gradients and update SGD 217 | self.optimizers[i].zero_grad() 218 | loss.backward() 219 | self.optimizers[i].step() 220 | 221 | # measure elapsed time 222 | toc = time.time() 223 | batch_time.update(toc-tic) 224 | 225 | pbar.set_description( 226 | ( 227 | "{:.1f}s - model1_loss: {:.3f} - model1_acc: {:.3f}".format( 228 | (toc-tic), losses[0].avg, accs[0].avg 229 | ) 230 | ) 231 | ) 232 | self.batch_size = images.shape[0] 233 | pbar.update(self.batch_size) 234 | 235 | # log to tensorboard 236 | if self.use_tensorboard: 237 | iteration = epoch*len(self.train_loader) + i 238 | for i in range(self.model_num): 239 | log_value('train_loss_%d' % (i+1), losses[i].avg, iteration) 240 | log_value('train_acc_%d' % (i+1), accs[i].avg, iteration) 241 | 242 | return losses, accs 243 | 244 | def validate(self, epoch): 245 | """ 246 | Evaluate the model on the validation set. 247 | """ 248 | losses = [] 249 | accs = [] 250 | for i in range(self.model_num): 251 | self.models[i].eval() 252 | losses.append(AverageMeter()) 253 | accs.append(AverageMeter()) 254 | 255 | for i, (images, labels) in enumerate(self.valid_loader): 256 | if self.use_gpu: 257 | images, labels = images.cuda(), labels.cuda() 258 | images, labels = Variable(images), Variable(labels) 259 | 260 | #forward pass 261 | outputs=[] 262 | for model in self.models: 263 | outputs.append(model(images)) 264 | for i in range(self.model_num): 265 | ce_loss = self.loss_ce(outputs[i], labels) 266 | kl_loss = 0 267 | for j in range(self.model_num): 268 | if i!=j: 269 | kl_loss += self.loss_kl(F.log_softmax(outputs[i], dim = 1), 270 | F.softmax(Variable(outputs[j]), dim=1)) 271 | loss = ce_loss + kl_loss / (self.model_num - 1) 272 | 273 | # measure accuracy and record loss 274 | prec = accuracy(outputs[i].data, labels.data, topk=(1,))[0] 275 | losses[i].update(loss.item(), images.size()[0]) 276 | accs[i].update(prec.item(), images.size()[0]) 277 | 278 | # log to tensorboard for every epoch 279 | if self.use_tensorboard: 280 | for i in range(self.model_num): 281 | log_value('valid_loss_%d' % (i+1), losses[i].avg, epoch+1) 282 | log_value('valid_acc_%d' % (i+1), accs[i].avg, epoch+1) 283 | 284 | return losses, accs 285 | 286 | def test(self): 287 | """ 288 | Test the model on the held-out test data. 
289 |         This function should only be called at the very
290 |         end once the models have finished training.
291 |         """
292 |         losses = AverageMeter()
293 |         top1 = AverageMeter()
294 |         top5 = AverageMeter()
295 | 
296 |         for i in range(self.model_num):
297 |             # load the checkpoint written by save_checkpoint() for this model
298 |             suffix = '_model_best.pth.tar' if self.best else '_ckpt.pth.tar'
299 |             ckpt = torch.load(os.path.join(self.ckpt_dir, self.model_name + str(i+1) + suffix))
300 |             self.models[i].load_state_dict(ckpt['model_state'])
301 |             self.models[i].eval()
302 |             losses.reset(); top1.reset(); top5.reset()
303 | 
304 |             with torch.no_grad():
305 |                 for images, labels in self.test_loader:
306 |                     if self.use_gpu:
307 |                         images, labels = images.cuda(), labels.cuda()
308 | 
309 |                     # forward pass and loss
310 |                     outputs = self.models[i](images)
311 |                     loss = self.loss_ce(outputs, labels)
312 | 
313 |                     # measure accuracy and record loss
314 |                     prec1, prec5 = accuracy(outputs.data, labels.data, topk=(1, 5))
315 |                     losses.update(loss.item(), images.size()[0])
316 |                     top1.update(prec1.item(), images.size()[0])
317 |                     top5.update(prec5.item(), images.size()[0])
318 | 
319 |     def save_checkpoint(self, i, state, is_best):
320 |         """
321 |         Save a copy of the model so that it can be loaded at a future
322 |         date. This function is used when the model is being evaluated
323 |         on the test data.
324 | 
325 |         If this model has reached the best validation accuracy thus
326 |         far, a separate file with the suffix `best` is created.
327 |         """
328 |         # print("[*] Saving model to {}".format(self.ckpt_dir))
329 | 
330 |         filename = self.model_name + str(i+1) + '_ckpt.pth.tar'
331 |         ckpt_path = os.path.join(self.ckpt_dir, filename)
332 |         torch.save(state, ckpt_path)
333 | 
334 |         if is_best:
335 |             filename = self.model_name + str(i+1) + '_model_best.pth.tar'
336 |             shutil.copyfile(
337 |                 ckpt_path, os.path.join(self.ckpt_dir, filename)
338 |             )
339 | 
340 |     '''def load_checkpoint(self, best=False):
341 |         """
342 |         Load the best copy of a model. This is useful for 2 cases:
343 | 
344 |         - Resuming training with the most recent model checkpoint.
345 |         - Loading the best validation model to evaluate on the test data.
346 | 
347 |         Params
348 |         ------
349 |         - best: if set to True, loads the best model. Use this if you want
350 |           to evaluate your model on the test data. Else, set to False in
351 |           which case the most recent version of the checkpoint is used.
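            # report per-model test metrics after the full pass over the test set
            print('[*] model_{}: test loss: {:.3f}, top1_acc: {:.3f}%, top5_acc: {:.3f}%'.format(
                i+1, losses.avg, top1.avg, top5.avg))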
352 | """ 353 | print("[*] Loading model from {}".format(self.ckpt_dir)) 354 | 355 | filename = self.model_name + '_ckpt.pth.tar' 356 | if best: 357 | filename = self.model_name + '_model_best.pth.tar' 358 | ckpt_path = os.path.join(self.ckpt_dir, filename) 359 | ckpt = torch.load(ckpt_path) 360 | 361 | # load variables from checkpoint 362 | self.start_epoch = ckpt['epoch'] 363 | self.best_valid_acc = ckpt['best_valid_acc'] 364 | self.model.load_state_dict(ckpt['model_state']) 365 | self.optimizer.load_state_dict(ckpt['optim_state']) 366 | 367 | if best: 368 | print( 369 | "[*] Loaded {} checkpoint @ epoch {} " 370 | "with best valid acc of {:.3f}".format( 371 | filename, ckpt['epoch'], ckpt['best_valid_acc']) 372 | ) 373 | else: 374 | print( 375 | "[*] Loaded {} checkpoint @ epoch {}".format( 376 | filename, ckpt['epoch']) 377 | )''' 378 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import matplotlib.patches as patches 6 | 7 | from PIL import Image 8 | 9 | 10 | def denormalize(T, coords): 11 | return (0.5 * ((coords + 1.0) * T)) 12 | 13 | class AverageMeter(object): 14 | """ 15 | Computes and stores the average and 16 | current value. 17 | """ 18 | def __init__(self): 19 | self.reset() 20 | 21 | def reset(self): 22 | self.val = 0 23 | self.avg = 0 24 | self.sum = 0 25 | self.count = 0 26 | 27 | def update(self, val, n=1): 28 | self.val = val 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / self.count 32 | 33 | def accuracy(output, target, topk=(1,)): 34 | """Computes the precision@k for the specified values of k""" 35 | maxk = max(topk) 36 | batch_size = target.size(0) 37 | 38 | _, pred = output.topk(maxk, 1, True, True) 39 | pred = pred.t() 40 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 41 | 42 | res = [] 43 | for k in topk: 44 | correct_k = correct[:k].view(-1).float().sum(0) 45 | res.append(correct_k.mul_(100.0 / batch_size)) 46 | return res 47 | 48 | def resize_array(x, size): 49 | # 3D and 4D tensors allowed only 50 | assert x.ndim in [3, 4], "Only 3D and 4D Tensors allowed!" 51 | 52 | # 4D Tensor 53 | if x.ndim == 4: 54 | res = [] 55 | for i in range(x.shape[0]): 56 | img = array2img(x[i]) 57 | img = img.resize((size, size)) 58 | img = np.asarray(img, dtype='float32') 59 | img = np.expand_dims(img, axis=0) 60 | img /= 255.0 61 | res.append(img) 62 | res = np.concatenate(res) 63 | res = np.expand_dims(res, axis=1) 64 | return res 65 | 66 | # 3D Tensor 67 | img = array2img(x) 68 | img = img.resize((size, size)) 69 | res = np.asarray(img, dtype='float32') 70 | res = np.expand_dims(res, axis=0) 71 | res /= 255.0 72 | return res 73 | 74 | 75 | def img2array(data_path, desired_size=None, expand=False, view=False): 76 | """ 77 | Util function for loading RGB image into a numpy array. 78 | 79 | Returns array of shape (1, H, W, C). 80 | """ 81 | img = Image.open(data_path) 82 | img = img.convert('RGB') 83 | if desired_size: 84 | img = img.resize((desired_size[1], desired_size[0])) 85 | if view: 86 | img.show() 87 | x = np.asarray(img, dtype='float32') 88 | if expand: 89 | x = np.expand_dims(x, axis=0) 90 | x /= 255.0 91 | return x 92 | 93 | 94 | def array2img(x): 95 | """ 96 | Util function for converting anumpy array to a PIL img. 97 | 98 | Returns PIL RGB img. 
99 | """ 100 | x = np.asarray(x) 101 | x = x + max(-np.min(x), 0) 102 | x_max = np.max(x) 103 | if x_max != 0: 104 | x /= x_max 105 | x *= 255 106 | return Image.fromarray(x.astype('uint8'), 'RGB') 107 | 108 | def prepare_dirs(config): 109 | for path in [config.ckpt_dir, config.logs_dir]: 110 | if not os.path.exists(path): 111 | os.makedirs(path) 112 | 113 | 114 | def save_config(config): 115 | model_name = config.save_name 116 | filename = model_name + '_params.json' 117 | param_path = os.path.join(config.ckpt_dir, filename) 118 | 119 | print("[*] Model Checkpoint Dir: {}".format(config.ckpt_dir)) 120 | print("[*] Param Path: {}".format(param_path)) 121 | 122 | with open(param_path, 'w') as fp: 123 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 124 | --------------------------------------------------------------------------------