├── .gitignore
├── README.md
├── dataset.py
├── figs
│   ├── RM.png
│   └── RM.svg
├── logs
│   ├── log_ade20k_80k_upernet_r50_baseline_40.40_miou.json
│   ├── log_ade20k_80k_upernet_r50_cutmix_41.24_miou.json
│   ├── log_ade20k_80k_upernet_r50_rm_42.30_miou.json
│   ├── log_cifar10_300epoch_pyramidnet_rm_97.60_top1_acc.log
│   ├── log_cifar10_300epoch_pyramidnet_rm_97.62_top1_acc.log
│   ├── log_cifar10_300epoch_pyramidnet_rm_97.72_top1_acc.log
│   ├── log_coco_12epoch_atss_r50_fpn_cutmix_40.1_ap.json
│   ├── log_coco_12epoch_atss_r50_fpn_rm_41.5_ap.json
│   └── log_imagenet_300epoch_r50_rm_79.20_top1_acc.log
├── loss
│   ├── __init__.py
│   └── label_smooth_loss.py
├── main.py
├── models_cifar
│   ├── __init__.py
│   ├── densenet.py
│   ├── pyramidnet.py
│   └── resnet.py
├── models_imagenet
│   ├── __init__.py
│   └── resnet.py
├── samplers.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pyc
3 | *.so
4 | checkpoints/
5 | *.pth
6 | *.pth.tar
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code for 'RecursiveMix: Mixed Learning with History'
2 |
3 | ![RecursiveMix](figs/RM.png)
4 |
5 | RecursiveMix (RM) exploits the historical input-prediction-label triplets to enhance the generalization of deep vision models. [Paper link here.](https://arxiv.org/pdf/2203.06844.pdf)
6 |
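In short: at each training step, RM pastes a shrunken copy of the previous batch (itself already mixed, which makes the history recursive) into a corner of the current batch, mixes the labels in proportion to the pasted area, and additionally distills the prediction on the pasted region toward the previous batch's prediction. The released operator is `recursive_mix` in `utils.py`, consumed by the training loop in `main.py`; the snippet below is only a rough, self-contained sketch, and the exact sampling of the paste ratio and the box placement are assumptions rather than the paper's precise recipe.

```python
import torch
import torch.nn.functional as F

def recursive_mix_sketch(x, y, old_x, old_y, alpha=0.5, mode='nearest'):
    # x, old_x: (B, C, H, W) current / previous batch of images
    # y, old_y: (B, num_classes) one-hot (or smoothed) label distributions
    tau = torch.empty(1).uniform_(0, alpha).item()     # assumed: pasted-area fraction ~ U(0, alpha)
    lam = 1.0 - tau                                    # fraction of the current image kept
    B, _, H, W = x.shape
    ph, pw = int(H * tau ** 0.5), int(W * tau ** 0.5)  # patch side lengths for area ratio ~ tau
    if ph > 0 and pw > 0:
        patch = F.interpolate(old_x[:B], size=(ph, pw), mode=mode)
        x = x.clone()
        x[:, :, :ph, :pw] = patch                      # paste into the top-left corner
        y = lam * y + (1.0 - lam) * old_y[:B]          # area-weighted label mixing
    box = torch.tensor([0.0, 0.0, float(pw), float(ph)]).repeat(B, 1)  # paste box (x1, y1, x2, y2)
    return x, y, box, lam
```

In `main.py`, the returned box is fed back into the model's `forward(x, boxes, share_fc)` so that `roi_align` can pool the pasted region, and the ROI prediction is trained against `old_out.softmax(-1)` with weight `aug_omega * (1 - lam)`.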
7 | ## Requirements
8 |
9 | Experiment Environment
10 | - python 3.6
11 | - pytorch 1.7.1+cu101
12 | - torchvision 0.8.2
13 | - mmcv-full 1.4.1
14 | - mmdet 2.19.1
15 | - mmsegmentation 0.20.2
16 |
17 | ## Usage
18 |
19 | ### 1. Train the model
20 | For example, to reproduce the RM result on CIFAR-10 (97.65% Top-1 accuracy averaged over 3 runs; logs are provided in logs/):
21 | ```bash
22 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29500 main.py \
23 | --name 'your_experiment_log_path' \
24 | --model_file 'pyramidnet' \
25 | --model_name 'pyramidnet_200_240' \
26 | --data 'cifar10' \
27 | --data_dir '/path/to/CIFAR10' \
28 | --epoch 300 \
29 | --batch_size 64 \
30 | --lr 0.25 \
31 | --scheduler 'step' \
32 | --schedule 150 225 \
33 | --weight_decay 1e-4 \
34 | --nesterov \
35 | --num_workers 8 \
36 | --save_model \
37 | --aug 'recursive_mix' \
38 | --aug_alpha 0.5 \
39 | --aug_omega 0.1
40 | ```
41 |
42 | RM on ImageNet (79.20% Top-1 accuracy):
43 | ```bash
44 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port 29500 main.py \
45 | --name 'your_experiment_log_path' \
46 | --model_file 'resnet' \
47 | --model_name 'resnet50' \
48 | --data 'imagenet' \
49 | --epoch 300 \
50 | --batch_size 512 \
51 | --lr 0.2 \
52 | --warmup 5 \
53 | --weight_decay 1e-4 \
54 | --aug_plus \
55 | --num_workers 32 \
56 | --save_model \
57 | --aug 'recursive_mix' \
58 | --aug_alpha 0.5 \
59 | --aug_omega 0.5
60 | ```
61 |
62 |
63 | ### 2. Test the model
64 | ```bash
65 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29500 main.py \
66 | --name 'your_experiment_log_path' \
67 | --batch_size 64 \
68 | --model_file 'pyramidnet' \
69 | --model_name 'pyramidnet_200_240' \
70 | --data 'cifar10' \
71 | --data_dir '/path/to/CIFAR10' \
72 | --num_workers 8 \
73 | --evaluate \
74 | --resume 'best'
75 | ```
76 |
77 | ## Model Zoo
78 | ### Image Classification
79 | - ImageNet-1K (300 epoch)
80 |
81 | | Backbone | Size | Params (M) | Acc@1 | Log | Download |
82 | | -------------- | :---: | :--------: | :-------: | :---------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------: |
83 | | ResNet-50 | 224 | 25.56 | 76.32 | log | [Google] [GitHub] |
84 | | + Mixup | 224 | 25.56 | 77.42 | log | [Google] [GitHub] |
85 | | + CutMix | 224 | 25.56 | 78.60 | log | [Google] [GitHub] |
86 | | + RecursiveMix | 224 | 25.56 | **79.20** | [log](logs/log_imagenet_300epoch_r50_rm_79.20_top1_acc.log) | [[Google]](https://drive.google.com/file/d/19dlKcrTgfY3UqAOFNXkLmIwX41wUOOPB/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_imagenet_300epoch_r50_rm_79.20_top1_acc.pth) |
87 |
88 | ### Object Detection
89 |
90 | - COCO (1x schedule)
91 |
92 | #### ATSS
93 |
94 | | Backbone | Lr schd | Mem (GB) | Inf time (fps) | box AP | Log | Download |
95 | | -------------- | :-----: | :------: | :------------: | :------: | :----------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------: |
96 | | ResNet-50 | 1x | 3.7 | 19.7 | 39.4 | [log](https://download.openmmlab.com/mmdetection/v2.0/atss/atss_r50_fpn_1x_coco/atss_r50_fpn_1x_coco_20200209_102539.log.json) | [[Google]](https://drive.google.com/file/d/1DVVarV-os8BwEGgnBHkkCpbdSWUquYIV/view?usp=sharing) [[GitHub]](https://download.openmmlab.com/mmdetection/v2.0/atss/atss_r50_fpn_1x_coco/atss_r50_fpn_1x_coco_20200209-985f7bd0.pth) |
97 | | + CutMix | 1x | 3.7 | 19.7 | 40.1 | [log](logs/log_coco_12epoch_atss_r50_fpn_cutmix_40.1_ap.json) | [[Google]](https://drive.google.com/file/d/1T2fVCmwyMMzBdxg5QevCIRmey9vlc7JZ/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_coco_12epoch_atss_r50_fpn_cutmix_40.1_ap.pth) |
98 | | + RecursiveMix | 1x | 3.7 | 19.7 | **41.5** | [log](logs/log_coco_12epoch_atss_r50_fpn_rm_41.5_ap.json) | [[Google]](https://drive.google.com/file/d/1iFqrhkrm05LqQw_1W2q7lcl0WB_WAgNA/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_coco_12epoch_atss_r50_fpn_rm_41.5_ap.pth) |
99 |
100 | ### Semantic Segmentation
101 |
102 | - ADE20K (80k iteration)
103 |
104 | #### UPerNet
105 | | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | Log | Download |
106 | | -------------- | :-------: | :-----: | :------: | :------------: | :-------: | :-------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------: |
107 | | ResNet-50 | 512x512 | 80000 | 8.1 | 23.40 | 40.40 | [log](logs/log_ade20k_80k_upernet_r50_baseline_40.40_miou.json) | [[Google]](https://drive.google.com/file/d/1ps0KJZUTM_e-0aaDrTE0pjvTTCSO3qpH/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_ade20k_80k_upernet_r50_baseline_40.40_miou.pth) |
108 | | + CutMix | 512x512 | 80000 | 8.1 | 23.40 | 41.24 | [log](logs/log_ade20k_80k_upernet_r50_cutmix_41.24_miou.json) | [[Google]](https://drive.google.com/file/d/1GIqjeir1hXPC8z6pkNYaL-9lBFgistJh/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_ade20k_80k_upernet_r50_cutmix_41.24_miou.pth) |
109 | | + RecursiveMix | 512x512 | 80000 | 8.1 | 23.40 | **42.30** | [log](logs/log_ade20k_80k_upernet_r50_rm_42.30_miou.json) | [[Google]](https://drive.google.com/file/d/1mysKzWVGnaEAZ61Vcdg8J_tHrU2Aotnt/view?usp=sharing) [[GitHub]](https://github.com/implus/RecursiveMix/releases/download/v0.0/checkpoint_ade20k_80k_upernet_r50_rm_42.30_miou.pth) |
110 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from torch.utils.data import DataLoader, DistributedSampler
4 | from torchvision import datasets, transforms
5 |
6 | from samplers import RASampler
7 | from utils import ColorJitter, Lighting
8 |
9 |
10 | def create_loader(args):
11 | loader = {
12 | 'cifar10': cifar10_loader,
13 | 'cifar100': cifar100_loader,
14 | 'imagenet': imagenet_loader,
15 | }
16 | trainset, testset = loader[args.data](args)
17 |
18 | if args.ddp:
19 | if args.repeated_aug:
20 | train_sampler = RASampler(trainset, shuffle=True)
21 | else:
22 | train_sampler = DistributedSampler(trainset, shuffle=True)
23 | test_sampler = DistributedSampler(testset, shuffle=False)
24 |
25 | train_loader = DataLoader(trainset,
26 | args.batch_size,
27 | sampler=train_sampler,
28 | num_workers=args.num_workers,
29 | pin_memory=True)
30 | test_loader = DataLoader(testset,
31 | args.batch_size,
32 | sampler=test_sampler,
33 | num_workers=args.num_workers,
34 | pin_memory=True)
35 | else:
36 | train_loader = DataLoader(trainset,
37 | args.batch_size,
38 | shuffle=True,
39 | num_workers=args.num_workers,
40 | pin_memory=True)
41 | test_loader = DataLoader(testset, args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
42 |
43 | return train_loader, test_loader
44 |
45 |
46 | def cifar10_loader(args):
47 | args.num_classes = 10
48 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
49 | transform_train = transforms.Compose([
50 | transforms.RandomCrop(32, padding=4),
51 | transforms.RandomHorizontalFlip(),
52 | transforms.ToTensor(),
53 | normalize,
54 | ])
55 | transform_test = transforms.Compose([
56 | transforms.ToTensor(),
57 | normalize,
58 | ])
59 |
60 | trainset = datasets.CIFAR10(root=args.data_dir, train=True, download=False, transform=transform_train)
61 | testset = datasets.CIFAR10(root=args.data_dir, train=False, download=False, transform=transform_test)
62 |
63 | return trainset, testset
64 |
65 |
66 | def cifar100_loader(args):
67 | args.num_classes = 100
68 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
69 | transform_train = transforms.Compose([
70 | transforms.RandomCrop(32, padding=4),
71 | transforms.RandomHorizontalFlip(),
72 | transforms.ToTensor(),
73 | normalize,
74 | ])
75 | transform_test = transforms.Compose([
76 | transforms.ToTensor(),
77 | normalize,
78 | ])
79 |
80 | trainset = datasets.CIFAR100(root=args.data_dir, train=True, download=False, transform=transform_train)
81 | testset = datasets.CIFAR100(root=args.data_dir, train=False, download=False, transform=transform_test)
82 |
83 | return trainset, testset
84 |
85 |
86 | def imagenet_loader(args):
87 | args.num_classes = 1000
88 | normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
89 |
90 | if args.aug_plus:
91 | args.logger.info('Using aug_plus')
92 | jittering = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)
93 | lighting = Lighting(alphastd=0.1,
94 | eigval=[0.2175, 0.0188, 0.0045],
95 | eigvec=[[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]])
96 |
97 | transform_train = transforms.Compose([
98 | transforms.RandomResizedCrop(224),
99 | transforms.RandomHorizontalFlip(),
100 | transforms.ToTensor(),
101 | jittering,
102 | lighting,
103 | normalize,
104 | ])
105 | else:
106 | transform_train = transforms.Compose([
107 | transforms.RandomResizedCrop(224),
108 | transforms.RandomHorizontalFlip(),
109 | transforms.ToTensor(),
110 | normalize,
111 | ])
112 | transform_test = transforms.Compose([
113 | transforms.Resize(256),
114 | transforms.CenterCrop(224),
115 | transforms.ToTensor(),
116 | normalize,
117 | ])
118 | trainset = datasets.ImageFolder(root=os.path.join(args.data_dir, 'train'), transform=transform_train)
119 | testset = datasets.ImageFolder(root=os.path.join(args.data_dir, 'val'), transform=transform_test)
120 |
121 | return trainset, testset
122 |
--------------------------------------------------------------------------------
/figs/RM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/implus/RecursiveMix-pytorch/ed381bf28304796efd46a4d44b357ae8d4d10061/figs/RM.png
--------------------------------------------------------------------------------
/figs/RM.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .label_smooth_loss import *
--------------------------------------------------------------------------------
/loss/label_smooth_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils import *
4 |
5 |
6 | class LabelSmoothingLoss(nn.Module):
7 | def __init__(self, num_classes, smoothing=0.0):
8 | super(LabelSmoothingLoss, self).__init__()
9 | assert 0 <= smoothing < 1
10 | self.num_classes = num_classes
11 | self.smoothing = smoothing
12 |
13 | def forward(self, pred: torch.Tensor, target: torch.Tensor):
14 | bs = float(pred.size(0))
15 | pred = pred.log_softmax(dim=1)
16 | if len(target.shape) == 2:
17 | true_dist = target
18 | else:
19 | true_dist = smooth_one_hot(target, self.num_classes, self.smoothing)
20 | loss = (-pred * true_dist).sum() / bs
21 | return loss
22 |
23 |
24 | if __name__ == '__main__':
25 | criterion = LabelSmoothingLoss(5)
26 | pred = torch.randn(2, 5)
27 | target = torch.tensor([3, 1])
28 | loss = criterion(pred, target)
29 | print(loss)
30 |
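# Note: smooth_one_hot is imported from utils (not included in this dump).
# A minimal sketch of the standard label-smoothing construction it is
# assumed to implement (confidence 1 - smoothing on the true class, the
# remaining mass spread evenly over the other classes):
#
# def smooth_one_hot(target, num_classes, smoothing=0.0):
#     dist = torch.full((target.size(0), num_classes),
#                       smoothing / (num_classes - 1), device=target.device)
#     dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
#     return dist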
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import sys
5 | import time
6 |
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | from torch.cuda.amp import GradScaler
13 | from torch.cuda.amp import autocast
14 | from torch.nn.parallel import DistributedDataParallel
15 |
16 | import models_cifar
17 | import models_imagenet
18 | from dataset import create_loader
19 | from loss import *
20 | from utils import *
21 |
22 | scaler = GradScaler()
23 |
24 |
25 | def main():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--name', required=True, type=str)
28 | parser.add_argument('--data', default='cifar100', type=str, help='cifar10|cifar100|imagenet')
29 | parser.add_argument('--data_dir', type=str, default='/data/datasets/cls/cifar')
30 | parser.add_argument('--save_dir', type=str, default='./logs')
31 | parser.add_argument('--model_file', default='resnet', type=str, help='model type')
32 | parser.add_argument('--model_name', default='resnet18', type=str, help='model type in detail')
33 |
34 | parser.add_argument('--epoch', default=200, type=int)
35 | parser.add_argument('--optimizer', default='sgd', type=str, help='sgd|adamw')
36 | parser.add_argument('--scheduler', default='cos', type=str, help='step|cos')
37 | parser.add_argument('--schedule', default=[100, 150], type=int, nargs='+')
38 | parser.add_argument('--batch_size', default=128, type=int)
39 | parser.add_argument('--warmup', default=0, type=int)
40 | parser.add_argument('--lr', default=0.1, type=float)
41 | parser.add_argument('--lr_decay', default=0.1, type=float)
42 | parser.add_argument('--momentum', default=0.9, type=float)
43 | parser.add_argument('--weight_decay', default=5e-4, type=float)
44 | parser.add_argument('--nesterov', action='store_true', help='enables Nesterov momentum (default: False)')
45 | parser.add_argument('--ddp', default=True, type=str2bool, help='nn.DataParallel|DistributedDataParallel')
46 | parser.add_argument('--smoothing', default=0.0, type=float, help='Label smoothing (default: 0.0)')
47 |
48 | parser.add_argument('--save_model', action='store_true')
49 | parser.add_argument('--print_freq', default=100, type=int)
50 | parser.add_argument('--random_seed', default=27, type=int)
51 | parser.add_argument('--num_workers', default=16, type=int)
52 | parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
53 | parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint (default: none)')
54 | parser.add_argument('--evaluate', action='store_true', help='evaluate model on validation set')
55 | parser.add_argument('--pretrained', action='store_true', help='use pretrained models')
56 | parser.add_argument('--fold', default=1, type=int, help='training fold')
57 | parser.add_argument('--strict', default=True, type=str2bool, help='args for resume training: load_state_dict')
58 |
59 | # augmentation
60 | parser.add_argument('--aug', default='none', type=str, help='none|recursive_mix')
61 | parser.add_argument('--aug_alpha', default=0.5, type=float, help='alpha of RM')
62 | parser.add_argument('--aug_omega', default=0.5, type=float, help='omega of RM')
63 | parser.add_argument('--aug_plus', action='store_true')
64 | parser.add_argument('--interpolate_mode', default='nearest', type=str, help='nearest|bilinear')
65 | parser.add_argument('--share_fc', action='store_true')
66 | parser.add_argument('--repeated_aug', action='store_true')
67 |
68 | args = parser.parse_args()
69 |
70 | # set random seed
71 | random.seed(args.random_seed)
72 | np.random.seed(args.random_seed)
73 | torch.manual_seed(args.random_seed)
74 | torch.cuda.manual_seed(args.random_seed)
75 | torch.cuda.manual_seed_all(args.random_seed)
76 | torch.backends.cudnn.deterministic = False
77 | torch.backends.cudnn.benchmark = True
78 |
79 | args.nprocs = torch.cuda.device_count()
80 | if args.ddp:
81 | dist.init_process_group(backend='nccl')
82 | torch.cuda.set_device(args.local_rank)
83 | args.batch_size = int(args.batch_size / args.nprocs)
84 | args.num_workers = int((args.num_workers + args.nprocs - 1) / args.nprocs)
85 |
86 | # create logger
87 | create_time = time.strftime("%Y%m%d%H%M%S", time.localtime())
88 | args.path_log = os.path.join(args.save_dir, f'{args.data}', f'{args.name}')
89 | os.makedirs(args.path_log, exist_ok=True)
90 | logger = create_logging(os.path.join(args.path_log, '%s_fold%s.log' % (create_time, args.fold)))
91 | args.logger = logger
92 |
93 | # create dataloader
94 | train_loader, test_loader = create_loader(args)
95 |
96 | # print args
97 | for param in sorted(vars(args).keys()):
98 | logger.info('--{0} {1}'.format(param, vars(args)[param]))
99 |
100 | # create model
101 | models_package = models_imagenet if args.data == 'imagenet' else models_cifar
102 | if args.pretrained:
103 | model = models_package.__dict__[args.model_file].__dict__[args.model_name](num_classes=args.num_classes,
104 | pretrained=args.pretrained)
105 | else:
106 | model = models_package.__dict__[args.model_file].__dict__[args.model_name](num_classes=args.num_classes)
107 | if args.ddp:
108 | model.cuda(args.local_rank)
109 | model = DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
110 | else:
111 | model = nn.DataParallel(model).cuda()
112 |
113 | # create criterion
114 | criterion = LabelSmoothingLoss(args.num_classes).cuda(args.local_rank)
115 |
116 | # create optimizer
117 | if args.optimizer == 'sgd':
118 | optimizer = optim.SGD(model.parameters(),
119 | lr=args.lr,
120 | momentum=args.momentum,
121 | weight_decay=args.weight_decay,
122 | nesterov=args.nesterov)
123 | elif args.optimizer == 'adamw':
124 | optimizer = optim.AdamW(model.parameters(),
125 | lr=args.lr,
126 | betas=(0.9, 0.999),
127 | eps=1e-8,
128 | weight_decay=args.weight_decay,
129 | amsgrad=False)
130 | else:
131 | raise NotImplementedError
132 |
133 | best_acc1 = 0.0
134 | best_acc5 = 0.0
135 | start_epoch = 1
136 | # optionally resume from a checkpoint
137 | if args.resume:
138 | if args.resume in ['best', 'latest']:
139 | args.resume = os.path.join(args.path_log, 'fold%s_%s.pth' % (args.fold, args.resume))
140 | if os.path.isfile(args.resume):
141 | logger.info("=> loading checkpoint '{}'".format(args.resume))
142 | # Map model to be loaded to specified single gpu.
143 | loc = 'cuda:{}'.format(args.local_rank) if args.ddp else None
144 | state_dict = torch.load(args.resume, map_location=loc)
145 |
146 | if 'state_dict' in state_dict:
147 | state_dict_ = state_dict['state_dict']
148 | elif 'model' in state_dict:
149 | state_dict_ = state_dict['model']
150 | else:
151 | state_dict_ = state_dict
152 | model.load_state_dict(state_dict_, strict=args.strict)
153 |
154 | start_epoch = state_dict['epoch'] + 1
155 | optimizer.load_state_dict(state_dict['optimizer'])
156 | logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, state_dict['epoch']))
157 | else:
158 | logger.info("=> no checkpoint found at '{}'".format(args.resume))
159 |
160 | # optionally evaluate
161 | if args.evaluate:
162 | epoch = start_epoch - 1
163 |
164 | acc1, acc5 = test(epoch, model, test_loader, logger, args)
165 | logger.info('Epoch(val) [{}]\tTest Acc@1: {:.4f}\tTest Acc@5: {:.4f}\tCopypaste: {:.4f}, {:.4f}'.format(
166 | epoch, acc1, acc5, acc1, acc5))
167 | logger.info('Exp path: %s' % args.path_log)
168 | return
169 |
170 | # start training
171 | for epoch in range(start_epoch, args.epoch + 1):
172 | if args.ddp:
173 | train_loader.sampler.set_epoch(epoch)
174 | train(epoch, model, optimizer, criterion, train_loader, logger, args)
175 | save_checkpoint(epoch, model, optimizer, args, save_name='latest')
176 |
177 | acc1, acc5 = test(epoch, model, test_loader, logger, args)
178 | if acc1 >= best_acc1:
179 | best_acc1 = acc1
180 | best_acc5 = acc5
181 | save_checkpoint(epoch, model, optimizer, args, save_name='best')
182 |
183 | logger.info('Epoch(val) [{}]\tTest Acc@1: {:.4f}\tTest Acc@5: {:.4f}\t'
184 | 'Best Acc@1: {:.4f}\tBest Acc@5: {:.4f}\tCopypaste: {:.4f}, {:.4f}'.format(
185 | epoch, acc1, acc5, best_acc1, best_acc5, best_acc1, best_acc5))
186 | logger.info('Exp path: %s' % args.path_log)
187 |
188 |
189 | def train(epoch, model, optimizer, criterion, train_loader, logger, args):
190 | model.train()
191 | losses = AverageMeter()
192 | top1 = AverageMeter()
193 | top5 = AverageMeter()
194 |
195 | old_inputs = None
196 | lr = adjust_learning_rate(optimizer, epoch, args)
197 | for idx, (inputs, targets) in enumerate(train_loader):
198 | optimizer.zero_grad()
199 | inputs, targets = inputs.cuda(), targets.cuda()
200 | targets_onehot = smooth_one_hot(targets, args.num_classes, args.smoothing)
201 |
202 | with autocast():
203 | if args.aug == 'none':
204 | out = model(inputs)
205 | loss = criterion(out, targets_onehot)
206 | elif args.aug == 'recursive_mix':
207 | if old_inputs is not None:
208 | inputs, targets_onehot, boxes, lam = recursive_mix(inputs, old_inputs, targets_onehot, old_targets,
209 | args.aug_alpha, args.interpolate_mode)
210 | else:
211 | lam = 1.0
212 |
213 | if lam < 1.0:
214 | out, out_roi = model(inputs, boxes, share_fc=args.share_fc)
215 | else:
216 | out = model(inputs, None, share_fc=args.share_fc)
217 | loss = criterion(out, targets_onehot)
218 | if lam < 1.0:
219 | loss_roi = criterion(out_roi, old_out.softmax(dim=-1)[:inputs.size(0)])
220 | loss += loss_roi * args.aug_omega * (1.0 - lam)
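# RecursiveMix bookkeeping: cache this (already mixed) batch, its mixed
# labels, and its predictions; the next iteration pastes them back into
# the incoming batch and distills the ROI prediction against old_out.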
221 | old_inputs = inputs.clone().detach()
222 | old_targets = targets_onehot.clone().detach()
223 | old_out = out.clone().detach()
224 | else:
225 | raise NotImplementedError
226 |
227 | scaler.scale(loss).backward()
228 | scaler.step(optimizer)
229 | scaler.update()
230 |
231 | batch_size = targets.size(0)
232 | losses.update(reduce_value(loss).item(), batch_size)
233 | acc1, acc5 = accuracy(out, targets, topk=(1, 5))
234 | top1.update(reduce_value(acc1), batch_size)
235 | top5.update(reduce_value(acc5), batch_size)
236 |
237 | if idx % args.print_freq == 0:
238 | logger.info("Epoch [{0}/{1}][{2}/{3}]\t"
239 | "lr {4:.6f}\t"
240 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
241 | "Acc@1 {top1.val:.4f} ({top1.avg:.4f})\t"
242 | "Acc@5 {top5.val:.4f} ({top5.avg:.4f})".format(
243 | epoch,
244 | args.epoch,
245 | idx,
246 | len(train_loader),
247 | lr,
248 | loss=losses,
249 | top1=top1,
250 | top5=top5,
251 | ))
252 | sys.stdout.flush()
253 | return top1.avg, top5.avg
254 |
255 |
256 | @torch.no_grad()
257 | def test(epoch, model, test_loader, logger, args):
258 | model.eval()
259 | top1 = AverageMeter()
260 | top5 = AverageMeter()
261 |
262 | for idx, (inputs, targets) in enumerate(test_loader):
263 | batch_size = targets.size(0)
264 | inputs, targets = inputs.cuda(), targets.cuda()
265 | out = model(inputs)
266 |
267 | acc1, acc5 = accuracy(out, targets, topk=(1, 5))
268 | top1.update(reduce_value(acc1), batch_size)
269 | top5.update(reduce_value(acc5), batch_size)
270 |
271 | if idx % args.print_freq == 0:
272 | logger.info("Epoch(val) [{0}/{1}][{2}/{3}]\t"
273 | "Acc@1 {top1.val:.4f} ({top1.avg:.4f})\t"
274 | "Acc@5 {top5.val:.4f} ({top5.avg:.4f})".format(
275 | epoch,
276 | args.epoch,
277 | idx,
278 | len(test_loader),
279 | top1=top1,
280 | top5=top5,
281 | ))
282 | return top1.avg, top5.avg
283 |
284 |
285 | if __name__ == '__main__':
286 | main()
287 |
--------------------------------------------------------------------------------
/models_cifar/__init__.py:
--------------------------------------------------------------------------------
1 | from .densenet import *
2 | from .pyramidnet import *
3 | from .resnet import *
4 |
--------------------------------------------------------------------------------
/models_cifar/densenet.py:
--------------------------------------------------------------------------------
1 | """dense net in pytorch
2 |
3 |
4 |
5 | [1] Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger.
6 |
7 | Densely Connected Convolutional Networks
8 | https://arxiv.org/abs/1608.06993v5
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torchvision.ops.roi_align import roi_align
14 |
15 |
16 | #"""Bottleneck layers. Although each layer only produces k
17 | #output feature-maps, it typically has many more inputs. It
18 | #has been noted in [37, 11] that a 1×1 convolution can be in-
19 | #troduced as bottleneck layer before each 3×3 convolution
20 | #to reduce the number of input feature-maps, and thus to
21 | #improve computational efficiency."""
22 | class Bottleneck(nn.Module):
23 | def __init__(self, in_channels, growth_rate):
24 | super().__init__()
25 | #"""In our experiments, we let each 1×1 convolution
26 | #produce 4k feature-maps."""
27 | inner_channel = 4 * growth_rate
28 |
29 | #"""We find this design especially effective for DenseNet and
30 | #we refer to our network with such a bottleneck layer, i.e.,
31 | #to the BN-ReLU-Conv(1×1)-BN-ReLU-Conv(3×3) version of H ` ,
32 | #as DenseNet-B."""
33 | self.bottle_neck = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
34 | nn.Conv2d(in_channels, inner_channel, kernel_size=1, bias=False),
35 | nn.BatchNorm2d(inner_channel), nn.ReLU(inplace=True),
36 | nn.Conv2d(inner_channel, growth_rate, kernel_size=3, padding=1, bias=False))
37 |
38 | def forward(self, x):
39 | return torch.cat([x, self.bottle_neck(x)], 1)
40 |
41 |
42 | #"""We refer to layers between blocks as transition
43 | #layers, which do convolution and pooling."""
44 | class Transition(nn.Module):
45 | def __init__(self, in_channels, out_channels):
46 | super().__init__()
47 | #"""The transition layers used in our experiments
48 | #consist of a batch normalization layer and an 1×1
49 | #convolutional layer followed by a 2×2 average pooling
50 | #layer""".
51 | self.down_sample = nn.Sequential(nn.BatchNorm2d(in_channels), nn.Conv2d(in_channels,
52 | out_channels,
53 | 1,
54 | bias=False), nn.AvgPool2d(2, stride=2))
55 |
56 | def forward(self, x):
57 | return self.down_sample(x)
58 |
59 |
60 | # DenseNet-BC
61 | # B stands for bottleneck layer (BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3))
62 | # C stands for compression factor (0 <= theta <= 1)
63 | class DenseNet(nn.Module):
64 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=100):
65 | super().__init__()
66 | self.growth_rate = growth_rate
67 |
68 | #"""Before entering the first dense block, a convolution
69 | #with 16 (or twice the growth rate for DenseNet-BC)
70 | #output channels is performed on the input images."""
71 | inner_channels = 2 * growth_rate
72 |
73 | #For convolutional layers with kernel size 3×3, each
74 | #side of the inputs is zero-padded by one pixel to keep
75 | #the feature-map size fixed.
76 | self.conv1 = nn.Conv2d(3, inner_channels, kernel_size=3, padding=1, bias=False)
77 |
78 | self.features = nn.Sequential()
79 |
80 | for index in range(len(nblocks) - 1):
81 | self.features.add_module("dense_block_layer_{}".format(index),
82 | self._make_dense_layers(block, inner_channels, nblocks[index]))
83 | inner_channels += growth_rate * nblocks[index]
84 |
85 | #"""If a dense block contains m feature-maps, we let the
86 | #following transition layer generate θm output feature-
87 | #maps, where 0 < θ ≤ 1 is referred to as the compression
88 | #factor."""
89 | out_channels = int(reduction * inner_channels)  # int() automatically floors the value
90 | self.features.add_module("transition_layer_{}".format(index), Transition(inner_channels, out_channels))
91 | inner_channels = out_channels
92 |
93 | self.features.add_module("dense_block{}".format(len(nblocks) - 1),
94 | self._make_dense_layers(block, inner_channels, nblocks[len(nblocks) - 1]))
95 | inner_channels += growth_rate * nblocks[len(nblocks) - 1]
96 | self.features.add_module('bn', nn.BatchNorm2d(inner_channels))
97 | self.features.add_module('relu', nn.ReLU(inplace=True))
98 |
99 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
100 |
101 | self.linear = nn.Linear(inner_channels, num_classes)
102 | self.linear_roi = nn.Linear(inner_channels, num_classes)
103 |
104 | def forward(self, x, boxes=None, share_fc=False):
105 | bs = x.shape[0]
106 | sz = x.shape[-1]
107 | output = self.conv1(x)
108 | output = self.features(output)
109 | feat_map = output
110 | output = self.avgpool(output)
111 | output = output.view(output.size()[0], -1)
112 | output = self.linear(output)
113 | if boxes is not None:
114 | index = torch.arange(bs).view(-1, 1).to(x.device)  # batch index for each box
115 | boxes = torch.cat([index, boxes], 1)  # roi_align expects (batch_idx, x1, y1, x2, y2)
116 | spatial_scale = feat_map.shape[-1] / sz
117 | roi_feat = roi_align(feat_map,
118 | boxes,
119 | output_size=(1, 1),
120 | spatial_scale=spatial_scale,
121 | sampling_ratio=-1,
122 | aligned=True).squeeze()
123 | if share_fc:
124 | out_roi = self.linear(roi_feat)
125 | else:
126 | out_roi = self.linear_roi(roi_feat)
127 | return output, out_roi
128 | return output
129 |
130 | def _make_dense_layers(self, block, in_channels, nblocks):
131 | dense_block = nn.Sequential()
132 | for index in range(nblocks):
133 | dense_block.add_module('bottle_neck_layer_{}'.format(index), block(in_channels, self.growth_rate))
134 | in_channels += self.growth_rate
135 | return dense_block
136 |
137 |
138 | def densenet121(**kwargs):
139 | return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32, **kwargs)
140 |
141 |
142 | def densenet169(**kwargs):
143 | return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32, **kwargs)
144 |
145 |
146 | def densenet201(**kwargs):
147 | return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32, **kwargs)
148 |
149 |
150 | def densenet161(**kwargs):
151 | return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48, **kwargs)
152 |
--------------------------------------------------------------------------------
/models_cifar/pyramidnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision.ops.roi_align import roi_align
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | "3x3 convolution with padding"
10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
11 |
12 |
13 | class BasicBlock(nn.Module):
14 | outchannel_ratio = 1
15 |
16 | def __init__(self, inplanes, planes, stride=1, downsample=None):
17 | super(BasicBlock, self).__init__()
18 | self.bn1 = nn.BatchNorm2d(inplanes)
19 | self.conv1 = conv3x3(inplanes, planes, stride)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 | self.conv2 = conv3x3(planes, planes)
22 | self.bn3 = nn.BatchNorm2d(planes)
23 | self.relu = nn.ReLU(inplace=True)
24 | self.downsample = downsample
25 | self.stride = stride
26 |
27 | def forward(self, x: torch.Tensor):
28 |
29 | out = self.bn1(x)
30 | out = self.conv1(out)
31 | out = self.bn2(out)
32 | out = self.relu(out)
33 | out = self.conv2(out)
34 | out = self.bn3(out)
35 |
36 | if self.downsample is not None:
37 | shortcut = self.downsample(x)
38 | featuremap_size = shortcut.size()[2:4]
39 | else:
40 | shortcut = x
41 | featuremap_size = out.size()[2:4]
42 |
43 | batch_size = out.size()[0]
44 | residual_channel = out.size()[1]
45 | shortcut_channel = shortcut.size()[1]
46 |
47 | if residual_channel != shortcut_channel:
48 | padding = x.new_zeros(batch_size, residual_channel - shortcut_channel,
49 | featuremap_size[0], featuremap_size[1])
50 | out += torch.cat((shortcut, padding), 1)
51 | else:
52 | out += shortcut
53 |
54 | return out
55 |
56 |
57 | class Bottleneck(nn.Module):
58 | outchannel_ratio = 4
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None):
61 | super(Bottleneck, self).__init__()
62 | self.bn1 = nn.BatchNorm2d(inplanes)
63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
64 | self.bn2 = nn.BatchNorm2d(planes)
65 | self.conv2 = nn.Conv2d(planes, (planes), kernel_size=3, stride=stride, padding=1, bias=False)
66 | self.bn3 = nn.BatchNorm2d((planes))
67 | self.conv3 = nn.Conv2d((planes), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False)
68 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio)
69 | self.relu = nn.ReLU(inplace=True)
70 | self.downsample = downsample
71 | self.stride = stride
72 |
73 | def forward(self, x: torch.Tensor):
74 |
75 | out = self.bn1(x)
76 | out = self.conv1(out)
77 |
78 | out = self.bn2(out)
79 | out = self.relu(out)
80 | out = self.conv2(out)
81 |
82 | out = self.bn3(out)
83 | out = self.relu(out)
84 | out = self.conv3(out)
85 |
86 | out = self.bn4(out)
87 |
88 | if self.downsample is not None:
89 | shortcut = self.downsample(x)
90 | featuremap_size = shortcut.size()[2:4]
91 | else:
92 | shortcut = x
93 | featuremap_size = out.size()[2:4]
94 |
95 | batch_size = out.size()[0]
96 | residual_channel = out.size()[1]
97 | shortcut_channel = shortcut.size()[1]
98 |
99 | if residual_channel != shortcut_channel:
100 | padding = x.new_zeros(batch_size, residual_channel - shortcut_channel,
101 | featuremap_size[0], featuremap_size[1])
102 | out += torch.cat((shortcut, padding), 1)
103 | else:
104 | out += shortcut
105 |
106 | return out
107 |
108 |
109 | class PyramidNet(nn.Module):
110 | def __init__(self, dataset, depth, alpha, num_classes, bottleneck=False):
111 | super(PyramidNet, self).__init__()
112 | self.dataset = dataset
113 | if self.dataset.startswith('cifar'):
114 | self.inplanes = 16
115 | if bottleneck:
116 | n = int((depth - 2) / 9)
117 | block = Bottleneck
118 | else:
119 | n = int((depth - 2) / 6)
120 | block = BasicBlock
121 |
122 | self.addrate = alpha / (3 * n * 1.0)
123 |
124 | self.input_featuremap_dim = self.inplanes
125 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False)
126 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim)
127 |
128 | self.featuremap_dim = self.input_featuremap_dim
129 | self.layer1 = self.pyramidal_make_layer(block, n)
130 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2)
131 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2)
132 |
133 | self.final_featuremap_dim = self.input_featuremap_dim
134 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim)
135 | self.relu_final = nn.ReLU(inplace=True)
136 | self.avgpool = nn.AvgPool2d(8)
137 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes)
138 | self.fc_roi = nn.Linear(self.final_featuremap_dim, num_classes)
139 |
140 | elif self.dataset == 'imagenet':
141 | blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
142 | layers = {
143 | 18: [2, 2, 2, 2],
144 | 34: [3, 4, 6, 3],
145 | 50: [3, 4, 6, 3],
146 | 101: [3, 4, 23, 3],
147 | 152: [3, 8, 36, 3],
148 | 200: [3, 24, 36, 3]
149 | }
150 |
151 | if layers.get(depth) is None:
152 | if bottleneck:
153 | blocks[depth] = Bottleneck
154 | temp_cfg = int((depth - 2) / 12)
155 | else:
156 | blocks[depth] = BasicBlock
157 | temp_cfg = int((depth - 2) / 8)
158 |
159 | layers[depth] = [temp_cfg, temp_cfg, temp_cfg, temp_cfg]
160 | print('=> the layer configuration for each stage is set to', layers[depth])
161 |
162 | self.inplanes = 64
163 | self.addrate = alpha / (sum(layers[depth]) * 1.0)
164 |
165 | self.input_featuremap_dim = self.inplanes
166 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False)
167 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim)
168 | self.relu = nn.ReLU(inplace=True)
169 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
170 |
171 | self.featuremap_dim = self.input_featuremap_dim
172 | self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0])
173 | self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2)
174 | self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2)
175 | self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2)
176 |
177 | self.final_featuremap_dim = self.input_featuremap_dim
178 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim)
179 | self.relu_final = nn.ReLU(inplace=True)
180 | self.avgpool = nn.AvgPool2d(7)
181 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes)
182 |
183 | for m in self.modules():
184 | if isinstance(m, nn.Conv2d):
185 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
186 | m.weight.data.normal_(0, math.sqrt(2. / n))
187 | elif isinstance(m, nn.BatchNorm2d):
188 | m.weight.data.fill_(1)
189 | m.bias.data.zero_()
190 |
191 | def pyramidal_make_layer(self, block, block_depth, stride=1):
192 | downsample = None
193 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio:
194 | downsample = nn.AvgPool2d((2, 2), stride=(2, 2), ceil_mode=True)
195 |
196 | layers = []
197 | self.featuremap_dim = self.featuremap_dim + self.addrate
198 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample))
199 | for i in range(1, block_depth):
200 | temp_featuremap_dim = self.featuremap_dim + self.addrate
201 | layers.append(
202 | block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1))
203 | self.featuremap_dim = temp_featuremap_dim
204 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio
205 |
206 | return nn.Sequential(*layers)
207 |
208 | def forward(self, x, boxes=None, share_fc=False):
209 | if self.dataset == 'cifar10' or self.dataset == 'cifar100':
210 | bs = x.shape[0]
211 | sz = x.shape[-1]
212 | x = self.conv1(x)
213 | x = self.bn1(x)
214 |
215 | x = self.layer1(x)
216 | x = self.layer2(x)
217 | x = self.layer3(x)
218 |
219 | x = self.bn_final(x)
220 | x = self.relu_final(x)
221 | feat_map = x
222 | x = self.avgpool(x)
223 | x = x.view(x.size(0), -1)
224 | x = self.fc(x)
225 | if boxes is not None:
226 | index = torch.arange(bs).view(-1, 1).to(x.device)
227 | boxes = torch.cat([index, boxes], 1)
228 | spatial_scale = feat_map.shape[-1] / sz
229 | roi_feat = roi_align(feat_map,
230 | boxes,
231 | output_size=(1, 1),
232 | spatial_scale=spatial_scale,
233 | sampling_ratio=-1,
234 | aligned=True).squeeze()
235 | if share_fc:
236 | out_roi = self.fc(roi_feat)
237 | else:
238 | out_roi = self.fc_roi(roi_feat)
239 |
240 | elif self.dataset == 'imagenet':
241 | x = self.conv1(x)
242 | x = self.bn1(x)
243 | x = self.relu(x)
244 | x = self.maxpool(x)
245 |
246 | x = self.layer1(x)
247 | x = self.layer2(x)
248 | x = self.layer3(x)
249 | x = self.layer4(x)
250 |
251 | x = self.bn_final(x)
252 | x = self.relu_final(x)
253 | x = self.avgpool(x)
254 | x = x.view(x.size(0), -1)
255 | x = self.fc(x)
256 |
257 | if boxes is not None:
258 | return x, out_roi
259 | return x
260 |
261 |
262 | def pyramidnet_200_240(num_classes=100):
263 | net = PyramidNet(dataset='cifar100', depth=200, alpha=240, num_classes=num_classes, bottleneck=True)
264 | return net
265 |
266 |
267 | def pyramidnet_164_270(num_classes=100):
268 | net = PyramidNet(dataset='cifar100', depth=164, alpha=270, num_classes=num_classes, bottleneck=True)
269 | return net
270 |
271 |
272 | if __name__ == '__main__':
273 | net = pyramidnet_200_240(100)
274 | img = torch.rand(10, 3, 32, 32)
275 | boxes = torch.tensor([[4.0, 4.0, 28.0, 28.0]]).repeat(10, 1)  # one (x1, y1, x2, y2) box per image
276 | y, y_roi = net(img, boxes)  # exercise both the plain and the ROI output paths
277 | print(y.shape, y_roi.shape)
--------------------------------------------------------------------------------
/models_cifar/resnet.py:
--------------------------------------------------------------------------------
1 | """resnet in pytorch
2 |
3 |
4 |
5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
6 |
7 | Deep Residual Learning for Image Recognition
8 | https://arxiv.org/abs/1512.03385v1
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torchvision.ops.roi_align import roi_align
14 |
15 |
16 | class BasicBlock(nn.Module):
17 | """Basic Block for resnet 18 and resnet 34
18 |
19 | """
20 |
21 | # BasicBlock and BottleNeck blocks
22 | # have different output sizes;
23 | # we use the class attribute `expansion`
24 | # to distinguish them
25 | expansion = 1
26 |
27 | def __init__(self, in_channels, out_channels, stride=1):
28 | super().__init__()
29 |
30 | #residual function
31 | self.residual_function = nn.Sequential(
32 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
33 | nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
34 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
35 | nn.BatchNorm2d(out_channels * BasicBlock.expansion))
36 |
37 | #shortcut
38 | self.shortcut = nn.Sequential()
39 |
40 | # when the shortcut output dimension does not match the residual function's,
41 | # use a 1x1 convolution to match the dimensions
42 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
43 | self.shortcut = nn.Sequential(
44 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
45 | nn.BatchNorm2d(out_channels * BasicBlock.expansion))
46 |
47 | def forward(self, x):
48 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
49 |
50 |
51 | class BottleNeck(nn.Module):
52 | """Residual block for resnet over 50 layers
53 |
54 | """
55 | expansion = 4
56 |
57 | def __init__(self, in_channels, out_channels, stride=1):
58 | super().__init__()
59 | self.residual_function = nn.Sequential(
60 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
61 | nn.BatchNorm2d(out_channels),
62 | nn.ReLU(inplace=True),
63 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
64 | nn.BatchNorm2d(out_channels),
65 | nn.ReLU(inplace=True),
66 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
67 | nn.BatchNorm2d(out_channels * BottleNeck.expansion),
68 | )
69 |
70 | self.shortcut = nn.Sequential()
71 |
72 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
73 | self.shortcut = nn.Sequential(
74 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
75 | nn.BatchNorm2d(out_channels * BottleNeck.expansion))
76 |
77 | def forward(self, x):
78 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
79 |
80 |
81 | class ResNet(nn.Module):
82 | def __init__(self, block, num_block, num_classes=100):
83 | super().__init__()
84 |
85 | self.in_channels = 64
86 |
87 | self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(64),
88 | nn.ReLU(inplace=True))
89 | # we use a different input size than the original paper,
90 | # so conv2_x's stride is 1
91 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
92 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
93 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
94 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
95 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
96 | self.fc = nn.Linear(512 * block.expansion, num_classes)
97 | self.fc_roi = nn.Linear(512 * block.expansion, num_classes)
98 |
99 | def _make_layer(self, block, out_channels, num_blocks, stride):
100 | """make resnet layers(by layer i didnt mean this 'layer' was the
101 | same as a neuron netowork layer, ex. conv layer), one layer may
102 | contain more than one residual block
103 |
104 | Args:
105 | block: block type, basic block or bottle neck block
106 | out_channels: output depth channel number of this layer
107 | num_blocks: how many blocks per layer
108 | stride: the stride of the first block of this layer
109 |
110 | Return:
111 | return a resnet layer
112 | """
113 |
114 | # we have num_blocks blocks per layer; the stride of the first block
115 | # can be 1 or 2, while all remaining blocks use stride 1
116 | strides = [stride] + [1] * (num_blocks - 1)
117 | layers = []
118 | for stride in strides:
119 | layers.append(block(self.in_channels, out_channels, stride))
120 | self.in_channels = out_channels * block.expansion
121 |
122 | return nn.Sequential(*layers)
123 |
124 | def forward(self, x, boxes=None, share_fc=False):
125 | bs = x.shape[0]
126 | sz = x.shape[-1]
127 | output = self.conv1(x)
128 | output = self.conv2_x(output)
129 | output = self.conv3_x(output)
130 | output = self.conv4_x(output)
131 | output = self.conv5_x(output)
132 | feat_map = output
133 | output = self.avg_pool(output)
134 | output = output.view(output.size(0), -1)
135 | output = self.fc(output)
136 |
137 | if boxes is not None:
138 | index = torch.arange(bs).view(-1, 1).to(x.device)
139 | boxes = torch.cat([index, boxes], 1)
140 | spatial_scale = feat_map.shape[-1] / sz
141 | roi_feat = roi_align(feat_map,
142 | boxes,
143 | output_size=(1, 1),
144 | spatial_scale=spatial_scale,
145 | sampling_ratio=-1,
146 | aligned=True).squeeze()
147 | if share_fc:
148 | out_roi = self.fc(roi_feat)
149 | else:
150 | out_roi = self.fc_roi(roi_feat)
151 | return output, out_roi
152 | return output
153 |
154 |
155 | def resnet18(**kwargs):
156 | """ return a ResNet 18 object
157 | """
158 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
159 |
160 |
161 | def resnet34(**kwargs):
162 | """ return a ResNet 34 object
163 | """
164 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
165 |
166 |
167 | def resnet50(**kwargs):
168 | """ return a ResNet 50 object
169 | """
170 | return ResNet(BottleNeck, [3, 4, 6, 3], **kwargs)
171 |
172 |
173 | def resnet101(**kwargs):
174 | """ return a ResNet 101 object
175 | """
176 | return ResNet(BottleNeck, [3, 4, 23, 3], **kwargs)
177 |
178 |
179 | def resnet152(**kwargs):
180 | """ return a ResNet 152 object
181 | """
182 | return ResNet(BottleNeck, [3, 8, 36, 3], **kwargs)
183 |
--------------------------------------------------------------------------------
/models_imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 |
--------------------------------------------------------------------------------
/models_imagenet/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url
4 | from torchvision.ops.roi_align import roi_align
5 |
6 | __all__ = [
7 | 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
8 | 'wide_resnet50_2', 'wide_resnet101_2'
9 | ]
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
21 | }
22 |
23 |
24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
25 | """3x3 convolution with padding"""
26 | return nn.Conv2d(in_planes,
27 | out_planes,
28 | kernel_size=3,
29 | stride=stride,
30 | padding=dilation,
31 | groups=groups,
32 | bias=False,
33 | dilation=dilation)
34 |
35 |
36 | def conv1x1(in_planes, out_planes, stride=1):
37 | """1x1 convolution"""
38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
39 |
40 |
41 | class BasicBlock(nn.Module):
42 | expansion = 1
43 |
44 | def __init__(self,
45 | inplanes,
46 | planes,
47 | stride=1,
48 | downsample=None,
49 | groups=1,
50 | base_width=64,
51 | dilation=1,
52 | norm_layer=None):
53 | super(BasicBlock, self).__init__()
54 | if norm_layer is None:
55 | norm_layer = nn.BatchNorm2d
56 | if groups != 1 or base_width != 64:
57 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
58 | if dilation > 1:
59 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
60 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
61 | self.conv1 = conv3x3(inplanes, planes, stride)
62 | self.bn1 = norm_layer(planes)
63 | self.relu = nn.ReLU(inplace=True)
64 | self.conv2 = conv3x3(planes, planes)
65 | self.bn2 = norm_layer(planes)
66 | self.downsample = downsample
67 | self.stride = stride
68 |
69 | def forward(self, x):
70 | identity = x
71 |
72 | out = self.conv1(x)
73 | out = self.bn1(out)
74 | out = self.relu(out)
75 |
76 | out = self.conv2(out)
77 | out = self.bn2(out)
78 |
79 | if self.downsample is not None:
80 | identity = self.downsample(x)
81 |
82 | out += identity
83 | out = self.relu(out)
84 |
85 | return out
86 |
87 |
88 | class Bottleneck(nn.Module):
89 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
90 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
91 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
92 | # This variant is also known as ResNet V1.5 and improves accuracy according to
93 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
94 |
95 | expansion = 4
96 |
97 | def __init__(self,
98 | inplanes,
99 | planes,
100 | stride=1,
101 | downsample=None,
102 | groups=1,
103 | base_width=64,
104 | dilation=1,
105 | norm_layer=None):
106 | super(Bottleneck, self).__init__()
107 | if norm_layer is None:
108 | norm_layer = nn.BatchNorm2d
109 | width = int(planes * (base_width / 64.)) * groups
110 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
111 | self.conv1 = conv1x1(inplanes, width)
112 | self.bn1 = norm_layer(width)
113 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
114 | self.bn2 = norm_layer(width)
115 | self.conv3 = conv1x1(width, planes * self.expansion)
116 | self.bn3 = norm_layer(planes * self.expansion)
117 | self.relu = nn.ReLU(inplace=True)
118 | self.downsample = downsample
119 | self.stride = stride
120 |
121 | def forward(self, x):
122 | identity = x
123 |
124 | out = self.conv1(x)
125 | out = self.bn1(out)
126 | out = self.relu(out)
127 |
128 | out = self.conv2(out)
129 | out = self.bn2(out)
130 | out = self.relu(out)
131 |
132 | out = self.conv3(out)
133 | out = self.bn3(out)
134 |
135 | if self.downsample is not None:
136 | identity = self.downsample(x)
137 |
138 | out += identity
139 | out = self.relu(out)
140 |
141 | return out
142 |
143 |
144 | class ResNet(nn.Module):
145 | def __init__(self,
146 | block,
147 | layers,
148 | num_classes=1000,
149 | zero_init_residual=False,
150 | groups=1,
151 | width_per_group=64,
152 | replace_stride_with_dilation=None,
153 | norm_layer=None):
154 | super(ResNet, self).__init__()
155 | if norm_layer is None:
156 | norm_layer = nn.BatchNorm2d
157 | self._norm_layer = norm_layer
158 |
159 | self.inplanes = 64
160 | self.dilation = 1
161 | if replace_stride_with_dilation is None:
162 | # each element in the tuple indicates if we should replace
163 | # the 2x2 stride with a dilated convolution instead
164 | replace_stride_with_dilation = [False, False, False]
165 | if len(replace_stride_with_dilation) != 3:
166 | raise ValueError("replace_stride_with_dilation should be None "
167 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
168 | self.groups = groups
169 | self.base_width = width_per_group
170 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
171 | self.bn1 = norm_layer(self.inplanes)
172 | self.relu = nn.ReLU(inplace=True)
173 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
174 | self.layer1 = self._make_layer(block, 64, layers[0])
175 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
176 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
177 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
178 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
179 | self.fc = nn.Linear(512 * block.expansion, num_classes)
180 | self.fc_roi = nn.Linear(512 * block.expansion, num_classes)
181 |
182 | for m in self.modules():
183 | if isinstance(m, nn.Conv2d):
184 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
185 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
186 | nn.init.constant_(m.weight, 1)
187 | nn.init.constant_(m.bias, 0)
188 |
189 | # Zero-initialize the last BN in each residual branch,
190 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
191 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
192 | if zero_init_residual:
193 | for m in self.modules():
194 | if isinstance(m, Bottleneck):
195 | nn.init.constant_(m.bn3.weight, 0)
196 | elif isinstance(m, BasicBlock):
197 | nn.init.constant_(m.bn2.weight, 0)
198 |
199 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
200 | norm_layer = self._norm_layer
201 | downsample = None
202 | previous_dilation = self.dilation
203 | if dilate:
204 | self.dilation *= stride
205 | stride = 1
206 | if stride != 1 or self.inplanes != planes * block.expansion:
207 | downsample = nn.Sequential(
208 | conv1x1(self.inplanes, planes * block.expansion, stride),
209 | norm_layer(planes * block.expansion),
210 | )
211 |
212 | layers = []
213 | layers.append(
214 | block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation,
215 | norm_layer))
216 | self.inplanes = planes * block.expansion
217 | for _ in range(1, blocks):
218 | layers.append(
219 | block(self.inplanes,
220 | planes,
221 | groups=self.groups,
222 | base_width=self.base_width,
223 | dilation=self.dilation,
224 | norm_layer=norm_layer))
225 |
226 | return nn.Sequential(*layers)
227 |
228 | def forward(self, x, boxes=None, share_fc=False):
229 | # See note [TorchScript super()]
230 | bs = x.shape[0]
231 | sz = x.shape[-1]
232 | x = self.conv1(x)
233 | x = self.bn1(x)
234 | x = self.relu(x)
235 | x = self.maxpool(x)
236 |
237 | x = self.layer1(x)
238 | x = self.layer2(x)
239 | x = self.layer3(x)
240 | x = self.layer4(x)
241 | feat_map = x
242 |
243 | x = self.avgpool(x)
244 | x = torch.flatten(x, 1)
245 | x = self.fc(x)
246 | if boxes is not None:
247 | index = torch.arange(bs).view(-1, 1).to(x.device)
248 | boxes = torch.cat([index, boxes], 1)
249 | spatial_scale = feat_map.shape[-1] / sz
250 | roi_feat = roi_align(feat_map,
251 | boxes,
252 | output_size=(1, 1),
253 | spatial_scale=spatial_scale,
254 | sampling_ratio=-1,
255 | aligned=True).squeeze()
256 | if share_fc:
257 | out_roi = self.fc(roi_feat)
258 | else:
259 | out_roi = self.fc_roi(roi_feat)
260 | return x, out_roi
261 |
262 | return x
263 |
264 |
265 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
266 | model = ResNet(block, layers, **kwargs)
267 | if pretrained:
268 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
269 | model.load_state_dict(state_dict)
270 | return model
271 |
272 |
273 | def resnet18(pretrained=False, progress=True, **kwargs):
274 | r"""ResNet-18 model from
275 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
276 |
277 | Args:
278 | pretrained (bool): If True, returns a model pre-trained on ImageNet
279 | progress (bool): If True, displays a progress bar of the download to stderr
280 | """
281 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
282 |
283 |
284 | def resnet34(pretrained=False, progress=True, **kwargs):
285 | r"""ResNet-34 model from
286 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
287 |
288 | Args:
289 | pretrained (bool): If True, returns a model pre-trained on ImageNet
290 | progress (bool): If True, displays a progress bar of the download to stderr
291 | """
292 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
293 |
294 |
295 | def resnet50(pretrained=False, progress=True, **kwargs):
296 | r"""ResNet-50 model from
297 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
298 |
299 | Args:
300 | pretrained (bool): If True, returns a model pre-trained on ImageNet
301 | progress (bool): If True, displays a progress bar of the download to stderr
302 | """
303 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
304 |
305 |
306 | def resnet101(pretrained=False, progress=True, **kwargs):
307 | r"""ResNet-101 model from
308 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
309 |
310 | Args:
311 | pretrained (bool): If True, returns a model pre-trained on ImageNet
312 | progress (bool): If True, displays a progress bar of the download to stderr
313 | """
314 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
315 |
316 |
317 | def resnet152(pretrained=False, progress=True, **kwargs):
318 | r"""ResNet-152 model from
319 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
320 |
321 | Args:
322 | pretrained (bool): If True, returns a model pre-trained on ImageNet
323 | progress (bool): If True, displays a progress bar of the download to stderr
324 | """
325 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
326 |
327 |
328 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
329 | r"""ResNeXt-50 32x4d model from
330 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
331 |
332 | Args:
333 | pretrained (bool): If True, returns a model pre-trained on ImageNet
334 | progress (bool): If True, displays a progress bar of the download to stderr
335 | """
336 | kwargs['groups'] = 32
337 | kwargs['width_per_group'] = 4
338 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
339 |
340 |
341 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
342 | r"""ResNeXt-101 32x8d model from
343 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
344 |
345 | Args:
346 | pretrained (bool): If True, returns a model pre-trained on ImageNet
347 | progress (bool): If True, displays a progress bar of the download to stderr
348 | """
349 | kwargs['groups'] = 32
350 | kwargs['width_per_group'] = 8
351 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
352 |
353 |
354 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
355 | r"""Wide ResNet-50-2 model from
356 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
357 |
358 |     The model is the same as ResNet except that the number of bottleneck
359 |     channels is twice as large in every block. The number of channels in outer
360 |     1x1 convolutions is the same, e.g. the last block in ResNet-50 has
361 |     2048-512-2048 channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
362 |
363 | Args:
364 | pretrained (bool): If True, returns a model pre-trained on ImageNet
365 | progress (bool): If True, displays a progress bar of the download to stderr
366 | """
367 | kwargs['width_per_group'] = 64 * 2
368 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
369 |
370 |
371 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
372 | r"""Wide ResNet-101-2 model from
373 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
374 |
375 |     The model is the same as ResNet except that the number of bottleneck
376 |     channels is twice as large in every block. The number of channels in outer
377 |     1x1 convolutions is the same, e.g. the last block in ResNet-50 has
378 |     2048-512-2048 channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
379 |
380 | Args:
381 | pretrained (bool): If True, returns a model pre-trained on ImageNet
382 | progress (bool): If True, displays a progress bar of the download to stderr
383 | """
384 | kwargs['width_per_group'] = 64 * 2
385 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
386 |
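# Minimal smoke test (added for illustration; not part of the original file):
# verifies that passing RecursiveMix boxes yields a second logit tensor from
# the RoI head alongside the usual global one.
if __name__ == '__main__':
    model = resnet50(num_classes=10)
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(2, 3, 224, 224)
        boxes = torch.tensor([[0.0, 0.0, 112.0, 112.0]]).expand(2, 4)
        out, out_roi = model(dummy, boxes=boxes)
    print(out.shape, out_roi.shape)  # torch.Size([2, 10]) torch.Size([2, 10])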
--------------------------------------------------------------------------------
/samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import math
4 |
5 | import torch
6 | import torch.distributed as dist
7 |
8 |
9 | class RASampler(torch.utils.data.Sampler):
10 | """Sampler that restricts data loading to a subset of the dataset for distributed,
11 | with repeated augmentation.
12 |     It ensures that each augmented version of a sample is visible to a
13 |     different process (GPU).
14 | Heavily based on torch.utils.data.DistributedSampler
15 | """
16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
17 | if num_replicas is None:
18 | if not dist.is_available():
19 | raise RuntimeError("Requires distributed package to be available")
20 | num_replicas = dist.get_world_size()
21 | if rank is None:
22 | if not dist.is_available():
23 | raise RuntimeError("Requires distributed package to be available")
24 | rank = dist.get_rank()
25 | self.dataset = dataset
26 | self.num_replicas = num_replicas
27 | self.rank = rank
28 | self.epoch = 0
29 |         self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))  # 3 repeated augmentations per sample
30 | self.total_size = self.num_samples * self.num_replicas
31 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
32 |         self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))  # truncate to a multiple of 256 before splitting across replicas
33 | self.shuffle = shuffle
34 |
35 | def __iter__(self):
36 | # deterministically shuffle based on epoch
37 | g = torch.Generator()
38 | g.manual_seed(self.epoch)
39 | if self.shuffle:
40 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
41 | else:
42 | indices = list(range(len(self.dataset)))
43 |
44 |         # repeat each index 3 times, then pad to make the list evenly divisible
45 |         indices = [ele for ele in indices for _ in range(3)]
46 |         indices += indices[:(self.total_size - len(indices))]
47 | assert len(indices) == self.total_size
48 |
49 | # subsample
50 | indices = indices[self.rank:self.total_size:self.num_replicas]
51 | assert len(indices) == self.num_samples
52 |
53 | return iter(indices[:self.num_selected_samples])
54 |
55 | def __len__(self):
56 | return self.num_selected_samples
57 |
58 | def set_epoch(self, epoch):
59 | self.epoch = epoch
60 |
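# Illustrative check (assumes a 1024-sample dataset and a 2-process job): each
# rank shuffles, repeats every index 3 times, takes its strided share of
# ceil(1024 * 3 / 2) = 1536 indices, and yields only the first
# floor(1024 // 256 * 256 / 2) = 512 of them.
if __name__ == '__main__':
    sampler = RASampler(dataset=list(range(1024)), num_replicas=2, rank=0)
    sampler.set_epoch(0)
    print(len(list(sampler)))  # 512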
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 |
6 | import numpy as np
7 | import torch
8 | import torch.distributed as dist
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 |
12 |
13 | def str2bool(v):
14 | if isinstance(v, bool):
15 | return v
16 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
17 | return True
18 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
19 | return False
20 | else:
21 | raise argparse.ArgumentTypeError('Boolean value expected.')
22 |
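# Typical argparse wiring (illustrative; the flag name is hypothetical):
#   parser.add_argument('--ddp', type=str2bool, nargs='?', const=True, default=False)
# so '--ddp', '--ddp yes' and '--ddp 0' all parse cleanly as booleans.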
23 |
24 | def print_peak_memory(prefix, device):
25 | if device == 0:
26 |         print(f"{prefix}: {int(torch.cuda.max_memory_allocated(device) // 1e6)}MB")
27 |
28 |
29 | def get_flops(model, img_size=224, backend='ptflops'):
30 | if backend == 'thop':
31 | from thop import clever_format, profile
32 | bs = 2
33 | img = torch.randn(bs, 3, img_size, img_size)
34 | flops, params = profile(model, inputs=(img, ))
35 | flops = flops / bs
36 | flops, params = clever_format([flops, params], "%.3f")
37 | else:
38 | from ptflops import get_model_complexity_info
39 | flops, params = get_model_complexity_info(model, (3, img_size, img_size),
40 | as_strings=True,
41 | print_per_layer_stat=True,
42 | verbose=True)
43 |
44 | print('{:<30} {:<8}'.format('Computational complexity: ', flops))
45 | print('{:<30} {:<8}'.format('Number of parameters: ', params))
46 |
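# Example usage (requires the optional `ptflops` or `thop` packages):
#   from models_imagenet.resnet import resnet50
#   get_flops(resnet50(), img_size=224, backend='ptflops')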
47 |
48 | def load_checkpoint():
49 | pass
50 |
51 |
52 | def save_checkpoint(epoch, model, optimizer, args, save_name='latest'):
53 | if args.save_model and (not args.ddp or (args.ddp and args.local_rank == 0)):
54 | state_dict = {
55 | 'epoch': epoch,
56 | 'state_dict': model.state_dict(),
57 | 'optimizer': optimizer.state_dict(),
58 | }
59 | torch.save(state_dict, os.path.join(args.path_log, 'fold%s_%s.pth' % (args.fold, save_name)))
60 |
61 |
62 | @torch.no_grad()
63 | def smooth_one_hot(target: torch.Tensor, num_classes: int, smoothing=0.0):
64 |     """
65 |     If smoothing == 0, this reduces to standard one-hot encoding;
66 |     if 0 < smoothing < 1, the one-hot target is label-smoothed.
67 |     """
68 | assert 0 <= smoothing < 1
69 | confidence = 1.0 - smoothing
70 | true_dist = target.new_zeros(size=(len(target), num_classes)).float()
71 | true_dist.fill_(smoothing / (num_classes - 1))
72 | true_dist.scatter_(1, target.data.unsqueeze(1), confidence)
73 | return true_dist
74 |
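# Worked example: with num_classes=5 and smoothing=0.1, the true class gets
# 1 - 0.1 = 0.9 and each of the other 4 classes gets 0.1 / (5 - 1) = 0.025:
#   smooth_one_hot(torch.tensor([2]), num_classes=5, smoothing=0.1)
#   -> tensor([[0.0250, 0.0250, 0.9000, 0.0250, 0.0250]])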
75 |
76 | def reduce_value(value, average=True):
77 | if dist.is_available() and dist.is_initialized():
78 | world_size = dist.get_world_size()
79 | if world_size < 2: # single gpu
80 | return value
81 |
82 | with torch.no_grad():
83 | dist.all_reduce(value) # sum
84 | if average:
85 | value /= world_size # mean
86 | return value
87 |
88 |
89 | def create_logging(log_file=None, log_level=logging.INFO, file_mode='a'):
90 | """Initialize and get a logger.
91 | If the logger has not been initialized, this method will initialize the
92 | logger by adding one or two handlers, otherwise the initialized logger will
93 | be directly returned. During initialization, a StreamHandler will always be
94 | added. If `log_file` is specified and the process rank is 0, a FileHandler
95 | will also be added.
96 |
97 | Args:
98 | log_file (str | None): The log filename. If specified, a FileHandler
99 | will be added to the logger.
100 |         log_level (int): The logger level. Note that only the process of
101 |             rank 0 is affected; other processes set the level to
102 |             "ERROR" and are thus silent most of the time.
103 |         file_mode (str): The file mode used to open the log file.
104 |             Defaults to 'a'.
105 |
106 | Returns:
107 | logging.Logger: The expected logger.
108 | """
109 | logger = logging.getLogger()
110 |
111 | handlers = []
112 | stream_handler = logging.StreamHandler()
113 | handlers.append(stream_handler)
114 |
115 | if dist.is_available() and dist.is_initialized():
116 | rank = dist.get_rank()
117 | else:
118 | rank = 0
119 |
120 | # only rank 0 will add a FileHandler
121 | if rank == 0 and log_file is not None:
122 |         # The default mode of logging.FileHandler is 'a' (append); the
123 |         # `file_mode` argument exposes this so callers can switch to
124 |         # 'w' (overwrite) if desired.
125 | file_handler = logging.FileHandler(log_file, file_mode)
126 | handlers.append(file_handler)
127 |
128 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
129 | for handler in handlers:
130 | handler.setFormatter(formatter)
131 | handler.setLevel(log_level)
132 | logger.addHandler(handler)
133 |
134 | if rank == 0:
135 | logger.setLevel(log_level)
136 | else:
137 | logger.setLevel(logging.ERROR)
138 |
139 | return logger
140 |
141 |
142 | class AverageMeter(object):
143 | """Computes and stores the average and current value"""
144 | def __init__(self):
145 | self.reset()
146 |
147 | def reset(self):
148 | self.val = 0
149 | self.avg = 0
150 | self.sum = 0
151 | self.count = 0
152 |
153 | def update(self, val, n=1):
154 | self.val = val
155 | self.sum += val * n
156 | self.count += n
157 | self.avg = self.sum / self.count
158 |
159 |
160 | def accuracy(output, target, topk=(1, )):
161 | """Computes the accuracy over the k top predictions for the specified values of k"""
162 | with torch.no_grad():
163 | maxk = max(topk)
164 | batch_size = target.size(0)
165 |
166 | _, pred = output.topk(maxk, 1, True, True)
167 | pred = pred.t()
168 | correct = pred.eq(target.view(1, -1).expand_as(pred))
169 |
170 | res = []
171 | for k in topk:
172 | correct_k = correct[:k].float().sum()
173 | res.append(correct_k.mul_(1.0 / batch_size))
174 | return res
175 |
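# Example: if 3 of 4 top-1 predictions are correct and every target appears in
# the top 5, accuracy(output, target, topk=(1, 5)) returns
# [tensor(0.7500), tensor(1.0000)] -- fractions, not percentages.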
176 |
177 | def adjust_learning_rate(optimizer, epoch, args):
178 | # epoch >= 1
179 | assert args.scheduler in ['step', 'cos']
180 | if epoch <= args.warmup:
181 |         lr = args.lr * (epoch / (args.warmup + 1))  # linear warmup
182 | elif args.scheduler == 'step':
183 | exp = 0
184 | for mile_stone in args.schedule:
185 | if epoch > mile_stone:
186 | exp += 1
187 | lr = args.lr * (args.lr_decay**exp)
188 | elif args.scheduler == 'cos':
189 | decay_rate = 0.5 * (1 + np.cos((epoch - 1) * np.pi / args.epoch))
190 | lr = args.lr * decay_rate
191 | else:
192 | raise NotImplementedError
193 |
194 | for param_group in optimizer.param_groups:
195 | param_group['lr'] = lr
196 | return lr
197 |
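# Example with the CIFAR-10 recipe from the README (lr=0.25, scheduler='step',
# schedule=[150, 225], warmup=0; assuming lr_decay=0.1): epochs 1-150 run at
# 0.25, epochs 151-225 at 0.025, and epochs 226-300 at 0.0025.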
198 |
199 | def lr_scheduler(optimizer, scheduler, schedule, lr_decay, total_epoch):
200 |     optimizer.zero_grad()
201 |     optimizer.step()  # dummy step to avoid the "lr_scheduler.step() before optimizer.step()" warning
202 | if scheduler == 'step':
203 | return optim.lr_scheduler.MultiStepLR(optimizer, schedule, gamma=lr_decay)
204 | elif scheduler == 'cos':
205 | return optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epoch)
206 | else:
207 |         raise NotImplementedError('{} learning rate scheduler is not implemented.'.format(scheduler))
208 |
209 |
210 | def mixed_criterion(criterion, pred, y_a, y_b, lam):
211 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
212 |
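# Convex combination of the loss on both label sets of a mixed batch, e.g. for
# mixup/CutMix with hard labels: mixed_criterion(F.cross_entropy, pred, y_a, y_b, lam).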
213 |
214 | class Compose(object):
215 | """Composes several transforms together.
216 | Args:
217 | transforms (list of ``Transform`` objects): list of transforms to compose.
218 | Example:
219 | >>> transforms.Compose([
220 | >>> transforms.CenterCrop(10),
221 | >>> transforms.ToTensor(),
222 | >>> ])
223 | """
224 | def __init__(self, transforms):
225 | self.transforms = transforms
226 |
227 | def __call__(self, img: torch.Tensor):
228 | for t in self.transforms:
229 | img = t(img)
230 | return img
231 |
232 | def __repr__(self):
233 | format_string = self.__class__.__name__ + '('
234 | for t in self.transforms:
235 | format_string += '\n'
236 | format_string += ' {0}'.format(t)
237 | format_string += '\n)'
238 | return format_string
239 |
240 |
241 | class Lighting(object):
242 |     """Lighting noise (AlexNet-style, PCA-based noise)."""
243 | def __init__(self, alphastd, eigval, eigvec):
244 | self.alphastd = alphastd
245 | self.eigval = torch.Tensor(eigval)
246 | self.eigvec = torch.Tensor(eigvec)
247 |
248 | def __call__(self, img: torch.Tensor):
249 | if self.alphastd == 0:
250 | return img
251 |
252 |         alpha = img.new_empty(3).normal_(0, self.alphastd)
253 | rgb = self.eigvec.type_as(img).clone().mul(alpha.view(1, 3).expand(3, 3)).mul(
254 | self.eigval.view(1, 3).expand(3, 3)).sum(1).squeeze()
255 |
256 | return img.add(rgb.view(3, 1, 1).expand_as(img))
257 |
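# Commonly used ImageNet PCA statistics for this transform (the AlexNet-style
# values; quoted here for illustration, not defined in this file):
#   eigval = [0.2175, 0.0188, 0.0045]
#   eigvec = [[-0.5675, 0.7192, 0.4009],
#             [-0.5808, -0.0045, -0.8140],
#             [-0.5836, -0.6948, 0.4203]]
#   lighting = Lighting(0.1, eigval, eigvec)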
258 |
259 | class Grayscale(object):
260 | def __call__(self, img: torch.Tensor):
261 | gs = img.clone()
262 |         gs[0].mul_(0.2989).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)  # ITU-R BT.601 luma weights
263 | gs[1].copy_(gs[0])
264 | gs[2].copy_(gs[0])
265 | return gs
266 |
267 |
268 | class Saturation(object):
269 | def __init__(self, var):
270 | self.var = var
271 |
272 | def __call__(self, img: torch.Tensor):
273 | gs = Grayscale()(img)
274 | alpha = random.uniform(-self.var, self.var)
275 | return img.lerp(gs, alpha)
276 |
277 |
278 | class Brightness(object):
279 | def __init__(self, var):
280 | self.var = var
281 |
282 | def __call__(self, img: torch.Tensor):
283 |         gs = torch.zeros_like(img)
284 | alpha = random.uniform(-self.var, self.var)
285 | return img.lerp(gs, alpha)
286 |
287 |
288 | class Contrast(object):
289 | def __init__(self, var):
290 | self.var = var
291 |
292 | def __call__(self, img: torch.Tensor):
293 | gs = Grayscale()(img)
294 | gs.fill_(gs.mean())
295 | alpha = random.uniform(-self.var, self.var)
296 | return img.lerp(gs, alpha)
297 |
298 |
299 | class ColorJitter(object):
300 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
301 | self.brightness = brightness
302 | self.contrast = contrast
303 | self.saturation = saturation
304 |
305 |     def __call__(self, img: torch.Tensor):
306 |         transforms = []  # built per call; no need to store state on the instance
307 |         if self.brightness != 0:
308 |             transforms.append(Brightness(self.brightness))
309 |         if self.contrast != 0:
310 |             transforms.append(Contrast(self.contrast))
311 |         if self.saturation != 0:
312 |             transforms.append(Saturation(self.saturation))
313 |
314 |         random.shuffle(transforms)
315 |         transform = Compose(transforms)
316 |         return transform(img)
317 |
318 |
319 | def mixup(x, y, alpha=0.4):
320 | index = torch.randperm(x.size(0)).to(x.device)
321 | lam = np.random.beta(alpha, alpha)
322 |
323 | x = lam * x + (1 - lam) * x[index]
324 | y = lam * y + (1 - lam) * y[index]
325 | return x, y
326 |
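# Usage sketch: labels must already be soft/one-hot (e.g. via smooth_one_hot),
# since integer class indices cannot be mixed linearly:
#   y = smooth_one_hot(target, num_classes, smoothing=0.1)
#   x, y = mixup(x, y, alpha=0.4)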
327 |
328 | def cutmix(x, y, alpha=1.0):
329 | def rand_bbox(size, alpha):
330 | H = size[2]
331 | W = size[3]
332 |
333 | cut_rat = np.sqrt(1. - np.random.beta(alpha, alpha))
334 |         cut_w = int(W * cut_rat)  # np.int was removed in NumPy 1.24; builtin int is equivalent
335 |         cut_h = int(H * cut_rat)
336 |
337 | # uniform
338 | cx = np.random.randint(W)
339 | cy = np.random.randint(H)
340 |
341 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
342 | bby1 = np.clip(cy - cut_h // 2, 0, H)
343 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
344 | bby2 = np.clip(cy + cut_h // 2, 0, H)
345 | return bbx1, bby1, bbx2, bby2
346 |
347 | index = torch.randperm(x.size(0)).to(x.device)
348 | bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), alpha)
349 |     lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))  # match label weight to the actual (clipped) box area
350 |
351 | x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
352 | y = lam * y + (1 - lam) * y[index]
353 | return x, y
354 |
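# As with mixup above, y must be a soft/one-hot tensor; note that lam is
# recomputed from the clipped box so label weights match the pasted area:
#   x, y = cutmix(x, smooth_one_hot(target, num_classes), alpha=1.0)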
355 |
356 | def recursive_mix(x, old_x, y, old_y, alpha, interpolate_mode):
357 | def rand_bbox(size, alpha):
358 | H = size[2]
359 | W = size[3]
360 |
361 |         cut_rat = np.sqrt(random.uniform(0.0, alpha))  # pasted-area ratio drawn from U(0, alpha)
362 |         cut_w = int(W * cut_rat)
363 |         cut_h = int(H * cut_rat)
364 |
365 | # uniform
366 | cx = np.random.randint(W)
367 | cy = np.random.randint(H)
368 |
369 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
370 | bby1 = np.clip(cy - cut_h // 2, 0, H)
371 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
372 | bby2 = np.clip(cy + cut_h // 2, 0, H)
373 | return bbx1, bby1, bbx2, bby2
374 |
375 | bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), alpha)
376 |
377 | size = (bby2 - bby1, bbx2 - bbx1)
378 | bs = x.size(0)
379 |     if min(size) > 0:  # skip degenerate boxes; a clipped box can have zero height or width
380 | align_corners = None if interpolate_mode == 'nearest' else True
381 | x[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(old_x[:bs],
382 | size=size,
383 | mode=interpolate_mode,
384 | align_corners=align_corners)
385 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
386 |
387 | y = lam * y + (1 - lam) * old_y[:bs]
388 | boxes = torch.Tensor([bbx1, bby1, bbx2, bby2]).float().to(x.device)
389 | boxes = boxes[None].expand(bs, 4)
390 | return x, y, boxes, lam
391 |
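# Rough training-loop sketch (illustrative; `model`, `loader`, `soft_ce`,
# `kl_div`, `omega`, `num_classes`, `smoothing`, and `alpha` are placeholders,
# and the exact loss terms follow the paper): the mixed batch, its soft label,
# and the detached prediction are carried to the next iteration as the
# "history" triplet.
#
#   old_x = old_y = old_pred = None
#   for x, target in loader:
#       y = smooth_one_hot(target, num_classes, smoothing)
#       if old_x is None:
#           out = model(x)
#           loss = soft_ce(out, y)
#       else:
#           x, y, boxes, lam = recursive_mix(x, old_x, y, old_y, alpha, 'nearest')
#           out, out_roi = model(x, boxes=boxes)
#           loss = soft_ce(out, y) + omega * kl_div(out_roi, old_pred)
#       old_x, old_y, old_pred = x.detach(), y.detach(), out.detach()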
--------------------------------------------------------------------------------