├── README.md ├── conf ├── __init__.py └── global_settings.py ├── models └── MS_ResNet.py ├── test.py ├── train_amp.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Advancing Spiking Neural Networks towards Deep Residual Learning 2 | 3 | This repo **covers the implementation of the following paper:** 4 | 5 | "Advancing Spiking Neural Networks towards Deep Residual Learning". [Paper](https://arxiv.org/abs/2112.08954). 6 | 7 | The most straightforward way of training higher quality models is by increasing their size. In this work, we would like to see that deepening network structures could get rid of the degradation problem and always be a trustworthy way to achieve satisfying accuracy for the direct training of SNNs. 8 | 9 | This repository contains the source code for the training of our MS-ResNet on ImageNet. The models are defined in `models/MS_ResNet.py` . 10 | 11 | ## Running 12 | 13 | 1. Install Python 3.7, PyTorch 1.8 and Tensorboard. 14 | 15 | 2. Change the data paths `valdir, traindir` to the image folders of the ImageNet dataset. 16 | 17 | 3. To train the model, please run `CUDA_VISIBLE_DEVICES=GPU_IDs python -m torch.distributed.launch --master_port=1234 --nproc_per_node=NUM_GPU_USED train_amp.py -net resnet34 -b 256 -lr 0.1` . 18 | 19 | `-net` option supports `resnet18/34/104` . 
20 | 21 | ## Citation 22 | 23 | If you find this repo useful for your research, please consider citing the paper 24 | 25 | ``` 26 | @misc{hu2023advancing, 27 | title={Advancing Spiking Neural Networks towards Deep Residual Learning}, 28 | author={Yifan Hu and Lei Deng and Yujie Wu and Man Yao and Guoqi Li}, 29 | year={2023}, 30 | eprint={2112.08954}, 31 | archivePrefix={arXiv}, 32 | primaryClass={cs.NE} 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /conf/__init__.py: -------------------------------------------------------------------------------- 1 | """ dynamically load settings 2 | 3 | author baiyu 4 | """ 5 | import conf.global_settings as settings 6 | 7 | class Settings: 8 | def __init__(self, settings): 9 | 10 | for attr in dir(settings): 11 | if attr.isupper(): 12 | setattr(self, attr, getattr(settings, attr)) 13 | 14 | settings = Settings(settings) -------------------------------------------------------------------------------- /conf/global_settings.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | CHECKPOINT_PATH = 'checkpoint' 4 | 5 | # total training epoches 6 | EPOCH = 125 7 | 8 | # initial learning rate 9 | # INIT_LR = 0.1 10 | 11 | # time of we run the script 12 | TIME_NOW = datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') 13 | 14 | # tensorboard log dir 15 | LOG_DIR = 'runs' 16 | 17 | # save weights file per SAVE_EPOCH epoch 18 | SAVE_EPOCH = 30 19 | -------------------------------------------------------------------------------- /models/MS_ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # Model for RM-ResNet 6 | 7 | thresh = 0.5 # neuronal threshold 8 | lens = 0.5 # hyper-parameters of approximate function 9 | decay = 0.25 # decay constants 10 | num_classes = 1000 11 | time_window = 6 12 | 
# Global default device. Kept for backward compatibility; the modules below
# now allocate on their input's device instead of forcing a move here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# define approximate firing function
class ActFun(torch.autograd.Function):
    """Heaviside firing with a rectangular surrogate gradient.

    Forward emits 1.0 wherever the membrane potential exceeds ``thresh``.
    Backward passes gradient only within ``lens`` of the threshold, scaled
    by 1 / (2 * lens) so the surrogate window has unit area.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Rectangular window around the threshold, normalised to area 1.
        window = (abs(input - thresh) < lens).float() / (2 * lens)
        return grad_output.clone() * window


act_fun = ActFun.apply


class mem_update(nn.Module):
    """Leaky integrate-and-fire membrane update over the time dimension.

    Input/output shape: (time_window, batch, C, H, W).  At each step the
    membrane decays by ``decay``, is soft-reset where the previous step
    spiked, integrates the current input, and fires through ``act_fun``.
    """

    def __init__(self):
        super(mem_update, self).__init__()

    def forward(self, x):
        # zeros_like inherits x's device/dtype; the original's explicit
        # .to(device) broke CPU / multi-device runs.
        output = torch.zeros_like(x)
        mem_old = 0
        spike = None
        for i in range(time_window):
            if i >= 1:
                # detach() blocks gradients from flowing through the reset.
                mem = mem_old * decay * (1 - spike.detach()) + x[i]
            else:
                mem = x[i]
            spike = act_fun(mem)
            mem_old = mem.clone()
            output[i] = spike
        return output


class batch_norm_2d(nn.Module):
    """TDBN: threshold-dependent batch norm, BN weight initialised to ``thresh``.

    Input is (T, B, C, H, W); the transposes present it to BatchNorm3d as
    (B, C, T, H, W) so statistics are shared across all time steps.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(batch_norm_2d, self).__init__()
        # Forward eps/momentum to the BN layer (previously silently ignored).
        self.bn = BatchNorm3d1(num_features, eps=eps, momentum=momentum)

    def forward(self, input):
        y = input.transpose(0, 2).contiguous().transpose(0, 1).contiguous()
        y = self.bn(y)
        return y.contiguous().transpose(0, 1).contiguous().transpose(0, 2)


class batch_norm_2d1(nn.Module):
    """TDBN with zero-initialised BN weight, used as the last BN of a
    residual branch so each block starts close to an identity mapping."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(batch_norm_2d1, self).__init__()
        # Forward eps/momentum to the BN layer (previously silently ignored).
        self.bn = BatchNorm3d2(num_features, eps=eps, momentum=momentum)

    def forward(self, input):
        y = input.transpose(0, 2).contiguous().transpose(0, 1).contiguous()
        y = self.bn(y)
        return y.contiguous().transpose(0, 1).contiguous().transpose(0, 2)


class BatchNorm3d1(torch.nn.BatchNorm3d):
    """BatchNorm3d whose affine weight is initialised to ``thresh`` (TDBN)."""

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            nn.init.constant_(self.weight, thresh)
            nn.init.zeros_(self.bias)


class BatchNorm3d2(torch.nn.BatchNorm3d):
    """BatchNorm3d whose affine weight is initialised to zero (TDBN-zero)."""

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            nn.init.constant_(self.weight, 0)
            nn.init.zeros_(self.bias)


class Snn_Conv2d(nn.Conv2d):
    """2-D convolution applied independently at every time step.

    Expects (T, B, C_in, H, W) and returns (T, B, C_out, H', W').  The
    spatial output size is whatever ``F.conv2d`` produces, so rectangular
    kernels/strides/paddings and dilation are handled correctly (the
    original hand-computed H'/W' assuming square, dilation-free convs).
    Stacking also preserves the conv output dtype under AMP autocast
    (the original fp32 zero-buffer silently upcast half activations).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 padding_mode='zeros',
                 marker='b'):
        super(Snn_Conv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias, padding_mode)
        self.marker = marker  # kept for interface compatibility; unused

    def forward(self, input):
        steps = [
            F.conv2d(input[t], self.weight, self.bias, self.stride,
                     self.padding, self.dilation, self.groups)
            for t in range(time_window)
        ]
        return torch.stack(steps, dim=0)


