├── .idea
│   ├── .gitignore
│   ├── PyTorch_ImageNet_experiments.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── webServers.xml
├── README.md
├── clonal_resnet18_from_scratch.log
├── clonal_resnet34_from_scratch.log
├── clonalnet_main.py
├── distill_loss
│   ├── KD.py
│   ├── __init__.py
│   └── fpLoss.py
├── main.py
├── models
│   ├── __init__.py
│   ├── mobilenet.py
│   └── resnet.py
└── requirements.txt
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FocusNet
2 | The implementation of our Pattern Recognition 2022 paper: "FocusNet: Classifying better by focusing on confusing classes"
3 |
4 | Paper: https://www.sciencedirect.com/science/article/abs/pii/S003132032200190X?via%3Dihub
5 | ## Note:
6 | - This repository is largely based on "[ImageNet training in PyTorch](https://github.com/pytorch/examples/tree/master/imagenet)", so its documentation is a useful reference.
7 | - The first version of our architecture was named ClonalNet; after the second revision we renamed it FocusNet. Therefore, **"clonalnet" in the code and commands below refers to FocusNet**.
8 | # ImageNet training in PyTorch
9 |
10 | This implements training of popular model architectures, such as ResNet, AlexNet, and VGG, on the ImageNet dataset.
11 |
12 | ## Requirements
13 |
14 | - Install PyTorch ([pytorch.org](http://pytorch.org))
15 | - `pip install -r requirements.txt`
16 | - Note: the `requirements.txt` in this repository is not identical to the official one. If something goes wrong, please fall back to the official requirements.
17 | - Download the ImageNet dataset from http://www.image-net.org/
18 | - Then, move the validation images into labeled subfolders using [the following shell script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh). Note that the scripts in this repository look for the validation split under `valid/` (see `valdir` in `main.py` and `clonalnet_main.py`) rather than the usual `val/`.
19 | ## Training
20 |
21 | To train our network, run `clonalnet_main.py` with the desired model architecture and the path to the ImageNet dataset:
22 |
23 | ```bash
24 | python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet18 --seed 42 --gpu 0 -ebc
25 | # for the other architectures, replace "-a resnet18" with "-a resnet34" or "-a mobilenet_v2"
26 |
27 | ```
28 |
29 | The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs. The training objective is the `fpLoss` criterion from `distill_loss/fpLoss.py`; a usage sketch is given after the results table below.
30 |
31 | ## Validation
32 |
33 | To evaluate our network, run `clonalnet_main.py` with the desired model architecture and the path to the ImageNet dataset:
34 |
35 | ```bash
36 | python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet18 --seed 42 --gpu 0 -ebc -e --resume clonalnet_resnet18_model_best.pth.tar
37 | # resnet34:     -a resnet34     --resume clonalnet_resnet34_model_best.pth.tar
38 | # mobilenet_v2: -a mobilenet_v2 --resume clonalnet_mobilenet_v2_model_best.pth.tar
39 |
40 | ```
41 |
42 | ## Logs
43 | `clonal_resnet18_from_scratch.log` and `clonal_resnet34_from_scratch.log` are the training logs of clonalnet_resnet18 and clonalnet_resnet34, respectively.
44 |
45 | ## Baseline
46 | To validate the baseline results, please run:
47 | ```bash
48 | # resnet18 / resnet34
49 | python main.py --paradigm baseline --data /path/to/ILSVRC2012 -a resnet18 --seed 10 -e --pretrained --gpu 0
50 | # (use "-a resnet34" for ResNet-34)
51 | # mobilenet_v2
52 | python main.py --paradigm baseline --data /path/to/ILSVRC2012 -a mobilenet_v2 --seed 10 -e --pretrained --gpu 0 --resume models/_pytorch_pretrained_checkpoints/baseline_mobilenet_v2_model_best.pth.tar
53 |
54 | ```
55 | ## Results on ILSVRC2012
56 | |Models|Acc@1|Acc@5|Checkpoint|
57 | |------|-----|-----|-----|
58 | |ResNet18|69.760|89.082|[PyTorch Pre-trained](https://pytorch.org/vision/stable/models.html)|
59 | |ClonalNet (r18)|70.422|89.562|[Baidu](https://pan.baidu.com/s/17GAra665g3Y9Uf9l_XIffg), code: 1234; [Google Drive](https://drive.google.com/file/d/1VuYREp2tWDyamjzphMeb0pGMIlVTN4Se/view?usp=sharing)|
60 | |ResNet34|73.310|91.420|[PyTorch Pre-trained](https://pytorch.org/vision/stable/models.html)|
61 | |ClonalNet (r34)|74.366|91.884|[Baidu](https://pan.baidu.com/s/1E-MocRLYlFUxc93_E-Ndtw), code: 1234; [Google Drive](https://drive.google.com/file/d/1NfnyQMP0dy3eYNuaIfs56fFj8nG4_L9f/view?usp=sharing)|
62 | |MobileNet_v2|65.558|86.744|[Baidu](https://pan.baidu.com/s/11f5wxVbuDtKQ2WguIPDtbw), code: 1234; [Google Drive](https://drive.google.com/file/d/1EecCV14dXD9yzFNfgbcTBw_CDPLZQi6i/view?usp=sharing)|
63 | |ClonalNet (MobileNet_v2)|66.300|87.232|[Baidu](https://pan.baidu.com/s/16aAsj3-RKIoL-k4Bydt14w); [Google Drive](https://drive.google.com/file/d/1nDfBea0GSQ4Fj8cdleRhwocJw8oO2T60/view?usp=sharing)|
64 |
65 | You can also download more checkpoints here: [Baidu](https://pan.baidu.com/s/1BPcyHRWokKcfpGTAiuVoug), code: 1234; [Google Drive](https://drive.google.com/drive/folders/18KBAvXccSPZDAZOjVKwLqZ9ZKGqL4RMf?usp=sharing).
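Each `*_model_best.pth.tar` file is a plain dictionary written by `save_checkpoint()` in the training scripts, with the keys `epoch`, `arch`, `state_dict`, `best_acc1`, and `optimizer`. If you only want to inspect a downloaded checkpoint outside of `clonalnet_main.py`, the snippet below is a minimal sketch; it assumes the checkpoint comes from a single-GPU run like the commands above (so the `state_dict` keys carry no `module.` prefix) and that it is executed from the repository root so the local `models` package is importable:

```python
import torch
import models  # the models package shipped in this repository

# load the checkpoint dictionary saved by save_checkpoint()
ckpt = torch.load('clonalnet_resnet18_model_best.pth.tar', map_location='cpu')
print(ckpt['arch'], ckpt['epoch'], ckpt['best_acc1'])

# rebuild the recorded architecture and restore its weights
model = models.__dict__[ckpt['arch']](pretrained=False)
model.load_state_dict(ckpt['state_dict'])
model.eval()
```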
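For reference, the training objective of `clonalnet_main.py` is the `fpLoss` criterion from `distill_loss/fpLoss.py`. It combines a classification term computed on the FocusNet logits shifted by their probability gap to the frozen baseline, a cross-entropy term against a multi-warm label built from the classes the baseline scores above the uniform probability (at temperature 2), and a negative-entropy regularizer. A minimal usage sketch follows; the batch size and class count are illustrative only:

```python
import torch
from distill_loss import fpLoss

criterion = fpLoss()

targets = torch.randint(0, 1000, (8,))              # ground-truth class indices
outputs = torch.randn(8, 1000, requires_grad=True)  # FocusNet (clonalnet) logits being trained
bi_outputs = torch.randn(8, 1000)                    # logits of the frozen pre-trained baseline

# fp_loss = loss_cls + R_attention + R_negtropy, as computed in fpLoss.forward()
loss = criterion(targets, outputs, bi_outputs)
loss.backward()
```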
66 | 67 | ## Reference 68 | If you find our work is helpful to you, please cite it: 69 | ```bash 70 | @article{zhang2022focusnet, 71 | title={FocusNet: Classifying better by focusing on confusing classes}, 72 | author={Zhang, Xue and Sheng, Zehua and Shen, Hui-Liang}, 73 | journal={Pattern Recognition}, 74 | pages={108709}, 75 | year={2022}, 76 | publisher={Elsevier} 77 | } 78 | ``` -------------------------------------------------------------------------------- /clonalnet_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | from enum import Enum 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | import models 21 | from distill_loss import fpLoss 22 | 23 | model_names = sorted(name for name in models.__dict__ 24 | if name.islower() and not name.startswith("__") 25 | and callable(models.__dict__[name])) 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 28 | parser.add_argument('--paradigm', default='clonalnet', type=str) 29 | parser.add_argument('--data', metavar='DIR', 30 | help='path to dataset') 31 | parser.add_argument('-a', '--arch', metavar='ARCH', default='mobilenet_v2', 32 | choices=model_names, 33 | help='model architecture: ' + 34 | ' | '.join(model_names) + 35 | ' (default: resnet18)') 36 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 37 | help='number of data loading workers (default: 4)') 38 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 39 | help='number of total epochs to run') 40 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 41 | help='manual epoch number (useful on restarts)') 42 | parser.add_argument('-b', '--batch-size', default=256, type=int, 43 | metavar='N', 44 | help='mini-batch size (default: 256), this is the total ' 45 | 'batch size of all GPUs on the current node when ' 46 | 'using Data Parallel or Distributed Data Parallel') 47 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 48 | metavar='LR', help='initial learning rate', dest='lr') 49 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 50 | help='momentum') 51 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 52 | metavar='W', help='weight decay (default: 1e-4)', 53 | dest='weight_decay') 54 | parser.add_argument('-p', '--print-freq', default=10, type=int, 55 | metavar='N', help='print frequency (default: 10)') 56 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 57 | help='path to latest checkpoint (default: none)') 58 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 59 | help='evaluate model on validation set') 60 | parser.add_argument('-ebc', '--evaluate_baseline_and_clonalnet_before_training_clonalnet', action='store_true', 61 | help='evaluate baseline model and clonalnet on validation set') 62 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 63 | help='use pre-trained model') 64 | parser.add_argument('--world-size', default=-1, type=int, 65 | help='number of nodes for 
distributed training') 66 | parser.add_argument('--rank', default=-1, type=int, 67 | help='node rank for distributed training') 68 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 69 | help='url used to set up distributed training') 70 | parser.add_argument('--dist-backend', default='nccl', type=str, 71 | help='distributed backend') 72 | parser.add_argument('--seed', default=None, type=int, 73 | help='seed for initializing training. ') 74 | parser.add_argument('--gpu', default=None, type=int, 75 | help='GPU id to use.') 76 | parser.add_argument('--multiprocessing-distributed', action='store_true', 77 | help='Use multi-processing distributed training to launch ' 78 | 'N processes per node, which has N GPUs. This is the ' 79 | 'fastest way to use PyTorch for either single node or ' 80 | 'multi node data parallel training') 81 | 82 | best_acc1 = 0 83 | 84 | 85 | def main(): 86 | args = parser.parse_args( 87 | # [training 88 | #'--data', '/UsrFile/yjc/xzq/ssddata/zx/ILSVRC2012', 89 | #'-a', 'resnet18','--seed', '42', 90 | #'--gpu', '1', '-ebc', 91 | 92 | ## validation 93 | ## '-e', '--resume', 'clonalnet_resnet18_model_best.pth.tar'] 94 | ) 95 | 96 | for _argsk, _argsv in args._get_kwargs(): 97 | print('--{} {}'.format(_argsk, _argsv)) 98 | 99 | if args.seed is not None: 100 | random.seed(args.seed) 101 | torch.manual_seed(args.seed) 102 | cudnn.deterministic = True 103 | warnings.warn('You have chosen to seed training. ' 104 | 'This will turn on the CUDNN deterministic setting, ' 105 | 'which can slow down your training considerably! ' 106 | 'You may see unexpected behavior when restarting ' 107 | 'from checkpoints.') 108 | 109 | if args.gpu is not None: 110 | warnings.warn('You have chosen a specific GPU. This will completely ' 111 | 'disable data parallelism.') 112 | 113 | if args.dist_url == "env://" and args.world_size == -1: 114 | args.world_size = int(os.environ["WORLD_SIZE"]) 115 | 116 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 117 | 118 | ngpus_per_node = torch.cuda.device_count() 119 | if args.multiprocessing_distributed: 120 | # Since we have ngpus_per_node processes per node, the total world_size 121 | # needs to be adjusted accordingly 122 | args.world_size = ngpus_per_node * args.world_size 123 | # Use torch.multiprocessing.spawn to launch distributed processes: the 124 | # main_worker process function 125 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 126 | else: 127 | # Simply call main_worker function 128 | main_worker(args.gpu, ngpus_per_node, args) 129 | 130 | 131 | def main_worker(gpu, ngpus_per_node, args): 132 | global best_acc1 133 | args.gpu = gpu 134 | 135 | if args.gpu is not None: 136 | print("Use GPU: {} for training".format(args.gpu)) 137 | 138 | if args.distributed: 139 | if args.dist_url == "env://" and args.rank == -1: 140 | args.rank = int(os.environ["RANK"]) 141 | if args.multiprocessing_distributed: 142 | # For multiprocessing distributed training, rank needs to be the 143 | # global rank among all the processes 144 | args.rank = args.rank * ngpus_per_node + gpu 145 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 146 | world_size=args.world_size, rank=args.rank) 147 | # create model 148 | if args.pretrained: 149 | print("=> using pre-trained model '{}'".format(args.arch)) 150 | model = models.__dict__[args.arch](pretrained=True) 151 | else: 152 | print("=> Baseline: using pre-trained model '{}'".format(args.arch)) 153 | base_model = 
models.__dict__[args.arch](pretrained=True) 154 | base_model.eval() 155 | 156 | print("=> ClonalNet: creating model '{}'".format(args.arch)) 157 | model = models.__dict__[args.arch](pretrained=False) 158 | 159 | if not torch.cuda.is_available(): 160 | print('using CPU, this will be slow') 161 | elif args.distributed: 162 | # For multiprocessing distributed, DistributedDataParallel constructor 163 | # should always set the single device scope, otherwise, 164 | # DistributedDataParallel will use all available devices. 165 | if args.gpu is not None: 166 | torch.cuda.set_device(args.gpu) 167 | model.cuda(args.gpu) 168 | # When using a single GPU per process and per 169 | # DistributedDataParallel, we need to divide the batch size 170 | # ourselves based on the total number of GPUs we have 171 | args.batch_size = int(args.batch_size / ngpus_per_node) 172 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 173 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 174 | else: 175 | model.cuda() 176 | # DistributedDataParallel will divide and allocate batch_size to all 177 | # available GPUs if device_ids are not set 178 | model = torch.nn.parallel.DistributedDataParallel(model) 179 | elif args.gpu is not None: 180 | torch.cuda.set_device(args.gpu) 181 | base_model = base_model.cuda(args.gpu) 182 | model = model.cuda(args.gpu) 183 | else: 184 | # DataParallel will divide and allocate batch_size to all available GPUs 185 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 186 | model.features = torch.nn.DataParallel(model.features) 187 | model.cuda() 188 | else: 189 | model = torch.nn.DataParallel(model).cuda() 190 | 191 | # define loss function (criterion) and optimizer 192 | criterion_ce = nn.CrossEntropyLoss().cuda(args.gpu) 193 | criterion = fpLoss().cuda(args.gpu) 194 | 195 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 196 | momentum=args.momentum, 197 | weight_decay=args.weight_decay) 198 | 199 | # optionally resume from a checkpoint 200 | if args.resume: 201 | if os.path.isfile(args.resume): 202 | print("=> loading checkpoint '{}'".format(args.resume)) 203 | if args.gpu is None: 204 | checkpoint = torch.load(args.resume) 205 | else: 206 | # Map model to be loaded to specified single gpu. 
207 | loc = 'cuda:{}'.format(args.gpu) 208 | checkpoint = torch.load(args.resume, map_location=loc) 209 | args.start_epoch = checkpoint['epoch'] 210 | best_acc1 = checkpoint['best_acc1'] 211 | if args.gpu is not None: 212 | # best_acc1 may be from a checkpoint from a different GPU 213 | best_acc1 = best_acc1.to(args.gpu) 214 | model.load_state_dict(checkpoint['state_dict']) 215 | optimizer.load_state_dict(checkpoint['optimizer']) 216 | print("=> loaded checkpoint '{}' (epoch {})" 217 | .format(args.resume, checkpoint['epoch'])) 218 | else: 219 | print("=> no checkpoint found at '{}'".format(args.resume)) 220 | 221 | cudnn.benchmark = True 222 | 223 | # Data loading code 224 | traindir = os.path.join(args.data, 'train') 225 | valdir = os.path.join(args.data, 'valid') 226 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 227 | std=[0.229, 0.224, 0.225]) 228 | 229 | train_dataset = datasets.ImageFolder( 230 | traindir, 231 | transforms.Compose([ 232 | transforms.RandomResizedCrop(224), 233 | transforms.RandomHorizontalFlip(), 234 | transforms.ToTensor(), 235 | normalize, 236 | ])) 237 | 238 | if args.distributed: 239 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 240 | else: 241 | train_sampler = None 242 | 243 | train_loader = torch.utils.data.DataLoader( 244 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 245 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 246 | 247 | val_loader = torch.utils.data.DataLoader( 248 | datasets.ImageFolder(valdir, transforms.Compose([ 249 | transforms.Resize(256), 250 | transforms.CenterCrop(224), 251 | transforms.ToTensor(), 252 | normalize, 253 | ])), 254 | batch_size=args.batch_size, shuffle=False, 255 | num_workers=args.workers, pin_memory=True) 256 | 257 | if args.evaluate: 258 | validate(val_loader, model, criterion_ce, args) 259 | return 260 | 261 | if args.evaluate_baseline_and_clonalnet_before_training_clonalnet: 262 | print("=> Baseline: evaluating pre-trained model '{}'".format(args.arch)) 263 | validate(val_loader, base_model, criterion_ce, args) 264 | print('-'*100) 265 | print("=> ClonalNet: evaluating random model '{}'".format(args.arch)) 266 | validate(val_loader, model, criterion_ce, args) 267 | 268 | for epoch in range(args.start_epoch, args.epochs): 269 | if args.distributed: 270 | train_sampler.set_epoch(epoch) 271 | adjust_learning_rate(optimizer, epoch, args) 272 | 273 | # train for one epoch 274 | train(train_loader, base_model, model, criterion, optimizer, epoch, args) 275 | 276 | # evaluate on validation set 277 | acc1 = validate(val_loader, model, criterion_ce, args) 278 | 279 | # remember best acc@1 and save checkpoint 280 | is_best = acc1 > best_acc1 281 | best_acc1 = max(acc1, best_acc1) 282 | 283 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 284 | and args.rank % ngpus_per_node == 0): 285 | save_checkpoint({ 286 | 'epoch': epoch + 1, 287 | 'arch': args.arch, 288 | 'state_dict': model.state_dict(), 289 | 'best_acc1': best_acc1, 290 | 'optimizer': optimizer.state_dict(), 291 | }, is_best, 292 | '{}_{}_checkpoint.pth.tar'.format(args.paradigm, args.arch)) 293 | 294 | 295 | def train(train_loader, base_model, model, criterion, optimizer, epoch, args): 296 | batch_time = AverageMeter('Time', ':6.3f') 297 | data_time = AverageMeter('Data', ':6.3f') 298 | losses = AverageMeter('Loss', ':.4e') 299 | top1 = AverageMeter('Acc@1', ':6.2f') 300 | top5 = AverageMeter('Acc@5', ':6.2f') 301 | progress = ProgressMeter( 
302 | len(train_loader), 303 | [batch_time, data_time, losses, top1, top5], 304 | prefix="Epoch: [{}]".format(epoch)) 305 | 306 | # switch to train mode 307 | base_model.eval() 308 | model.train() 309 | 310 | end = time.time() 311 | for i, (images, target) in enumerate(train_loader): 312 | # measure data loading time 313 | data_time.update(time.time() - end) 314 | 315 | if args.gpu is not None: 316 | images = images.cuda(args.gpu, non_blocking=True) 317 | if torch.cuda.is_available(): 318 | target = target.cuda(args.gpu, non_blocking=True) 319 | 320 | # compute output 321 | with torch.no_grad(): 322 | bi_output = base_model(images) 323 | output = model(images) 324 | loss = criterion(target, output, bi_output) 325 | 326 | # measure accuracy and record loss 327 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 328 | losses.update(loss.item(), images.size(0)) 329 | top1.update(acc1[0], images.size(0)) 330 | top5.update(acc5[0], images.size(0)) 331 | 332 | # compute gradient and do SGD step 333 | optimizer.zero_grad() 334 | loss.backward() 335 | optimizer.step() 336 | 337 | # measure elapsed time 338 | batch_time.update(time.time() - end) 339 | end = time.time() 340 | 341 | if i % args.print_freq == 0: 342 | progress.display(i) 343 | 344 | 345 | def validate(val_loader, model, criterion, args): 346 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE) 347 | losses = AverageMeter('Loss', ':.4e', Summary.NONE) 348 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE) 349 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE) 350 | progress = ProgressMeter( 351 | len(val_loader), 352 | [batch_time, losses, top1, top5], 353 | prefix='Test: ') 354 | 355 | # switch to evaluate mode 356 | model.eval() 357 | 358 | with torch.no_grad(): 359 | end = time.time() 360 | for i, (images, target) in enumerate(val_loader): 361 | if args.gpu is not None: 362 | images = images.cuda(args.gpu, non_blocking=True) 363 | if torch.cuda.is_available(): 364 | target = target.cuda(args.gpu, non_blocking=True) 365 | 366 | # compute output 367 | output = model(images) 368 | loss = criterion(output, target) 369 | 370 | # measure accuracy and record loss 371 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 372 | losses.update(loss.item(), images.size(0)) 373 | top1.update(acc1[0], images.size(0)) 374 | top5.update(acc5[0], images.size(0)) 375 | 376 | # measure elapsed time 377 | batch_time.update(time.time() - end) 378 | end = time.time() 379 | 380 | if i % args.print_freq == 0: 381 | progress.display(i) 382 | 383 | progress.display_summary() 384 | 385 | return top1.avg 386 | 387 | 388 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 389 | torch.save(state, filename) 390 | if is_best: 391 | shutil.copyfile(filename, filename.split('checkpoint.pth.tar')[0]+'model_best.pth.tar') 392 | 393 | 394 | class Summary(Enum): 395 | NONE = 0 396 | AVERAGE = 1 397 | SUM = 2 398 | COUNT = 3 399 | 400 | 401 | class AverageMeter(object): 402 | """Computes and stores the average and current value""" 403 | 404 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 405 | self.name = name 406 | self.fmt = fmt 407 | self.summary_type = summary_type 408 | self.reset() 409 | 410 | def reset(self): 411 | self.val = 0 412 | self.avg = 0 413 | self.sum = 0 414 | self.count = 0 415 | 416 | def update(self, val, n=1): 417 | self.val = val 418 | self.sum += val * n 419 | self.count += n 420 | self.avg = self.sum / self.count 421 | 422 | def __str__(self): 423 | fmtstr = '{name} {val' + self.fmt + '} 
({avg' + self.fmt + '})' 424 | return fmtstr.format(**self.__dict__) 425 | 426 | def summary(self): 427 | fmtstr = '' 428 | if self.summary_type is Summary.NONE: 429 | fmtstr = '' 430 | elif self.summary_type is Summary.AVERAGE: 431 | fmtstr = '{name} {avg:.3f}' 432 | elif self.summary_type is Summary.SUM: 433 | fmtstr = '{name} {sum:.3f}' 434 | elif self.summary_type is Summary.COUNT: 435 | fmtstr = '{name} {count:.3f}' 436 | else: 437 | raise ValueError('invalid summary type %r' % self.summary_type) 438 | 439 | return fmtstr.format(**self.__dict__) 440 | 441 | 442 | class ProgressMeter(object): 443 | def __init__(self, num_batches, meters, prefix=""): 444 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 445 | self.meters = meters 446 | self.prefix = prefix 447 | 448 | def display(self, batch): 449 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 450 | entries += [str(meter) for meter in self.meters] 451 | print('\t'.join(entries)) 452 | 453 | def display_summary(self): 454 | entries = [" *"] 455 | entries += [meter.summary() for meter in self.meters] 456 | print(' '.join(entries)) 457 | 458 | def _get_batch_fmtstr(self, num_batches): 459 | num_digits = len(str(num_batches // 1)) 460 | fmt = '{:' + str(num_digits) + 'd}' 461 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 462 | 463 | 464 | def adjust_learning_rate(optimizer, epoch, args): 465 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 466 | lr = args.lr * (0.1 ** (epoch // 30)) 467 | for param_group in optimizer.param_groups: 468 | param_group['lr'] = lr 469 | 470 | 471 | def accuracy(output, target, topk=(1,)): 472 | """Computes the accuracy over the k top predictions for the specified values of k""" 473 | with torch.no_grad(): 474 | maxk = max(topk) 475 | batch_size = target.size(0) 476 | 477 | _, pred = output.topk(maxk, 1, True, True) 478 | pred = pred.t() 479 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 480 | 481 | res = [] 482 | for k in topk: 483 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 484 | res.append(correct_k.mul_(100.0 / batch_size)) 485 | return res 486 | 487 | 488 | if __name__ == '__main__': 489 | main() -------------------------------------------------------------------------------- /distill_loss/KD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DistillKL(nn.Module): 7 | def __init__(self, args): 8 | super(DistillKL, self).__init__() 9 | self.T = args.temperature 10 | 11 | def forward(self, y_s, y_t): 12 | p_s = F.log_softmax(y_s/self.T, dim=1) 13 | p_t = F.softmax(y_t/self.T, dim=1) 14 | loss = F.kl_div(p_s, p_t.detach(), reduction='sum') * (self.T**2) / y_s.shape[0] 15 | return loss 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /distill_loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .KD import * 2 | from .fpLoss import * 3 | -------------------------------------------------------------------------------- /distill_loss/fpLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | eps = 1e-10 5 | 6 | class fpLoss(nn.Module): 7 | def __init__(self, ): 8 | super(fpLoss, self).__init__() 9 | 10 | def cross_entropy(self, logits, onehot_labels, ls=False): 11 | if 
ls:
12 | onehot_labels = 0.9 * onehot_labels + 0.1 / logits.size(-1)
13 | onehot_labels = onehot_labels.double()
14 | return (-1.0 * torch.mean(torch.sum(onehot_labels * F.log_softmax(logits, -1), -1), 0))
15 |
16 |
17 | def neg_entropy(self, logits):
18 | probs = F.softmax(logits, -1)
19 | return torch.mean(torch.sum(probs * F.log_softmax(logits, -1), -1), 0)
20 |
21 | def forward(self, targets, outputs, bi_outputs,):
22 | # Loss_cls
23 | difference = F.softmax(outputs, -1) - F.softmax(bi_outputs, -1)  # detach() was used here when comparing with FRSKD
24 | onehot_labels = F.one_hot(targets, outputs.size(-1))
25 | loss_cls = self.cross_entropy(outputs + difference, onehot_labels, True)
26 | # on tiny-imagenet, alpha=1, beta=1, ls=True, best test acc: 0.5870
27 | # on tiny-imagenet, alpha=1, beta=1, ls=False, best test acc: 0.5840
28 | # so ls is not the main factor
29 |
30 | # R_attention
31 |
32 | # multi_warm_lb = bi_outputs.detach() > 0.0
33 | '''The derivation shows that, with a multi-warm label, the cross-entropy gradient is hat{y(x)} - m(x),
34 | where hat{y(x)} is the probability distribution predicted by clonalnet and m(x) is the multi-warm label, whose non-zero entries equal 1/len(m(x)!=0).
35 | By comparison, ordinary cross-entropy gives the gradient hat{y(x)} - y(x), where y(x) is the one-hot label,
36 | so the gradient is negative at the correct position and positive at the incorrect positions: the prediction at the correct position grows while the others shrink, pushing the predicted distribution towards the one-hot label.
37 | However, we found that
38 | with a multi-warm label the gradient is positive wherever hat{y(x)} exceeds m(x) and negative wherever it falls below m(x), which means
39 | the predicted probabilities drift towards the distribution m(x). Therefore m(x) should keep as few non-zero entries as possible, so that only the few most confusing classes are attended to; this also makes its non-zero values larger.
40 | If the non-zero values of m(x) are too small, confident predictions are harmed,
41 | so the multi-warm label was adjusted accordingly.
42 | '''
43 | # multi_warm_lb = bi_outputs > 0.0
44 | multi_warm_lb = F.softmax(bi_outputs/2, -1) > 1.0/bi_outputs.size(-1)
45 | multi_warm_lb = torch.clamp(multi_warm_lb.double() + onehot_labels, 0, 1)
46 | multi_warm_lb = multi_warm_lb/torch.sum(multi_warm_lb, -1, True)
47 | R_attention = self.cross_entropy(outputs, multi_warm_lb.detach(), False)  # detach() was used here when comparing with FRSKD
48 |
49 | # R_entropy
50 | R_negtropy = self.neg_entropy(outputs)
51 |
52 | fp_loss = loss_cls + R_attention + R_negtropy
53 |
54 | # test for CE + neg_entropy
55 | # loss_cls = self.cross_entropy(outputs, onehot_labels)
56 | # fp_loss = loss_cls + R_negtropy # experiments showed that CE + negative entropy reaches 59.10% on CUB200, lower than the 60.72% of loss_cls + negative entropy
57 | return fp_loss
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 | from enum import Enum
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import torch.optim
15 | import torch.multiprocessing as mp
16 | import torch.utils.data
17 | import torch.utils.data.distributed
18 | import torchvision.transforms as transforms
19 | import torchvision.datasets as datasets
20 | import models
21 |
22 | model_names = sorted(name for name in models.__dict__
23 | if name.islower() and not name.startswith("__")
24 | and callable(models.__dict__[name]))
25 |
26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
27 | parser.add_argument('--paradigm', default='baseline', type=str)
28 | parser.add_argument('--data', metavar='DIR',
29 | help='path to dataset')
30 | parser.add_argument('-a', '--arch', metavar='ARCH', default='mobilenet_v2',
31 | choices=model_names,
32 | help='model architecture: ' +
33 | ' | '.join(model_names) +
34 | ' (default: resnet18)')
35 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
36 | help='number of data
loading workers (default: 4)') 37 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 38 | help='number of total epochs to run') 39 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 40 | help='manual epoch number (useful on restarts)') 41 | parser.add_argument('-b', '--batch-size', default=256, type=int, 42 | metavar='N', 43 | help='mini-batch size (default: 256), this is the total ' 44 | 'batch size of all GPUs on the current node when ' 45 | 'using Data Parallel or Distributed Data Parallel') 46 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 47 | metavar='LR', help='initial learning rate', dest='lr') 48 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 49 | help='momentum') 50 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 51 | metavar='W', help='weight decay (default: 1e-4)', 52 | dest='weight_decay') 53 | parser.add_argument('-p', '--print-freq', default=10, type=int, 54 | metavar='N', help='print frequency (default: 10)') 55 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 56 | help='path to latest checkpoint (default: none)') 57 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 58 | help='evaluate model on validation set') 59 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 60 | help='use pre-trained model') 61 | parser.add_argument('--world-size', default=-1, type=int, 62 | help='number of nodes for distributed training') 63 | parser.add_argument('--rank', default=-1, type=int, 64 | help='node rank for distributed training') 65 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 66 | help='url used to set up distributed training') 67 | parser.add_argument('--dist-backend', default='nccl', type=str, 68 | help='distributed backend') 69 | parser.add_argument('--seed', default=None, type=int, 70 | help='seed for initializing training. ') 71 | parser.add_argument('--gpu', default=None, type=int, 72 | help='GPU id to use.') 73 | parser.add_argument('--multiprocessing-distributed', action='store_true', 74 | help='Use multi-processing distributed training to launch ' 75 | 'N processes per node, which has N GPUs. This is the ' 76 | 'fastest way to use PyTorch for either single node or ' 77 | 'multi node data parallel training') 78 | 79 | best_acc1 = 0 80 | 81 | 82 | def main(): 83 | args = parser.parse_args( 84 | # ['--paradigm', 'baseline', 85 | # '--data', '/UsrFile/yjc/xzq/ssddata/zx/ILSVRC2012', 86 | # '-a', 'resnet18', '--seed', '10', '-e', --pretrained, 87 | # '--resume', './baseline_mobilenet_v2_model_best.pth.tar', 88 | # '--gpu', '1'] 89 | ) 90 | 91 | for _argsk, _argsv in args._get_kwargs(): 92 | print('--{} {}'.format(_argsk, _argsv)) 93 | 94 | if args.seed is not None: 95 | random.seed(args.seed) 96 | torch.manual_seed(args.seed) 97 | cudnn.deterministic = True 98 | warnings.warn('You have chosen to seed training. ' 99 | 'This will turn on the CUDNN deterministic setting, ' 100 | 'which can slow down your training considerably! ' 101 | 'You may see unexpected behavior when restarting ' 102 | 'from checkpoints.') 103 | 104 | if args.gpu is not None: 105 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 106 | 'disable data parallelism.') 107 | 108 | if args.dist_url == "env://" and args.world_size == -1: 109 | args.world_size = int(os.environ["WORLD_SIZE"]) 110 | 111 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 112 | 113 | ngpus_per_node = torch.cuda.device_count() 114 | if args.multiprocessing_distributed: 115 | # Since we have ngpus_per_node processes per node, the total world_size 116 | # needs to be adjusted accordingly 117 | args.world_size = ngpus_per_node * args.world_size 118 | # Use torch.multiprocessing.spawn to launch distributed processes: the 119 | # main_worker process function 120 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 121 | else: 122 | # Simply call main_worker function 123 | main_worker(args.gpu, ngpus_per_node, args) 124 | 125 | 126 | def main_worker(gpu, ngpus_per_node, args): 127 | global best_acc1 128 | args.gpu = gpu 129 | 130 | if args.gpu is not None: 131 | print("Use GPU: {} for training".format(args.gpu)) 132 | 133 | if args.distributed: 134 | if args.dist_url == "env://" and args.rank == -1: 135 | args.rank = int(os.environ["RANK"]) 136 | if args.multiprocessing_distributed: 137 | # For multiprocessing distributed training, rank needs to be the 138 | # global rank among all the processes 139 | args.rank = args.rank * ngpus_per_node + gpu 140 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 141 | world_size=args.world_size, rank=args.rank) 142 | # create model 143 | if args.pretrained: 144 | print("=> using pre-trained model '{}'".format(args.arch)) 145 | model = models.__dict__[args.arch](pretrained=True) 146 | else: 147 | print("=> creating model '{}'".format(args.arch)) 148 | model = models.__dict__[args.arch]() 149 | 150 | if not torch.cuda.is_available(): 151 | print('using CPU, this will be slow') 152 | elif args.distributed: 153 | # For multiprocessing distributed, DistributedDataParallel constructor 154 | # should always set the single device scope, otherwise, 155 | # DistributedDataParallel will use all available devices. 
156 | if args.gpu is not None: 157 | torch.cuda.set_device(args.gpu) 158 | model.cuda(args.gpu) 159 | # When using a single GPU per process and per 160 | # DistributedDataParallel, we need to divide the batch size 161 | # ourselves based on the total number of GPUs we have 162 | args.batch_size = int(args.batch_size / ngpus_per_node) 163 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 164 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 165 | else: 166 | model.cuda() 167 | # DistributedDataParallel will divide and allocate batch_size to all 168 | # available GPUs if device_ids are not set 169 | model = torch.nn.parallel.DistributedDataParallel(model) 170 | elif args.gpu is not None: 171 | torch.cuda.set_device(args.gpu) 172 | model = model.cuda(args.gpu) 173 | else: 174 | # DataParallel will divide and allocate batch_size to all available GPUs 175 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 176 | model.features = torch.nn.DataParallel(model.features) 177 | model.cuda() 178 | else: 179 | model = torch.nn.DataParallel(model).cuda() 180 | 181 | # define loss function (criterion) and optimizer 182 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 183 | 184 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 185 | momentum=args.momentum, 186 | weight_decay=args.weight_decay) 187 | 188 | # optionally resume from a checkpoint 189 | if args.resume: 190 | if os.path.isfile(args.resume): 191 | print("=> loading checkpoint '{}'".format(args.resume)) 192 | if args.gpu is None: 193 | checkpoint = torch.load(args.resume) 194 | else: 195 | # Map model to be loaded to specified single gpu. 196 | loc = 'cuda:{}'.format(args.gpu) 197 | checkpoint = torch.load(args.resume, map_location=loc) 198 | args.start_epoch = checkpoint['epoch'] 199 | best_acc1 = checkpoint['best_acc1'] 200 | if args.gpu is not None: 201 | # best_acc1 may be from a checkpoint from a different GPU 202 | best_acc1 = best_acc1.to(args.gpu) 203 | model.load_state_dict(checkpoint['state_dict']) 204 | optimizer.load_state_dict(checkpoint['optimizer']) 205 | print("=> loaded checkpoint '{}' (epoch {})" 206 | .format(args.resume, checkpoint['epoch'])) 207 | else: 208 | print("=> no checkpoint found at '{}'".format(args.resume)) 209 | 210 | cudnn.benchmark = True 211 | 212 | # Data loading code 213 | traindir = os.path.join(args.data, 'train') 214 | valdir = os.path.join(args.data, 'valid') 215 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 216 | std=[0.229, 0.224, 0.225]) 217 | 218 | train_dataset = datasets.ImageFolder( 219 | traindir, 220 | transforms.Compose([ 221 | transforms.RandomResizedCrop(224), 222 | transforms.RandomHorizontalFlip(), 223 | transforms.ToTensor(), 224 | normalize, 225 | ])) 226 | 227 | if args.distributed: 228 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 229 | else: 230 | train_sampler = None 231 | 232 | train_loader = torch.utils.data.DataLoader( 233 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 234 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 235 | 236 | val_loader = torch.utils.data.DataLoader( 237 | datasets.ImageFolder(valdir, transforms.Compose([ 238 | transforms.Resize(256), 239 | transforms.CenterCrop(224), 240 | transforms.ToTensor(), 241 | normalize, 242 | ])), 243 | batch_size=args.batch_size, shuffle=False, 244 | num_workers=args.workers, pin_memory=True) 245 | 246 | if args.evaluate: 247 | validate(val_loader, 
model, criterion, args) 248 | return 249 | 250 | for epoch in range(args.start_epoch, args.epochs): 251 | if args.distributed: 252 | train_sampler.set_epoch(epoch) 253 | adjust_learning_rate(optimizer, epoch, args) 254 | 255 | # train for one epoch 256 | train(train_loader, model, criterion, optimizer, epoch, args) 257 | 258 | # evaluate on validation set 259 | acc1 = validate(val_loader, model, criterion, args) 260 | 261 | # remember best acc@1 and save checkpoint 262 | is_best = acc1 > best_acc1 263 | best_acc1 = max(acc1, best_acc1) 264 | 265 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 266 | and args.rank % ngpus_per_node == 0): 267 | save_checkpoint({ 268 | 'epoch': epoch + 1, 269 | 'arch': args.arch, 270 | 'state_dict': model.state_dict(), 271 | 'best_acc1': best_acc1, 272 | 'optimizer': optimizer.state_dict(), 273 | }, is_best, 274 | '{}_{}_checkpoint.pth.tar'.format(args.paradigm, args.arch)) 275 | 276 | 277 | def train(train_loader, model, criterion, optimizer, epoch, args): 278 | batch_time = AverageMeter('Time', ':6.3f') 279 | data_time = AverageMeter('Data', ':6.3f') 280 | losses = AverageMeter('Loss', ':.4e') 281 | top1 = AverageMeter('Acc@1', ':6.2f') 282 | top5 = AverageMeter('Acc@5', ':6.2f') 283 | progress = ProgressMeter( 284 | len(train_loader), 285 | [batch_time, data_time, losses, top1, top5], 286 | prefix="Epoch: [{}]".format(epoch)) 287 | 288 | # switch to train mode 289 | model.train() 290 | 291 | end = time.time() 292 | for i, (images, target) in enumerate(train_loader): 293 | # measure data loading time 294 | data_time.update(time.time() - end) 295 | 296 | if args.gpu is not None: 297 | images = images.cuda(args.gpu, non_blocking=True) 298 | if torch.cuda.is_available(): 299 | target = target.cuda(args.gpu, non_blocking=True) 300 | 301 | # compute output 302 | output = model(images) 303 | loss = criterion(output, target) 304 | 305 | # measure accuracy and record loss 306 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 307 | losses.update(loss.item(), images.size(0)) 308 | top1.update(acc1[0], images.size(0)) 309 | top5.update(acc5[0], images.size(0)) 310 | 311 | # compute gradient and do SGD step 312 | optimizer.zero_grad() 313 | loss.backward() 314 | optimizer.step() 315 | 316 | # measure elapsed time 317 | batch_time.update(time.time() - end) 318 | end = time.time() 319 | 320 | if i % args.print_freq == 0: 321 | progress.display(i) 322 | 323 | 324 | def validate(val_loader, model, criterion, args): 325 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE) 326 | losses = AverageMeter('Loss', ':.4e', Summary.NONE) 327 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE) 328 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE) 329 | progress = ProgressMeter( 330 | len(val_loader), 331 | [batch_time, losses, top1, top5], 332 | prefix='Test: ') 333 | 334 | # switch to evaluate mode 335 | model.eval() 336 | 337 | with torch.no_grad(): 338 | end = time.time() 339 | for i, (images, target) in enumerate(val_loader): 340 | if args.gpu is not None: 341 | images = images.cuda(args.gpu, non_blocking=True) 342 | if torch.cuda.is_available(): 343 | target = target.cuda(args.gpu, non_blocking=True) 344 | 345 | # compute output 346 | output = model(images) 347 | loss = criterion(output, target) 348 | 349 | # measure accuracy and record loss 350 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 351 | losses.update(loss.item(), images.size(0)) 352 | top1.update(acc1[0], images.size(0)) 353 | top5.update(acc5[0], 
images.size(0)) 354 | 355 | # measure elapsed time 356 | batch_time.update(time.time() - end) 357 | end = time.time() 358 | 359 | if i % args.print_freq == 0: 360 | progress.display(i) 361 | 362 | progress.display_summary() 363 | 364 | return top1.avg 365 | 366 | 367 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 368 | torch.save(state, filename) 369 | if is_best: 370 | shutil.copyfile(filename, filename.split('checkpoint.pth.tar')[0]+'model_best.pth.tar') 371 | 372 | 373 | class Summary(Enum): 374 | NONE = 0 375 | AVERAGE = 1 376 | SUM = 2 377 | COUNT = 3 378 | 379 | 380 | class AverageMeter(object): 381 | """Computes and stores the average and current value""" 382 | 383 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 384 | self.name = name 385 | self.fmt = fmt 386 | self.summary_type = summary_type 387 | self.reset() 388 | 389 | def reset(self): 390 | self.val = 0 391 | self.avg = 0 392 | self.sum = 0 393 | self.count = 0 394 | 395 | def update(self, val, n=1): 396 | self.val = val 397 | self.sum += val * n 398 | self.count += n 399 | self.avg = self.sum / self.count 400 | 401 | def __str__(self): 402 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 403 | return fmtstr.format(**self.__dict__) 404 | 405 | def summary(self): 406 | fmtstr = '' 407 | if self.summary_type is Summary.NONE: 408 | fmtstr = '' 409 | elif self.summary_type is Summary.AVERAGE: 410 | fmtstr = '{name} {avg:.3f}' 411 | elif self.summary_type is Summary.SUM: 412 | fmtstr = '{name} {sum:.3f}' 413 | elif self.summary_type is Summary.COUNT: 414 | fmtstr = '{name} {count:.3f}' 415 | else: 416 | raise ValueError('invalid summary type %r' % self.summary_type) 417 | 418 | return fmtstr.format(**self.__dict__) 419 | 420 | 421 | class ProgressMeter(object): 422 | def __init__(self, num_batches, meters, prefix=""): 423 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 424 | self.meters = meters 425 | self.prefix = prefix 426 | 427 | def display(self, batch): 428 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 429 | entries += [str(meter) for meter in self.meters] 430 | print('\t'.join(entries)) 431 | 432 | def display_summary(self): 433 | entries = [" *"] 434 | entries += [meter.summary() for meter in self.meters] 435 | print(' '.join(entries)) 436 | 437 | def _get_batch_fmtstr(self, num_batches): 438 | num_digits = len(str(num_batches // 1)) 439 | fmt = '{:' + str(num_digits) + 'd}' 440 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 441 | 442 | 443 | def adjust_learning_rate(optimizer, epoch, args): 444 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 445 | lr = args.lr * (0.1 ** (epoch // 30)) 446 | for param_group in optimizer.param_groups: 447 | param_group['lr'] = lr 448 | 449 | 450 | def accuracy(output, target, topk=(1,)): 451 | """Computes the accuracy over the k top predictions for the specified values of k""" 452 | with torch.no_grad(): 453 | maxk = max(topk) 454 | batch_size = target.size(0) 455 | 456 | _, pred = output.topk(maxk, 1, True, True) 457 | pred = pred.t() 458 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 459 | 460 | res = [] 461 | for k in topk: 462 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 463 | res.append(correct_k.mul_(100.0 / batch_size)) 464 | return res 465 | 466 | 467 | if __name__ == '__main__': 468 | main() -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .mobilenet import * 2 | from .resnet import * -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch import nn 4 | from torchvision.models.utils import load_state_dict_from_url 5 | 6 | 7 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 8 | 9 | 10 | model_urls = { 11 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 12 | } 13 | 14 | 15 | def _make_divisible(v, divisor, min_value=None): 16 | """ 17 | This function is taken from the original tf repo. 18 | It ensures that all layers have a channel number that is divisible by 8 19 | It can be seen here: 20 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 21 | :param v: 22 | :param divisor: 23 | :param min_value: 24 | :return: 25 | """ 26 | if min_value is None: 27 | min_value = divisor 28 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 29 | # Make sure that round down does not go down by more than 10%. 30 | if new_v < 0.9 * v: 31 | new_v += divisor 32 | return new_v 33 | 34 | 35 | class ConvBNReLU(nn.Sequential): 36 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 37 | padding = (kernel_size - 1) // 2 38 | super(ConvBNReLU, self).__init__( 39 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 40 | nn.BatchNorm2d(out_planes), 41 | nn.ReLU6(inplace=True) 42 | ) 43 | 44 | 45 | class InvertedResidual(nn.Module): 46 | def __init__(self, inp, oup, stride, expand_ratio): 47 | super(InvertedResidual, self).__init__() 48 | self.stride = stride 49 | assert stride in [1, 2] 50 | 51 | hidden_dim = int(round(inp * expand_ratio)) 52 | self.use_res_connect = self.stride == 1 and inp == oup 53 | 54 | layers = [] 55 | if expand_ratio != 1: 56 | # pw 57 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 58 | layers.extend([ 59 | # dw 60 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 61 | # pw-linear 62 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 63 | nn.BatchNorm2d(oup), 64 | ]) 65 | self.conv = nn.Sequential(*layers) 66 | 67 | def forward(self, x): 68 | if self.use_res_connect: 69 | return x + self.conv(x) 70 | else: 71 | return self.conv(x) 72 | 73 | 74 | class MobileNetV2(nn.Module): 75 | def __init__(self, 76 | num_classes=1000, 77 | width_mult=1.0, 78 | inverted_residual_setting=None, 79 | round_nearest=8, 80 | block=None): 81 | """ 82 | MobileNet V2 main class 83 | 84 | Args: 85 | num_classes (int): Number of classes 86 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 87 | inverted_residual_setting: Network structure 88 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 89 | Set to 1 to turn off rounding 90 | block: Module specifying inverted residual building block for mobilenet 91 | 92 | """ 93 | super(MobileNetV2, self).__init__() 94 | 95 | if block is None: 96 | block = InvertedResidual 97 | input_channel = 32 98 | last_channel = 1280 99 | 100 | if inverted_residual_setting is None: 101 | inverted_residual_setting = [ 102 | # t, c, n, s 103 | [1, 16, 1, 1], 104 | [6, 24, 2, 2], 105 | [6, 32, 3, 2], 106 | [6, 64, 4, 2], 107 | [6, 96, 3, 1], 108 | [6, 160, 3, 2], 109 | [6, 320, 1, 1], 110 | ] 111 | 112 | # only 
check the first element, assuming user knows t,c,n,s are required 113 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 114 | raise ValueError("inverted_residual_setting should be non-empty " 115 | "or a 4-element list, got {}".format(inverted_residual_setting)) 116 | 117 | # building first layer 118 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 119 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 120 | features = [ConvBNReLU(3, input_channel, stride=2)] 121 | # building inverted residual blocks 122 | for t, c, n, s in inverted_residual_setting: 123 | output_channel = _make_divisible(c * width_mult, round_nearest) 124 | for i in range(n): 125 | stride = s if i == 0 else 1 126 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 127 | input_channel = output_channel 128 | # building last several layers 129 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 130 | # make it nn.Sequential 131 | self.features = nn.Sequential(*features) 132 | 133 | # building classifier 134 | self.classifier = nn.Sequential( 135 | nn.Dropout(0.2), 136 | nn.Linear(self.last_channel, num_classes), 137 | ) 138 | 139 | # weight initialization 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 143 | if m.bias is not None: 144 | nn.init.zeros_(m.bias) 145 | elif isinstance(m, nn.BatchNorm2d): 146 | nn.init.ones_(m.weight) 147 | nn.init.zeros_(m.bias) 148 | elif isinstance(m, nn.Linear): 149 | nn.init.normal_(m.weight, 0, 0.01) 150 | nn.init.zeros_(m.bias) 151 | 152 | def _forward_impl(self, x): 153 | # This exists since TorchScript doesn't support inheritance, so the superclass method 154 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 155 | x = self.features(x) 156 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 157 | x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1) 158 | x = self.classifier(x) 159 | return x 160 | 161 | def forward(self, x): 162 | return self._forward_impl(x) 163 | 164 | 165 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 166 | """ 167 | Constructs a MobileNetV2 architecture from 168 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
169 | 170 | Args: 171 | pretrained (bool): If True, returns a model pre-trained on ImageNet 172 | progress (bool): If True, displays a progress bar of the download to stderr 173 | """ 174 | model = MobileNetV2(**kwargs) 175 | if pretrained: 176 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], model_dir='./models/_pytorch_pretrained_checkpoints/', 177 | progress=progress) 178 | model.load_state_dict(state_dict) 179 | return model 180 | 181 | 182 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models.utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 39 | base_width=64, dilation=1, norm_layer=None): 40 | super(BasicBlock, self).__init__() 41 | if norm_layer is None: 42 | norm_layer = nn.BatchNorm2d 43 | if groups != 1 or base_width != 64: 44 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 45 | if dilation > 1: 46 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 48 | self.conv1 = conv3x3(inplanes, planes, stride) 49 | self.bn1 = norm_layer(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = norm_layer(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | identity = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | identity = self.downsample(x) 68 | 69 | out += identity 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 77 | # while original 
implementation places the stride at the first 1x1 convolution(self.conv1) 78 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 79 | # This variant is also known as ResNet V1.5 and improves accuracy according to 80 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 81 | 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 85 | base_width=64, dilation=1, norm_layer=None): 86 | super(Bottleneck, self).__init__() 87 | if norm_layer is None: 88 | norm_layer = nn.BatchNorm2d 89 | width = int(planes * (base_width / 64.)) * groups 90 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 91 | self.conv1 = conv1x1(inplanes, width) 92 | self.bn1 = norm_layer(width) 93 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 94 | self.bn2 = norm_layer(width) 95 | self.conv3 = conv1x1(width, planes * self.expansion) 96 | self.bn3 = norm_layer(planes * self.expansion) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.downsample = downsample 99 | self.stride = stride 100 | 101 | def forward(self, x): 102 | identity = x 103 | 104 | out = self.conv1(x) 105 | out = self.bn1(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv2(out) 109 | out = self.bn2(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv3(out) 113 | out = self.bn3(out) 114 | 115 | if self.downsample is not None: 116 | identity = self.downsample(x) 117 | 118 | out += identity 119 | out = self.relu(out) 120 | 121 | return out 122 | 123 | 124 | class ResNet(nn.Module): 125 | 126 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 127 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 128 | norm_layer=None): 129 | super(ResNet, self).__init__() 130 | if norm_layer is None: 131 | norm_layer = nn.BatchNorm2d 132 | self._norm_layer = norm_layer 133 | 134 | self.inplanes = 64 135 | self.dilation = 1 136 | if replace_stride_with_dilation is None: 137 | # each element in the tuple indicates if we should replace 138 | # the 2x2 stride with a dilated convolution instead 139 | replace_stride_with_dilation = [False, False, False] 140 | if len(replace_stride_with_dilation) != 3: 141 | raise ValueError("replace_stride_with_dilation should be None " 142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 143 | self.groups = groups 144 | self.base_width = width_per_group 145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 146 | bias=False) 147 | self.bn1 = norm_layer(self.inplanes) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 152 | dilate=replace_stride_with_dilation[0]) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 154 | dilate=replace_stride_with_dilation[1]) 155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 156 | dilate=replace_stride_with_dilation[2]) 157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 158 | self.fc = nn.Linear(512 * block.expansion, num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | # 
Zero-initialize the last BN in each residual branch, 168 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity. 169 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 170 |         if zero_init_residual: 171 |             for m in self.modules(): 172 |                 if isinstance(m, Bottleneck): 173 |                     nn.init.constant_(m.bn3.weight, 0) 174 |                 elif isinstance(m, BasicBlock): 175 |                     nn.init.constant_(m.bn2.weight, 0) 176 | 177 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 178 |         norm_layer = self._norm_layer 179 |         downsample = None 180 |         previous_dilation = self.dilation 181 |         if dilate: 182 |             self.dilation *= stride 183 |             stride = 1 184 |         if stride != 1 or self.inplanes != planes * block.expansion: 185 |             downsample = nn.Sequential( 186 |                 conv1x1(self.inplanes, planes * block.expansion, stride), 187 |                 norm_layer(planes * block.expansion), 188 |             ) 189 | 190 |         layers = [] 191 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 192 |                             self.base_width, previous_dilation, norm_layer)) 193 |         self.inplanes = planes * block.expansion 194 |         for _ in range(1, blocks): 195 |             layers.append(block(self.inplanes, planes, groups=self.groups, 196 |                                 base_width=self.base_width, dilation=self.dilation, 197 |                                 norm_layer=norm_layer)) 198 | 199 |         return nn.Sequential(*layers) 200 | 201 |     def _forward_impl(self, x): 202 |         # See note [TorchScript super()] 203 |         x = self.conv1(x) 204 |         x = self.bn1(x) 205 |         x = self.relu(x) 206 |         x = self.maxpool(x) 207 | 208 |         x = self.layer1(x) 209 |         x = self.layer2(x) 210 |         x = self.layer3(x) 211 |         x = self.layer4(x) 212 | 213 |         x = self.avgpool(x) 214 |         x = torch.flatten(x, 1) 215 |         x = self.fc(x) 216 | 217 |         return x 218 | 219 |     def forward(self, x): 220 |         return self._forward_impl(x) 221 | 222 | 223 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 224 |     model = ResNet(block, layers, **kwargs) 225 |     if pretrained: 226 |         state_dict = load_state_dict_from_url(model_urls[arch], model_dir='./models/_pytorch_pretrained_checkpoints/', 227 |                                                progress=progress) 228 |         model.load_state_dict(state_dict) 229 |     return model 230 | 231 | 232 | def resnet18(pretrained=False, progress=True, **kwargs): 233 |     r"""ResNet-18 model from 234 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ 235 | 236 |     Args: 237 |         pretrained (bool): If True, returns a model pre-trained on ImageNet 238 |         progress (bool): If True, displays a progress bar of the download to stderr 239 |     """ 240 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 241 |                    **kwargs) 242 | 243 | 244 | def resnet34(pretrained=False, progress=True, **kwargs): 245 |     r"""ResNet-34 model from 246 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ 247 | 248 |     Args: 249 |         pretrained (bool): If True, returns a model pre-trained on ImageNet 250 |         progress (bool): If True, displays a progress bar of the download to stderr 251 |     """ 252 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 253 |                    **kwargs) 254 | 255 | 256 | def resnet50(pretrained=False, progress=True, **kwargs): 257 |     r"""ResNet-50 model from 258 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ 259 | 260 |     Args: 261 |         pretrained (bool): If True, returns a model pre-trained on ImageNet 262 |         progress (bool): If True, displays a progress bar of the download to stderr 263 |     """ 264 |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 265 |                    **kwargs) 266 | 267 | 268 | def resnet101(pretrained=False, 
progress=True, **kwargs): 269 |     r"""ResNet-101 model from 270 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ 271 | 272 |     Args: 273 |         pretrained (bool): If True, returns a model pre-trained on ImageNet 274 |         progress (bool): If True, displays a progress bar of the download to stderr 275 |     """ 276 |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 277 |                    **kwargs) 278 | 279 | 280 | def resnet152(pretrained=False, progress=True, **kwargs): 281 |     r"""ResNet-152 model from 282 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ 283 | 284 |     Args: 285 |         pretrained (bool): If True, returns a model pre-trained on ImageNet 286 |         progress (bool): If True, displays a progress bar of the download to stderr 287 |     """ 288 |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 289 |                    **kwargs) 290 | 291 | 292 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 293 |     r"""ResNeXt-50 32x4d model from 294 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ 295 | 296 |     Args: 297 |         pretrained (bool): If True, returns a model pre-trained on ImageNet 298 |         progress (bool): If True, displays a progress bar of the download to stderr 299 |     """ 300 |     kwargs['groups'] = 32 301 |     kwargs['width_per_group'] = 4 302 |     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 303 |                    pretrained, progress, **kwargs) 304 | 305 | 306 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 307 |     r"""ResNeXt-101 32x8d model from 308 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ 309 | 310 |     Args: 311 |         pretrained (bool): If True, returns a model pre-trained on ImageNet 312 |         progress (bool): If True, displays a progress bar of the download to stderr 313 |     """ 314 |     kwargs['groups'] = 32 315 |     kwargs['width_per_group'] = 8 316 |     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 317 |                    pretrained, progress, **kwargs) 318 | 319 | 320 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 321 |     r"""Wide ResNet-50-2 model from 322 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ 323 | 324 |     The model is the same as ResNet except for the bottleneck number of channels 325 |     which is twice larger in every block. The number of channels in outer 1x1 326 |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 327 |     channels, and in Wide ResNet-50-2 has 2048-1024-2048. 328 | 329 |     Args: 330 |         pretrained (bool): If True, returns a model pre-trained on ImageNet 331 |         progress (bool): If True, displays a progress bar of the download to stderr 332 |     """ 333 |     kwargs['width_per_group'] = 64 * 2 334 |     return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 335 |                    pretrained, progress, **kwargs) 336 | 337 | 338 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 339 |     r"""Wide ResNet-101-2 model from 340 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ 341 | 342 |     The model is the same as ResNet except for the bottleneck number of channels 343 |     which is twice larger in every block. The number of channels in outer 1x1 344 |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 345 |     channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
346 | 347 | Args: 348 | pretrained (bool): If True, returns a model pre-trained on ImageNet 349 | progress (bool): If True, displays a progress bar of the download to stderr 350 | """ 351 | kwargs['width_per_group'] = 64 * 2 352 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 353 | pretrained, progress, **kwargs) 354 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | albumentations==1.1.0 3 | cachetools==5.0.0 4 | certifi==2021.5.30 5 | chardet==4.0.0 6 | charset-normalizer==2.0.12 7 | click==8.0.4 8 | cycler==0.10.0 9 | Cython==0.29.28 10 | docker-pycreds==0.4.0 11 | fonttools==4.30.0 12 | gitdb==4.0.9 13 | GitPython==3.1.27 14 | google-auth==2.6.2 15 | google-auth-oauthlib==0.4.6 16 | grpcio==1.44.0 17 | idna==2.10 18 | imageio==2.16.1 19 | importlib-metadata==4.11.3 20 | joblib==1.1.0 21 | kiwisolver==1.3.1 22 | Markdown==3.3.6 23 | matplotlib==3.5.1 24 | networkx==2.6.3 25 | numpy==1.21.5 26 | oauthlib==3.2.0 27 | opencv-python==4.5.5.64 28 | opencv-python-headless==4.5.5.64 29 | packaging==21.3 30 | pandas==1.3.5 31 | pathtools==0.1.2 32 | Pillow==9.0.1 33 | promise==2.3 34 | protobuf==3.19.4 35 | psutil==5.9.0 36 | pyasn1==0.4.8 37 | pyasn1-modules==0.2.8 38 | pycocotools==2.0.4 39 | pyparsing==2.4.7 40 | python-dateutil==2.8.2 41 | python-dotenv==0.19.2 42 | pytz==2021.3 43 | PyWavelets==1.3.0 44 | PyYAML==6.0 45 | qudida==0.0.4 46 | requests==2.27.1 47 | requests-oauthlib==1.3.1 48 | roboflow==0.2.2 49 | rsa==4.8 50 | scikit-image==0.19.2 51 | scikit-learn==1.0.2 52 | scipy==1.7.3 53 | seaborn==0.11.2 54 | sentry-sdk==1.5.8 55 | setproctitle==1.2.2 56 | shortuuid==1.0.8 57 | six==1.16.0 58 | smmap==5.0.0 59 | tensorboard==2.8.0 60 | tensorboard-data-server==0.6.1 61 | tensorboard-plugin-wit==1.8.1 62 | termcolor==1.1.0 63 | thop==0.0.31.post2005241907 64 | threadpoolctl==3.1.0 65 | tifffile==2021.11.2 66 | torch @ file:///UsrFile/ybn/zx/pytorch_wheels/torch1.9/torch-1.9.0%2Bcu102-cp37-cp37m-linux_x86_64.whl 67 | torchvision @ file:///UsrFile/ybn/zx/pytorch_wheels/torch1.9/torchvision-0.10.0%2Bcu102-cp37-cp37m-linux_x86_64.whl 68 | tqdm==4.63.0 69 | typing_extensions==4.1.1 70 | urllib3==1.26.6 71 | wandb==0.12.11 72 | Werkzeug==2.0.3 73 | wget==3.2 74 | yaspin==2.1.0 75 | zipp==3.7.0 76 | --------------------------------------------------------------------------------
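The `_resnet` helper in `models/resnet.py` (and the matching loader at the end of `models/mobilenet.py`) builds the network and, when `pretrained=True`, fetches the torchvision checkpoint into `./models/_pytorch_pretrained_checkpoints/` via `load_state_dict_from_url`. The snippet below is a minimal sketch of that flow, not part of the repository; it assumes the repository root is the working directory and torchvision 0.10.0 as pinned in `requirements.txt`, where `torchvision.models.utils.load_state_dict_from_url` still exists.

```python
# Minimal usage sketch (not repository code): build resnet18 from models/resnet.py
# with torchvision's pretrained ImageNet weights and run one dummy forward pass.
import torch
from models.resnet import resnet18

model = resnet18(pretrained=True)   # checkpoint cached under ./models/_pytorch_pretrained_checkpoints/
model.eval()

x = torch.randn(1, 3, 224, 224)     # one ImageNet-sized input
with torch.no_grad():
    logits = model(x)               # shape: [1, 1000]
print(logits.argmax(dim=1).item())  # predicted ImageNet class index
```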
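The `zero_init_residual` branch in `ResNet.__init__` zeroes the weight of the last BatchNorm in every residual block (`bn3` for `Bottleneck`, `bn2` for `BasicBlock`), so each residual branch initially contributes nothing and training starts close to an identity mapping, as the comment citing https://arxiv.org/abs/1706.02677 explains. A quick check of that behaviour, again as a sketch rather than repository code:

```python
# Sketch: verify the zero-init trick on a ResNet-18-shaped network.
import torch
from models.resnet import ResNet, BasicBlock

net = ResNet(BasicBlock, [2, 2, 2, 2], zero_init_residual=True)
for m in net.modules():
    if isinstance(m, BasicBlock):
        # The last BN in each block starts at zero, so the residual branch
        # initially outputs zeros and the block reduces to relu(identity).
        assert bool((m.bn2.weight == 0).all())
```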
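The Wide ResNet docstrings describe the widening rule implemented by `width = int(planes * (base_width / 64.)) * groups` in `Bottleneck.__init__`: passing `width_per_group = 64 * 2` doubles the 3x3 width while the outer 1x1 channels (`planes * expansion`) stay the same. A small worked example of that arithmetic (a sketch using a hypothetical helper, not repository code):

```python
# Worked example of the Bottleneck width rule from models/resnet.py.
def bottleneck_width(planes, base_width=64, groups=1):
    return int(planes * (base_width / 64.)) * groups

# Last stage of ResNet-50 vs. Wide ResNet-50-2 (expansion = 4 -> 2048 outer channels):
print(bottleneck_width(512, base_width=64))    # 512  -> 2048-512-2048
print(bottleneck_width(512, base_width=128))   # 1024 -> 2048-1024-2048
```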