├── IN100.txt ├── LICENSE ├── README.md ├── generate_IN100.py ├── main_IN100.py ├── networks ├── __pycache__ │ ├── resnet.cpython-37.pyc │ └── resnet_big.cpython-37.pyc └── resnet.py └── util.py /IN100.txt: -------------------------------------------------------------------------------- 1 | n02869837 2 | n01749939 3 | n02488291 4 | n02107142 5 | n13037406 6 | n02091831 7 | n04517823 8 | n04589890 9 | n03062245 10 | n01773797 11 | n01735189 12 | n07831146 13 | n07753275 14 | n03085013 15 | n04485082 16 | n02105505 17 | n01983481 18 | n02788148 19 | n03530642 20 | n04435653 21 | n02086910 22 | n02859443 23 | n13040303 24 | n03594734 25 | n02085620 26 | n02099849 27 | n01558993 28 | n04493381 29 | n02109047 30 | n04111531 31 | n02877765 32 | n04429376 33 | n02009229 34 | n01978455 35 | n02106550 36 | n01820546 37 | n01692333 38 | n07714571 39 | n02974003 40 | n02114855 41 | n03785016 42 | n03764736 43 | n03775546 44 | n02087046 45 | n07836838 46 | n04099969 47 | n04592741 48 | n03891251 49 | n02701002 50 | n03379051 51 | n02259212 52 | n07715103 53 | n03947888 54 | n04026417 55 | n02326432 56 | n03637318 57 | n01980166 58 | n02113799 59 | n02086240 60 | n03903868 61 | n02483362 62 | n04127249 63 | n02089973 64 | n03017168 65 | n02093428 66 | n02804414 67 | n02396427 68 | n04418357 69 | n02172182 70 | n01729322 71 | n02113978 72 | n03787032 73 | n02089867 74 | n02119022 75 | n03777754 76 | n04238763 77 | n02231487 78 | n03032252 79 | n02138441 80 | n02104029 81 | n03837869 82 | n03494278 83 | n04136333 84 | n03794056 85 | n03492542 86 | n02018207 87 | n04067472 88 | n03930630 89 | n03584829 90 | n02123045 91 | n04229816 92 | n02100583 93 | n03642806 94 | n04336792 95 | n03259280 96 | n02116738 97 | n02108089 98 | n03424325 99 | n01855672 100 | n02090622 101 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 
2022 Chun-Hsiao (Daniel) Yeh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ImageNet-100 (IN100) PyTorch Implementation 2 | 3 | PyTorch Implementation: Training ResNets on ImageNet-100 data 4 | 5 | ## Prepare Datasets (ImageNet) 6 | ImageNet-1K data can be obtained from [ILSVRC 2012](http://www.image-net.org/challenges/LSVRC/2012/). If you already have ImageNet-1K, jump to the Quick Start section below to generate ImageNet-100.
7 | 8 | ``` 9 | root 10 | ├── data 11 | │ ├── imagenet 12 | │ │ ├── train 13 | │ │ ├── val 14 | │ ├── imagenet-100 (will be generated, no need to create it) 15 | │ │ ├── train 16 | │ │ ├── val 17 | 18 | ``` 19 | 20 | 21 | ## Quick Start 22 | 23 | Generate the ImageNet-100 dataset from the [selected class file](https://arxiv.org/pdf/1906.05849.pdf), a list of 100 classes randomly sampled from the ImageNet-1K dataset. Running generate_IN100.py creates the ImageNet-100 folder. 24 | 25 | For example, run the following command to generate ImageNet-100 from ImageNet-1K data. 26 | 27 | Arguments: 28 | - `--source_folder`: the ImageNet-1K data folder (e.g., `/root/data/imagenet/train`) 29 | - `--target_folder`: the ImageNet-100 data folder (e.g., `/root/data/imagenet-100/train`) 30 | - `--target_class`: the ImageNet-100 txt file with the list of classes [default: 'IN100.txt'] 31 | 32 | ``` 33 | python generate_IN100.py \ 34 | --source_folder /path/to/ImageNet-1K data \ 35 | --target_folder /path/to/ImageNet-100 data 36 | ``` 37 | 38 | Note: Replace `train` with `val` in both folders to generate the ImageNet-100 val data as well. 39 | 40 | ## Training ResNets on ImageNet-100 41 | 42 | The training and validation code is in main_IN100.py; run it as follows. 43 | 44 | ``` 45 | python main_IN100.py --model resnet18 \ 46 | --data_folder /path/to/ImageNet-100 main folder \ 47 | --batch_size 256 \ 48 | --epochs 200 \ 49 | --learning_rate 0.2 \ 50 | --cosine 51 | ``` 52 | Note: Set the argument `--data_folder` to the main ImageNet-100 path (e.g., `/root/data/imagenet-100`), which contains the `train` and `val` subfolders. 53 | 54 | The command above uses the cosine annealing schedule (remove `--cosine` to switch to the step learning-rate schedule).
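The two learning-rate schedules selected by `--cosine` can be sketched in plain Python. This is a minimal sketch that mirrors the formula in `adjust_learning_rate` from util.py, using the default values of `--learning_rate`, `--epochs`, `--lr_decay_rate` and `--lr_decay_epochs`; the function name `learning_rate_at` and its parameter names are illustrative assumptions and not part of the repository.

```python
import math


def learning_rate_at(epoch, base_lr=0.2, epochs=200, cosine=True,
                     decay_rate=0.1, decay_epochs=(120, 160)):
    """Return the learning rate for a 1-based epoch (hypothetical helper)."""
    if cosine:
        # Cosine annealing from base_lr down to eta_min at the last epoch.
        eta_min = base_lr * decay_rate ** 3
        return eta_min + (base_lr - eta_min) * (
            1 + math.cos(math.pi * epoch / epochs)) / 2
    # Step schedule: multiply by decay_rate once per milestone already passed.
    steps = sum(epoch > milestone for milestone in decay_epochs)
    return base_lr * decay_rate ** steps


if __name__ == "__main__":
    # Compare both schedules at a few epochs.
    for epoch in (1, 100, 121, 161, 200):
        print(epoch,
              learning_rate_at(epoch, cosine=True),
              learning_rate_at(epoch, cosine=False))
```

With the defaults, the cosine schedule decays smoothly towards `0.2 * 0.1 ** 3`, while the step schedule keeps the rate at 0.2 until epoch 120 and then divides it by 10 after epochs 120 and 160.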
55 | 56 | ## Results 57 | Experiments on ImageNet-100: 58 | | Arch | Batch Size | Epoch | Loss | kNN Accuracy(%) | 59 | |:----:|:---:|:---:|:---:|:---:| 60 | | ResNet18 | 256 | 200 | Cross Entropy | - | 61 | | ResNet50 | 256 | 200 | Cross Entropy | - | 62 | 63 | ## Citation 64 | 65 | If you use this toolbox in your work, please cite this project. 66 | 67 | ```bibtex 68 | @misc{imagenet100pytorch, 69 | title={{IN100pytorch}: PyTorch Implementation: Training ResNets on ImageNet-100}, 70 | author={Chun-Hsiao Yeh and Yubei Chen}, 71 | howpublished={\url{https://github.com/danielchyeh/ImageNet-100-Pytorch}}, 72 | year={2022} 73 | } 74 | ``` 75 | 76 | ## Acknowledgements 77 | 78 | Part of this code is based on [HobbitLong/SupContrast](https://github.com/HobbitLong/SupContrast). -------------------------------------------------------------------------------- /generate_IN100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | 5 | 6 | def parse_option(): 7 | parser = argparse.ArgumentParser('argument for generating ImageNet-100') 8 | 9 | parser.add_argument('--source_folder', type=str, 10 | default='', help='folder of ImageNet-1K dataset') 11 | parser.add_argument('--target_folder', type=str, 12 | default='', help='folder of ImageNet-100 dataset') 13 | parser.add_argument('--target_class', type=str, 14 | default='IN100.txt', help='class file of ImageNet-100') 15 | 16 | opt = parser.parse_args() 17 | 18 | return opt 19 | 20 | 21 | def generate_data(source_folder, target_folder, target_class): 22 | # read the target class IDs (one per line), skipping blank lines 23 | with open(target_class, "r") as txt_data: 24 | target_ids = {txt.strip() for txt in txt_data if txt.strip()} 25 | 26 | # copy every source class folder that is listed in the class file 27 | for dirs in os.listdir(source_folder): 28 | if dirs in target_ids: 29 | print('{} is transferred'.format(dirs)) 30 | shutil.copytree(os.path.join(source_folder, dirs), 31 | os.path.join(target_folder, dirs)) 32 | 33 | 34 | 35
| opt = parse_option() 36 | generate_data(opt.source_folder, opt.target_folder, opt.target_class) 37 | 38 | -------------------------------------------------------------------------------- /main_IN100.py: -------------------------------------------------------------------------------- 1 | """Implementation of ResNet train and eval on ImageNet-100""" 2 | """The implementation code is partially based on SupContrast GitHub repo""" 3 | 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import argparse 9 | import time 10 | import math 11 | 12 | import torch 13 | import torch.backends.cudnn as cudnn 14 | from torchvision import transforms, datasets 15 | 16 | from util import AverageMeter 17 | from util import adjust_learning_rate, warmup_learning_rate, accuracy 18 | from util import set_optimizer, save_model 19 | from networks.resnet import ResNet_Model 20 | 21 | try: 22 | import apex 23 | from apex import amp, optimizers 24 | except ImportError: 25 | pass 26 | 27 | 28 | def parse_option(): 29 | parser = argparse.ArgumentParser('argument for training ResNets on ImageNet-100') 30 | 31 | parser.add_argument('--data_folder', type=str, 32 | default='', help='dataset') 33 | parser.add_argument('--print_freq', type=int, default=10, 34 | help='print frequency') 35 | parser.add_argument('--save_freq', type=int, default=50, 36 | help='save frequency') 37 | parser.add_argument('--batch_size', type=int, default=256, 38 | help='batch_size') 39 | parser.add_argument('--num_workers', type=int, default=4, 40 | help='num of workers to use') 41 | parser.add_argument('--epochs', type=int, default=200, 42 | help='number of training epochs') 43 | 44 | # optimization 45 | parser.add_argument('--learning_rate', type=float, default=0.2, 46 | help='learning rate') 47 | parser.add_argument('--lr_decay_epochs', type=str, default='120,160', 48 | help='where to decay lr, can be a list') 49 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, 50 | help='decay 
rate for learning rate') 51 | parser.add_argument('--weight_decay', type=float, default=1e-4, 52 | help='weight decay') 53 | parser.add_argument('--momentum', type=float, default=0.9, 54 | help='momentum') 55 | 56 | # model dataset 57 | parser.add_argument('--model', type=str, default='resnet50') 58 | parser.add_argument('--dataset', type=str, default='imagenet100', 59 | choices=['imagenet100', 'imagenet'], help='dataset') 60 | 61 | # other setting 62 | parser.add_argument('--cosine', action='store_true', 63 | help='using cosine annealing') 64 | parser.add_argument('--warm', action='store_true', 65 | help='warm-up for large batch training') 66 | 67 | opt = parser.parse_args() 68 | 69 | # set the path according to the environment 70 | opt.traindir = os.path.join(opt.data_folder, 'train') 71 | opt.valdir = os.path.join(opt.data_folder, 'val') 72 | 73 | 74 | opt.model_path = './save/{}_models'.format(opt.dataset) 75 | 76 | iterations = opt.lr_decay_epochs.split(',') 77 | opt.lr_decay_epochs = list([]) 78 | for it in iterations: 79 | opt.lr_decay_epochs.append(int(it)) 80 | 81 | opt.model_name = 'CE_{}_{}_lr_{}_decay_{}_bsz_{}'.\ 82 | format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay, 83 | opt.batch_size) 84 | 85 | if opt.cosine: 86 | opt.model_name = '{}_cosine'.format(opt.model_name) 87 | 88 | # warm-up for large-batch training, 89 | if opt.batch_size > 256: 90 | opt.warm = True 91 | if opt.warm: 92 | opt.model_name = '{}_warm'.format(opt.model_name) 93 | opt.warmup_from = 0.01 94 | opt.warm_epochs = 10 95 | if opt.cosine: 96 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 97 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 98 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 99 | else: 100 | opt.warmup_to = opt.learning_rate 101 | 102 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 103 | if not os.path.isdir(opt.save_folder): 104 | os.makedirs(opt.save_folder) 105 | 106 | opt.n_cls = 100 107 | 108 | 
109 | return opt 110 | 111 | 112 | def set_loader(opt): 113 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 114 | std=[0.229, 0.224, 0.225]) 115 | train_dataset = datasets.ImageFolder( 116 | opt.traindir, 117 | transform=transforms.Compose([ 118 | transforms.RandomResizedCrop(224), 119 | transforms.RandomHorizontalFlip(), 120 | transforms.ToTensor(), 121 | normalize, 122 | ])) 123 | val_dataset = datasets.ImageFolder( 124 | opt.valdir, 125 | transform=transforms.Compose([ 126 | transforms.Resize(256), 127 | transforms.CenterCrop(224), 128 | transforms.ToTensor(), 129 | normalize, 130 | ])) 131 | 132 | train_loader = torch.utils.data.DataLoader( 133 | train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers, pin_memory=True) 134 | val_loader = torch.utils.data.DataLoader( 135 | val_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers, pin_memory=True) 136 | 137 | return train_loader, val_loader 138 | 139 | 140 | def set_model(opt): 141 | model = ResNet_Model(name=opt.model, num_classes=opt.n_cls) 142 | criterion = torch.nn.CrossEntropyLoss() 143 | 144 | if torch.cuda.is_available(): 145 | if torch.cuda.device_count() > 1: 146 | model = torch.nn.DataParallel(model) 147 | model = model.cuda() 148 | criterion = criterion.cuda() 149 | cudnn.benchmark = True 150 | 151 | return model, criterion 152 | 153 | 154 | def train(train_loader, model, criterion, optimizer, epoch, opt): 155 | """one epoch training""" 156 | model.train() 157 | 158 | batch_time = AverageMeter() 159 | data_time = AverageMeter() 160 | losses = AverageMeter() 161 | top1 = AverageMeter() 162 | 163 | end = time.time() 164 | for idx, (images, labels) in enumerate(train_loader): 165 | data_time.update(time.time() - end) 166 | 167 | images = images.cuda(non_blocking=True) 168 | labels = labels.cuda(non_blocking=True) 169 | bsz = labels.shape[0] 170 | 171 | # warm-up learning rate 172 | warmup_learning_rate(opt, epoch, idx, len(train_loader), 
optimizer) 173 | 174 | # compute loss 175 | output = model(images) 176 | loss = criterion(output, labels) 177 | 178 | # update metric 179 | losses.update(loss.item(), bsz) 180 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 181 | top1.update(acc1[0], bsz) 182 | 183 | # SGD 184 | optimizer.zero_grad() 185 | loss.backward() 186 | optimizer.step() 187 | 188 | # measure elapsed time 189 | batch_time.update(time.time() - end) 190 | end = time.time() 191 | 192 | # print info 193 | if (idx + 1) % opt.print_freq == 0: 194 | print('Train: [{0}][{1}/{2}]\t' 195 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 196 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 197 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 198 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 199 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 200 | data_time=data_time, loss=losses, top1=top1)) 201 | sys.stdout.flush() 202 | 203 | return losses.avg, top1.avg 204 | 205 | 206 | def validate(val_loader, model, criterion, opt): 207 | """validation""" 208 | model.eval() 209 | 210 | batch_time = AverageMeter() 211 | losses = AverageMeter() 212 | top1 = AverageMeter() 213 | 214 | with torch.no_grad(): 215 | end = time.time() 216 | for idx, (images, labels) in enumerate(val_loader): 217 | images = images.float().cuda() 218 | labels = labels.cuda() 219 | bsz = labels.shape[0] 220 | 221 | # forward 222 | output = model(images) 223 | loss = criterion(output, labels) 224 | 225 | # update metric 226 | losses.update(loss.item(), bsz) 227 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 228 | top1.update(acc1[0], bsz) 229 | 230 | # measure elapsed time 231 | batch_time.update(time.time() - end) 232 | end = time.time() 233 | 234 | if idx % opt.print_freq == 0: 235 | print('Test: [{0}/{1}]\t' 236 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 237 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 238 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 239 | idx, len(val_loader), 
batch_time=batch_time, 240 | loss=losses, top1=top1)) 241 | 242 | print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) 243 | return losses.avg, top1.avg 244 | 245 | 246 | def main(): 247 | best_acc = 0 248 | opt = parse_option() 249 | 250 | # build data loader 251 | train_loader, val_loader = set_loader(opt) 252 | # build model and criterion 253 | model, criterion = set_model(opt) 254 | # build optimizer 255 | optimizer = set_optimizer(opt, model) 256 | 257 | # training routine 258 | for epoch in range(1, opt.epochs + 1): 259 | adjust_learning_rate(opt, optimizer, epoch) 260 | 261 | # train for one epoch 262 | time1 = time.time() 263 | loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, opt) 264 | time2 = time.time() 265 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 266 | 267 | # evaluation 268 | loss, val_acc = validate(val_loader, model, criterion, opt) 269 | 270 | if val_acc > best_acc: 271 | best_acc = val_acc 272 | 273 | if epoch % opt.save_freq == 0: 274 | save_file = os.path.join( 275 | opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 276 | save_model(model, optimizer, opt, epoch, save_file) 277 | 278 | # save the last model 279 | save_file = os.path.join( 280 | opt.save_folder, 'last.pth') 281 | save_model(model, optimizer, opt, opt.epochs, save_file) 282 | 283 | print('best accuracy: {:.2f}'.format(best_acc)) 284 | 285 | 286 | if __name__ == '__main__': 287 | main() 288 | -------------------------------------------------------------------------------- /networks/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielchyeh/ImageNet-100-Pytorch/5ae5a42f74a23e8107aa060067fc355f16007d91/networks/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet_big.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielchyeh/ImageNet-100-Pytorch/5ae5a42f74a23e8107aa060067fc355f16007d91/networks/__pycache__/resnet_big.cpython-37.pyc -------------------------------------------------------------------------------- /networks/resnet.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | ImageNet-Style ResNet 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 5 | Adapted from: https://github.com/bearpaw/pytorch-classification 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1, is_last=False): 16 | super(BasicBlock, self).__init__() 17 | self.is_last = is_last 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != self.expansion * planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 27 | nn.BatchNorm2d(self.expansion * planes) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | out = self.bn2(self.conv2(out)) 33 | out += self.shortcut(x) 34 | preact = out 35 | out = F.relu(out) 36 | if self.is_last: 37 | return out, preact 38 | else: 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1, is_last=False): 46 | super(Bottleneck, self).__init__() 47 | self.is_last = is_last 48 | self.conv1 = 
nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 54 | 55 | self.shortcut = nn.Sequential() 56 | if stride != 1 or in_planes != self.expansion * planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 59 | nn.BatchNorm2d(self.expansion * planes) 60 | ) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(self.conv1(x))) 64 | out = F.relu(self.bn2(self.conv2(out))) 65 | out = self.bn3(self.conv3(out)) 66 | out += self.shortcut(x) 67 | preact = out 68 | out = F.relu(out) 69 | if self.is_last: 70 | return out, preact 71 | else: 72 | return out 73 | 74 | 75 | class ResNet(nn.Module): 76 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): 77 | super(ResNet, self).__init__() 78 | self.in_planes = 64 79 | 80 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, 81 | bias=False) 82 | self.bn1 = nn.BatchNorm2d(64) 83 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 84 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 85 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 86 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 87 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 92 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 93 | nn.init.constant_(m.weight, 1) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | # Zero-initialize the last BN in each residual branch, 97 | # so that the residual branch 
starts with zeros, and each residual block behaves 98 | # like an identity. This improves the model by 0.2~0.3% according to: 99 | # https://arxiv.org/abs/1706.02677 100 | if zero_init_residual: 101 | for m in self.modules(): 102 | if isinstance(m, Bottleneck): 103 | nn.init.constant_(m.bn3.weight, 0) 104 | elif isinstance(m, BasicBlock): 105 | nn.init.constant_(m.bn2.weight, 0) 106 | 107 | def _make_layer(self, block, planes, num_blocks, stride): 108 | strides = [stride] + [1] * (num_blocks - 1) 109 | layers = [] 110 | for i in range(num_blocks): 111 | stride = strides[i] 112 | layers.append(block(self.in_planes, planes, stride)) 113 | self.in_planes = planes * block.expansion 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x, layer=100): 117 | out = F.relu(self.bn1(self.conv1(x))) 118 | out = self.layer1(out) 119 | out = self.layer2(out) 120 | out = self.layer3(out) 121 | out = self.layer4(out) 122 | out = self.avgpool(out) 123 | out = torch.flatten(out, 1) 124 | return out 125 | 126 | 127 | def resnet18(**kwargs): 128 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 129 | 130 | 131 | def resnet34(**kwargs): 132 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 133 | 134 | 135 | def resnet50(**kwargs): 136 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 137 | 138 | 139 | def resnet101(**kwargs): 140 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 141 | 142 | 143 | model_dict = { 144 | 'resnet18': [resnet18, 512], 145 | 'resnet34': [resnet34, 512], 146 | 'resnet50': [resnet50, 2048], 147 | 'resnet101': [resnet101, 2048], 148 | } 149 | 150 | 151 | class LinearBatchNorm(nn.Module): 152 | """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose""" 153 | def __init__(self, dim, affine=True): 154 | super(LinearBatchNorm, self).__init__() 155 | self.dim = dim 156 | self.bn = nn.BatchNorm2d(dim, affine=affine) 157 | 158 | def forward(self, x): 159 | x = x.view(-1, self.dim, 1, 1) 160 | x = self.bn(x) 161 | x = x.view(-1, self.dim) 162 
| return x 163 | 164 | 165 | 166 | 167 | class ResNet_Model(nn.Module): 168 | """encoder + classifier""" 169 | def __init__(self, name='resnet50', num_classes=10): 170 | super(ResNet_Model, self).__init__() 171 | model_fun, dim_in = model_dict[name] 172 | self.encoder = model_fun() 173 | self.fc = nn.Linear(dim_in, num_classes) 174 | 175 | def forward(self, x): 176 | return self.fc(self.encoder(x)) 177 | 178 | 179 | class LinearClassifier(nn.Module): 180 | """Linear classifier""" 181 | def __init__(self, name='resnet50', num_classes=10): 182 | super(LinearClassifier, self).__init__() 183 | _, feat_dim = model_dict[name] 184 | self.fc = nn.Linear(feat_dim, num_classes) 185 | 186 | def forward(self, features): 187 | return self.fc(features) 188 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | self.avg = self.sum / self.count 25 | 26 | 27 | def accuracy(output, target, topk=(1,)): 28 | """Computes the accuracy over the k top predictions for the specified values of k""" 29 | with torch.no_grad(): 30 | maxk = max(topk) 31 | batch_size = target.size(0) 32 | 33 | _, pred = output.topk(maxk, 1, True, True) 34 | pred = pred.t() 35 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 36 | 37 | res = [] 38 | for k in topk: 39 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 40 | res.append(correct_k.mul_(100.0 / batch_size)) 41 | return res 42 | 43 | 44 
| def adjust_learning_rate(args, optimizer, epoch): 45 | lr = args.learning_rate 46 | if args.cosine: 47 | eta_min = lr * (args.lr_decay_rate ** 3) 48 | lr = eta_min + (lr - eta_min) * ( 49 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 50 | else: 51 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 52 | if steps > 0: 53 | lr = lr * (args.lr_decay_rate ** steps) 54 | 55 | for param_group in optimizer.param_groups: 56 | param_group['lr'] = lr 57 | 58 | 59 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 60 | if args.warm and epoch <= args.warm_epochs: 61 | p = (batch_id + (epoch - 1) * total_batches) / \ 62 | (args.warm_epochs * total_batches) 63 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 64 | 65 | for param_group in optimizer.param_groups: 66 | param_group['lr'] = lr 67 | 68 | 69 | def set_optimizer(opt, model): 70 | optimizer = optim.SGD(model.parameters(), 71 | lr=opt.learning_rate, 72 | momentum=opt.momentum, 73 | weight_decay=opt.weight_decay) 74 | return optimizer 75 | 76 | 77 | def save_model(model, optimizer, opt, epoch, save_file): 78 | print('==> Saving...') 79 | state = { 80 | 'opt': opt, 81 | 'model': model.state_dict(), 82 | 'optimizer': optimizer.state_dict(), 83 | 'epoch': epoch, 84 | } 85 | torch.save(state, save_file) 86 | del state 87 | --------------------------------------------------------------------------------
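The warm-up logic in `warmup_learning_rate` above can be tried in isolation. The following is a minimal sketch, assuming the values set in main_IN100.py (`warmup_from = 0.01`, `warm_epochs = 10`, and `warmup_to` equal to the default learning rate 0.2 when `--cosine` is not used); the function name `warmup_lr` is a hypothetical stand-in that returns the rate instead of writing it into an optimizer.

```python
def warmup_lr(epoch, batch_id, total_batches,
              warmup_from=0.01, warmup_to=0.2, warm_epochs=10):
    """Linearly interpolate the learning rate during warm-up (hypothetical helper).

    Returns None once warm-up is over, meaning the regular schedule applies.
    """
    if epoch > warm_epochs:
        return None
    # Fraction of warm-up completed, counted in batches (epochs are 1-based).
    p = (batch_id + (epoch - 1) * total_batches) / (warm_epochs * total_batches)
    return warmup_from + p * (warmup_to - warmup_from)


if __name__ == "__main__":
    # First batch, halfway point, last warm-up batch, and first batch after warm-up.
    for epoch, batch_id in ((1, 0), (6, 0), (10, 99), (11, 0)):
        print(epoch, batch_id, warmup_lr(epoch, batch_id, total_batches=100))
```

In the training script, warm-up is switched on by `--warm` or automatically when `--batch_size` exceeds 256, and the per-batch value overrides the rate set by `adjust_learning_rate` for the first `warm_epochs` epochs.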