######################################################################################################################
class BasicBlock_104(nn.Module):
    """Pre-activation (spike -> conv -> BN) basic block for MS-ResNet-104."""

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            mem_update(),
            Snn_Conv2d(in_channels,
                       out_channels,
                       kernel_size=3,
                       stride=stride,
                       padding=1,
                       bias=False),
            batch_norm_2d(out_channels),
            mem_update(),
            Snn_Conv2d(out_channels,
                       out_channels * BasicBlock_104.expansion,
                       kernel_size=3,
                       padding=1,
                       bias=False),
            batch_norm_2d1(out_channels * BasicBlock_104.expansion),
        )
        # Identity shortcut; when the shape changes, downsample with
        # avg-pool + 1x1 conv (pooling first keeps the 1x1 conv at stride 1).
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != BasicBlock_104.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.AvgPool3d((1, 2, 2), stride=(1, 2, 2)),
                Snn_Conv2d(in_channels,
                           out_channels * BasicBlock_104.expansion,
                           kernel_size=1,
                           stride=1,
                           bias=False),
                batch_norm_2d(out_channels * BasicBlock_104.expansion),
            )

    def forward(self, x):
        return self.residual_function(x) + self.shortcut(x)


class ResNet_104(nn.Module):
    """MS-ResNet-104 backbone with a deep stem of three 3x3 convs."""

    def __init__(self, block, num_block, num_classes=1000):
        super().__init__()
        k = 1  # width multiplier
        self.in_channels = 64 * k
        self.conv1 = nn.Sequential(
            Snn_Conv2d(3, 64 * k, kernel_size=3, padding=1, stride=2),
            Snn_Conv2d(64 * k, 64 * k, kernel_size=3, padding=1, stride=1),
            Snn_Conv2d(64 * k, 64 * k, kernel_size=3, padding=1, stride=1),
            batch_norm_2d(64 * k),
        )
        # NOTE(review): defined but never used in forward(); kept so the
        # module structure (and existing checkpoints) stay compatible.
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.mem_update = mem_update()
        self.conv2_x = self._make_layer(block, 64 * k, num_block[0], 2)
        self.conv3_x = self._make_layer(block, 128 * k, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256 * k, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512 * k, num_block[3], 2)
        self.fc = nn.Linear(512 * block.expansion * k, num_classes)
        self.dropout = nn.Dropout(p=0.2)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one downsamples."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        # Repeat the static image over the time dimension on x's own device
        # (the original allocated the buffer on the global ``device``).
        inp = x.unsqueeze(0).repeat(time_window, 1, 1, 1, 1)
        output = self.conv1(inp)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.mem_update(output)
        output = F.adaptive_avg_pool3d(output, (None, 1, 1))
        output = output.view(output.size()[0], output.size()[1], -1)
        # Rate decoding: average spikes over time (identical to sum / T).
        output = output.mean(dim=0)
        output = self.dropout(output)
        output = self.fc(output)
        return output


