├── .gitignore
├── README.md
├── dataset.py
├── figs
│   ├── RM.png
│   └── RM.svg
├── logs
│   ├── log_ade20k_80k_upernet_r50_baseline_40.40_miou.json
│   ├── log_ade20k_80k_upernet_r50_cutmix_41.24_miou.json
│   ├── log_ade20k_80k_upernet_r50_rm_42.30_miou.json
│   ├── log_cifar10_300epoch_pyramidnet_rm_97.60_top1_acc.log
│   ├── log_cifar10_300epoch_pyramidnet_rm_97.62_top1_acc.log
│   ├── log_cifar10_300epoch_pyramidnet_rm_97.72_top1_acc.log
│   ├── log_coco_12epoch_atss_r50_fpn_cutmix_40.1_ap.json
│   ├── log_coco_12epoch_atss_r50_fpn_rm_41.5_ap.json
│   └── log_imagenet_300epoch_r50_rm_79.20_top1_acc.log
├── loss
│   ├── __init__.py
│   └── label_smooth_loss.py
├── main.py
├── models_cifar
│   ├── __init__.py
│   ├── densenet.py
│   ├── pyramidnet.py
│   └── resnet.py
├── models_imagenet
│   ├── __init__.py
│   └── resnet.py
├── samplers.py
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pyc
3 | *.so
4 | checkpoints/
5 | *.pth
6 | *.pth.tar
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code for 'RecursiveMix: Mixed Learning with History'
2 | 
3 | ![RecursiveMix](figs/RM.png)

4 | 
5 | RecursiveMix (RM) uses the historical input-prediction-label triplet to enhance the generalization of deep vision models. [Paper link.](https://arxiv.org/pdf/2203.06844.pdf)
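
The mixing operation itself lives in `utils.py` as `recursive_mix` (its listing is truncated below). Roughly: each iteration resizes the previous mini-batch, pastes it into a region of the current one, mixes the one-hot labels by the pasted area, and returns the pasted boxes so the previous prediction can supervise that region. A minimal sketch of the idea, assuming equal batch sizes; the corner placement and scale convention here are simplifications, not the reference implementation:

```python
import torch
import torch.nn.functional as F

def recursive_mix_sketch(inputs, old_inputs, targets, old_targets,
                         alpha=0.5, mode='nearest'):
    """Illustrative RecursiveMix step, shaped like the call in main.py."""
    bs, _, h, w = inputs.shape
    s = float(torch.empty(1).uniform_(0., alpha))  # sampled scale of the history patch
    ph, pw = max(1, int(h * s)), max(1, int(w * s))
    lam = 1.0 - (ph * pw) / (h * w)                # share kept from the current batch
    # resize the previous batch and paste it into a corner (placement simplified here)
    inputs[:, :, :ph, :pw] = F.interpolate(old_inputs[:bs], size=(ph, pw), mode=mode)
    targets = lam * targets + (1.0 - lam) * old_targets[:bs]
    boxes = torch.tensor([0., 0., float(pw), float(ph)]).repeat(bs, 1)  # (x1, y1, x2, y2)
    return inputs, targets, boxes, lam
```

`train()` in `main.py` then adds the history-supervision term: RoI features pooled over `boxes` are pushed toward the cached prediction `old_out`, weighted by `aug_omega * (1 - lam)`.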
6 | 
7 | ## Requirements
8 | 
9 | Experiment environment:
10 | - python 3.6
11 | - pytorch 1.7.1+cu101
12 | - torchvision 0.8.2
13 | - mmcv-full 1.4.1
14 | - mmdet 2.19.1
15 | - mmsegmentation 0.20.2
16 | 
17 | ## Usage
18 | 
19 | ### 1. Train the model
20 | For example, to reproduce the RM result on CIFAR-10 (97.65% Top-1 accuracy averaged over 3 runs; logs are provided in `logs/`):
21 | ```bash
22 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29500 main.py \
23 |     --name 'your_experiment_log_path' \
24 |     --model_file 'pyramidnet' \
25 |     --model_name 'pyramidnet_200_240' \
26 |     --data 'cifar10' \
27 |     --data_dir '/path/to/CIFAR10' \
28 |     --epoch 300 \
29 |     --batch_size 64 \
30 |     --lr 0.25 \
31 |     --scheduler 'step' \
32 |     --schedule 150 225 \
33 |     --weight_decay 1e-4 \
34 |     --nesterov \
35 |     --num_workers 8 \
36 |     --save_model \
37 |     --aug 'recursive_mix' \
38 |     --aug_alpha 0.5 \
39 |     --aug_omega 0.1
40 | ```
41 | 
42 | To reproduce RM on ImageNet (79.20% Top-1 accuracy):
43 | ```bash
44 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port 29500 main.py \
45 |     --name 'your_experiment_log_path' \
46 |     --model_file 'resnet' \
47 |     --model_name 'resnet50' \
48 |     --data 'imagenet' \
49 |     --epoch 300 \
50 |     --batch_size 512 \
51 |     --lr 0.2 \
52 |     --warmup 5 \
53 |     --weight_decay 1e-4 \
54 |     --aug_plus \
55 |     --num_workers 32 \
56 |     --save_model \
57 |     --aug 'recursive_mix' \
58 |     --aug_alpha 0.5 \
59 |     --aug_omega 0.5
60 | ```
61 | 
62 | 
63 | ### 2. Test the model
64 | ```bash
65 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29500 main.py \
66 |     --name 'your_experiment_log_path' \
67 |     --batch_size 64 \
68 |     --model_file 'pyramidnet' \
69 |     --model_name 'pyramidnet_200_240' \
70 |     --data 'cifar10' \
71 |     --data_dir '/path/to/CIFAR10' \
72 |     --num_workers 8 \
73 |     --evaluate \
74 |     --resume 'best'
75 | ```
76 | 
77 | ## Model Zoo
78 | ### Image Classification
79 | - ImageNet-1K (300 epoch)
80 | 
81 | | Backbone | Size | Params (M) | Acc@1 | Log | Download |
82 | | -------------- | :---: | :--------: | :-------: | :---: | :---: |
83 | | ResNet-50 | 224 | 25.56 | 76.32 | log | [Google] [GitHub] |
84 | | + Mixup | 224 | 25.56 | 77.42 | log | [Google] [GitHub] |
85 | | + CutMix | 224 | 25.56 | 78.60 | log | [Google] [GitHub] |
86 | | + RecursiveMix | 224 | 25.56 | **79.20** | [log](logs/log_imagenet_300epoch_r50_rm_79.20_top1_acc.log) | [[Google]](https://drive.google.com/file/d/19dlKcrTgfY3UqAOFNXkLmIwX41wUOOPB/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_imagenet_300epoch_r50_rm_79.20_top1_acc.pth) |
87 | 
88 | ### Object Detection
89 | 
90 | - COCO (1x schedule)
91 | 
92 | #### ATSS
93 | 
94 | | Backbone | Lr schd | Mem (GB) | Inf time (fps) | box AP | Log | Download |
95 | | -------------- | :-----: | :------: | :------------: | :------: | :---: | :---: |
96 | | ResNet-50 | 1x | 3.7 | 19.7 | 39.4 | [log](https://download.openmmlab.com/mmdetection/v2.0/atss/atss_r50_fpn_1x_coco/atss_r50_fpn_1x_coco_20200209_102539.log.json) | [[Google]](https://drive.google.com/file/d/1DVVarV-os8BwEGgnBHkkCpbdSWUquYIV/view?usp=sharing) [[GitHub]](https://download.openmmlab.com/mmdetection/v2.0/atss/atss_r50_fpn_1x_coco/atss_r50_fpn_1x_coco_20200209-985f7bd0.pth) |
97 | | + CutMix | 1x | 3.7 | 19.7 | 40.1 | [log](logs/log_coco_12epoch_atss_r50_fpn_cutmix_40.1_ap.json) | [[Google]](https://drive.google.com/file/d/1T2fVCmwyMMzBdxg5QevCIRmey9vlc7JZ/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_coco_12epoch_atss_r50_fpn_cutmix_40.1_ap.pth) |
98 | | + RecursiveMix | 1x | 3.7 | 19.7 | **41.5** | [log](logs/log_coco_12epoch_atss_r50_fpn_rm_41.5_ap.json) | [[Google]](https://drive.google.com/file/d/1iFqrhkrm05LqQw_1W2q7lcl0WB_WAgNA/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_coco_12epoch_atss_r50_fpn_rm_41.5_ap.pth) |
99 | 
100 | ### Semantic Segmentation
101 | 
102 | - ADE20K (80k iteration)
103 | 
104 | #### UPerNet
105 | | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | Log | Download |
106 | | -------------- | :-------: | :-----: | :------: | :------------: | :-------: | :---: | :---: |
107 | | ResNet-50 | 512x512 | 80000 | 8.1 | 23.40 | 40.40 | [log](logs/log_ade20k_80k_upernet_r50_baseline_40.40_miou.json) | [[Google]](https://drive.google.com/file/d/1ps0KJZUTM_e-0aaDrTE0pjvTTCSO3qpH/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_ade20k_80k_upernet_r50_baseline_40.40_miou.pth) |
108 | | + CutMix | 512x512 | 80000 | 8.1 | 23.40 | 41.24 | [log](logs/log_ade20k_80k_upernet_r50_cutmix_41.24_miou.json) | [[Google]](https://drive.google.com/file/d/1GIqjeir1hXPC8z6pkNYaL-9lBFgistJh/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_ade20k_80k_upernet_r50_cutmix_41.24_miou.pth) |
109 | | + RecursiveMix | 512x512 | 80000 | 8.1 | 23.40 | **42.30** | [log](logs/log_ade20k_80k_upernet_r50_rm_42.30_miou.json) | [[Google]](https://drive.google.com/file/d/1mysKzWVGnaEAZ61Vcdg8J_tHrU2Aotnt/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_ade20k_80k_upernet_r50_rm_42.30_miou.pth) |
110 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from torch.utils.data import DataLoader, DistributedSampler
4 | from torchvision import datasets, transforms
5 | 
6 | from samplers import RASampler
7 | from utils import ColorJitter, Lighting
8 | 
9 | 
10 | def create_loader(args):
11 |     loader = {
12 |         'cifar10': cifar10_loader,
13 |         'cifar100': cifar100_loader,
14 |         'imagenet': imagenet_loader,
15 |     }
16 |     trainset, testset = loader[args.data](args)
17 | 
18 |     if args.ddp:
19 |         if args.repeated_aug:
20 |             train_sampler = RASampler(trainset, shuffle=True)
21 |         else:
22 |             train_sampler = DistributedSampler(trainset, shuffle=True)
23 |         test_sampler = DistributedSampler(testset, shuffle=False)
24 | 
25 |         train_loader = DataLoader(trainset,
26 |                                   args.batch_size,
27 |                                   sampler=train_sampler,
28 |                                   num_workers=args.num_workers,
29 |                                   pin_memory=True)
30 |         test_loader = DataLoader(testset,
31 |                                  args.batch_size,
32 |                                  sampler=test_sampler,
33 |                                  num_workers=args.num_workers,
34 |                                  pin_memory=True)
35 |     else:
36 |         train_loader = DataLoader(trainset,
37 |                                   args.batch_size,
38 |                                   shuffle=True,
39 |                                   num_workers=args.num_workers,
40 |                                   pin_memory=True)
41 |         test_loader = DataLoader(testset, args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
42 | 
43 |     return train_loader, test_loader
44 | 
45 | 
46 | def cifar10_loader(args):
47 |     args.num_classes = 10
48 |     normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
49 |     transform_train = transforms.Compose([
50 |         transforms.RandomCrop(32, padding=4),
51 |         transforms.RandomHorizontalFlip(),
52 |         transforms.ToTensor(),
53 |         normalize,
54 |     ])
55 |     transform_test = transforms.Compose([
56 |         transforms.ToTensor(),
57 |         normalize,
58 |     ])
59 | 
60 |     trainset = datasets.CIFAR10(root=args.data_dir, train=True, download=False, transform=transform_train)
61 |     testset = datasets.CIFAR10(root=args.data_dir, train=False, download=False, transform=transform_test)
62 | 
63 |     return trainset, testset
64 | 
65 | 
66 | def cifar100_loader(args):
67 |     args.num_classes = 100
68 |     normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # note: these are the CIFAR-10 statistics
69 |     transform_train = transforms.Compose([
70 |         transforms.RandomCrop(32, padding=4),
71 |         transforms.RandomHorizontalFlip(),
72 |         transforms.ToTensor(),
73 |         normalize,
74 |     ])
75 |     transform_test = transforms.Compose([
76 |         transforms.ToTensor(),
77 |         normalize,
78 |     ])
79 | 
80 |     trainset = datasets.CIFAR100(root=args.data_dir, train=True, download=False, transform=transform_train)
81 |     testset = datasets.CIFAR100(root=args.data_dir, train=False, download=False, transform=transform_test)
82 | 
83 |     return trainset, testset
84 | 
85 | 
86 | def imagenet_loader(args):
87 |     args.num_classes = 1000
88 |     normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
89 | 
90 |     if args.aug_plus:
91 |         args.logger.info('Using aug_plus')
92 |         jittering = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)
93 |         lighting = Lighting(alphastd=0.1,
94 |                             eigval=[0.2175, 0.0188, 0.0045],
95 |                             eigvec=[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]])
96 | 
97 |         transform_train = transforms.Compose([
98 |             transforms.RandomResizedCrop(224),
99 |             transforms.RandomHorizontalFlip(),
100 |             transforms.ToTensor(),
101 |             jittering,
102 |             lighting,
103 |             normalize,
104 |         ])
105 |     else:
106 |         transform_train = transforms.Compose([
107 |             transforms.RandomResizedCrop(224),
108 |             transforms.RandomHorizontalFlip(),
109 |             transforms.ToTensor(),
110 |             normalize,
111 |         ])
112 |     transform_test = transforms.Compose([
113 |         transforms.Resize(256),
114 |         transforms.CenterCrop(224),
115 |         transforms.ToTensor(),
116 |         normalize,
117 |     ])
118 |     trainset = datasets.ImageFolder(root=os.path.join(args.data_dir, 'train'), transform=transform_train)
119 |     testset = datasets.ImageFolder(root=os.path.join(args.data_dir, 'val'), transform=transform_test)
120 | 
121 |     return trainset, testset
122 | 
--------------------------------------------------------------------------------
/figs/RM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/implus/RecursiveMix-pytorch/ed381bf28304796efd46a4d44b357ae8d4d10061/figs/RM.png
--------------------------------------------------------------------------------
/figs/RM.svg:
--------------------------------------------------------------------------------
1 | [Vector version of the pipeline figure above: at each iteration t, the historical input is resized and pasted into the current input with ratio λ; the one-hot labels (cat, dog, bird, horse, ...) are mixed accordingly; and the historical prediction supervises the pasted region via RoIAlign/GAP features under a KL-divergence loss.]
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .label_smooth_loss import *
--------------------------------------------------------------------------------
/loss/label_smooth_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils import *
4 | 
5 | 
6 | class LabelSmoothingLoss(nn.Module):
7 |     def __init__(self, num_classes, smoothing=0.0):
8 |         super(LabelSmoothingLoss, self).__init__()
9 |         assert 0 <= smoothing < 1
10 |         self.num_classes = num_classes
11 |         self.smoothing = smoothing
12 | 
13 |     def forward(self, pred: torch.Tensor, target: torch.Tensor):
14 |         bs = float(pred.size(0))
15 |         pred = pred.log_softmax(dim=1)
16 |         if len(target.shape) == 2:  # a 2-D target is treated as an already-soft label distribution
17 |             true_dist = target
18 |         else:
19 |             true_dist = smooth_one_hot(target, self.num_classes, self.smoothing)  # from utils: smoothed one-hot over num_classes
20 |         loss = (-pred * true_dist).sum() / bs
21 |         return loss
22 | 
23 | 
24 | if __name__ == '__main__':
25 |     criterion = LabelSmoothingLoss(5)
26 |     pred = torch.randn(2, 5)
27 |     target = torch.tensor([3, 1])
28 |     loss = criterion(pred, target)
29 |     print(loss)
30 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import sys
5 | import time
6 | 
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | from torch.cuda.amp import GradScaler
13 | from torch.cuda.amp import autocast
14 | from torch.nn.parallel import DistributedDataParallel
15 | 
16 | import models_cifar
17 | import models_imagenet
18 | from dataset import create_loader
19 | from loss import *
20 | from utils import *
21 | 
22 | scaler = GradScaler()
23 | 
24 | 
25 | def main():
26 |     parser = argparse.ArgumentParser()
27 |     parser.add_argument('--name', required=True, type=str)
28 |     parser.add_argument('--data', default='cifar100', type=str, help='cifar10|cifar100|imagenet')
29 |     parser.add_argument('--data_dir', type=str, default='/data/datasets/cls/cifar')
30 |     parser.add_argument('--save_dir', type=str, default='./logs')
31 |     parser.add_argument('--model_file', default='resnet', type=str, help='model type')
32 |     parser.add_argument('--model_name', default='resnet18', type=str, help='model type in detail')
33 | 
34 |     parser.add_argument('--epoch', default=200, type=int)
35 |     parser.add_argument('--optimizer', default='sgd', type=str, help='sgd|adamw')
36 |     parser.add_argument('--scheduler', default='cos', type=str, help='step|cos')
37 |     parser.add_argument('--schedule', default=[100, 150], type=int, nargs='+')
38 |     parser.add_argument('--batch_size', default=128, type=int)
39 |     parser.add_argument('--warmup', default=0, type=int)
40 |     parser.add_argument('--lr', default=0.1, type=float)
41 |     parser.add_argument('--lr_decay', default=0.1, type=float)
42 |     parser.add_argument('--momentum', default=0.9, type=float)
43 |     parser.add_argument('--weight_decay', default=5e-4, type=float)
44 |     parser.add_argument('--nesterov', action='store_true', help='enables Nesterov momentum (default: False)')
45 |     parser.add_argument('--ddp', default=True, type=str2bool, help='nn.DataParallel|DistributedDataParallel')
46 |     parser.add_argument('--smoothing', default=0.0, type=float, help='Label smoothing (default: 0.0)')
47 | 
48 |     parser.add_argument('--save_model', action='store_true')
49 |     parser.add_argument('--print_freq', default=100, type=int)
50 |     parser.add_argument('--random_seed', default=27, type=int)
51 |     parser.add_argument('--num_workers', default=16, type=int)
52 |     parser.add_argument('--local_rank', default=-1, type=int, help='local rank for distributed training')
53 |     parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint (default: none)')
54 |     parser.add_argument('--evaluate', action='store_true', help='evaluate model on validation set')
55 |     parser.add_argument('--pretrained', action='store_true', help='use pretrained models')
56 |     parser.add_argument('--fold', default=1, type=int, help='training fold')
57 |     parser.add_argument('--strict', default=True, type=str2bool, help='args for resume training: load_state_dict')
58 | 
59 |     # augmentation
60 |     parser.add_argument('--aug', default='none', type=str, help='none|recursive_mix')
61 |     parser.add_argument('--aug_alpha', default=0.5, type=float, help='alpha of RM')
62 |     parser.add_argument('--aug_omega', default=0.5, type=float, help='omega of RM')
63 |     parser.add_argument('--aug_plus', action='store_true')
64 |     parser.add_argument('--interpolate_mode', default='nearest', type=str, help='nearest|bilinear')
65 |     parser.add_argument('--share_fc', action='store_true')
66 |     parser.add_argument('--repeated_aug', action='store_true')
67 | 
68 |     args = parser.parse_args()
69 | 
70 |     # set random seed
71 |     random.seed(args.random_seed)
72 |     np.random.seed(args.random_seed)
73 |     torch.manual_seed(args.random_seed)
74 |     torch.cuda.manual_seed(args.random_seed)
75 |     torch.cuda.manual_seed_all(args.random_seed)
76 |     torch.backends.cudnn.deterministic = False
77 |     torch.backends.cudnn.benchmark = True
78 | 
79 |     args.nprocs = torch.cuda.device_count()
80 |     if args.ddp:
81 |         dist.init_process_group(backend='nccl')
82 |         torch.cuda.set_device(args.local_rank)
83 |         args.batch_size = int(args.batch_size / args.nprocs)
84 |         args.num_workers = int((args.num_workers + args.nprocs - 1) / args.nprocs)
85 | 
86 |     # create logger
87 |     create_time = time.strftime("%Y%m%d%H%M%S", time.localtime())
88 |     args.path_log = os.path.join(args.save_dir, f'{args.data}', f'{args.name}')
89 |     os.makedirs(args.path_log, exist_ok=True)
90 |     logger = create_logging(os.path.join(args.path_log, '%s_fold%s.log' % (create_time, args.fold)))
91 |     args.logger = logger
92 | 
93 |     # create dataloader
94 |     train_loader, test_loader = create_loader(args)
95 | 
96 |     # print args
97 |     for param in sorted(vars(args).keys()):
98 |         logger.info('--{0} {1}'.format(param, vars(args)[param]))
99 | 
100 |     # create model
101 |     models_package = models_imagenet if args.data == 'imagenet' else models_cifar
102 |     if args.pretrained:
103 |         model = models_package.__dict__[args.model_file].__dict__[args.model_name](num_classes=args.num_classes,
104 |                                                                                   pretrained=args.pretrained)
105 |     else:
106 |         model = models_package.__dict__[args.model_file].__dict__[args.model_name](num_classes=args.num_classes)
107 |     if args.ddp:
108 |         model.cuda(args.local_rank)
109 |         model = DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)  # fc_roi is idle on steps without boxes, hence find_unused_parameters=True
110 |     else:
111 |         model = nn.DataParallel(model).cuda()
112 | 
113 |     # create criterion
114 |     criterion = LabelSmoothingLoss(args.num_classes).cuda(args.local_rank)
115 | 
116 |     # create optimizer
117 |     if args.optimizer == 'sgd':
118 |         optimizer = optim.SGD(model.parameters(),
119 |                               lr=args.lr,
120 |                               momentum=args.momentum,
121 |                               weight_decay=args.weight_decay,
122 |                               nesterov=args.nesterov)
123 |     elif args.optimizer == 'adamw':
124 |         optimizer = optim.AdamW(model.parameters(),
125 |                                 lr=args.lr,
126 |                                 betas=(0.9, 0.999),
127 |                                 eps=1e-8,
128 |                                 weight_decay=args.weight_decay,
129 |                                 amsgrad=False)
130 |     else:
131 |         raise NotImplementedError
132 | 
133 |     best_acc1 = 0.0
134 |     best_acc5 = 0.0
135 |     start_epoch = 1
136 |     # optionally resume from a checkpoint
137 |     if args.resume:
138 |         if args.resume in ['best', 'latest']:
139 |             args.resume = os.path.join(args.path_log, 'fold%s_%s.pth' % (args.fold, args.resume))
140 |         if os.path.isfile(args.resume):
141 |             logger.info("=> loading checkpoint '{}'".format(args.resume))
142 |             # Map model to be loaded to specified single gpu.
143 |             loc = 'cuda:{}'.format(args.local_rank) if args.ddp else None
144 |             state_dict = torch.load(args.resume, map_location=loc)
145 | 
146 |             if 'state_dict' in state_dict:
147 |                 state_dict_ = state_dict['state_dict']
148 |             elif 'model' in state_dict:
149 |                 state_dict_ = state_dict['model']
150 |             else:
151 |                 state_dict_ = state_dict
152 |             model.load_state_dict(state_dict_, strict=args.strict)
153 | 
154 |             start_epoch = state_dict['epoch'] + 1
155 |             optimizer.load_state_dict(state_dict['optimizer'])
156 |             logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, state_dict['epoch']))
157 |         else:
158 |             logger.info("=> no checkpoint found at '{}'".format(args.resume))
159 | 
160 |     # optionally evaluate
161 |     if args.evaluate:
162 |         epoch = start_epoch - 1
163 | 
164 |         acc1, acc5 = test(epoch, model, test_loader, logger, args)
165 |         logger.info('Epoch(val) [{}]\tTest Acc@1: {:.4f}\tTest Acc@5: {:.4f}\tCopypaste: {:.4f}, {:.4f}'.format(
166 |             epoch, acc1, acc5, acc1, acc5))
167 |         logger.info('Exp path: %s' % args.path_log)
168 |         return
169 | 
170 |     # start training
171 |     for epoch in range(start_epoch, args.epoch + 1):
172 |         if args.ddp:
173 |             train_loader.sampler.set_epoch(epoch)
174 |         train(epoch, model, optimizer, criterion, train_loader, logger, args)
175 |         save_checkpoint(epoch, model, optimizer, args, save_name='latest')
176 | 
177 |         acc1, acc5 = test(epoch, model, test_loader, logger, args)
178 |         if acc1 >= best_acc1:
179 |             best_acc1 = acc1
180 |             best_acc5 = acc5
181 |             save_checkpoint(epoch, model, optimizer, args, save_name='best')
182 | 
183 |         logger.info('Epoch(val) [{}]\tTest Acc@1: {:.4f}\tTest Acc@5: {:.4f}\t'
184 |                     'Best Acc@1: {:.4f}\tBest Acc@5: {:.4f}\tCopypaste: {:.4f}, {:.4f}'.format(
185 |                         epoch, acc1, acc5, best_acc1, best_acc5, best_acc1, best_acc5))
186 |         logger.info('Exp path: %s' % args.path_log)
187 | 
188 | 
189 | def train(epoch, model, optimizer, criterion, train_loader, logger, args):
190 |     model.train()
191 |     losses = AverageMeter()
192 |     top1 = AverageMeter()
193 |     top5 = AverageMeter()
194 | 
195 |     old_inputs = None
196 |     lr = adjust_learning_rate(optimizer, epoch, args)
197 |     for idx, (inputs, targets) in enumerate(train_loader):
198 |         optimizer.zero_grad()
199 |         inputs, targets = inputs.cuda(), targets.cuda()
200 |         targets_onehot = smooth_one_hot(targets, args.num_classes, args.smoothing)
201 | 
202 |         with autocast():
203 |             if args.aug == 'none':
204 |                 out = model(inputs)
205 |                 loss = criterion(out, targets_onehot)
206 |             elif args.aug == 'recursive_mix':
207 |                 if old_inputs is not None:
208 |                     inputs, targets_onehot, boxes, lam = recursive_mix(inputs, old_inputs, targets_onehot, old_targets,
209 |                                                                        args.aug_alpha, args.interpolate_mode)
210 |                 else:
211 |                     lam = 1.0
212 | 
213 |                 if lam < 1.0:
214 |                     out, out_roi = model(inputs, boxes, share_fc=args.share_fc)
215 |                 else:
216 |                     out = model(inputs, None, share_fc=args.share_fc)
217 |                 loss = criterion(out, targets_onehot)
218 |                 if lam < 1.0:
219 |                     loss_roi = criterion(out_roi, old_out.softmax(dim=-1)[:inputs.size(0)])  # distill the old prediction into the pasted region
220 |                     loss += loss_roi * args.aug_omega * (1.0 - lam)
221 |                 old_inputs = inputs.clone().detach()  # cache this iteration's (mixed) input ...
222 |                 old_targets = targets_onehot.clone().detach()  # ... its mixed label ...
223 |                 old_out = out.clone().detach()  # ... and its prediction: the historical triplet
224 |             else:
225 |                 raise NotImplementedError
226 | 
227 |         scaler.scale(loss).backward()
228 |         scaler.step(optimizer)
229 |         scaler.update()
230 | 
231 |         batch_size = targets.size(0)
232 |         losses.update(reduce_value(loss).item(), batch_size)
233 |         acc1, acc5 = accuracy(out, targets, topk=(1, 5))
234 |         top1.update(reduce_value(acc1), batch_size)
235 |         top5.update(reduce_value(acc5), batch_size)
236 | 
237 |         if idx % args.print_freq == 0:
238 |             logger.info("Epoch [{0}/{1}][{2}/{3}]\t"
239 |                         "lr {4:.6f}\t"
240 |                         "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
241 |                         "Acc@1 {top1.val:.4f} ({top1.avg:.4f})\t"
242 |                         "Acc@5 {top5.val:.4f} ({top5.avg:.4f})".format(
243 |                             epoch,
244 |                             args.epoch,
245 |                             idx,
246 |                             len(train_loader),
247 |                             lr,
248 |                             loss=losses,
249 |                             top1=top1,
250 |                             top5=top5,
251 |                         ))
252 |             sys.stdout.flush()
253 |     return top1.avg, top5.avg
254 | 
255 | 
256 | @torch.no_grad()
257 | def test(epoch, model, test_loader, logger, args):
258 |     model.eval()
259 |     top1 = AverageMeter()
260 |     top5 = AverageMeter()
261 | 
262 |     for idx, (inputs, targets) in enumerate(test_loader):
263 |         batch_size = targets.size(0)
264 |         inputs, targets = inputs.cuda(), targets.cuda()
265 |         out = model(inputs)
266 | 
267 |         acc1, acc5 = accuracy(out, targets, topk=(1, 5))
268 |         top1.update(reduce_value(acc1), batch_size)
269 |         top5.update(reduce_value(acc5), batch_size)
270 | 
271 |         if idx % args.print_freq == 0:
272 |             logger.info("Epoch(val) [{0}/{1}][{2}/{3}]\t"
273 |                         "Acc@1 {top1.val:.4f} ({top1.avg:.4f})\t"
274 |                         "Acc@5 {top5.val:.4f} ({top5.avg:.4f})".format(
275 |                             epoch,
276 |                             args.epoch,
277 |                             idx,
278 |                             len(test_loader),
279 |                             top1=top1,
280 |                             top5=top5,
281 |                         ))
282 |     return top1.avg, top5.avg
283 | 
284 | 
285 | if __name__ == '__main__':
286 |     main()
287 | 
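
The two-term objective in the `recursive_mix` branch of `train()` can be restated in isolation. A runnable toy version (random tensors stand in for real logits; run from the repo root so `loss` and `utils` import; the `lam` and `omega` values are arbitrary):

```python
import torch
from loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(num_classes=10)
out     = torch.randn(4, 10)                 # logits for the current mixed batch
out_roi = torch.randn(4, 10)                 # logits pooled from the pasted region
old_out = torch.randn(4, 10)                 # cached logits from the previous iteration
mixed_targets = torch.rand(4, 10)
mixed_targets /= mixed_targets.sum(1, keepdim=True)  # stand-in mixed soft labels

lam, omega = 0.7, 0.5                        # current-image share, --aug_omega
loss = criterion(out, mixed_targets)         # classification against mixed labels
# KD-style term: the historical prediction supervises the pasted region,
# weighted by the pasted area (1 - lam)
loss = loss + omega * (1.0 - lam) * criterion(out_roi, old_out.softmax(dim=-1))
print(loss.item())
```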
--------------------------------------------------------------------------------
/models_cifar/__init__.py:
--------------------------------------------------------------------------------
1 | from .densenet import *
2 | from .pyramidnet import *
3 | from .resnet import *
4 | 
--------------------------------------------------------------------------------
/models_cifar/densenet.py:
--------------------------------------------------------------------------------
1 | """DenseNet in PyTorch
2 | 
3 | 
4 | 
5 | [1] Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger.
6 | 
7 |     Densely Connected Convolutional Networks
8 |     https://arxiv.org/abs/1608.06993v5
9 | """
10 | 
11 | import torch
12 | import torch.nn as nn
13 | from torchvision.ops.roi_align import roi_align
14 | 
15 | 
16 | #"""Bottleneck layers. Although each layer only produces k
17 | #output feature-maps, it typically has many more inputs. It
18 | #has been noted in [37, 11] that a 1×1 convolution can be
19 | #introduced as a bottleneck layer before each 3×3 convolution
20 | #to reduce the number of input feature-maps, and thus to
21 | #improve computational efficiency."""
22 | class Bottleneck(nn.Module):
23 |     def __init__(self, in_channels, growth_rate):
24 |         super().__init__()
25 |         #"""In our experiments, we let each 1×1 convolution
26 |         #produce 4k feature-maps."""
27 |         inner_channel = 4 * growth_rate
28 | 
29 |         #"""We find this design especially effective for DenseNet and
30 |         #we refer to our network with such a bottleneck layer, i.e.,
31 |         #to the BN-ReLU-Conv(1×1)-BN-ReLU-Conv(3×3) version of H_l,
32 |         #as DenseNet-B."""
33 |         self.bottle_neck = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
34 |                                          nn.Conv2d(in_channels, inner_channel, kernel_size=1, bias=False),
35 |                                          nn.BatchNorm2d(inner_channel), nn.ReLU(inplace=True),
36 |                                          nn.Conv2d(inner_channel, growth_rate, kernel_size=3, padding=1, bias=False))
37 | 
38 |     def forward(self, x):
39 |         return torch.cat([x, self.bottle_neck(x)], 1)
40 | 
41 | 
42 | #"""We refer to layers between blocks as transition
43 | #layers, which do convolution and pooling."""
44 | class Transition(nn.Module):
45 |     def __init__(self, in_channels, out_channels):
46 |         super().__init__()
47 |         #"""The transition layers used in our experiments
48 |         #consist of a batch normalization layer and an 1×1
49 |         #convolutional layer followed by a 2×2 average pooling
50 |         #layer."""
51 |         self.down_sample = nn.Sequential(nn.BatchNorm2d(in_channels), nn.Conv2d(in_channels,
52 |                                                                                 out_channels,
53 |                                                                                 1,
54 |                                                                                 bias=False), nn.AvgPool2d(2, stride=2))
55 | 
56 |     def forward(self, x):
57 |         return self.down_sample(x)
58 | 
59 | 
60 | #DenseNet-BC
61 | #B stands for bottleneck layer (BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3))
62 | #C stands for compression factor (0 <= theta <= 1)
63 | class DenseNet(nn.Module):
64 |     def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=100):
65 |         super().__init__()
66 |         self.growth_rate = growth_rate
67 | 
68 |         #"""Before entering the first dense block, a convolution
69 |         #with 16 (or twice the growth rate for DenseNet-BC)
70 |         #output channels is performed on the input images."""
71 |         inner_channels = 2 * growth_rate
72 | 
73 |         #For convolutional layers with kernel size 3×3, each
74 |         #side of the inputs is zero-padded by one pixel to keep
75 |         #the feature-map size fixed.
76 |         self.conv1 = nn.Conv2d(3, inner_channels, kernel_size=3, padding=1, bias=False)
77 | 
78 |         self.features = nn.Sequential()
79 | 
80 |         for index in range(len(nblocks) - 1):
81 |             self.features.add_module("dense_block_layer_{}".format(index),
82 |                                      self._make_dense_layers(block, inner_channels, nblocks[index]))
83 |             inner_channels += growth_rate * nblocks[index]
84 | 
85 |             #"""If a dense block contains m feature-maps, we let the
86 |             #following transition layer generate θm output feature-
87 |             #maps, where 0 < θ ≤ 1 is referred to as the compression
88 |             #factor."""
89 |             out_channels = int(reduction * inner_channels)  # int() floors the value
90 |             self.features.add_module("transition_layer_{}".format(index), Transition(inner_channels, out_channels))
91 |             inner_channels = out_channels
92 | 
93 |         self.features.add_module("dense_block{}".format(len(nblocks) - 1),
94 |                                  self._make_dense_layers(block, inner_channels, nblocks[len(nblocks) - 1]))
95 |         inner_channels += growth_rate * nblocks[len(nblocks) - 1]
96 |         self.features.add_module('bn', nn.BatchNorm2d(inner_channels))
97 |         self.features.add_module('relu', nn.ReLU(inplace=True))
98 | 
99 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
100 | 
101 |         self.linear = nn.Linear(inner_channels, num_classes)
102 |         self.linear_roi = nn.Linear(inner_channels, num_classes)
103 | 
104 |     def forward(self, x, boxes=None, share_fc=False):
105 |         bs = x.shape[0]
106 |         sz = x.shape[-1]
107 |         output = self.conv1(x)
108 |         output = self.features(output)
109 |         feat_map = output
110 |         output = self.avgpool(output)
111 |         output = output.view(output.size()[0], -1)
112 |         output = self.linear(output)
113 |         if boxes is not None:
114 |             index = torch.arange(bs).view(-1, 1).to(x.device)
115 |             boxes = torch.cat([index, boxes], 1)
116 |             spatial_scale = feat_map.shape[-1] / sz  # maps input-pixel boxes onto the feature map
117 |             roi_feat = roi_align(feat_map,
118 |                                  boxes,
119 |                                  output_size=(1, 1),
120 |                                  spatial_scale=spatial_scale,
121 |                                  sampling_ratio=-1,
122 |                                  aligned=True).squeeze()
123 |             if share_fc:
124 |                 out_roi = self.linear(roi_feat)
125 |             else:
126 |                 out_roi = self.linear_roi(roi_feat)
127 |             return output, out_roi
128 |         return output
129 | 
130 |     def _make_dense_layers(self, block, in_channels, nblocks):
131 |         dense_block = nn.Sequential()
132 |         for index in range(nblocks):
133 |             dense_block.add_module('bottle_neck_layer_{}'.format(index), block(in_channels, self.growth_rate))
134 |             in_channels += self.growth_rate
135 |         return dense_block
136 | 
137 | 
138 | def densenet121(**kwargs):
139 |     return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32, **kwargs)
140 | 
141 | 
142 | def densenet169(**kwargs):
143 |     return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32, **kwargs)
144 | 
145 | 
146 | def densenet201(**kwargs):
147 |     return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32, **kwargs)
148 | 
149 | 
150 | def densenet161(**kwargs):
151 |     return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48, **kwargs)
152 | 
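
The constructor's channel bookkeeping above is easy to verify by hand; a sketch traced for `densenet121` (`growth_rate=32`, `reduction=0.5`):

```python
# Width bookkeeping for densenet121 (growth_rate=32, reduction=0.5):
channels = 2 * 32                    # stem produces 2 * growth_rate maps
for nblocks in [6, 12, 24]:          # dense blocks followed by transitions
    channels += 32 * nblocks         # each layer concatenates growth_rate maps
    channels = int(0.5 * channels)   # transition compresses by theta = 0.5
channels += 32 * 16                  # final block has no transition after it
print(channels)                      # 1024 -> width entering the classifier
```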
--------------------------------------------------------------------------------
/models_cifar/pyramidnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from torchvision.ops.roi_align import roi_align
6 | 
7 | 
8 | def conv3x3(in_planes, out_planes, stride=1):
9 |     "3x3 convolution with padding"
10 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
11 | 
12 | 
13 | class BasicBlock(nn.Module):
14 |     outchannel_ratio = 1
15 | 
16 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
17 |         super(BasicBlock, self).__init__()
18 |         self.bn1 = nn.BatchNorm2d(inplanes)
19 |         self.conv1 = conv3x3(inplanes, planes, stride)
20 |         self.bn2 = nn.BatchNorm2d(planes)
21 |         self.conv2 = conv3x3(planes, planes)
22 |         self.bn3 = nn.BatchNorm2d(planes)
23 |         self.relu = nn.ReLU(inplace=True)
24 |         self.downsample = downsample
25 |         self.stride = stride
26 | 
27 |     def forward(self, x: torch.Tensor):
28 | 
29 |         out = self.bn1(x)
30 |         out = self.conv1(out)
31 |         out = self.bn2(out)
32 |         out = self.relu(out)
33 |         out = self.conv2(out)
34 |         out = self.bn3(out)
35 | 
36 |         if self.downsample is not None:
37 |             shortcut = self.downsample(x)
38 |             featuremap_size = shortcut.size()[2:4]
39 |         else:
40 |             shortcut = x
41 |             featuremap_size = out.size()[2:4]
42 | 
43 |         batch_size = out.size()[0]
44 |         residual_channel = out.size()[1]
45 |         shortcut_channel = shortcut.size()[1]
46 | 
47 |         if residual_channel != shortcut_channel:
48 |             padding = x.new_zeros(batch_size, residual_channel - shortcut_channel,
49 |                                   featuremap_size[0], featuremap_size[1])
50 |             out += torch.cat((shortcut, padding), 1)
51 |         else:
52 |             out += shortcut
53 | 
54 |         return out
55 | 
56 | 
57 | class Bottleneck(nn.Module):
58 |     outchannel_ratio = 4
59 | 
60 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
61 |         super(Bottleneck, self).__init__()
62 |         self.bn1 = nn.BatchNorm2d(inplanes)
63 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
64 |         self.bn2 = nn.BatchNorm2d(planes)
65 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
66 |         self.bn3 = nn.BatchNorm2d(planes)
67 |         self.conv3 = nn.Conv2d(planes, planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False)
68 |         self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio)
69 |         self.relu = nn.ReLU(inplace=True)
70 |         self.downsample = downsample
71 |         self.stride = stride
72 | 
73 |     def forward(self, x: torch.Tensor):
74 | 
75 |         out = self.bn1(x)
76 |         out = self.conv1(out)
77 | 
78 |         out = self.bn2(out)
79 |         out = self.relu(out)
80 |         out = self.conv2(out)
81 | 
82 |         out = self.bn3(out)
83 |         out = self.relu(out)
84 |         out = self.conv3(out)
85 | 
86 |         out = self.bn4(out)
87 | 
88 |         if self.downsample is not None:
89 |             shortcut = self.downsample(x)
90 |             featuremap_size = shortcut.size()[2:4]
91 |         else:
92 |             shortcut = x
93 |             featuremap_size = out.size()[2:4]
94 | 
95 |         batch_size = out.size()[0]
96 |         residual_channel = out.size()[1]
97 |         shortcut_channel = shortcut.size()[1]
98 | 
99 |         if residual_channel != shortcut_channel:
100 |             padding = x.new_zeros(batch_size, residual_channel - shortcut_channel,
101 |                                   featuremap_size[0], featuremap_size[1])
102 |             out += torch.cat((shortcut, padding), 1)
103 |         else:
104 |             out += shortcut
105 | 
106 |         return out
107 | 
108 | 
109 | class PyramidNet(nn.Module):
110 |     def __init__(self, dataset, depth, alpha, num_classes, bottleneck=False):
111 |         super(PyramidNet, self).__init__()
112 |         self.dataset = dataset
113 |         if self.dataset.startswith('cifar'):
114 |             self.inplanes = 16
115 |             if bottleneck:
116 |                 n = int((depth - 2) / 9)
117 |                 block = Bottleneck
118 |             else:
119 |                 n = int((depth - 2) / 6)
120 |                 block = BasicBlock
121 | 
122 |             self.addrate = alpha / (3 * n * 1.0)  # e.g. depth=200, alpha=240: n=22, so ~3.64 channels are added per block
123 | 
124 |             self.input_featuremap_dim = self.inplanes
125 |             self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False)
126 |             self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim)
127 | 
128 |             self.featuremap_dim = self.input_featuremap_dim
129 |             self.layer1 = self.pyramidal_make_layer(block, n)
130 |             self.layer2 = self.pyramidal_make_layer(block, n, stride=2)
131 |             self.layer3 = self.pyramidal_make_layer(block, n, stride=2)
132 | 
133 |             self.final_featuremap_dim = self.input_featuremap_dim
134 |             self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim)
135 |             self.relu_final = nn.ReLU(inplace=True)
136 |             self.avgpool = nn.AvgPool2d(8)
137 |             self.fc = nn.Linear(self.final_featuremap_dim, num_classes)
138 |             self.fc_roi = nn.Linear(self.final_featuremap_dim, num_classes)
139 | 
140 |         elif self.dataset == 'imagenet':
141 |             blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
142 |             layers = {
143 |                 18: [2, 2, 2, 2],
144 |                 34: [3, 4, 6, 3],
145 |                 50: [3, 4, 6, 3],
146 |                 101: [3, 4, 23, 3],
147 |                 152: [3, 8, 36, 3],
148 |                 200: [3, 24, 36, 3]
149 |             }
150 | 
151 |             if layers.get(depth) is None:
152 |                 if bottleneck:
153 |                     blocks[depth] = Bottleneck
154 |                     temp_cfg = int((depth - 2) / 12)
155 |                 else:
156 |                     blocks[depth] = BasicBlock
157 |                     temp_cfg = int((depth - 2) / 8)
158 | 
159 |                 layers[depth] = [temp_cfg, temp_cfg, temp_cfg, temp_cfg]
160 |                 print('=> the layer configuration for each stage is set to', layers[depth])
161 | 
162 |             self.inplanes = 64
163 |             self.addrate = alpha / (sum(layers[depth]) * 1.0)
164 | 
165 |             self.input_featuremap_dim = self.inplanes
166 |             self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False)
167 |             self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim)
168 |             self.relu = nn.ReLU(inplace=True)
169 |             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
170 | 
171 |             self.featuremap_dim = self.input_featuremap_dim
172 |             self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0])
173 |             self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2)
174 |             self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2)
175 |             self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2)
176 | 
177 |             self.final_featuremap_dim = self.input_featuremap_dim
178 |             self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim)
179 |             self.relu_final = nn.ReLU(inplace=True)
180 |             self.avgpool = nn.AvgPool2d(7)
181 |             self.fc = nn.Linear(self.final_featuremap_dim, num_classes)  # NOTE: no fc_roi head here; boxes are only supported by the CIFAR branch
182 | 
183 |         for m in self.modules():
184 |             if isinstance(m, nn.Conv2d):
185 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
186 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
187 |             elif isinstance(m, nn.BatchNorm2d):
188 |                 m.weight.data.fill_(1)
189 |                 m.bias.data.zero_()
190 | 
191 |     def pyramidal_make_layer(self, block, block_depth, stride=1):
192 |         downsample = None
193 |         if stride != 1:  # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio:
194 |             downsample = nn.AvgPool2d((2, 2), stride=(2, 2), ceil_mode=True)
195 | 
196 |         layers = []
197 |         self.featuremap_dim = self.featuremap_dim + self.addrate
198 |         layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample))
199 |         for i in range(1, block_depth):
200 |             temp_featuremap_dim = self.featuremap_dim + self.addrate
201 |             layers.append(
202 |                 block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1))
203 |             self.featuremap_dim = temp_featuremap_dim
204 |         self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio
205 | 
206 |         return nn.Sequential(*layers)
207 | 
208 |     def forward(self, x, boxes=None, share_fc=False):
209 |         if self.dataset == 'cifar10' or self.dataset == 'cifar100':
210 |             bs = x.shape[0]
211 |             sz = x.shape[-1]
212 |             x = self.conv1(x)
213 |             x = self.bn1(x)
214 | 
215 |             x = self.layer1(x)
216 |             x = self.layer2(x)
217 |             x = self.layer3(x)
218 | 
219 |             x = self.bn_final(x)
220 |             x = self.relu_final(x)
221 |             feat_map = x
222 |             x = self.avgpool(x)
223 |             x = x.view(x.size(0), -1)
224 |             x = self.fc(x)
225 |             if boxes is not None:
226 |                 index = torch.arange(bs).view(-1, 1).to(x.device)
227 |                 boxes = torch.cat([index, boxes], 1)
228 |                 spatial_scale = feat_map.shape[-1] / sz  # maps input-pixel boxes onto the feature map
229 |                 roi_feat = roi_align(feat_map,
230 |                                      boxes,
231 |                                      output_size=(1, 1),
232 |                                      spatial_scale=spatial_scale,
233 |                                      sampling_ratio=-1,
234 |                                      aligned=True).squeeze()
235 |                 if share_fc:
236 |                     out_roi = self.fc(roi_feat)
237 |                 else:
238 |                     out_roi = self.fc_roi(roi_feat)
239 | 
240 |         elif self.dataset == 'imagenet':
241 |             x = self.conv1(x)
242 |             x = self.bn1(x)
243 |             x = self.relu(x)
244 |             x = self.maxpool(x)
245 | 
246 |             x = self.layer1(x)
247 |             x = self.layer2(x)
248 |             x = self.layer3(x)
249 |             x = self.layer4(x)
250 | 
251 |             x = self.bn_final(x)
252 |             x = self.relu_final(x)
253 |             x = self.avgpool(x)
254 |             x = x.view(x.size(0), -1)
255 |             x = self.fc(x)
256 | 
257 |         if boxes is not None:  # out_roi only exists if the CIFAR branch ran; the ImageNet branch has no RoI head
258 |             return x, out_roi
259 |         return x
260 | 
261 | 
262 | def pyramidnet_200_240(num_classes=100):
263 |     net = PyramidNet(dataset='cifar100', depth=200, alpha=240, num_classes=num_classes, bottleneck=True)
264 |     return net
265 | 
266 | 
267 | def pyramidnet_164_270(num_classes=100):
268 |     net = PyramidNet(dataset='cifar100', depth=164, alpha=270, num_classes=num_classes, bottleneck=True)
269 |     return net
270 | 
271 | 
272 | if __name__ == '__main__':
273 |     net = pyramidnet_200_240(100)
274 |     img = torch.rand(10, 3, 32, 32)
275 |     boxes = torch.Tensor([0, 0, 0, 0]).float()  # unused in this smoke test
276 |     y = net(img, None)
277 |     print(len(y))
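
The additive widening performed by `pyramidal_make_layer` can likewise be checked with a few lines (a sketch of the arithmetic only; the real code rounds the width per block):

```python
depth, alpha = 200, 240        # pyramidnet_200_240 (bottleneck)
n = (depth - 2) // 9           # 22 bottleneck blocks per stage
addrate = alpha / (3 * n)      # ~3.64 channels added by every block
width = 16 + 3 * n * addrate   # 16 + 240 = 256 before the 4x bottleneck expansion
print(n, round(addrate, 2), int(width) * 4)   # 22 3.64 1024
```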
--------------------------------------------------------------------------------
/models_cifar/resnet.py:
--------------------------------------------------------------------------------
1 | """ResNet in PyTorch
2 | 
3 | 
4 | 
5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
6 | 
7 |     Deep Residual Learning for Image Recognition
8 |     https://arxiv.org/abs/1512.03385v1
9 | """
10 | 
11 | import torch
12 | import torch.nn as nn
13 | from torchvision.ops.roi_align import roi_align
14 | 
15 | 
16 | class BasicBlock(nn.Module):
17 |     """Basic Block for resnet 18 and resnet 34
18 | 
19 |     """
20 | 
21 |     #BasicBlock and BottleNeck block
22 |     #have different output size
23 |     #we use class attribute expansion
24 |     #to distinguish them
25 |     expansion = 1
26 | 
27 |     def __init__(self, in_channels, out_channels, stride=1):
28 |         super().__init__()
29 | 
30 |         #residual function
31 |         self.residual_function = nn.Sequential(
32 |             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
33 |             nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
34 |             nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
35 |             nn.BatchNorm2d(out_channels * BasicBlock.expansion))
36 | 
37 |         #shortcut
38 |         self.shortcut = nn.Sequential()
39 | 
40 |         #when the shortcut output dimension does not match the residual function's,
41 |         #use a 1x1 convolution to match the dimension
42 |         if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
43 |             self.shortcut = nn.Sequential(
44 |                 nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
45 |                 nn.BatchNorm2d(out_channels * BasicBlock.expansion))
46 | 
47 |     def forward(self, x):
48 |         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
49 | 
50 | 
51 | class BottleNeck(nn.Module):
52 |     """Residual block for resnet over 50 layers
53 | 
54 |     """
55 |     expansion = 4
56 | 
57 |     def __init__(self, in_channels, out_channels, stride=1):
58 |         super().__init__()
59 |         self.residual_function = nn.Sequential(
60 |             nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
61 |             nn.BatchNorm2d(out_channels),
62 |             nn.ReLU(inplace=True),
63 |             nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
64 |             nn.BatchNorm2d(out_channels),
65 |             nn.ReLU(inplace=True),
66 |             nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
67 |             nn.BatchNorm2d(out_channels * BottleNeck.expansion),
68 |         )
69 | 
70 |         self.shortcut = nn.Sequential()
71 | 
72 |         if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
73 |             self.shortcut = nn.Sequential(
74 |                 nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
75 |                 nn.BatchNorm2d(out_channels * BottleNeck.expansion))
76 | 
77 |     def forward(self, x):
78 |         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
79 | 
80 | 
81 | class ResNet(nn.Module):
82 |     def __init__(self, block, num_block, num_classes=100):
83 |         super().__init__()
84 | 
85 |         self.in_channels = 64
86 | 
87 |         self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(64),
88 |                                    nn.ReLU(inplace=True))
89 |         #we use a different input size than the original paper
90 |         #so conv2_x's stride is 1
91 |         self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
92 |         self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
93 |         self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
94 |         self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
95 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
96 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
97 |         self.fc_roi = nn.Linear(512 * block.expansion, num_classes)
98 | 
99 |     def _make_layer(self, block, out_channels, num_blocks, stride):
100 |         """Build a ResNet stage ('layer' here does not mean a single network
101 |         layer such as a conv layer; one stage may
102 |         contain more than one residual block).
103 | 
104 |         Args:
105 |             block: block type, basic block or bottle neck block
106 |             out_channels: output depth channel number of this layer
107 |             num_blocks: how many blocks per layer
108 |             stride: the stride of the first block of this layer
109 | 
110 |         Return:
111 |             return a resnet layer
112 |         """
113 | 
114 |         # the first block of a stage uses the given stride (1 or 2);
115 |         # all remaining blocks use stride 1
116 |         strides = [stride] + [1] * (num_blocks - 1)
117 |         layers = []
118 |         for stride in strides:
119 |             layers.append(block(self.in_channels, out_channels, stride))
120 |             self.in_channels = out_channels * block.expansion
121 | 
122 |         return nn.Sequential(*layers)
123 | 
124 |     def forward(self, x, boxes=None, share_fc=False):
125 |         bs = x.shape[0]
126 |         sz = x.shape[-1]
127 |         output = self.conv1(x)
128 |         output = self.conv2_x(output)
129 |         output = self.conv3_x(output)
130 |         output = self.conv4_x(output)
131 |         output = self.conv5_x(output)
132 |         feat_map = output
133 |         output = self.avg_pool(output)
134 |         output = output.view(output.size(0), -1)
135 |         output = self.fc(output)
136 | 
137 |         if boxes is not None:
138 |             index = torch.arange(bs).view(-1, 1).to(x.device)
139 |             boxes = torch.cat([index, boxes], 1)
140 |             spatial_scale = feat_map.shape[-1] / sz
141 |             roi_feat = roi_align(feat_map,
142 |                                  boxes,
143 |                                  output_size=(1, 1),
144 |                                  spatial_scale=spatial_scale,
145 |                                  sampling_ratio=-1,
146 |                                  aligned=True).squeeze()
147 |             if share_fc:
148 |                 out_roi = self.fc(roi_feat)
149 |             else:
150 |                 out_roi = self.fc_roi(roi_feat)
151 |             return output, out_roi
152 |         return output
153 | 
154 | 
155 | def resnet18(**kwargs):
156 |     """ return a ResNet 18 object
157 |     """
158 |     return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
159 | 
160 | 
161 | def resnet34(**kwargs):
162 |     """ return a ResNet 34 object
163 |     """
164 |     return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
165 | 
166 | 
167 | def resnet50(**kwargs):
168 |     """ return a ResNet 50 object
169 |     """
170 |     return ResNet(BottleNeck, [3, 4, 6, 3], **kwargs)
171 | 
172 | 
173 | def resnet101(**kwargs):
174 |     """ return a ResNet 101 object
175 |     """
176 |     return ResNet(BottleNeck, [3, 4, 23, 3], **kwargs)
177 | 
178 | 
179 | def resnet152(**kwargs):
180 |     """ return a ResNet 152 object
181 |     """
182 |     return ResNet(BottleNeck, [3, 8, 36, 3], **kwargs)
183 | 
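
All three CIFAR models expose the same two-headed forward used by RecursiveMix. A smoke test of that contract (box values are arbitrary illustrations; coordinates are in input pixels, and the forward itself prepends batch indices and rescales them for `roi_align`):

```python
import torch
from models_cifar.resnet import resnet18

model = resnet18(num_classes=10).eval()
x = torch.randn(4, 3, 32, 32)
# one (x1, y1, x2, y2) box per image: here the top-left 16x16 patch,
# the kind of region recursive_mix reports for the pasted history
boxes = torch.tensor([[0., 0., 16., 16.]]).repeat(4, 1)
with torch.no_grad():
    out, out_roi = model(x, boxes)   # logits for the image and for the RoI
    out_plain = model(x)             # no boxes -> single output
print(out.shape, out_roi.shape, out_plain.shape)  # all torch.Size([4, 10])
```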
--------------------------------------------------------------------------------
/models_imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | 
--------------------------------------------------------------------------------
/models_imagenet/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url
4 | from torchvision.ops.roi_align import roi_align
5 | 
6 | __all__ = [
7 |     'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
8 |     'wide_resnet50_2', 'wide_resnet101_2'
9 | ]
10 | 
11 | model_urls = {
12 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 |     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
18 |     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
19 |     'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
20 |     'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
21 | }
22 | 
23 | 
24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
25 |     """3x3 convolution with padding"""
26 |     return nn.Conv2d(in_planes,
27 |                      out_planes,
28 |                      kernel_size=3,
29 |                      stride=stride,
30 |                      padding=dilation,
31 |                      groups=groups,
32 |                      bias=False,
33 |                      dilation=dilation)
34 | 
35 | 
36 | def conv1x1(in_planes, out_planes, stride=1):
37 |     """1x1 convolution"""
38 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
39 | 
40 | 
41 | class BasicBlock(nn.Module):
42 |     expansion = 1
43 | 
44 |     def __init__(self,
45 |                  inplanes,
46 |                  planes,
47 |                  stride=1,
48 |                  downsample=None,
49 |                  groups=1,
50 |                  base_width=64,
51 |                  dilation=1,
52 |                  norm_layer=None):
53 |         super(BasicBlock, self).__init__()
54 |         if norm_layer is None:
55 |             norm_layer = nn.BatchNorm2d
56 |         if groups != 1 or base_width != 64:
57 |             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
58 |         if dilation > 1:
59 |             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
60 |         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
61 |         self.conv1 = conv3x3(inplanes, planes, stride)
62 |         self.bn1 = norm_layer(planes)
63 |         self.relu = nn.ReLU(inplace=True)
64 |         self.conv2 = conv3x3(planes, planes)
65 |         self.bn2 = norm_layer(planes)
66 |         self.downsample = downsample
67 |         self.stride = stride
68 | 
69 |     def forward(self, x):
70 |         identity = x
71 | 
72 |         out = self.conv1(x)
73 |         out = self.bn1(out)
74 |         out = self.relu(out)
75 | 
76 |         out = self.conv2(out)
77 |         out = self.bn2(out)
78 | 
79 |         if self.downsample is not None:
80 |             identity = self.downsample(x)
81 | 
82 |         out += identity
83 |         out = self.relu(out)
84 | 
85 |         return out
86 | 
87 | 
88 | class Bottleneck(nn.Module):
89 |     # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
90 |     # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
91 |     # according to "Deep residual learning for image recognition": https://arxiv.org/abs/1512.03385.
92 |     # This variant is also known as ResNet V1.5 and improves accuracy according to
93 |     # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
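    # (For contrast with the V1 ordering described above: V1 puts stride 2 on the
    # first 1x1 convolution, conv1, and stride 1 on the 3x3 conv2; the V1.5
    # Bottleneck below keeps conv1 at stride 1 and moves the stride to the 3x3 conv2.)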
94 | 95 | expansion = 4 96 | 97 | def __init__(self, 98 | inplanes, 99 | planes, 100 | stride=1, 101 | downsample=None, 102 | groups=1, 103 | base_width=64, 104 | dilation=1, 105 | norm_layer=None): 106 | super(Bottleneck, self).__init__() 107 | if norm_layer is None: 108 | norm_layer = nn.BatchNorm2d 109 | width = int(planes * (base_width / 64.)) * groups 110 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 111 | self.conv1 = conv1x1(inplanes, width) 112 | self.bn1 = norm_layer(width) 113 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 114 | self.bn2 = norm_layer(width) 115 | self.conv3 = conv1x1(width, planes * self.expansion) 116 | self.bn3 = norm_layer(planes * self.expansion) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.downsample = downsample 119 | self.stride = stride 120 | 121 | def forward(self, x): 122 | identity = x 123 | 124 | out = self.conv1(x) 125 | out = self.bn1(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv2(out) 129 | out = self.bn2(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv3(out) 133 | out = self.bn3(out) 134 | 135 | if self.downsample is not None: 136 | identity = self.downsample(x) 137 | 138 | out += identity 139 | out = self.relu(out) 140 | 141 | return out 142 | 143 | 144 | class ResNet(nn.Module): 145 | def __init__(self, 146 | block, 147 | layers, 148 | num_classes=1000, 149 | zero_init_residual=False, 150 | groups=1, 151 | width_per_group=64, 152 | replace_stride_with_dilation=None, 153 | norm_layer=None): 154 | super(ResNet, self).__init__() 155 | if norm_layer is None: 156 | norm_layer = nn.BatchNorm2d 157 | self._norm_layer = norm_layer 158 | 159 | self.inplanes = 64 160 | self.dilation = 1 161 | if replace_stride_with_dilation is None: 162 | # each element in the tuple indicates if we should replace 163 | # the 2x2 stride with a dilated convolution instead 164 | replace_stride_with_dilation = [False, False, False] 165 | if len(replace_stride_with_dilation) != 3: 166 | raise ValueError("replace_stride_with_dilation should be None " 167 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 168 | self.groups = groups 169 | self.base_width = width_per_group 170 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 171 | self.bn1 = norm_layer(self.inplanes) 172 | self.relu = nn.ReLU(inplace=True) 173 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 174 | self.layer1 = self._make_layer(block, 64, layers[0]) 175 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 176 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 177 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 178 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 179 | self.fc = nn.Linear(512 * block.expansion, num_classes) 180 | self.fc_roi = nn.Linear(512 * block.expansion, num_classes) 181 | 182 | for m in self.modules(): 183 | if isinstance(m, nn.Conv2d): 184 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 185 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 186 | nn.init.constant_(m.weight, 1) 187 | nn.init.constant_(m.bias, 0) 188 | 189 | # Zero-initialize the last BN in each residual branch, 190 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
191 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
192 |         if zero_init_residual:
193 |             for m in self.modules():
194 |                 if isinstance(m, Bottleneck):
195 |                     nn.init.constant_(m.bn3.weight, 0)
196 |                 elif isinstance(m, BasicBlock):
197 |                     nn.init.constant_(m.bn2.weight, 0)
198 | 
199 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
200 |         norm_layer = self._norm_layer
201 |         downsample = None
202 |         previous_dilation = self.dilation
203 |         if dilate:
204 |             self.dilation *= stride
205 |             stride = 1
206 |         if stride != 1 or self.inplanes != planes * block.expansion:
207 |             downsample = nn.Sequential(
208 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
209 |                 norm_layer(planes * block.expansion),
210 |             )
211 | 
212 |         layers = []
213 |         layers.append(
214 |             block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation,
215 |                   norm_layer))
216 |         self.inplanes = planes * block.expansion
217 |         for _ in range(1, blocks):
218 |             layers.append(
219 |                 block(self.inplanes,
220 |                       planes,
221 |                       groups=self.groups,
222 |                       base_width=self.base_width,
223 |                       dilation=self.dilation,
224 |                       norm_layer=norm_layer))
225 | 
226 |         return nn.Sequential(*layers)
227 | 
228 |     def forward(self, x, boxes=None, share_fc=False):
229 |         # See note [TorchScript super()]
230 |         bs = x.shape[0]
231 |         sz = x.shape[-1]
232 |         x = self.conv1(x)
233 |         x = self.bn1(x)
234 |         x = self.relu(x)
235 |         x = self.maxpool(x)
236 | 
237 |         x = self.layer1(x)
238 |         x = self.layer2(x)
239 |         x = self.layer3(x)
240 |         x = self.layer4(x)
241 |         feat_map = x
242 | 
243 |         x = self.avgpool(x)
244 |         x = torch.flatten(x, 1)
245 |         x = self.fc(x)
246 |         if boxes is not None:
247 |             index = torch.arange(bs).view(-1, 1).to(x.device)
248 |             boxes = torch.cat([index, boxes], 1)
249 |             spatial_scale = feat_map.shape[-1] / sz
250 |             roi_feat = roi_align(feat_map,
251 |                                  boxes,
252 |                                  output_size=(1, 1),
253 |                                  spatial_scale=spatial_scale,
254 |                                  sampling_ratio=-1,
255 |                                  aligned=True).squeeze()
256 |             if share_fc:
257 |                 out_roi = self.fc(roi_feat)
258 |             else:
259 |                 out_roi = self.fc_roi(roi_feat)
260 |             return x, out_roi
261 | 
262 |         return x
263 | 
264 | 
265 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
266 |     model = ResNet(block, layers, **kwargs)
267 |     if pretrained:
268 |         state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
269 |         model.load_state_dict(state_dict)
270 |     return model
271 | 
272 | 
273 | def resnet18(pretrained=False, progress=True, **kwargs):
274 |     r"""ResNet-18 model from
275 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
276 | 
277 |     Args:
278 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
279 |         progress (bool): If True, displays a progress bar of the download to stderr
280 |     """
281 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
282 | 
283 | 
284 | def resnet34(pretrained=False, progress=True, **kwargs):
285 |     r"""ResNet-34 model from
286 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
287 | 
288 |     Args:
289 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
290 |         progress (bool): If True, displays a progress bar of the download to stderr
291 |     """
292 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
293 | 
294 | 
295 | def resnet50(pretrained=False, progress=True, **kwargs):
296 |     r"""ResNet-50 model from
297 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
298 | 
299 |     Args:
300 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
True, returns a model pre-trained on ImageNet 301 | progress (bool): If True, displays a progress bar of the download to stderr 302 | """ 303 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 304 | 305 | 306 | def resnet101(pretrained=False, progress=True, **kwargs): 307 | r"""ResNet-101 model from 308 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 309 | 310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | progress (bool): If True, displays a progress bar of the download to stderr 313 | """ 314 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 315 | 316 | 317 | def resnet152(pretrained=False, progress=True, **kwargs): 318 | r"""ResNet-152 model from 319 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 320 | 321 | Args: 322 | pretrained (bool): If True, returns a model pre-trained on ImageNet 323 | progress (bool): If True, displays a progress bar of the download to stderr 324 | """ 325 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) 326 | 327 | 328 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 329 | r"""ResNeXt-50 32x4d model from 330 | `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_ 331 | 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | kwargs['groups'] = 32 337 | kwargs['width_per_group'] = 4 338 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 339 | 340 | 341 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 342 | r"""ResNeXt-101 32x8d model from 343 | `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_ 344 | 345 | Args: 346 | pretrained (bool): If True, returns a model pre-trained on ImageNet 347 | progress (bool): If True, displays a progress bar of the download to stderr 348 | """ 349 | kwargs['groups'] = 32 350 | kwargs['width_per_group'] = 8 351 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 352 | 353 | 354 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 355 | r"""Wide ResNet-50-2 model from 356 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_ 357 | 358 | The model is the same as ResNet except for the bottleneck number of channels, 359 | which is twice as large in every block. The number of channels in outer 1x1 360 | convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048 361 | channels, while in Wide ResNet-50-2 it has 2048-1024-2048. 362 | 363 | Args: 364 | pretrained (bool): If True, returns a model pre-trained on ImageNet 365 | progress (bool): If True, displays a progress bar of the download to stderr 366 | """ 367 | kwargs['width_per_group'] = 64 * 2 368 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 369 | 370 | 371 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 372 | r"""Wide ResNet-101-2 model from 373 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_ 374 | 375 | The model is the same as ResNet except for the bottleneck number of channels, 376 | which is twice as large in every block. The number of channels in outer 1x1 377 | convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048 378 | channels, while in Wide ResNet-50-2 it has 2048-1024-2048. 379 | 380 | Args: 381 | pretrained (bool): If True, returns a model pre-trained on ImageNet 382 | progress (bool): If True, displays a progress bar of the download to stderr 383 | """ 384 | kwargs['width_per_group'] = 64 * 2 385 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 386 | --------------------------------------------------------------------------------
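The `forward` above is the RecursiveMix-specific part of this file: when `boxes` is passed, the pasted historical region is pooled with `roi_align` and classified by a second head (`fc_roi`, or the shared `fc` when `share_fc=True`). A minimal sketch of driving both paths (the batch, box values, and import path are illustrative assumptions, not taken from `main.py`):

```python
import torch
from models_imagenet.resnet import resnet50  # module path assumed from the repo layout

model = resnet50(num_classes=1000).eval()
x = torch.randn(4, 3, 224, 224)  # a toy mini-batch
# one (x1, y1, x2, y2) box per image in input-pixel coordinates;
# forward() prepends the batch index that roi_align expects
boxes = torch.tensor([[32., 32., 160., 160.]]).expand(4, 4)

logits = model(x)                           # plain classification path
logits, roi_logits = model(x, boxes=boxes)  # RecursiveMix path: both heads
print(logits.shape, roi_logits.shape)       # torch.Size([4, 1000]) twice
```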
/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import math 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | class RASampler(torch.utils.data.Sampler): 10 | """Sampler that restricts data loading to a subset of the dataset for distributed training, 11 | with repeated augmentation. 12 | It ensures that each augmented version of a sample is visible to a 13 | different process (GPU). 14 | Heavily based on torch.utils.data.DistributedSampler. 15 | """ 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | self.dataset = dataset 26 | self.num_replicas = num_replicas 27 | self.rank = rank 28 | self.epoch = 0 29 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 30 | self.total_size = self.num_samples * self.num_replicas 31 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 32 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 33 | self.shuffle = shuffle 34 | 35 | def __iter__(self): 36 | # deterministically shuffle based on epoch 37 | g = torch.Generator() 38 | g.manual_seed(self.epoch) 39 | if self.shuffle: 40 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 41 | else: 42 | indices = list(range(len(self.dataset))) 43 | 44 | # repeat each sample 3 times, then add extra samples to make the list evenly divisible 45 | indices = [ele for ele in indices for i in range(3)] 46 | indices += indices[:(self.total_size - len(indices))] 47 | assert len(indices) == self.total_size 48 | 49 | # subsample 50 | indices = indices[self.rank:self.total_size:self.num_replicas] 51 | assert len(indices) == self.num_samples 52 | 53 | return iter(indices[:self.num_selected_samples]) 54 | 55 | def __len__(self): 56 | return self.num_selected_samples 57 | 58 | def set_epoch(self, epoch): 59 | self.epoch = epoch 60 | --------------------------------------------------------------------------------
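`RASampler` implements repeated augmentation: each epoch, every index is repeated 3 times, the repeated list is padded and then sharded across processes with a stride of `num_replicas`, so the augmented copies of one image land on different GPUs. A hedged sketch of wiring it into a `DataLoader` (`train_dataset`, `world_size`, `rank`, and `num_epochs` are placeholders, and `torch.distributed` is assumed to be initialized already):

```python
import torch
from samplers import RASampler

sampler = RASampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = torch.utils.data.DataLoader(train_dataset,
                                     batch_size=64,
                                     sampler=sampler,
                                     num_workers=8,
                                     pin_memory=True,
                                     drop_last=True)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # re-seed the deterministic shuffle each epoch
    for images, targets in loader:
        ...  # forward / backward as usual
```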
/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | 13 | def str2bool(v): 14 | if isinstance(v, bool): 15 | return v 16 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 17 | return True 18 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 19 | return False 20 | else: 21 | raise argparse.ArgumentTypeError('Boolean value expected.') 22 | 23 | 24 | def print_peak_memory(prefix, device): 25 | if device == 0: 26 | print(f"{prefix}: {int(torch.cuda.max_memory_allocated(device) // 1e6)}MB") 27 | 28 | 29 | def get_flops(model, img_size=224, backend='ptflops'): 30 | if backend == 'thop': 31 | from thop import clever_format, profile 32 | bs = 2 33 | img = torch.randn(bs, 3, img_size, img_size) 34 | flops, params = profile(model, inputs=(img, )) 35 | flops = flops / bs 36 | flops, params = clever_format([flops, params], "%.3f") 37 | else: 38 | from ptflops import get_model_complexity_info 39 | flops, params = get_model_complexity_info(model, (3, img_size, img_size), 40 | as_strings=True, 41 | print_per_layer_stat=True, 42 | verbose=True) 43 | 44 | print('{:<30} {:<8}'.format('Computational complexity: ', flops)) 45 | print('{:<30} {:<8}'.format('Number of parameters: ', params)) 46 | 47 | 48 | def load_checkpoint(): 49 | pass 50 | 51 | 52 | def save_checkpoint(epoch, model, optimizer, args, save_name='latest'): 53 | if args.save_model and (not args.ddp or (args.ddp and args.local_rank == 0)): 54 | state_dict = { 55 | 'epoch': epoch, 56 | 'state_dict': model.state_dict(), 57 | 'optimizer': optimizer.state_dict(), 58 | } 59 | torch.save(state_dict, os.path.join(args.path_log, 'fold%s_%s.pth' % (args.fold, save_name))) 60 | 61 | 62 | @torch.no_grad() 63 | def smooth_one_hot(target: torch.Tensor, num_classes: int, smoothing=0.0): 64 | """ 65 | If smoothing == 0, this reduces to standard one-hot encoding; 66 | if 0 < smoothing < 1, the smoothing mass is spread over the non-target classes. 67 | """ 68 | assert 0 <= smoothing < 1 69 | confidence = 1.0 - smoothing 70 | true_dist = target.new_zeros(size=(len(target), num_classes)).float() 71 | true_dist.fill_(smoothing / (num_classes - 1)) 72 | true_dist.scatter_(1, target.data.unsqueeze(1), confidence) 73 | return true_dist 74 | 75 | 76 | def reduce_value(value, average=True): 77 | if dist.is_available() and dist.is_initialized(): 78 | world_size = dist.get_world_size() 79 | if world_size < 2: # single gpu 80 | return value 81 | 82 | with torch.no_grad(): 83 | dist.all_reduce(value) # sum 84 | if average: 85 | value /= world_size # mean 86 | return value 87 | 88 | 89 | def create_logging(log_file=None, log_level=logging.INFO, file_mode='a'): 90 | """Initialize and get a logger. 91 | If the logger has not been initialized, this method will initialize the 92 | logger by adding one or two handlers, otherwise the initialized logger will 93 | be directly returned. During initialization, a StreamHandler will always be 94 | added. If `log_file` is specified and the process rank is 0, a FileHandler 95 | will also be added. 96 | 97 | Args: 98 | log_file (str | None): The log filename. If specified, a FileHandler 99 | will be added to the logger. 100 | log_level (int): The logger level. Note that only the process of 101 | rank 0 is affected, and other processes will set the level to 102 | "Error" thus be silent most of the time. 103 | file_mode (str): The file mode used in opening log file. 104 | Defaults to 'a'. 105 | 106 | Returns: 107 | logging.Logger: The expected logger. 108 | """ 109 | logger = logging.getLogger() 110 | 111 | handlers = [] 112 | stream_handler = logging.StreamHandler() 113 | handlers.append(stream_handler) 114 | 115 | if dist.is_available() and dist.is_initialized(): 116 | rank = dist.get_rank() 117 | else: 118 | rank = 0 119 | 120 | # only rank 0 will add a FileHandler 121 | if rank == 0 and log_file is not None: 122 | # logging.FileHandler appends ('a') by default; the `file_mode` 123 | # argument is exposed so this behaviour can be changed, e.g. to 124 | # overwrite ('w').
125 | file_handler = logging.FileHandler(log_file, file_mode) 126 | handlers.append(file_handler) 127 | 128 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 129 | for handler in handlers: 130 | handler.setFormatter(formatter) 131 | handler.setLevel(log_level) 132 | logger.addHandler(handler) 133 | 134 | if rank == 0: 135 | logger.setLevel(log_level) 136 | else: 137 | logger.setLevel(logging.ERROR) 138 | 139 | return logger 140 | 141 | 142 | class AverageMeter(object): 143 | """Computes and stores the average and current value""" 144 | def __init__(self): 145 | self.reset() 146 | 147 | def reset(self): 148 | self.val = 0 149 | self.avg = 0 150 | self.sum = 0 151 | self.count = 0 152 | 153 | def update(self, val, n=1): 154 | self.val = val 155 | self.sum += val * n 156 | self.count += n 157 | self.avg = self.sum / self.count 158 | 159 | 160 | def accuracy(output, target, topk=(1, )): 161 | """Computes the accuracy over the k top predictions for the specified values of k""" 162 | with torch.no_grad(): 163 | maxk = max(topk) 164 | batch_size = target.size(0) 165 | 166 | _, pred = output.topk(maxk, 1, True, True) 167 | pred = pred.t() 168 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 169 | 170 | res = [] 171 | for k in topk: 172 | correct_k = correct[:k].float().sum() 173 | res.append(correct_k.mul_(1.0 / batch_size)) 174 | return res 175 | 176 | 177 | def adjust_learning_rate(optimizer, epoch, args): 178 | # epoch >= 1 179 | assert args.scheduler in ['step', 'cos'] 180 | if epoch <= args.warmup: 181 | lr = args.lr * (epoch / (args.warmup + 1)) 182 | elif args.scheduler == 'step': 183 | exp = 0 184 | for mile_stone in args.schedule: 185 | if epoch > mile_stone: 186 | exp += 1 187 | lr = args.lr * (args.lr_decay**exp) 188 | elif args.scheduler == 'cos': 189 | decay_rate = 0.5 * (1 + np.cos((epoch - 1) * np.pi / args.epoch)) 190 | lr = args.lr * decay_rate 191 | else: 192 | raise NotImplementedError 193 | 194 | for param_group in optimizer.param_groups: 195 | param_group['lr'] = lr 196 | return lr 197 | 198 | 199 | def lr_scheduler(optimizer, scheduler, schedule, lr_decay, total_epoch): 200 | optimizer.zero_grad() # dummy step so that the scheduler's first step() call 201 | optimizer.step() # does not trigger PyTorch's ordering warning 202 | if scheduler == 'step': 203 | return optim.lr_scheduler.MultiStepLR(optimizer, schedule, gamma=lr_decay) 204 | elif scheduler == 'cos': 205 | return optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epoch) 206 | else: 207 | raise NotImplementedError('{} learning rate scheduler is not implemented.'.format(scheduler)) 208 | 209 | 210 | def mixed_criterion(criterion, pred, y_a, y_b, lam): 211 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 212 | 213 | 214 | class Compose(object): 215 | """Composes several transforms together. 216 | Args: 217 | transforms (list of ``Transform`` objects): list of transforms to compose.
218 | Example: 219 | >>> transforms.Compose([ 220 | >>> transforms.CenterCrop(10), 221 | >>> transforms.ToTensor(), 222 | >>> ]) 223 | """ 224 | def __init__(self, transforms): 225 | self.transforms = transforms 226 | 227 | def __call__(self, img: torch.Tensor): 228 | for t in self.transforms: 229 | img = t(img) 230 | return img 231 | 232 | def __repr__(self): 233 | format_string = self.__class__.__name__ + '(' 234 | for t in self.transforms: 235 | format_string += '\n' 236 | format_string += ' {0}'.format(t) 237 | format_string += '\n)' 238 | return format_string 239 | 240 | 241 | class Lighting(object): 242 | """Lighting noise (AlexNet-style, PCA-based noise)""" 243 | def __init__(self, alphastd, eigval, eigvec): 244 | self.alphastd = alphastd 245 | self.eigval = torch.Tensor(eigval) 246 | self.eigvec = torch.Tensor(eigvec) 247 | 248 | def __call__(self, img: torch.Tensor): 249 | if self.alphastd == 0: 250 | return img 251 | 252 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 253 | rgb = self.eigvec.type_as(img).clone().mul(alpha.view(1, 3).expand(3, 3)).mul( 254 | self.eigval.view(1, 3).expand(3, 3)).sum(1).squeeze() 255 | 256 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 257 | 258 | 259 | class Grayscale(object): 260 | def __call__(self, img: torch.Tensor): 261 | gs = img.clone() 262 | gs[0].mul_(0.2989).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114) # ITU-R BT.601 luma weights 263 | gs[1].copy_(gs[0]) 264 | gs[2].copy_(gs[0]) 265 | return gs 266 | 267 | 268 | class Saturation(object): 269 | def __init__(self, var): 270 | self.var = var 271 | 272 | def __call__(self, img: torch.Tensor): 273 | gs = Grayscale()(img) 274 | alpha = random.uniform(-self.var, self.var) 275 | return img.lerp(gs, alpha) 276 | 277 | 278 | class Brightness(object): 279 | def __init__(self, var): 280 | self.var = var 281 | 282 | def __call__(self, img: torch.Tensor): 283 | gs = img.new().resize_as_(img).zero_() 284 | alpha = random.uniform(-self.var, self.var) 285 | return img.lerp(gs, alpha) 286 | 287 | 288 | class Contrast(object): 289 | def __init__(self, var): 290 | self.var = var 291 | 292 | def __call__(self, img: torch.Tensor): 293 | gs = Grayscale()(img) 294 | gs.fill_(gs.mean()) 295 | alpha = random.uniform(-self.var, self.var) 296 | return img.lerp(gs, alpha) 297 | 298 | 299 | class ColorJitter(object): 300 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 301 | self.brightness = brightness 302 | self.contrast = contrast 303 | self.saturation = saturation 304 | 305 | def __call__(self, img: torch.Tensor): 306 | self.transforms = [] 307 | if self.brightness != 0: 308 | self.transforms.append(Brightness(self.brightness)) 309 | if self.contrast != 0: 310 | self.transforms.append(Contrast(self.contrast)) 311 | if self.saturation != 0: 312 | self.transforms.append(Saturation(self.saturation)) 313 | 314 | random.shuffle(self.transforms) 315 | transform = Compose(self.transforms) 316 | return transform(img) 317 | 318 | 319 | def mixup(x, y, alpha=0.4): 320 | index = torch.randperm(x.size(0)).to(x.device) 321 | lam = np.random.beta(alpha, alpha) 322 | 323 | x = lam * x + (1 - lam) * x[index] 324 | y = lam * y + (1 - lam) * y[index] # y is expected to be a soft (e.g. one-hot or smoothed) label distribution 325 | return x, y 326 | 327 | 328 | def cutmix(x, y, alpha=1.0): 329 | def rand_bbox(size, alpha): 330 | H = size[2] 331 | W = size[3] 332 | 333 | cut_rat = np.sqrt(1.
- np.random.beta(alpha, alpha)) 334 | cut_w = int(W * cut_rat) # builtin int: np.int is deprecated/removed in recent NumPy 335 | cut_h = int(H * cut_rat) 336 | 337 | # uniform 338 | cx = np.random.randint(W) 339 | cy = np.random.randint(H) 340 | 341 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 342 | bby1 = np.clip(cy - cut_h // 2, 0, H) 343 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 344 | bby2 = np.clip(cy + cut_h // 2, 0, H) 345 | return bbx1, bby1, bbx2, bby2 346 | 347 | index = torch.randperm(x.size(0)).to(x.device) 348 | bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), alpha) 349 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2])) 350 | 351 | x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2] 352 | y = lam * y + (1 - lam) * y[index] 353 | return x, y 354 | 355 | 356 | def recursive_mix(x, old_x, y, old_y, alpha, interpolate_mode): 357 | def rand_bbox(size, alpha): 358 | H = size[2] 359 | W = size[3] 360 | 361 | cut_rat = np.sqrt(random.uniform(0.0, alpha)) # pasted area fraction ~ U(0, alpha) 362 | cut_w = int(W * cut_rat) 363 | cut_h = int(H * cut_rat) 364 | 365 | # uniform 366 | cx = np.random.randint(W) 367 | cy = np.random.randint(H) 368 | 369 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 370 | bby1 = np.clip(cy - cut_h // 2, 0, H) 371 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 372 | bby2 = np.clip(cy + cut_h // 2, 0, H) 373 | return bbx1, bby1, bbx2, bby2 374 | 375 | bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), alpha) 376 | 377 | size = (bby2 - bby1, bbx2 - bbx1) 378 | bs = x.size(0) 379 | if size != (0, 0): # paste a resized copy of the historical batch into the current batch 380 | align_corners = None if interpolate_mode == 'nearest' else True 381 | x[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(old_x[:bs], 382 | size=size, 383 | mode=interpolate_mode, 384 | align_corners=align_corners) 385 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2])) 386 | 387 | y = lam * y + (1 - lam) * old_y[:bs] 388 | boxes = torch.Tensor([bbx1, bby1, bbx2, bby2]).float().to(x.device) 389 | boxes = boxes[None].expand(bs, 4) 390 | return x, y, boxes, lam 391 | --------------------------------------------------------------------------------
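Putting the pieces of `utils.py` together: `recursive_mix` pastes a resized copy of the previous iteration's batch (`old_x`) into a random box of the current batch, mixes the soft labels by the uncovered area `lam`, and returns the box so the model's ROI head can be supervised against the previous prediction. The sketch below shows one plausible training step under these assumptions; the actual loss composition lives in `main.py` (not shown here), and the KL consistency term with weight 0.5 (`--aug_omega`) follows the paper/README rather than this file. `model` and `loader` are placeholders, and a constant batch size is assumed.

```python
import torch
import torch.nn.functional as F
from utils import smooth_one_hot, recursive_mix

old_x = old_y = old_pred = None
for x, target in loader:
    y = smooth_one_hot(target, num_classes=1000, smoothing=0.1)
    if old_x is None:  # first iteration: no history yet
        logits = model(x)
        loss = torch.sum(-y * F.log_softmax(logits, dim=-1), dim=-1).mean()
    else:
        x, y, boxes, lam = recursive_mix(x, old_x, y, old_y,
                                         alpha=0.5, interpolate_mode='nearest')
        logits, roi_logits = model(x, boxes=boxes)
        cls_loss = torch.sum(-y * F.log_softmax(logits, dim=-1), dim=-1).mean()
        # consistency between the ROI prediction and the historical prediction
        kd_loss = F.kl_div(F.log_softmax(roi_logits, dim=-1),
                           F.softmax(old_pred, dim=-1),
                           reduction='batchmean')
        loss = cls_loss + 0.5 * kd_loss  # 0.5 ~ --aug_omega in the README
    loss.backward()
    ...  # optimizer.step(), optimizer.zero_grad(), etc.
    # the current triplet becomes the history for the next iteration
    old_x, old_y, old_pred = x.detach(), y.detach(), logits.detach()
```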