├── .idea/
│   ├── .gitignore
│   ├── PyTorch_ImageNet_experiments.iml
│   ├── deployment.xml
│   ├── inspectionProfiles/
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── webServers.xml
├── README.md
├── clonal_resnet18_from_scratch.log
├── clonal_resnet34_from_scratch.log
├── clonalnet_main.py
├── distill_loss/
│   ├── KD.py
│   ├── __init__.py
│   └── fpLoss.py
├── main.py
├── models/
│   ├── __init__.py
│   ├── mobilenet.py
│   └── resnet.py
└── requirements.txt
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FocusNet
2 | The implementation of our Pattern Recognition 2022 paper: "FocusNet: Classifying better by focusing on confusing classes"
3 |
4 | Paper: https://www.sciencedirect.com/science/article/abs/pii/S003132032200190X?via%3Dihub
5 | ## Note:
6 | - This repository mainly builds on "[ImageNet training in PyTorch](https://github.com/pytorch/examples/tree/master/imagenet)", so it is helpful to refer to its documentation as well.
7 | - The first version of our architecture was named ClonalNet, and after the second revision we renamed it to FocusNet. Therefore, **"clonalnet" below refers to FocusNet**.
8 | # ImageNet training in PyTorch
9 |
10 | This implements training of popular model architectures, such as ResNet, AlexNet, and VGG on the ImageNet dataset.
11 |
12 | ## Requirements
13 |
14 | - Install PyTorch ([pytorch.org](http://pytorch.org))
15 | - `pip install -r requirements.txt`
16 | - Note: the `requirements.txt` in this repository is not the same as the official requirements. If something goes wrong, please use the official requirements.
17 | - Download the ImageNet dataset from http://www.image-net.org/
18 | - Then, move validation images to labeled subfolders, using [the following shell script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh)
19 | ## Training
20 |
21 | To train our network, run `clonalnet_main.py` with the desired model architecture and the path to the ImageNet dataset:
22 |
23 | ```bash
24 | python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet18 --seed 42 --gpu 0 -ebc
25 | # replace "resnet18" with "resnet34" or "mobilenet_v2" to train the other architectures
27 | ```
28 |
29 | The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs.
30 |
31 | ## Validation
32 |
33 | To evaluate our network, run `clonalnet_main.py` with the desired model architecture and the path to the ImageNet dataset:
34 |
35 | ```bash
36 | python clonalnet_main.py --data /path/to/ILSVRC2012 -a resnet18 --seed 42 --gpu 0 -ebc -e --resume clonalnet_resnet18_model_best.pth.tar
37 | # for resnet34:     -a resnet34     --resume clonalnet_resnet34_model_best.pth.tar
38 | # for mobilenet_v2: -a mobilenet_v2 --resume clonalnet_mobilenet_v2_model_best.pth.tar
39 |
40 | ```
41 |
42 | ## Logs
43 | The `clonal_resnet18_from_scratch.log` and `clonal_resnet34_from_scratch.log` files are the training logs of clonalnet_resnet18 and clonalnet_resnet34, respectively.
44 |
45 | ## Baseline
46 | To validate the baseline results, please run:
47 | ```bash
48 | # resnet18 / resnet34
49 | python main.py --paradigm baseline --data /path/to/ILSVRC2012 -a resnet18 --seed 10 -e --pretrained --gpu 0
50 | # for resnet34, replace "-a resnet18" with "-a resnet34"
51 | # mobilenet_v2
52 | python main.py --paradigm baseline --data /path/to/ILSVRC2012 -a mobilenet_v2 --seed 10 -e --pretrained --gpu 0 --resume models/_pytorch_pretrained_checkpoints/baseline_mobilenet_v2_model_best.pth.tar
53 |
54 | ```
55 | ## Results on ILSVRC2012
56 | |Models|Acc@1|Acc@5|Checkpoint|
57 | |------|-----|-----|-----|
58 | |ResNet18|69.760|89.082|[PyTorch Pre-trained](https://pytorch.org/vision/stable/models.html)|
59 | |ClonalNet (r18)|70.422|89.562|[Baidu](https://pan.baidu.com/s/17GAra665g3Y9Uf9l_XIffg), code:1234; [Google Drive](https://drive.google.com/file/d/1VuYREp2tWDyamjzphMeb0pGMIlVTN4Se/view?usp=sharing)|
60 | |ResNet34|73.310|91.420|[PyTorch Pre-trained](https://pytorch.org/vision/stable/models.html)|
61 | |ClonalNet (r34)|74.366|91.884|[Baidu](https://pan.baidu.com/s/1E-MocRLYlFUxc93_E-Ndtw), code:1234; [Google Drive](https://drive.google.com/file/d/1NfnyQMP0dy3eYNuaIfs56fFj8nG4_L9f/view?usp=sharing)|
62 | |MobileNet_v2|65.558|86.744|[Baidu](https://pan.baidu.com/s/11f5wxVbuDtKQ2WguIPDtbw), code:1234; [Google Drive](https://drive.google.com/file/d/1EecCV14dXD9yzFNfgbcTBw_CDPLZQi6i/view?usp=sharing)|
63 | |ClonalNet (MobileNet_v2)|66.300|87.232|[Baidu](https://pan.baidu.com/s/16aAsj3-RKIoL-k4Bydt14w); [Google Drive](https://drive.google.com/file/d/1nDfBea0GSQ4Fj8cdleRhwocJw8oO2T60/view?usp=sharing)|
64 |
65 | You can also download more checkpoints here: [Baidu](https://pan.baidu.com/s/1BPcyHRWokKcfpGTAiuVoug), code: 1234; [Google Drive](https://drive.google.com/drive/folders/18KBAvXccSPZDAZOjVKwLqZ9ZKGqL4RMf?usp=sharing).
66 |
67 | ## Reference
68 | If you find our work helpful, please cite it:
69 | ```bibtex
70 | @article{zhang2022focusnet,
71 | title={FocusNet: Classifying better by focusing on confusing classes},
72 | author={Zhang, Xue and Sheng, Zehua and Shen, Hui-Liang},
73 | journal={Pattern Recognition},
74 | pages={108709},
75 | year={2022},
76 | publisher={Elsevier}
77 | }
78 | ```
--------------------------------------------------------------------------------
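Regarding the learning-rate note in the README above: the step schedule is implemented by the `adjust_learning_rate` helper that appears in both `clonalnet_main.py` and `main.py` below. A minimal standalone sketch of the same rule:

```python
# Step schedule from the README: start at 0.1 and decay by 10x every 30 epochs.
# This mirrors adjust_learning_rate() in clonalnet_main.py / main.py.
def step_lr(initial_lr: float, epoch: int) -> float:
    return initial_lr * (0.1 ** (epoch // 30))

if __name__ == "__main__":
    for epoch in (0, 29, 30, 60, 89):
        print(epoch, step_lr(0.1, epoch))  # 0.1, 0.1, 0.01, 0.001, 0.001
```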
/clonalnet_main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 | from enum import Enum
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import torch.optim
15 | import torch.multiprocessing as mp
16 | import torch.utils.data
17 | import torch.utils.data.distributed
18 | import torchvision.transforms as transforms
19 | import torchvision.datasets as datasets
20 | import models
21 | from distill_loss import fpLoss
22 |
23 | model_names = sorted(name for name in models.__dict__
24 | if name.islower() and not name.startswith("__")
25 | and callable(models.__dict__[name]))
26 |
27 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
28 | parser.add_argument('--paradigm', default='clonalnet', type=str)
29 | parser.add_argument('--data', metavar='DIR',
30 | help='path to dataset')
31 | parser.add_argument('-a', '--arch', metavar='ARCH', default='mobilenet_v2',
32 | choices=model_names,
33 | help='model architecture: ' +
34 | ' | '.join(model_names) +
35 | ' (default: mobilenet_v2)')
36 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
37 | help='number of data loading workers (default: 4)')
38 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
39 | help='number of total epochs to run')
40 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
41 | help='manual epoch number (useful on restarts)')
42 | parser.add_argument('-b', '--batch-size', default=256, type=int,
43 | metavar='N',
44 | help='mini-batch size (default: 256), this is the total '
45 | 'batch size of all GPUs on the current node when '
46 | 'using Data Parallel or Distributed Data Parallel')
47 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
48 | metavar='LR', help='initial learning rate', dest='lr')
49 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
50 | help='momentum')
51 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
52 | metavar='W', help='weight decay (default: 1e-4)',
53 | dest='weight_decay')
54 | parser.add_argument('-p', '--print-freq', default=10, type=int,
55 | metavar='N', help='print frequency (default: 10)')
56 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
57 | help='path to latest checkpoint (default: none)')
58 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
59 | help='evaluate model on validation set')
60 | parser.add_argument('-ebc', '--evaluate_baseline_and_clonalnet_before_training_clonalnet', action='store_true',
61 | help='evaluate baseline model and clonalnet on validation set')
62 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
63 | help='use pre-trained model')
64 | parser.add_argument('--world-size', default=-1, type=int,
65 | help='number of nodes for distributed training')
66 | parser.add_argument('--rank', default=-1, type=int,
67 | help='node rank for distributed training')
68 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
69 | help='url used to set up distributed training')
70 | parser.add_argument('--dist-backend', default='nccl', type=str,
71 | help='distributed backend')
72 | parser.add_argument('--seed', default=None, type=int,
73 | help='seed for initializing training. ')
74 | parser.add_argument('--gpu', default=None, type=int,
75 | help='GPU id to use.')
76 | parser.add_argument('--multiprocessing-distributed', action='store_true',
77 | help='Use multi-processing distributed training to launch '
78 | 'N processes per node, which has N GPUs. This is the '
79 | 'fastest way to use PyTorch for either single node or '
80 | 'multi node data parallel training')
81 |
82 | best_acc1 = 0
83 |
84 |
85 | def main():
86 | args = parser.parse_args(
87 | # [training
88 | #'--data', '/UsrFile/yjc/xzq/ssddata/zx/ILSVRC2012',
89 | #'-a', 'resnet18','--seed', '42',
90 | #'--gpu', '1', '-ebc',
91 |
92 | ## validation
93 | ## '-e', '--resume', 'clonalnet_resnet18_model_best.pth.tar']
94 | )
95 |
96 | for _argsk, _argsv in args._get_kwargs():
97 | print('--{} {}'.format(_argsk, _argsv))
98 |
99 | if args.seed is not None:
100 | random.seed(args.seed)
101 | torch.manual_seed(args.seed)
102 | cudnn.deterministic = True
103 | warnings.warn('You have chosen to seed training. '
104 | 'This will turn on the CUDNN deterministic setting, '
105 | 'which can slow down your training considerably! '
106 | 'You may see unexpected behavior when restarting '
107 | 'from checkpoints.')
108 |
109 | if args.gpu is not None:
110 | warnings.warn('You have chosen a specific GPU. This will completely '
111 | 'disable data parallelism.')
112 |
113 | if args.dist_url == "env://" and args.world_size == -1:
114 | args.world_size = int(os.environ["WORLD_SIZE"])
115 |
116 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
117 |
118 | ngpus_per_node = torch.cuda.device_count()
119 | if args.multiprocessing_distributed:
120 | # Since we have ngpus_per_node processes per node, the total world_size
121 | # needs to be adjusted accordingly
122 | args.world_size = ngpus_per_node * args.world_size
123 | # Use torch.multiprocessing.spawn to launch distributed processes: the
124 | # main_worker process function
125 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
126 | else:
127 | # Simply call main_worker function
128 | main_worker(args.gpu, ngpus_per_node, args)
129 |
130 |
131 | def main_worker(gpu, ngpus_per_node, args):
132 | global best_acc1
133 | args.gpu = gpu
134 |
135 | if args.gpu is not None:
136 | print("Use GPU: {} for training".format(args.gpu))
137 |
138 | if args.distributed:
139 | if args.dist_url == "env://" and args.rank == -1:
140 | args.rank = int(os.environ["RANK"])
141 | if args.multiprocessing_distributed:
142 | # For multiprocessing distributed training, rank needs to be the
143 | # global rank among all the processes
144 | args.rank = args.rank * ngpus_per_node + gpu
145 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
146 | world_size=args.world_size, rank=args.rank)
147 | # create model
148 | if args.pretrained:
149 | print("=> using pre-trained model '{}'".format(args.arch))
150 | model = models.__dict__[args.arch](pretrained=True)
151 | else:
152 | print("=> Baseline: using pre-trained model '{}'".format(args.arch))
153 | base_model = models.__dict__[args.arch](pretrained=True)
154 | base_model.eval()
155 |
156 | print("=> ClonalNet: creating model '{}'".format(args.arch))
157 | model = models.__dict__[args.arch](pretrained=False)
158 |
159 | if not torch.cuda.is_available():
160 | print('using CPU, this will be slow')
161 | elif args.distributed:
162 | # For multiprocessing distributed, DistributedDataParallel constructor
163 | # should always set the single device scope, otherwise,
164 | # DistributedDataParallel will use all available devices.
165 | if args.gpu is not None:
166 | torch.cuda.set_device(args.gpu)
167 | model.cuda(args.gpu)
168 | # When using a single GPU per process and per
169 | # DistributedDataParallel, we need to divide the batch size
170 | # ourselves based on the total number of GPUs we have
171 | args.batch_size = int(args.batch_size / ngpus_per_node)
172 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
173 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
174 | else:
175 | model.cuda()
176 | # DistributedDataParallel will divide and allocate batch_size to all
177 | # available GPUs if device_ids are not set
178 | model = torch.nn.parallel.DistributedDataParallel(model)
179 | elif args.gpu is not None:
180 | torch.cuda.set_device(args.gpu)
181 | base_model = base_model.cuda(args.gpu)
182 | model = model.cuda(args.gpu)
183 | else:
184 | # DataParallel will divide and allocate batch_size to all available GPUs
185 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
186 | model.features = torch.nn.DataParallel(model.features)
187 | model.cuda()
188 | else:
189 | model = torch.nn.DataParallel(model).cuda()
190 |
191 | # define loss function (criterion) and optimizer
192 | criterion_ce = nn.CrossEntropyLoss().cuda(args.gpu)
193 | criterion = fpLoss().cuda(args.gpu)
194 |
195 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
196 | momentum=args.momentum,
197 | weight_decay=args.weight_decay)
198 |
199 | # optionally resume from a checkpoint
200 | if args.resume:
201 | if os.path.isfile(args.resume):
202 | print("=> loading checkpoint '{}'".format(args.resume))
203 | if args.gpu is None:
204 | checkpoint = torch.load(args.resume)
205 | else:
206 | # Map model to be loaded to specified single gpu.
207 | loc = 'cuda:{}'.format(args.gpu)
208 | checkpoint = torch.load(args.resume, map_location=loc)
209 | args.start_epoch = checkpoint['epoch']
210 | best_acc1 = checkpoint['best_acc1']
211 | if args.gpu is not None:
212 | # best_acc1 may be from a checkpoint from a different GPU
213 | best_acc1 = best_acc1.to(args.gpu)
214 | model.load_state_dict(checkpoint['state_dict'])
215 | optimizer.load_state_dict(checkpoint['optimizer'])
216 | print("=> loaded checkpoint '{}' (epoch {})"
217 | .format(args.resume, checkpoint['epoch']))
218 | else:
219 | print("=> no checkpoint found at '{}'".format(args.resume))
220 |
221 | cudnn.benchmark = True
222 |
223 | # Data loading code
224 | traindir = os.path.join(args.data, 'train')
225 | valdir = os.path.join(args.data, 'valid')
226 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
227 | std=[0.229, 0.224, 0.225])
228 |
229 | train_dataset = datasets.ImageFolder(
230 | traindir,
231 | transforms.Compose([
232 | transforms.RandomResizedCrop(224),
233 | transforms.RandomHorizontalFlip(),
234 | transforms.ToTensor(),
235 | normalize,
236 | ]))
237 |
238 | if args.distributed:
239 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
240 | else:
241 | train_sampler = None
242 |
243 | train_loader = torch.utils.data.DataLoader(
244 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
245 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
246 |
247 | val_loader = torch.utils.data.DataLoader(
248 | datasets.ImageFolder(valdir, transforms.Compose([
249 | transforms.Resize(256),
250 | transforms.CenterCrop(224),
251 | transforms.ToTensor(),
252 | normalize,
253 | ])),
254 | batch_size=args.batch_size, shuffle=False,
255 | num_workers=args.workers, pin_memory=True)
256 |
257 | if args.evaluate:
258 | validate(val_loader, model, criterion_ce, args)
259 | return
260 |
261 | if args.evaluate_baseline_and_clonalnet_before_training_clonalnet:
262 | print("=> Baseline: evaluating pre-trained model '{}'".format(args.arch))
263 | validate(val_loader, base_model, criterion_ce, args)
264 | print('-'*100)
265 | print("=> ClonalNet: evaluating random model '{}'".format(args.arch))
266 | validate(val_loader, model, criterion_ce, args)
267 |
268 | for epoch in range(args.start_epoch, args.epochs):
269 | if args.distributed:
270 | train_sampler.set_epoch(epoch)
271 | adjust_learning_rate(optimizer, epoch, args)
272 |
273 | # train for one epoch
274 | train(train_loader, base_model, model, criterion, optimizer, epoch, args)
275 |
276 | # evaluate on validation set
277 | acc1 = validate(val_loader, model, criterion_ce, args)
278 |
279 | # remember best acc@1 and save checkpoint
280 | is_best = acc1 > best_acc1
281 | best_acc1 = max(acc1, best_acc1)
282 |
283 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
284 | and args.rank % ngpus_per_node == 0):
285 | save_checkpoint({
286 | 'epoch': epoch + 1,
287 | 'arch': args.arch,
288 | 'state_dict': model.state_dict(),
289 | 'best_acc1': best_acc1,
290 | 'optimizer': optimizer.state_dict(),
291 | }, is_best,
292 | '{}_{}_checkpoint.pth.tar'.format(args.paradigm, args.arch))
293 |
294 |
295 | def train(train_loader, base_model, model, criterion, optimizer, epoch, args):
296 | batch_time = AverageMeter('Time', ':6.3f')
297 | data_time = AverageMeter('Data', ':6.3f')
298 | losses = AverageMeter('Loss', ':.4e')
299 | top1 = AverageMeter('Acc@1', ':6.2f')
300 | top5 = AverageMeter('Acc@5', ':6.2f')
301 | progress = ProgressMeter(
302 | len(train_loader),
303 | [batch_time, data_time, losses, top1, top5],
304 | prefix="Epoch: [{}]".format(epoch))
305 |
306 | # switch to train mode
307 | base_model.eval()
308 | model.train()
309 |
310 | end = time.time()
311 | for i, (images, target) in enumerate(train_loader):
312 | # measure data loading time
313 | data_time.update(time.time() - end)
314 |
315 | if args.gpu is not None:
316 | images = images.cuda(args.gpu, non_blocking=True)
317 | if torch.cuda.is_available():
318 | target = target.cuda(args.gpu, non_blocking=True)
319 |
320 | # compute output
321 | with torch.no_grad():
322 | bi_output = base_model(images)
323 | output = model(images)
324 | loss = criterion(target, output, bi_output)
325 |
326 | # measure accuracy and record loss
327 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
328 | losses.update(loss.item(), images.size(0))
329 | top1.update(acc1[0], images.size(0))
330 | top5.update(acc5[0], images.size(0))
331 |
332 | # compute gradient and do SGD step
333 | optimizer.zero_grad()
334 | loss.backward()
335 | optimizer.step()
336 |
337 | # measure elapsed time
338 | batch_time.update(time.time() - end)
339 | end = time.time()
340 |
341 | if i % args.print_freq == 0:
342 | progress.display(i)
343 |
344 |
345 | def validate(val_loader, model, criterion, args):
346 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
347 | losses = AverageMeter('Loss', ':.4e', Summary.NONE)
348 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
349 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
350 | progress = ProgressMeter(
351 | len(val_loader),
352 | [batch_time, losses, top1, top5],
353 | prefix='Test: ')
354 |
355 | # switch to evaluate mode
356 | model.eval()
357 |
358 | with torch.no_grad():
359 | end = time.time()
360 | for i, (images, target) in enumerate(val_loader):
361 | if args.gpu is not None:
362 | images = images.cuda(args.gpu, non_blocking=True)
363 | if torch.cuda.is_available():
364 | target = target.cuda(args.gpu, non_blocking=True)
365 |
366 | # compute output
367 | output = model(images)
368 | loss = criterion(output, target)
369 |
370 | # measure accuracy and record loss
371 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
372 | losses.update(loss.item(), images.size(0))
373 | top1.update(acc1[0], images.size(0))
374 | top5.update(acc5[0], images.size(0))
375 |
376 | # measure elapsed time
377 | batch_time.update(time.time() - end)
378 | end = time.time()
379 |
380 | if i % args.print_freq == 0:
381 | progress.display(i)
382 |
383 | progress.display_summary()
384 |
385 | return top1.avg
386 |
387 |
388 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
389 | torch.save(state, filename)
390 | if is_best:
391 | shutil.copyfile(filename, filename.split('checkpoint.pth.tar')[0]+'model_best.pth.tar')
392 |
393 |
394 | class Summary(Enum):
395 | NONE = 0
396 | AVERAGE = 1
397 | SUM = 2
398 | COUNT = 3
399 |
400 |
401 | class AverageMeter(object):
402 | """Computes and stores the average and current value"""
403 |
404 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
405 | self.name = name
406 | self.fmt = fmt
407 | self.summary_type = summary_type
408 | self.reset()
409 |
410 | def reset(self):
411 | self.val = 0
412 | self.avg = 0
413 | self.sum = 0
414 | self.count = 0
415 |
416 | def update(self, val, n=1):
417 | self.val = val
418 | self.sum += val * n
419 | self.count += n
420 | self.avg = self.sum / self.count
421 |
422 | def __str__(self):
423 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
424 | return fmtstr.format(**self.__dict__)
425 |
426 | def summary(self):
427 | fmtstr = ''
428 | if self.summary_type is Summary.NONE:
429 | fmtstr = ''
430 | elif self.summary_type is Summary.AVERAGE:
431 | fmtstr = '{name} {avg:.3f}'
432 | elif self.summary_type is Summary.SUM:
433 | fmtstr = '{name} {sum:.3f}'
434 | elif self.summary_type is Summary.COUNT:
435 | fmtstr = '{name} {count:.3f}'
436 | else:
437 | raise ValueError('invalid summary type %r' % self.summary_type)
438 |
439 | return fmtstr.format(**self.__dict__)
440 |
441 |
442 | class ProgressMeter(object):
443 | def __init__(self, num_batches, meters, prefix=""):
444 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
445 | self.meters = meters
446 | self.prefix = prefix
447 |
448 | def display(self, batch):
449 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
450 | entries += [str(meter) for meter in self.meters]
451 | print('\t'.join(entries))
452 |
453 | def display_summary(self):
454 | entries = [" *"]
455 | entries += [meter.summary() for meter in self.meters]
456 | print(' '.join(entries))
457 |
458 | def _get_batch_fmtstr(self, num_batches):
459 | num_digits = len(str(num_batches // 1))
460 | fmt = '{:' + str(num_digits) + 'd}'
461 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
462 |
463 |
464 | def adjust_learning_rate(optimizer, epoch, args):
465 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
466 | lr = args.lr * (0.1 ** (epoch // 30))
467 | for param_group in optimizer.param_groups:
468 | param_group['lr'] = lr
469 |
470 |
471 | def accuracy(output, target, topk=(1,)):
472 | """Computes the accuracy over the k top predictions for the specified values of k"""
473 | with torch.no_grad():
474 | maxk = max(topk)
475 | batch_size = target.size(0)
476 |
477 | _, pred = output.topk(maxk, 1, True, True)
478 | pred = pred.t()
479 | correct = pred.eq(target.view(1, -1).expand_as(pred))
480 |
481 | res = []
482 | for k in topk:
483 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
484 | res.append(correct_k.mul_(100.0 / batch_size))
485 | return res
486 |
487 |
488 | if __name__ == '__main__':
489 | main()
--------------------------------------------------------------------------------
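For readers skimming `clonalnet_main.py`, here is a condensed, self-contained sketch of the training step implemented in `train()` above: the pre-trained baseline stays frozen and only supplies logits, while the randomly initialised clone is optimised with `fpLoss`. Toy linear models and random tensors stand in for the real networks and ImageNet batches; run it from the repository root so that `distill_loss` is importable.

```python
import torch
import torch.nn as nn
from distill_loss import fpLoss

num_classes, batch_size = 10, 4
base_model = nn.Linear(32, num_classes)   # stands in for the frozen pre-trained baseline
model = nn.Linear(32, num_classes)        # stands in for the ClonalNet being trained
base_model.eval()
model.train()

criterion = fpLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=1e-4)

images = torch.randn(batch_size, 32)
target = torch.randint(0, num_classes, (batch_size,))

with torch.no_grad():
    bi_output = base_model(images)        # baseline logits, no gradient
output = model(images)
loss = criterion(target, output, bi_output)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
```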
/distill_loss/KD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class DistillKL(nn.Module):
7 | def __init__(self, args):
8 | super(DistillKL, self).__init__()
9 | self.T = args.temperature
10 |
11 | def forward(self, y_s, y_t):
12 | p_s = F.log_softmax(y_s/self.T, dim=1)
13 | p_t = F.softmax(y_t/self.T, dim=1)
14 | loss = F.kl_div(p_s, p_t.detach(), reduction='sum') * (self.T**2) / y_s.shape[0]
15 | return loss
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
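`DistillKL` above is exported by the `distill_loss` package but is not instantiated by either training script in this repository. A minimal usage sketch; the temperature value below is an arbitrary placeholder, not a setting taken from the paper:

```python
from types import SimpleNamespace

import torch
from distill_loss import DistillKL

args = SimpleNamespace(temperature=4.0)     # hypothetical temperature
criterion_kd = DistillKL(args)

student_logits = torch.randn(8, 1000, requires_grad=True)
teacher_logits = torch.randn(8, 1000)

loss = criterion_kd(student_logits, teacher_logits)  # temperature-softened distillation loss
loss.backward()
print(loss.item())
```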
/distill_loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .KD import *
2 | from .fpLoss import *
3 |
--------------------------------------------------------------------------------
/distill_loss/fpLoss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | eps = 1e-10
5 |
6 | class fpLoss(nn.Module):
7 | def __init__(self, ):
8 | super(fpLoss, self).__init__()
9 |
10 | def cross_entropy(self, logits, onehot_labels, ls=False):
11 | if ls:
12 | onehot_labels = 0.9 * onehot_labels + 0.1 / logits.size(-1)
13 | onehot_labels = onehot_labels.double()
14 | return (-1.0 * torch.mean(torch.sum(onehot_labels * F.log_softmax(logits, -1), -1), 0))
15 |
16 |
17 | def neg_entropy(self, logits):
18 | probs = F.softmax(logits, -1)
19 | return torch.mean(torch.sum(probs * F.log_softmax(logits, -1), -1), 0)
20 |
21 | def forward(self, targets, outputs, bi_outputs,):
22 | # Loss_cls
23 | difference = F.softmax(outputs, -1) - F.softmax(bi_outputs, -1)  # detach() was used when comparing with FRSKD
24 | onehot_labels = F.one_hot(targets, outputs.size(-1))
25 | loss_cls = self.cross_entropy(outputs + difference, onehot_labels, True)
26 | # On tiny-imagenet: alpha=1, beta=1, ls=True, best test acc: 0.5870
27 | # On tiny-imagenet: alpha=1, beta=1, ls=False, best test acc: 0.5840
28 | # so label smoothing (ls) is not the main factor
29 |
30 | # R_attention
31 |
32 | # multi_warm_lb = bi_outputs.detach() > 0.0
33 | '''The derivation shows that, with a multi-warm label, the gradient of the cross-entropy is hat{y(x)} - m(x),
34 | where hat{y(x)} is the probability distribution predicted by ClonalNet and m(x) is the multi-warm label, whose non-zero entries equal 1/len(m(x) != 0).
35 | For comparison, the gradient of the standard cross-entropy is hat{y(x)} - y(x), where y(x) is the one-hot label,
36 | so the gradient at the correct position is negative and at the incorrect positions it is positive; the prediction therefore grows at the correct position and shrinks elsewhere, pushing the predicted distribution towards the one-hot label.
37 | However, we found that
38 | with a multi-warm label the cross-entropy yields a positive gradient wherever hat{y(x)} exceeds m(x) and a negative gradient wherever it is smaller, which means
39 | the predicted probabilities are driven towards the distribution m(x). Hence m(x) should contain as few non-zero entries as possible, so that only the few truly confusing classes are attended to; this also makes the non-zero values larger.
40 | If the non-zero values of m(x) are too small, confident predictions are harmed,
41 | so the multi-warm label is adjusted accordingly.
42 | '''
43 | # multi_warm_lb = bi_outputs > 0.0
44 | multi_warm_lb = F.softmax(bi_outputs/2, -1) > 1.0/bi_outputs.size(-1)
45 | multi_warm_lb = torch.clamp(multi_warm_lb.double() + onehot_labels, 0, 1)
46 | multi_warm_lb = multi_warm_lb/torch.sum(multi_warm_lb, -1, True)
47 | R_attention = self.cross_entropy(outputs, multi_warm_lb.detach(), False)  # detach() was used when comparing with FRSKD
48 |
49 | # R_entropy
50 | R_negtropy = self.neg_entropy(outputs)
51 |
52 | fp_loss = loss_cls + R_attention + R_negtropy
53 |
54 | # test for CE + neg_entropy
55 | # loss_cls = self.cross_entropy(outputs, onehot_labels)
56 | # fp_loss = loss_cls + R_negtropy  # experiments showed that CE + negative_entropy gives lower CUB200 accuracy (59.10%) than loss_cls + negative_entropy (60.72%)
57 | return fp_loss
--------------------------------------------------------------------------------
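To make the multi-warm label construction in `forward()` concrete, here is a small self-contained illustration with made-up logits; it reproduces the same masking and normalisation steps outside the loss class:

```python
import torch
import torch.nn.functional as F

bi_outputs = torch.tensor([[4.0, 3.5, 0.2, -1.0, -2.0]])  # baseline logits, C = 5 classes
targets = torch.tensor([1])                                # ground-truth class

# Classes whose temperature-softened baseline probability exceeds the uniform
# value 1/C are marked as "confusing"; the ground-truth class is always kept.
onehot = F.one_hot(targets, bi_outputs.size(-1))
multi_warm = F.softmax(bi_outputs / 2, -1) > 1.0 / bi_outputs.size(-1)
multi_warm = torch.clamp(multi_warm.double() + onehot, 0, 1)
multi_warm = multi_warm / torch.sum(multi_warm, -1, True)

print(multi_warm)  # only classes 0 and 1 get non-zero mass: [[0.5, 0.5, 0, 0, 0]]
```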
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 | from enum import Enum
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import torch.optim
15 | import torch.multiprocessing as mp
16 | import torch.utils.data
17 | import torch.utils.data.distributed
18 | import torchvision.transforms as transforms
19 | import torchvision.datasets as datasets
20 | import models
21 |
22 | model_names = sorted(name for name in models.__dict__
23 | if name.islower() and not name.startswith("__")
24 | and callable(models.__dict__[name]))
25 |
26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
27 | parser.add_argument('--paradigm', default='baseline', type=str)
28 | parser.add_argument('--data', metavar='DIR',
29 | help='path to dataset')
30 | parser.add_argument('-a', '--arch', metavar='ARCH', default='mobilenet_v2',
31 | choices=model_names,
32 | help='model architecture: ' +
33 | ' | '.join(model_names) +
34 | ' (default: mobilenet_v2)')
35 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
36 | help='number of data loading workers (default: 4)')
37 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
38 | help='number of total epochs to run')
39 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
40 | help='manual epoch number (useful on restarts)')
41 | parser.add_argument('-b', '--batch-size', default=256, type=int,
42 | metavar='N',
43 | help='mini-batch size (default: 256), this is the total '
44 | 'batch size of all GPUs on the current node when '
45 | 'using Data Parallel or Distributed Data Parallel')
46 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
47 | metavar='LR', help='initial learning rate', dest='lr')
48 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
49 | help='momentum')
50 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
51 | metavar='W', help='weight decay (default: 1e-4)',
52 | dest='weight_decay')
53 | parser.add_argument('-p', '--print-freq', default=10, type=int,
54 | metavar='N', help='print frequency (default: 10)')
55 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
56 | help='path to latest checkpoint (default: none)')
57 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
58 | help='evaluate model on validation set')
59 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
60 | help='use pre-trained model')
61 | parser.add_argument('--world-size', default=-1, type=int,
62 | help='number of nodes for distributed training')
63 | parser.add_argument('--rank', default=-1, type=int,
64 | help='node rank for distributed training')
65 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
66 | help='url used to set up distributed training')
67 | parser.add_argument('--dist-backend', default='nccl', type=str,
68 | help='distributed backend')
69 | parser.add_argument('--seed', default=None, type=int,
70 | help='seed for initializing training. ')
71 | parser.add_argument('--gpu', default=None, type=int,
72 | help='GPU id to use.')
73 | parser.add_argument('--multiprocessing-distributed', action='store_true',
74 | help='Use multi-processing distributed training to launch '
75 | 'N processes per node, which has N GPUs. This is the '
76 | 'fastest way to use PyTorch for either single node or '
77 | 'multi node data parallel training')
78 |
79 | best_acc1 = 0
80 |
81 |
82 | def main():
83 | args = parser.parse_args(
84 | # ['--paradigm', 'baseline',
85 | # '--data', '/UsrFile/yjc/xzq/ssddata/zx/ILSVRC2012',
86 | # '-a', 'resnet18', '--seed', '10', '-e', --pretrained,
87 | # '--resume', './baseline_mobilenet_v2_model_best.pth.tar',
88 | # '--gpu', '1']
89 | )
90 |
91 | for _argsk, _argsv in args._get_kwargs():
92 | print('--{} {}'.format(_argsk, _argsv))
93 |
94 | if args.seed is not None:
95 | random.seed(args.seed)
96 | torch.manual_seed(args.seed)
97 | cudnn.deterministic = True
98 | warnings.warn('You have chosen to seed training. '
99 | 'This will turn on the CUDNN deterministic setting, '
100 | 'which can slow down your training considerably! '
101 | 'You may see unexpected behavior when restarting '
102 | 'from checkpoints.')
103 |
104 | if args.gpu is not None:
105 | warnings.warn('You have chosen a specific GPU. This will completely '
106 | 'disable data parallelism.')
107 |
108 | if args.dist_url == "env://" and args.world_size == -1:
109 | args.world_size = int(os.environ["WORLD_SIZE"])
110 |
111 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
112 |
113 | ngpus_per_node = torch.cuda.device_count()
114 | if args.multiprocessing_distributed:
115 | # Since we have ngpus_per_node processes per node, the total world_size
116 | # needs to be adjusted accordingly
117 | args.world_size = ngpus_per_node * args.world_size
118 | # Use torch.multiprocessing.spawn to launch distributed processes: the
119 | # main_worker process function
120 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
121 | else:
122 | # Simply call main_worker function
123 | main_worker(args.gpu, ngpus_per_node, args)
124 |
125 |
126 | def main_worker(gpu, ngpus_per_node, args):
127 | global best_acc1
128 | args.gpu = gpu
129 |
130 | if args.gpu is not None:
131 | print("Use GPU: {} for training".format(args.gpu))
132 |
133 | if args.distributed:
134 | if args.dist_url == "env://" and args.rank == -1:
135 | args.rank = int(os.environ["RANK"])
136 | if args.multiprocessing_distributed:
137 | # For multiprocessing distributed training, rank needs to be the
138 | # global rank among all the processes
139 | args.rank = args.rank * ngpus_per_node + gpu
140 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
141 | world_size=args.world_size, rank=args.rank)
142 | # create model
143 | if args.pretrained:
144 | print("=> using pre-trained model '{}'".format(args.arch))
145 | model = models.__dict__[args.arch](pretrained=True)
146 | else:
147 | print("=> creating model '{}'".format(args.arch))
148 | model = models.__dict__[args.arch]()
149 |
150 | if not torch.cuda.is_available():
151 | print('using CPU, this will be slow')
152 | elif args.distributed:
153 | # For multiprocessing distributed, DistributedDataParallel constructor
154 | # should always set the single device scope, otherwise,
155 | # DistributedDataParallel will use all available devices.
156 | if args.gpu is not None:
157 | torch.cuda.set_device(args.gpu)
158 | model.cuda(args.gpu)
159 | # When using a single GPU per process and per
160 | # DistributedDataParallel, we need to divide the batch size
161 | # ourselves based on the total number of GPUs we have
162 | args.batch_size = int(args.batch_size / ngpus_per_node)
163 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
164 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
165 | else:
166 | model.cuda()
167 | # DistributedDataParallel will divide and allocate batch_size to all
168 | # available GPUs if device_ids are not set
169 | model = torch.nn.parallel.DistributedDataParallel(model)
170 | elif args.gpu is not None:
171 | torch.cuda.set_device(args.gpu)
172 | model = model.cuda(args.gpu)
173 | else:
174 | # DataParallel will divide and allocate batch_size to all available GPUs
175 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
176 | model.features = torch.nn.DataParallel(model.features)
177 | model.cuda()
178 | else:
179 | model = torch.nn.DataParallel(model).cuda()
180 |
181 | # define loss function (criterion) and optimizer
182 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
183 |
184 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
185 | momentum=args.momentum,
186 | weight_decay=args.weight_decay)
187 |
188 | # optionally resume from a checkpoint
189 | if args.resume:
190 | if os.path.isfile(args.resume):
191 | print("=> loading checkpoint '{}'".format(args.resume))
192 | if args.gpu is None:
193 | checkpoint = torch.load(args.resume)
194 | else:
195 | # Map model to be loaded to specified single gpu.
196 | loc = 'cuda:{}'.format(args.gpu)
197 | checkpoint = torch.load(args.resume, map_location=loc)
198 | args.start_epoch = checkpoint['epoch']
199 | best_acc1 = checkpoint['best_acc1']
200 | if args.gpu is not None:
201 | # best_acc1 may be from a checkpoint from a different GPU
202 | best_acc1 = best_acc1.to(args.gpu)
203 | model.load_state_dict(checkpoint['state_dict'])
204 | optimizer.load_state_dict(checkpoint['optimizer'])
205 | print("=> loaded checkpoint '{}' (epoch {})"
206 | .format(args.resume, checkpoint['epoch']))
207 | else:
208 | print("=> no checkpoint found at '{}'".format(args.resume))
209 |
210 | cudnn.benchmark = True
211 |
212 | # Data loading code
213 | traindir = os.path.join(args.data, 'train')
214 | valdir = os.path.join(args.data, 'valid')
215 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
216 | std=[0.229, 0.224, 0.225])
217 |
218 | train_dataset = datasets.ImageFolder(
219 | traindir,
220 | transforms.Compose([
221 | transforms.RandomResizedCrop(224),
222 | transforms.RandomHorizontalFlip(),
223 | transforms.ToTensor(),
224 | normalize,
225 | ]))
226 |
227 | if args.distributed:
228 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
229 | else:
230 | train_sampler = None
231 |
232 | train_loader = torch.utils.data.DataLoader(
233 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
234 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
235 |
236 | val_loader = torch.utils.data.DataLoader(
237 | datasets.ImageFolder(valdir, transforms.Compose([
238 | transforms.Resize(256),
239 | transforms.CenterCrop(224),
240 | transforms.ToTensor(),
241 | normalize,
242 | ])),
243 | batch_size=args.batch_size, shuffle=False,
244 | num_workers=args.workers, pin_memory=True)
245 |
246 | if args.evaluate:
247 | validate(val_loader, model, criterion, args)
248 | return
249 |
250 | for epoch in range(args.start_epoch, args.epochs):
251 | if args.distributed:
252 | train_sampler.set_epoch(epoch)
253 | adjust_learning_rate(optimizer, epoch, args)
254 |
255 | # train for one epoch
256 | train(train_loader, model, criterion, optimizer, epoch, args)
257 |
258 | # evaluate on validation set
259 | acc1 = validate(val_loader, model, criterion, args)
260 |
261 | # remember best acc@1 and save checkpoint
262 | is_best = acc1 > best_acc1
263 | best_acc1 = max(acc1, best_acc1)
264 |
265 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
266 | and args.rank % ngpus_per_node == 0):
267 | save_checkpoint({
268 | 'epoch': epoch + 1,
269 | 'arch': args.arch,
270 | 'state_dict': model.state_dict(),
271 | 'best_acc1': best_acc1,
272 | 'optimizer': optimizer.state_dict(),
273 | }, is_best,
274 | '{}_{}_checkpoint.pth.tar'.format(args.paradigm, args.arch))
275 |
276 |
277 | def train(train_loader, model, criterion, optimizer, epoch, args):
278 | batch_time = AverageMeter('Time', ':6.3f')
279 | data_time = AverageMeter('Data', ':6.3f')
280 | losses = AverageMeter('Loss', ':.4e')
281 | top1 = AverageMeter('Acc@1', ':6.2f')
282 | top5 = AverageMeter('Acc@5', ':6.2f')
283 | progress = ProgressMeter(
284 | len(train_loader),
285 | [batch_time, data_time, losses, top1, top5],
286 | prefix="Epoch: [{}]".format(epoch))
287 |
288 | # switch to train mode
289 | model.train()
290 |
291 | end = time.time()
292 | for i, (images, target) in enumerate(train_loader):
293 | # measure data loading time
294 | data_time.update(time.time() - end)
295 |
296 | if args.gpu is not None:
297 | images = images.cuda(args.gpu, non_blocking=True)
298 | if torch.cuda.is_available():
299 | target = target.cuda(args.gpu, non_blocking=True)
300 |
301 | # compute output
302 | output = model(images)
303 | loss = criterion(output, target)
304 |
305 | # measure accuracy and record loss
306 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
307 | losses.update(loss.item(), images.size(0))
308 | top1.update(acc1[0], images.size(0))
309 | top5.update(acc5[0], images.size(0))
310 |
311 | # compute gradient and do SGD step
312 | optimizer.zero_grad()
313 | loss.backward()
314 | optimizer.step()
315 |
316 | # measure elapsed time
317 | batch_time.update(time.time() - end)
318 | end = time.time()
319 |
320 | if i % args.print_freq == 0:
321 | progress.display(i)
322 |
323 |
324 | def validate(val_loader, model, criterion, args):
325 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
326 | losses = AverageMeter('Loss', ':.4e', Summary.NONE)
327 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
328 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
329 | progress = ProgressMeter(
330 | len(val_loader),
331 | [batch_time, losses, top1, top5],
332 | prefix='Test: ')
333 |
334 | # switch to evaluate mode
335 | model.eval()
336 |
337 | with torch.no_grad():
338 | end = time.time()
339 | for i, (images, target) in enumerate(val_loader):
340 | if args.gpu is not None:
341 | images = images.cuda(args.gpu, non_blocking=True)
342 | if torch.cuda.is_available():
343 | target = target.cuda(args.gpu, non_blocking=True)
344 |
345 | # compute output
346 | output = model(images)
347 | loss = criterion(output, target)
348 |
349 | # measure accuracy and record loss
350 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
351 | losses.update(loss.item(), images.size(0))
352 | top1.update(acc1[0], images.size(0))
353 | top5.update(acc5[0], images.size(0))
354 |
355 | # measure elapsed time
356 | batch_time.update(time.time() - end)
357 | end = time.time()
358 |
359 | if i % args.print_freq == 0:
360 | progress.display(i)
361 |
362 | progress.display_summary()
363 |
364 | return top1.avg
365 |
366 |
367 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
368 | torch.save(state, filename)
369 | if is_best:
370 | shutil.copyfile(filename, filename.split('checkpoint.pth.tar')[0]+'model_best.pth.tar')
371 |
372 |
373 | class Summary(Enum):
374 | NONE = 0
375 | AVERAGE = 1
376 | SUM = 2
377 | COUNT = 3
378 |
379 |
380 | class AverageMeter(object):
381 | """Computes and stores the average and current value"""
382 |
383 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
384 | self.name = name
385 | self.fmt = fmt
386 | self.summary_type = summary_type
387 | self.reset()
388 |
389 | def reset(self):
390 | self.val = 0
391 | self.avg = 0
392 | self.sum = 0
393 | self.count = 0
394 |
395 | def update(self, val, n=1):
396 | self.val = val
397 | self.sum += val * n
398 | self.count += n
399 | self.avg = self.sum / self.count
400 |
401 | def __str__(self):
402 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
403 | return fmtstr.format(**self.__dict__)
404 |
405 | def summary(self):
406 | fmtstr = ''
407 | if self.summary_type is Summary.NONE:
408 | fmtstr = ''
409 | elif self.summary_type is Summary.AVERAGE:
410 | fmtstr = '{name} {avg:.3f}'
411 | elif self.summary_type is Summary.SUM:
412 | fmtstr = '{name} {sum:.3f}'
413 | elif self.summary_type is Summary.COUNT:
414 | fmtstr = '{name} {count:.3f}'
415 | else:
416 | raise ValueError('invalid summary type %r' % self.summary_type)
417 |
418 | return fmtstr.format(**self.__dict__)
419 |
420 |
421 | class ProgressMeter(object):
422 | def __init__(self, num_batches, meters, prefix=""):
423 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
424 | self.meters = meters
425 | self.prefix = prefix
426 |
427 | def display(self, batch):
428 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
429 | entries += [str(meter) for meter in self.meters]
430 | print('\t'.join(entries))
431 |
432 | def display_summary(self):
433 | entries = [" *"]
434 | entries += [meter.summary() for meter in self.meters]
435 | print(' '.join(entries))
436 |
437 | def _get_batch_fmtstr(self, num_batches):
438 | num_digits = len(str(num_batches // 1))
439 | fmt = '{:' + str(num_digits) + 'd}'
440 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
441 |
442 |
443 | def adjust_learning_rate(optimizer, epoch, args):
444 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
445 | lr = args.lr * (0.1 ** (epoch // 30))
446 | for param_group in optimizer.param_groups:
447 | param_group['lr'] = lr
448 |
449 |
450 | def accuracy(output, target, topk=(1,)):
451 | """Computes the accuracy over the k top predictions for the specified values of k"""
452 | with torch.no_grad():
453 | maxk = max(topk)
454 | batch_size = target.size(0)
455 |
456 | _, pred = output.topk(maxk, 1, True, True)
457 | pred = pred.t()
458 | correct = pred.eq(target.view(1, -1).expand_as(pred))
459 |
460 | res = []
461 | for k in topk:
462 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
463 | res.append(correct_k.mul_(100.0 / batch_size))
464 | return res
465 |
466 |
467 | if __name__ == '__main__':
468 | main()
--------------------------------------------------------------------------------
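As a quick sanity check of the `accuracy()` helper above, the snippet below performs the same top-k computation by hand on a toy two-sample batch (3 classes and `topk=(1, 2)` for readability):

```python
import torch

output = torch.tensor([[0.1, 0.7, 0.2],    # top-1 prediction: class 1 (correct)
                       [0.5, 0.1, 0.4]])   # top-1 prediction: class 0 (target is 2)
target = torch.tensor([1, 2])

_, pred = output.topk(2, 1, True, True)    # top-2 class indices per sample
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))

acc1 = correct[:1].reshape(-1).float().sum() * 100.0 / target.size(0)
acc2 = correct[:2].reshape(-1).float().sum() * 100.0 / target.size(0)
print(acc1.item(), acc2.item())            # 50.0 100.0
```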
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .mobilenet import *
2 | from .resnet import *
--------------------------------------------------------------------------------
/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from torch import nn
4 | from torchvision.models.utils import load_state_dict_from_url
5 |
6 |
7 | __all__ = ['MobileNetV2', 'mobilenet_v2']
8 |
9 |
10 | model_urls = {
11 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
12 | }
13 |
14 |
15 | def _make_divisible(v, divisor, min_value=None):
16 | """
17 | This function is taken from the original tf repo.
18 | It ensures that all layers have a channel number that is divisible by 8
19 | It can be seen here:
20 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
21 | :param v:
22 | :param divisor:
23 | :param min_value:
24 | :return:
25 | """
26 | if min_value is None:
27 | min_value = divisor
28 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
29 | # Make sure that round down does not go down by more than 10%.
30 | if new_v < 0.9 * v:
31 | new_v += divisor
32 | return new_v
33 |
34 |
35 | class ConvBNReLU(nn.Sequential):
36 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
37 | padding = (kernel_size - 1) // 2
38 | super(ConvBNReLU, self).__init__(
39 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
40 | nn.BatchNorm2d(out_planes),
41 | nn.ReLU6(inplace=True)
42 | )
43 |
44 |
45 | class InvertedResidual(nn.Module):
46 | def __init__(self, inp, oup, stride, expand_ratio):
47 | super(InvertedResidual, self).__init__()
48 | self.stride = stride
49 | assert stride in [1, 2]
50 |
51 | hidden_dim = int(round(inp * expand_ratio))
52 | self.use_res_connect = self.stride == 1 and inp == oup
53 |
54 | layers = []
55 | if expand_ratio != 1:
56 | # pw
57 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
58 | layers.extend([
59 | # dw
60 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
61 | # pw-linear
62 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
63 | nn.BatchNorm2d(oup),
64 | ])
65 | self.conv = nn.Sequential(*layers)
66 |
67 | def forward(self, x):
68 | if self.use_res_connect:
69 | return x + self.conv(x)
70 | else:
71 | return self.conv(x)
72 |
73 |
74 | class MobileNetV2(nn.Module):
75 | def __init__(self,
76 | num_classes=1000,
77 | width_mult=1.0,
78 | inverted_residual_setting=None,
79 | round_nearest=8,
80 | block=None):
81 | """
82 | MobileNet V2 main class
83 |
84 | Args:
85 | num_classes (int): Number of classes
86 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
87 | inverted_residual_setting: Network structure
88 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
89 | Set to 1 to turn off rounding
90 | block: Module specifying inverted residual building block for mobilenet
91 |
92 | """
93 | super(MobileNetV2, self).__init__()
94 |
95 | if block is None:
96 | block = InvertedResidual
97 | input_channel = 32
98 | last_channel = 1280
99 |
100 | if inverted_residual_setting is None:
101 | inverted_residual_setting = [
102 | # t, c, n, s
103 | [1, 16, 1, 1],
104 | [6, 24, 2, 2],
105 | [6, 32, 3, 2],
106 | [6, 64, 4, 2],
107 | [6, 96, 3, 1],
108 | [6, 160, 3, 2],
109 | [6, 320, 1, 1],
110 | ]
111 |
112 | # only check the first element, assuming user knows t,c,n,s are required
113 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
114 | raise ValueError("inverted_residual_setting should be non-empty "
115 | "or a 4-element list, got {}".format(inverted_residual_setting))
116 |
117 | # building first layer
118 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
119 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
120 | features = [ConvBNReLU(3, input_channel, stride=2)]
121 | # building inverted residual blocks
122 | for t, c, n, s in inverted_residual_setting:
123 | output_channel = _make_divisible(c * width_mult, round_nearest)
124 | for i in range(n):
125 | stride = s if i == 0 else 1
126 | features.append(block(input_channel, output_channel, stride, expand_ratio=t))
127 | input_channel = output_channel
128 | # building last several layers
129 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
130 | # make it nn.Sequential
131 | self.features = nn.Sequential(*features)
132 |
133 | # building classifier
134 | self.classifier = nn.Sequential(
135 | nn.Dropout(0.2),
136 | nn.Linear(self.last_channel, num_classes),
137 | )
138 |
139 | # weight initialization
140 | for m in self.modules():
141 | if isinstance(m, nn.Conv2d):
142 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
143 | if m.bias is not None:
144 | nn.init.zeros_(m.bias)
145 | elif isinstance(m, nn.BatchNorm2d):
146 | nn.init.ones_(m.weight)
147 | nn.init.zeros_(m.bias)
148 | elif isinstance(m, nn.Linear):
149 | nn.init.normal_(m.weight, 0, 0.01)
150 | nn.init.zeros_(m.bias)
151 |
152 | def _forward_impl(self, x):
153 | # This exists since TorchScript doesn't support inheritance, so the superclass method
154 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
155 | x = self.features(x)
156 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
157 | x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
158 | x = self.classifier(x)
159 | return x
160 |
161 | def forward(self, x):
162 | return self._forward_impl(x)
163 |
164 |
165 | def mobilenet_v2(pretrained=False, progress=True, **kwargs):
166 | """
167 | Constructs a MobileNetV2 architecture from
168 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
169 |
170 | Args:
171 | pretrained (bool): If True, returns a model pre-trained on ImageNet
172 | progress (bool): If True, displays a progress bar of the download to stderr
173 | """
174 | model = MobileNetV2(**kwargs)
175 | if pretrained:
176 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], model_dir='./models/_pytorch_pretrained_checkpoints/',
177 | progress=progress)
178 | model.load_state_dict(state_dict)
179 | return model
180 |
181 |
182 |
--------------------------------------------------------------------------------
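A small smoke test for the `MobileNetV2` definition above (not part of the original repository). It assumes the torchvision version pinned in `requirements.txt`, since `torchvision.models.utils.load_state_dict_from_url` was moved in later torchvision releases:

```python
import torch
from models import mobilenet_v2  # run from the repository root

model = mobilenet_v2(pretrained=False)  # no download, random weights
model.eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 1000])
```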
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models.utils import load_state_dict_from_url
4 |
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
8 | 'wide_resnet50_2', 'wide_resnet101_2']
9 |
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
21 | }
22 |
23 |
24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
25 | """3x3 convolution with padding"""
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 | padding=dilation, groups=groups, bias=False, dilation=dilation)
28 |
29 |
30 | def conv1x1(in_planes, out_planes, stride=1):
31 | """1x1 convolution"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
33 |
34 |
35 | class BasicBlock(nn.Module):
36 | expansion = 1
37 |
38 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
39 | base_width=64, dilation=1, norm_layer=None):
40 | super(BasicBlock, self).__init__()
41 | if norm_layer is None:
42 | norm_layer = nn.BatchNorm2d
43 | if groups != 1 or base_width != 64:
44 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
45 | if dilation > 1:
46 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
48 | self.conv1 = conv3x3(inplanes, planes, stride)
49 | self.bn1 = norm_layer(planes)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.conv2 = conv3x3(planes, planes)
52 | self.bn2 = norm_layer(planes)
53 | self.downsample = downsample
54 | self.stride = stride
55 |
56 | def forward(self, x):
57 | identity = x
58 |
59 | out = self.conv1(x)
60 | out = self.bn1(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv2(out)
64 | out = self.bn2(out)
65 |
66 | if self.downsample is not None:
67 | identity = self.downsample(x)
68 |
69 | out += identity
70 | out = self.relu(out)
71 |
72 | return out
73 |
74 |
75 | class Bottleneck(nn.Module):
76 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
77 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
78 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
79 | # This variant is also known as ResNet V1.5 and improves accuracy according to
80 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
81 |
82 | expansion = 4
83 |
84 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
85 | base_width=64, dilation=1, norm_layer=None):
86 | super(Bottleneck, self).__init__()
87 | if norm_layer is None:
88 | norm_layer = nn.BatchNorm2d
89 | width = int(planes * (base_width / 64.)) * groups
90 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
91 | self.conv1 = conv1x1(inplanes, width)
92 | self.bn1 = norm_layer(width)
93 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
94 | self.bn2 = norm_layer(width)
95 | self.conv3 = conv1x1(width, planes * self.expansion)
96 | self.bn3 = norm_layer(planes * self.expansion)
97 | self.relu = nn.ReLU(inplace=True)
98 | self.downsample = downsample
99 | self.stride = stride
100 |
101 | def forward(self, x):
102 | identity = x
103 |
104 | out = self.conv1(x)
105 | out = self.bn1(out)
106 | out = self.relu(out)
107 |
108 | out = self.conv2(out)
109 | out = self.bn2(out)
110 | out = self.relu(out)
111 |
112 | out = self.conv3(out)
113 | out = self.bn3(out)
114 |
115 | if self.downsample is not None:
116 | identity = self.downsample(x)
117 |
118 | out += identity
119 | out = self.relu(out)
120 |
121 | return out
122 |
123 |
124 | class ResNet(nn.Module):
125 |
126 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
127 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
128 | norm_layer=None):
129 | super(ResNet, self).__init__()
130 | if norm_layer is None:
131 | norm_layer = nn.BatchNorm2d
132 | self._norm_layer = norm_layer
133 |
134 | self.inplanes = 64
135 | self.dilation = 1
136 | if replace_stride_with_dilation is None:
137 | # each element in the tuple indicates whether the stride-2 downsampling in the
138 | # corresponding layer should be replaced with a dilated convolution instead
139 | replace_stride_with_dilation = [False, False, False]
140 | if len(replace_stride_with_dilation) != 3:
141 | raise ValueError("replace_stride_with_dilation should be None "
142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
143 | self.groups = groups
144 | self.base_width = width_per_group
145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
146 | bias=False)
147 | self.bn1 = norm_layer(self.inplanes)
148 | self.relu = nn.ReLU(inplace=True)
149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
150 | self.layer1 = self._make_layer(block, 64, layers[0])
151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
152 | dilate=replace_stride_with_dilation[0])
153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
154 | dilate=replace_stride_with_dilation[1])
155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
156 | dilate=replace_stride_with_dilation[2])
157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
158 | self.fc = nn.Linear(512 * block.expansion, num_classes)
159 |
160 | for m in self.modules():
161 | if isinstance(m, nn.Conv2d):
162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
164 | nn.init.constant_(m.weight, 1)
165 | nn.init.constant_(m.bias, 0)
166 |
167 | # Zero-initialize the last BN in each residual branch,
168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
170 | if zero_init_residual:
171 | for m in self.modules():
172 | if isinstance(m, Bottleneck):
173 | nn.init.constant_(m.bn3.weight, 0)
174 | elif isinstance(m, BasicBlock):
175 | nn.init.constant_(m.bn2.weight, 0)
176 |
177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
178 | norm_layer = self._norm_layer
179 | downsample = None
180 | previous_dilation = self.dilation
181 | if dilate:
182 | self.dilation *= stride
183 | stride = 1
184 | if stride != 1 or self.inplanes != planes * block.expansion:
185 | downsample = nn.Sequential(
186 | conv1x1(self.inplanes, planes * block.expansion, stride),
187 | norm_layer(planes * block.expansion),
188 | )
189 |
190 | layers = []
191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
192 | self.base_width, previous_dilation, norm_layer))
193 | self.inplanes = planes * block.expansion
194 | for _ in range(1, blocks):
195 | layers.append(block(self.inplanes, planes, groups=self.groups,
196 | base_width=self.base_width, dilation=self.dilation,
197 | norm_layer=norm_layer))
198 |
199 | return nn.Sequential(*layers)
200 |
201 | def _forward_impl(self, x):
202 | # See note [TorchScript super()]
203 | x = self.conv1(x)
204 | x = self.bn1(x)
205 | x = self.relu(x)
206 | x = self.maxpool(x)
207 |
208 | x = self.layer1(x)
209 | x = self.layer2(x)
210 | x = self.layer3(x)
211 | x = self.layer4(x)
212 |
213 | x = self.avgpool(x)
214 | x = torch.flatten(x, 1)
215 | x = self.fc(x)
216 |
217 | return x
218 |
219 | def forward(self, x):
220 | return self._forward_impl(x)
221 |
222 |
223 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
224 | model = ResNet(block, layers, **kwargs)
225 | if pretrained:
226 | state_dict = load_state_dict_from_url(model_urls[arch], model_dir='./models/_pytorch_pretrained_checkpoints/',
227 | progress=progress)
228 | model.load_state_dict(state_dict)
229 | return model
230 |
231 |
232 | def resnet18(pretrained=False, progress=True, **kwargs):
233 | r"""ResNet-18 model from
234 | `"Deep Residual Learning for Image Recognition" `_
235 |
236 | Args:
237 | pretrained (bool): If True, returns a model pre-trained on ImageNet
238 | progress (bool): If True, displays a progress bar of the download to stderr
239 | """
240 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
241 | **kwargs)
242 |
243 |
244 | def resnet34(pretrained=False, progress=True, **kwargs):
245 | r"""ResNet-34 model from
246 | `"Deep Residual Learning for Image Recognition" `_
247 |
248 | Args:
249 | pretrained (bool): If True, returns a model pre-trained on ImageNet
250 | progress (bool): If True, displays a progress bar of the download to stderr
251 | """
252 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
253 | **kwargs)
254 |
255 |
256 | def resnet50(pretrained=False, progress=True, **kwargs):
257 | r"""ResNet-50 model from
258 | `"Deep Residual Learning for Image Recognition" `_
259 |
260 | Args:
261 | pretrained (bool): If True, returns a model pre-trained on ImageNet
262 | progress (bool): If True, displays a progress bar of the download to stderr
263 | """
264 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
265 | **kwargs)
266 |
267 |
268 | def resnet101(pretrained=False, progress=True, **kwargs):
269 | r"""ResNet-101 model from
270 | `"Deep Residual Learning for Image Recognition" `_
271 |
272 | Args:
273 | pretrained (bool): If True, returns a model pre-trained on ImageNet
274 | progress (bool): If True, displays a progress bar of the download to stderr
275 | """
276 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
277 | **kwargs)
278 |
279 |
280 | def resnet152(pretrained=False, progress=True, **kwargs):
281 | r"""ResNet-152 model from
282 | `"Deep Residual Learning for Image Recognition" `_
283 |
284 | Args:
285 | pretrained (bool): If True, returns a model pre-trained on ImageNet
286 | progress (bool): If True, displays a progress bar of the download to stderr
287 | """
288 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
289 | **kwargs)
290 |
291 |
292 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
293 | r"""ResNeXt-50 32x4d model from
294 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
295 |
296 | Args:
297 | pretrained (bool): If True, returns a model pre-trained on ImageNet
298 | progress (bool): If True, displays a progress bar of the download to stderr
299 | """
300 | kwargs['groups'] = 32
301 | kwargs['width_per_group'] = 4
302 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
303 | pretrained, progress, **kwargs)
304 |
305 |
306 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
307 | r"""ResNeXt-101 32x8d model from
308 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
309 |
310 | Args:
311 | pretrained (bool): If True, returns a model pre-trained on ImageNet
312 | progress (bool): If True, displays a progress bar of the download to stderr
313 | """
314 | kwargs['groups'] = 32
315 | kwargs['width_per_group'] = 8
316 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
317 | pretrained, progress, **kwargs)
318 |
319 |
320 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
321 | r"""Wide ResNet-50-2 model from
322 | `"Wide Residual Networks" `_
323 |
324 | The model is the same as ResNet except that the number of inner bottleneck
325 | channels is twice as large in every block. The number of channels in the outer
326 | 1x1 convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
327 | channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
328 |
329 | Args:
330 | pretrained (bool): If True, returns a model pre-trained on ImageNet
331 | progress (bool): If True, displays a progress bar of the download to stderr
332 | """
333 | kwargs['width_per_group'] = 64 * 2
334 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
335 | pretrained, progress, **kwargs)
336 |
337 |
338 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
339 | r"""Wide ResNet-101-2 model from
340 | `"Wide Residual Networks" `_
341 |
342 | The model is the same as ResNet except that the number of inner bottleneck
343 | channels is twice as large in every block. The number of channels in the outer
344 | 1x1 convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
345 | channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
346 |
347 | Args:
348 | pretrained (bool): If True, returns a model pre-trained on ImageNet
349 | progress (bool): If True, displays a progress bar of the download to stderr
350 | """
351 | kwargs['width_per_group'] = 64 * 2
352 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
353 | pretrained, progress, **kwargs)
354 |
--------------------------------------------------------------------------------
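
A minimal usage sketch for models/resnet.py (not part of the repository files): it assumes the module is importable as models.resnet from the repository root and that ./models/_pytorch_pretrained_checkpoints/ is writable when pretrained=True. It also illustrates the bottleneck width arithmetic, width = int(planes * (base_width / 64.)) * groups, which is how the ResNeXt and Wide ResNet constructors widen the inner convolutions.

```python
# Hypothetical usage sketch -- not part of the original repository files.
import torch
from models.resnet import resnet18, resnet50, wide_resnet50_2

# Randomly initialized ResNet-18 (e.g. as a small baseline/student network).
student = resnet18(num_classes=1000)

# ResNet-50 with ImageNet weights; load_state_dict_from_url (torchvision 0.10 API,
# as pinned in requirements.txt) downloads the checkpoint into
# ./models/_pytorch_pretrained_checkpoints/.
teacher = resnet50(pretrained=True)
teacher.eval()

# Bottleneck width: width = int(planes * (base_width / 64.)) * groups, so
# wide_resnet50_2 (width_per_group=128) doubles the inner channels, e.g.
# the last stage uses 2048-1024-2048 instead of 2048-512-2048.
wide = wide_resnet50_2()

# Forward pass on a dummy ImageNet-sized batch.
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = teacher(x)
print(logits.shape)  # torch.Size([2, 1000])
```
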
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | albumentations==1.1.0
3 | cachetools==5.0.0
4 | certifi==2021.5.30
5 | chardet==4.0.0
6 | charset-normalizer==2.0.12
7 | click==8.0.4
8 | cycler==0.10.0
9 | Cython==0.29.28
10 | docker-pycreds==0.4.0
11 | fonttools==4.30.0
12 | gitdb==4.0.9
13 | GitPython==3.1.27
14 | google-auth==2.6.2
15 | google-auth-oauthlib==0.4.6
16 | grpcio==1.44.0
17 | idna==2.10
18 | imageio==2.16.1
19 | importlib-metadata==4.11.3
20 | joblib==1.1.0
21 | kiwisolver==1.3.1
22 | Markdown==3.3.6
23 | matplotlib==3.5.1
24 | networkx==2.6.3
25 | numpy==1.21.5
26 | oauthlib==3.2.0
27 | opencv-python==4.5.5.64
28 | opencv-python-headless==4.5.5.64
29 | packaging==21.3
30 | pandas==1.3.5
31 | pathtools==0.1.2
32 | Pillow==9.0.1
33 | promise==2.3
34 | protobuf==3.19.4
35 | psutil==5.9.0
36 | pyasn1==0.4.8
37 | pyasn1-modules==0.2.8
38 | pycocotools==2.0.4
39 | pyparsing==2.4.7
40 | python-dateutil==2.8.2
41 | python-dotenv==0.19.2
42 | pytz==2021.3
43 | PyWavelets==1.3.0
44 | PyYAML==6.0
45 | qudida==0.0.4
46 | requests==2.27.1
47 | requests-oauthlib==1.3.1
48 | roboflow==0.2.2
49 | rsa==4.8
50 | scikit-image==0.19.2
51 | scikit-learn==1.0.2
52 | scipy==1.7.3
53 | seaborn==0.11.2
54 | sentry-sdk==1.5.8
55 | setproctitle==1.2.2
56 | shortuuid==1.0.8
57 | six==1.16.0
58 | smmap==5.0.0
59 | tensorboard==2.8.0
60 | tensorboard-data-server==0.6.1
61 | tensorboard-plugin-wit==1.8.1
62 | termcolor==1.1.0
63 | thop==0.0.31.post2005241907
64 | threadpoolctl==3.1.0
65 | tifffile==2021.11.2
66 | torch @ file:///UsrFile/ybn/zx/pytorch_wheels/torch1.9/torch-1.9.0%2Bcu102-cp37-cp37m-linux_x86_64.whl
67 | torchvision @ file:///UsrFile/ybn/zx/pytorch_wheels/torch1.9/torchvision-0.10.0%2Bcu102-cp37-cp37m-linux_x86_64.whl
68 | tqdm==4.63.0
69 | typing_extensions==4.1.1
70 | urllib3==1.26.6
71 | wandb==0.12.11
72 | Werkzeug==2.0.3
73 | wget==3.2
74 | yaspin==2.1.0
75 | zipp==3.7.0
76 |
--------------------------------------------------------------------------------