def resnet104():
    """MS-ResNet-104: [3, 8, 32, 8] basic blocks per stage."""
    return ResNet_104(BasicBlock_104, [3, 8, 32, 8])


class BasicBlock_18(nn.Module):
    """Pre-activation basic block for MS-ResNet-18/34."""

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            mem_update(),
            Snn_Conv2d(in_channels,
                       out_channels,
                       kernel_size=3,
                       stride=stride,
                       padding=1,
                       bias=False),
            batch_norm_2d(out_channels),
            mem_update(),
            Snn_Conv2d(out_channels,
                       out_channels * BasicBlock_18.expansion,
                       kernel_size=3,
                       padding=1,
                       bias=False),
            batch_norm_2d1(out_channels * BasicBlock_18.expansion),
        )
        # Identity shortcut; strided 1x1 conv projection when shape changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != BasicBlock_18.expansion * out_channels:
            self.shortcut = nn.Sequential(
                Snn_Conv2d(in_channels,
                           out_channels * BasicBlock_18.expansion,
                           kernel_size=1,
                           stride=stride,
                           bias=False),
                batch_norm_2d(out_channels * BasicBlock_18.expansion),
            )

    def forward(self, x):
        return self.residual_function(x) + self.shortcut(x)


class ResNet_origin_18(nn.Module):
    """MS-ResNet-18/34 backbone with the classic 7x7 stem."""

    def __init__(self, block, num_block, num_classes=1000):
        super().__init__()
        k = 1  # width multiplier
        self.in_channels = 64 * k
        self.conv1 = nn.Sequential(
            Snn_Conv2d(3,
                       64 * k,
                       kernel_size=7,
                       padding=3,
                       bias=False,
                       stride=2),
            batch_norm_2d(64 * k),
        )
        # NOTE(review): defined but never used in forward(); kept so the
        # module structure (and existing checkpoints) stay compatible.
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.mem_update = mem_update()
        self.conv2_x = self._make_layer(block, 64 * k, num_block[0], 2)
        self.conv3_x = self._make_layer(block, 128 * k, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256 * k, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512 * k, num_block[3], 2)
        self.fc = nn.Linear(512 * block.expansion * k, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one downsamples."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        # Repeat the static image over the time dimension on x's own device.
        inp = x.unsqueeze(0).repeat(time_window, 1, 1, 1, 1)
        output = self.conv1(inp)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.mem_update(output)
        output = F.adaptive_avg_pool3d(output, (None, 1, 1))
        output = output.view(output.size()[0], output.size()[1], -1)
        # Rate decoding: average spikes over time (identical to sum / T).
        output = output.mean(dim=0)
        output = self.fc(output)
        return output


def resnet18():
    """MS-ResNet-18: [2, 2, 2, 2] basic blocks per stage."""
    return ResNet_origin_18(BasicBlock_18, [2, 2, 2, 2])


def resnet34():
    """MS-ResNet-34: [3, 4, 6, 3] basic blocks per stage."""
    return ResNet_origin_18(BasicBlock_18, [3, 4, 6, 3])
# ------------------------------------------------------------ /test.py
import argparse
import time

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import models.MS_ResNet
import torchvision.datasets as datasets

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('-weights',
                        type=str,
                        default="resnet34.pth",
                        help='the weights file you want to test')
    parser.add_argument('-net', type=str, required=True, help='net type')
    # store_true instead of type=bool: bool("False") is truthy, so the old
    # flag could never actually be switched off from the command line.
    parser.add_argument('-gpu',
                        action='store_true',
                        default=True,
                        help='use gpu or not')
    parser.add_argument('-b',
                        type=int,
                        default=100,
                        help='batch size for dataloader')
    args = parser.parse_args()

    # Fail fast on an unknown -net (previously fell through to a NameError).
    builders = {
        'resnet18': models.MS_ResNet.resnet18,
        'resnet34': models.MS_ResNet.resnet34,
        'resnet104': models.MS_ResNet.resnet104,
    }
    if args.net not in builders:
        raise SystemExit('unsupported net type: {}'.format(args.net))
    net = builders[args.net]()

    def get_test_dataloader(batch_size=16, num_workers=4, shuffle=False):
        """ImageNet validation loader: resize 256, center-crop 224."""
        valdir = "/data1/imagenet/val"
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        ImageNet_test = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        return DataLoader(ImageNet_test,
                          shuffle=shuffle,
                          num_workers=num_workers,
                          batch_size=batch_size)

    ImageNet_test_loader = get_test_dataloader(
        num_workers=4,
        batch_size=args.b,
    )

    # Checkpoints were saved from a DDP-wrapped model; strip the 'module.'
    # prefix so they load into a bare network.
    net.load_state_dict({
        k.replace('module.', ''): v
        for k, v in torch.load(args.weights).items()
    })
    net.cuda()
    net = torch.nn.DataParallel(net)
    net.eval()

    correct_1 = 0.0
    correct_5 = 0.0
    start = time.time()
    with torch.no_grad():
        for n_iter, (image, label) in enumerate(ImageNet_test_loader):
            if n_iter % 10 == 0:
                print("iteration: {}\ttotal {} iterations".format(
                    n_iter + 1, len(ImageNet_test_loader)))

            if args.gpu:
                image = image.cuda()
                label = label.cuda()

            output = net(image)
            _, pred = output.topk(5, 1, largest=True, sorted=True)
            label = label.view(label.size(0), -1).expand_as(pred)
            correct = pred.eq(label).float()

            correct_5 += correct[:, :5].sum()  # hit anywhere in top 5
            correct_1 += correct[:, :1].sum()  # hit in the first column
    finish = time.time()

    print()
    print("Time consumed:", finish - start)
    print("Top 1 acc: ", correct_1.item() / len(ImageNet_test_loader.dataset))
    print("Top 5 acc: ", correct_5.item() / len(ImageNet_test_loader.dataset))
    print("Parameter numbers: {}".format(
        sum(p.numel() for p in net.parameters())))

# ------------------------------------------------------- /train_amp.py
import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast
import torch.cuda.amp
from conf import settings
from utils import get_network, get_training_dataloader, get_test_dataloader


def train(epoch, args):
    """Run one AMP training epoch; logs every 10 batches (rank 0 only).

    Uses the module-level net/optimizer/scaler/loss_function/loaders set up
    in the __main__ block below.
    """
    running_loss = 0
    start = time.time()
    net.train()
    correct = 0.0
    num_sample = 0
    for batch_index, (images, labels) in enumerate(ImageNet_training_loader):
        if args.gpu:
            labels = labels.cuda(non_blocking=True)
            images = images.cuda(non_blocking=True)
        num_sample += images.size()[0]
        optimizer.zero_grad()
        with autocast():
            outputs = net(images)
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum()
            loss = loss_function(outputs, labels)
        running_loss += loss.item()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        n_iter = (epoch - 1) * len(ImageNet_training_loader) + batch_index + 1
        if batch_index % 10 == 9:
            # NOTE(review): args.b is the global batch while the loader
            # yields per-GPU batches, so trained_samples is approximate.
            print(
                'Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.6f}'
                .format(running_loss / 10,
                        optimizer.param_groups[0]['lr'],
                        epoch=epoch,
                        trained_samples=batch_index * args.b + len(images),
                        total_samples=len(ImageNet_training_loader.dataset)))
            print('training time consumed: {:.2f}s'.format(time.time() -
                                                           start))
            if args.local_rank == 0:
                writer.add_scalar('Train/avg_loss', running_loss / 10, n_iter)
                writer.add_scalar('Train/avg_loss_numpic', running_loss / 10,
                                  n_iter * args.b)
            running_loss = 0
    finish = time.time()
    if args.local_rank == 0:
        writer.add_scalar('Train/acc', correct / num_sample * 100, epoch)
        print("Training accuracy: {:.2f} of epoch {}".format(
            correct / num_sample * 100, epoch))
        print('epoch {} training time consumed: {:.2f}s'.format(
            epoch, finish - start))


@torch.no_grad()
def eval_training(epoch, args):
    """Evaluate on this rank's shard; returns shard accuracy in [0, 1].

    The denominator is the number of samples this rank actually saw
    (real_batch) — the original returned correct / len(dataset) while
    printing correct / real_batch, which disagreed under DDP sharding.
    """
    start = time.time()
    net.eval()

    test_loss = 0.0
    correct = 0.0
    real_batch = 0
    for (images, labels) in ImageNet_test_loader:
        real_batch += images.size()[0]
        if args.gpu:
            images = images.cuda()
            labels = labels.cuda()

        outputs = net(images)
        loss = loss_function(outputs, labels)
        test_loss += loss.item()
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum()

    finish = time.time()
    print('Evaluating Network.....')
    print(
        'Test set: Average loss: {:.4f}, Accuracy: {:.4f}%, Time consumed:{:.2f}s'
        .format(test_loss * args.b / len(ImageNet_test_loader.dataset),
                correct.float() / real_batch * 100, finish - start))

    if args.local_rank == 0:
        writer.add_scalar(
            'Test/Average loss',
            test_loss * args.b / len(ImageNet_test_loader.dataset), epoch)
        writer.add_scalar('Test/Accuracy',
                          correct.float() / real_batch * 100, epoch)

    return correct.float() / real_batch


class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy with label smoothing (used for ResNet-104).

    Targets become (1 - epsilon) * one_hot + epsilon / num_classes; the
    loss is averaged over the batch and summed over classes.
    """

    def __init__(self, num_classes=1000, epsilon=0.1):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(
            1, targets.unsqueeze(1), 1)
        soft = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-soft * log_probs).mean(0).sum()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-net', type=str, required=True, help='net type')
    parser.add_argument('-gpu',
                        action='store_true',
                        default=True,
                        help='use gpu or not')
    parser.add_argument('-b',
                        type=int,
                        default=256,
                        help='batch size for dataloader')
    parser.add_argument('-lr',
                        type=float,
                        default=0.1,
                        help='initial learning rate')
    parser.add_argument('--local_rank',
                        default=-1,
                        type=int,
                        help='node rank for distributed training')
    args = parser.parse_args()
    print(args.local_rank)
    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)

    SEED = 445
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)

    net = get_network(args)
    net.cuda()
    net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = torch.nn.parallel.DistributedDataParallel(
        net, device_ids=[args.local_rank])

    # to load a pretrained model:
    # map_location = {'cuda:%d' % 0: 'cuda:%d' % args.local_rank}
    # net.load_state_dict(torch.load("path", map_location=map_location))

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print("Let's use", num_gpus, "GPUs!")

    # NOTE(review): the DistributedSampler inside these loaders never gets
    # set_epoch(), so the shuffle order repeats every epoch — confirm
    # whether that is intended before relying on it.
    ImageNet_training_loader = get_training_dataloader(
        traindir="/data/imagenet/train",
        num_workers=2,
        batch_size=args.b // num_gpus,
        shuffle=False,
        sampler=1  # any non-None value enables the DistributedSampler
    )

    ImageNet_test_loader = get_test_dataloader(valdir="/data/imagenet/val",
                                               num_workers=2,
                                               batch_size=args.b // num_gpus,
                                               shuffle=False,
                                               sampler=1)

    # learning rate should go with batch size.
    b_lr = args.lr

    loss_function = CrossEntropyLabelSmooth()
    optimizer = optim.SGD([{
        'params': net.parameters(),
        'initial_lr': b_lr
    }],
                          momentum=0.9,
                          lr=b_lr,
                          weight_decay=1e-4)  # SGD with momentum
    train_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=settings.EPOCH, eta_min=0, last_epoch=0)
    iter_per_epoch = len(ImageNet_training_loader)
    LOG_INFO = "ImageNet_ACC"
    checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net,
                                   str(args.b), str(args.lr), LOG_INFO,
                                   settings.TIME_NOW)

    # tensorboard + checkpoint folders (rank 0 only); exist_ok avoids the
    # check-then-create race of the original mkdir calls.
    if args.local_rank == 0:
        os.makedirs(settings.LOG_DIR, exist_ok=True)
        writer = SummaryWriter(
            log_dir=os.path.join(settings.LOG_DIR, args.net, str(args.b),
                                 str(args.lr), LOG_INFO, settings.TIME_NOW))
        os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_path,
                                   '{net}-{epoch}-{type}.pth')
    scaler = torch.cuda.amp.GradScaler()
    best_acc = 0.0

    for epoch in range(1, settings.EPOCH + 1):
        train(epoch, args)

        train_scheduler.step()
        acc = eval_training(epoch, args)

        if args.local_rank == 0:
            # Keep the best checkpoint from the final 5 epochs, a regular
            # one for the other final epochs, and periodic snapshots every
            # SAVE_EPOCH epochs (same schedule as the original branches).
            if epoch > settings.EPOCH - 5 and best_acc < acc:
                torch.save(
                    net.state_dict(),
                    checkpoint_path.format(net=args.net,
                                           epoch=epoch,
                                           type='best'))
                best_acc = acc
            elif epoch >= settings.EPOCH - 5 or not epoch % settings.SAVE_EPOCH:
                torch.save(
                    net.state_dict(),
                    checkpoint_path.format(net=args.net,
                                           epoch=epoch,
                                           type='regular'))

    if args.local_rank == 0:
        writer.close()

# --------------------------------------------------------- /utils.py
import sys
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torch.utils.data.distributed import DistributedSampler


def get_network(args):
    """Return the network named by args.net, or exit with a message."""
    if args.net == 'resnet18':
        from models.MS_ResNet import resnet18
        net = resnet18()
    elif args.net == 'resnet34':
        from models.MS_ResNet import resnet34
        net = resnet34()
    elif args.net == 'resnet104':
        from models.MS_ResNet import resnet104
        net = resnet104()
    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    return net


def get_training_dataloader(traindir,
                            sampler=None,
                            batch_size=16,
                            num_workers=2,
                            shuffle=True):
    """ImageNet training loader (RandomResizedCrop + AutoAugment).

    Any non-None ``sampler`` value enables a DistributedSampler over the
    dataset (callers pass 1 as a flag).
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    ImageNet_training = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.AutoAugment(),
            transforms.ToTensor(),
            normalize,
        ]))
    if sampler is not None:
        # shuffle must be False when a sampler is supplied (DataLoader
        # raises otherwise); the DistributedSampler shuffles instead.
        return DataLoader(ImageNet_training,
                          shuffle=False,
                          num_workers=num_workers,
                          batch_size=batch_size,
                          pin_memory=True,
                          sampler=DistributedSampler(ImageNet_training))
    return DataLoader(ImageNet_training,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      batch_size=batch_size,
                      pin_memory=True)


def get_test_dataloader(valdir,
                        sampler=None,
                        batch_size=16,
                        num_workers=2,
                        shuffle=False):
    """ImageNet validation loader (resize 256, center-crop 224).

    Any non-None ``sampler`` value enables a DistributedSampler over the
    dataset (callers pass 1 as a flag).
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    ImageNet_test = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    if sampler is not None:
        # See get_training_dataloader: sampler and shuffle are exclusive.
        return DataLoader(ImageNet_test,
                          shuffle=False,
                          num_workers=num_workers,
                          batch_size=batch_size,
                          pin_memory=True,
                          sampler=DistributedSampler(ImageNet_test))
    return DataLoader(ImageNet_test,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      batch_size=batch_size,
                      pin_memory=True)