├── README.md ├── main.py ├── model.py └── utils ├── bridge.py ├── bridge_split.py ├── config.py ├── former.py ├── mobile.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Simple-implementation-of-Mobile-Former 2 | 3 | At present, only the model is provided; it has not been trained. There may be bugs in the code, and some details may differ from the original paper. If you are interested in this, you are welcome to discuss. 4 | 5 | Added: CutMix, MixUp, RandomErasing and SyncBatchNorm for DDP training 6 | 7 | There are two ways to align q/k/v in the new code. A: split the token dim into heads (utils/bridge_split.py); B: broadcast x while computing the product (utils/bridge.py, the variant imported by model.py) 8 | 9 | Added: build the model from a config (mf52, mf294, mf508) in config.py; the number of parameters is almost the same as in the paper 10 | 11 | Train: python main.py --name mf294 --data path/to/ImageNet --dist-url 'tcp://127.0.0.1:12345' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 --batch-size 256 12 | 13 | Inference: not provided yet 14 | 15 | Paper: https://arxiv.org/pdf/2108.05895.pdf 16 | 17 | https://github.com/xiaolai-sqlai/mobilenetv3 18 | 19 | https://github.com/lucidrains/vit-pytorch 20 | 21 | https://github.com/Islanna/DynamicReLU 22 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.multiprocessing as mp 16 | import torch.utils.data 17 | import torch.utils.data.distributed 18 | import torchvision.transforms as transforms 19 | import torchvision.datasets as datasets 20 | 21 | from utils.utils import cutmix, cutmix_criterion 22 | from utils.config import config 23 | from model import MobileFormer 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 26 | parser.add_argument('--name', default='mf294', type=str, 27 | help='model name') 28 | parser.add_argument('--data', metavar='DIR', 29 | help='path to dataset') 30 | parser.add_argument('--num_cls', default=1000, type=int, 31 | help='number of classes') 32 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 33 | help='number of data loading workers (default: 4)') 34 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 35 | help='number of total epochs to run') 36 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 37 | help='manual epoch number (useful on restarts)') 38 | parser.add_argument('--batch-size', default=256, type=int, 39 | metavar='N', 40 | help='mini-batch size (default: 256), this is the total ' 41 | 'batch size of all GPUs on the current node when ' 42 | 'using Data Parallel or Distributed Data Parallel') 43 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 44 | metavar='LR', help='initial learning rate', dest='lr') 45 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 46 | help='momentum') 47 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 48 | metavar='W', help='weight decay (default: 1e-4)', 49 | dest='weight_decay') 50 | parser.add_argument('-p', '--print-freq', default=10, type=int, 51 | metavar='N', help='print frequency (default: 10)') 52 | 
parser.add_argument('--resume', default='', type=str, metavar='PATH', 53 | help='path to latest checkpoint (default: none)') 54 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 55 | help='evaluate model on validation set') 56 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 57 | help='use pre-trained model') 58 | parser.add_argument('--world-size', default=-1, type=int, 59 | help='number of nodes for distributed training') 60 | parser.add_argument('--rank', default=-1, type=int, 61 | help='node rank for distributed training') 62 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:55554', type=str, 63 | help='url used to set up distributed training') 64 | parser.add_argument('--dist-backend', default='nccl', type=str, 65 | help='distributed backend') 66 | parser.add_argument('--seed', default=None, type=int, 67 | help='seed for initializing training. ') 68 | parser.add_argument('--gpu', default=None, type=int, 69 | help='GPU id to use.') 70 | parser.add_argument('--multiprocessing-distributed', action='store_true', 71 | help='Use multi-processing distributed training to launch ' 72 | 'N processes per node, which has N GPUs. This is the ' 73 | 'fastest way to use PyTorch for either single node or ' 74 | 'multi node data parallel training') 75 | parser.add_argument('--cutmix', action='store_true', 76 | help='Use cutmix data augmentation') 77 | parser.add_argument('--cutmix-prob', default=0.5, type=float, 78 | help='cutmix probability') 79 | parser.add_argument('--beta', default=1.0, type=float, help='beta distribution parameter for cutmix') 80 | 81 | best_acc1 = 0 82 | 83 | 84 | def main(): 85 | args = parser.parse_args() 86 | 87 | if args.seed is not None: 88 | random.seed(args.seed) 89 | torch.manual_seed(args.seed) 90 | cudnn.deterministic = True 91 | warnings.warn('You have chosen to seed training. ' 92 | 'This will turn on the CUDNN deterministic setting, ' 93 | 'which can slow down your training considerably! ' 94 | 'You may see unexpected behavior when restarting ' 95 | 'from checkpoints.') 96 | 97 | if args.gpu is not None: 98 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 99 | 'disable data parallelism.') 100 | 101 | if args.dist_url == "env://" and args.world_size == -1: 102 | args.world_size = int(os.environ["WORLD_SIZE"]) 103 | 104 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 105 | 106 | ngpus_per_node = torch.cuda.device_count() 107 | print('n_per_node:', ngpus_per_node) 108 | if args.multiprocessing_distributed: 109 | # Since we have ngpus_per_node processes per node, the total world_size 110 | # needs to be adjusted accordingly 111 | args.world_size = ngpus_per_node * args.world_size 112 | print('world_size:', args.world_size) 113 | # Use torch.multiprocessing.spawn to launch distributed processes: the 114 | # main_worker process function 115 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 116 | else: 117 | # Simply call main_worker function 118 | main_worker(args.gpu, ngpus_per_node, args) 119 | 120 | 121 | def main_worker(gpu, ngpus_per_node, args): 122 | global best_acc1 123 | print('gpu', gpu) 124 | args.gpu = gpu 125 | 126 | if args.gpu is not None: 127 | print("Use GPU: {} for training".format(args.gpu)) 128 | 129 | if args.distributed: 130 | if args.dist_url == "env://" and args.rank == -1: 131 | args.rank = int(os.environ["RANK"]) 132 | if args.multiprocessing_distributed: 133 | # For multiprocessing distributed training, rank needs to be the 134 | # global rank among all the processes 135 | args.rank = args.rank * ngpus_per_node + gpu 136 | print('rank:', args.rank) 137 | print('init ...', args.dist_url) 138 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 139 | world_size=args.world_size, rank=args.rank) 140 | 141 | # create model 142 | assert args.name in ['mf52', 'mf294', 'mf508'] 143 | print('create model {}'.format(args.name)) 144 | cfg = config[args.name] 145 | model = MobileFormer(cfg) 146 | 147 | if not torch.cuda.is_available(): 148 | print('using CPU, this will be slow') 149 | elif args.distributed: 150 | print('ddp mode') 151 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 152 | # For multiprocessing distributed, DistributedDataParallel constructor 153 | # should always set the single device scope, otherwise, 154 | # DistributedDataParallel will use all available devices. 
155 | if args.gpu is not None: 156 | torch.cuda.set_device(args.gpu) 157 | model.cuda(args.gpu) 158 | # When using a single GPU per process and per 159 | # DistributedDataParallel, we need to divide the batch size 160 | # ourselves based on the total number of GPUs we have 161 | args.batch_size = int(args.batch_size / ngpus_per_node) 162 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 163 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 164 | else: 165 | model.cuda() 166 | # DistributedDataParallel will divide and allocate batch_size to all 167 | # available GPUs if device_ids are not set 168 | model = torch.nn.parallel.DistributedDataParallel(model) 169 | elif args.gpu is not None: 170 | torch.cuda.set_device(args.gpu) 171 | model = model.cuda(args.gpu) 172 | else: 173 | # DataParallel will divide and allocate batch_size to all available GPUs 174 | # if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 175 | # model.features = torch.nn.DataParallel(model.features) 176 | # model.cuda() 177 | # else: 178 | model = torch.nn.DataParallel(model).cuda() 179 | 180 | # define loss function (criterion) and optimizer 181 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 182 | 183 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 184 | momentum=args.momentum, 185 | weight_decay=args.weight_decay) 186 | 187 | # optionally resume from a checkpoint 188 | if args.resume: 189 | if os.path.isfile(args.resume): 190 | print("=> loading checkpoint '{}'".format(args.resume)) 191 | if args.gpu is None: 192 | checkpoint = torch.load(args.resume) 193 | else: 194 | # Map model to be loaded to specified single gpu. 195 | loc = 'cuda:{}'.format(args.gpu) 196 | checkpoint = torch.load(args.resume, map_location=loc) 197 | args.start_epoch = checkpoint['epoch'] 198 | best_acc1 = checkpoint['best_acc1'] 199 | if args.gpu is not None: 200 | # best_acc1 may be from a checkpoint from a different GPU 201 | best_acc1 = best_acc1.to(args.gpu) 202 | model.load_state_dict(checkpoint['state_dict']) 203 | optimizer.load_state_dict(checkpoint['optimizer']) 204 | print("=> loaded checkpoint '{}' (epoch {})" 205 | .format(args.resume, checkpoint['epoch'])) 206 | else: 207 | print("=> no checkpoint found at '{}'".format(args.resume)) 208 | 209 | cudnn.benchmark = True 210 | 211 | # Data loading code 212 | traindir = os.path.join(args.data, 'train') 213 | valdir = os.path.join(args.data, 'valid') 214 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 215 | std=[0.229, 0.224, 0.225]) 216 | 217 | train_dataset = datasets.ImageFolder( 218 | traindir, 219 | transforms.Compose([ 220 | transforms.RandomResizedCrop(224), 221 | transforms.RandomHorizontalFlip(), 222 | transforms.ToTensor(), 223 | normalize, 224 | ])) 225 | 226 | if args.distributed: 227 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 228 | else: 229 | train_sampler = None 230 | 231 | train_loader = torch.utils.data.DataLoader( 232 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 233 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 234 | 235 | val_loader = torch.utils.data.DataLoader( 236 | datasets.ImageFolder(valdir, transforms.Compose([ 237 | transforms.Resize(256), 238 | transforms.CenterCrop(224), 239 | transforms.ToTensor(), 240 | normalize, 241 | ])), 242 | batch_size=args.batch_size, shuffle=False, 243 | num_workers=args.workers, pin_memory=True) 244 | 245 | if args.evaluate: 246 | 
validate(val_loader, model, criterion, args) 247 | return 248 | 249 | for epoch in range(args.start_epoch, args.epochs): 250 | if args.distributed: 251 | train_sampler.set_epoch(epoch) 252 | adjust_learning_rate(optimizer, epoch, args) 253 | 254 | # train for one epoch 255 | train(train_loader, model, criterion, optimizer, epoch, args) 256 | 257 | # evaluate on validation set 258 | acc1 = validate(val_loader, model, criterion, args) 259 | 260 | # remember best acc@1 and save checkpoint 261 | is_best = acc1 > best_acc1 262 | best_acc1 = max(acc1, best_acc1) 263 | 264 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 265 | and args.rank % ngpus_per_node == 0): 266 | save_checkpoint({ 267 | 'epoch': epoch + 1, 268 | # 'arch': args.arch, 269 | 'state_dict': model.state_dict(), 270 | 'best_acc1': best_acc1, 271 | 'optimizer': optimizer.state_dict(), 272 | }, is_best) 273 | 274 | 275 | def train(train_loader, model, criterion, optimizer, epoch, args): 276 | batch_time = AverageMeter('Time', ':6.3f') 277 | data_time = AverageMeter('Data', ':6.3f') 278 | losses = AverageMeter('Loss', ':.4e') 279 | top1 = AverageMeter('Acc@1', ':6.2f') 280 | top5 = AverageMeter('Acc@5', ':6.2f') 281 | progress = ProgressMeter( 282 | len(train_loader), 283 | [batch_time, data_time, losses, top1, top5], 284 | prefix="Epoch: [{}]".format(epoch)) 285 | 286 | # switch to train mode 287 | model.train() 288 | 289 | end = time.time() 290 | for i, (images, target) in enumerate(train_loader): 291 | # measure data loading time 292 | data_time.update(time.time() - end) 293 | 294 | if args.gpu is not None: 295 | images = images.cuda(args.gpu, non_blocking=True) 296 | if torch.cuda.is_available(): 297 | target = target.cuda(args.gpu, non_blocking=True) 298 | 299 | if args.cutmix and np.random.rand(1) < args.cutmix_prob: 300 | images, target_a, target_b, lam = cutmix(images, target, args.beta) 301 | output = model(images) 302 | loss = cutmix_criterion(criterion, output, target_a, target_b, lam) 303 | else: 304 | # compute output 305 | output = model(images) 306 | loss = criterion(output, target) 307 | 308 | # measure accuracy and record loss 309 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 310 | losses.update(loss.item(), images.size(0)) 311 | top1.update(acc1[0], images.size(0)) 312 | top5.update(acc5[0], images.size(0)) 313 | 314 | # compute gradient and do SGD step 315 | optimizer.zero_grad() 316 | loss.backward() 317 | optimizer.step() 318 | 319 | # measure elapsed time 320 | batch_time.update(time.time() - end) 321 | end = time.time() 322 | 323 | if i % args.print_freq == 0: 324 | progress.display(i) 325 | 326 | 327 | def validate(val_loader, model, criterion, args): 328 | batch_time = AverageMeter('Time', ':6.3f') 329 | losses = AverageMeter('Loss', ':.4e') 330 | top1 = AverageMeter('Acc@1', ':6.2f') 331 | top5 = AverageMeter('Acc@5', ':6.2f') 332 | progress = ProgressMeter( 333 | len(val_loader), 334 | [batch_time, losses, top1, top5], 335 | prefix='Test: ') 336 | 337 | # switch to evaluate mode 338 | model.eval() 339 | 340 | with torch.no_grad(): 341 | end = time.time() 342 | for i, (images, target) in enumerate(val_loader): 343 | if args.gpu is not None: 344 | images = images.cuda(args.gpu, non_blocking=True) 345 | if torch.cuda.is_available(): 346 | target = target.cuda(args.gpu, non_blocking=True) 347 | 348 | # compute output 349 | output = model(images) 350 | loss = criterion(output, target) 351 | 352 | # measure accuracy and record loss 353 | acc1, acc5 = accuracy(output, target, 
topk=(1, 5)) 354 | losses.update(loss.item(), images.size(0)) 355 | top1.update(acc1[0], images.size(0)) 356 | top5.update(acc5[0], images.size(0)) 357 | 358 | # measure elapsed time 359 | batch_time.update(time.time() - end) 360 | end = time.time() 361 | 362 | if i % args.print_freq == 0: 363 | progress.display(i) 364 | 365 | # TODO: this should also be done with the ProgressMeter 366 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 367 | .format(top1=top1, top5=top5)) 368 | 369 | return top1.avg 370 | 371 | 372 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 373 | torch.save(state, filename) 374 | if is_best: 375 | shutil.copyfile(filename, 'model_best.pth.tar') 376 | 377 | 378 | class AverageMeter(object): 379 | """Computes and stores the average and current value""" 380 | 381 | def __init__(self, name, fmt=':f'): 382 | self.name = name 383 | self.fmt = fmt 384 | self.reset() 385 | 386 | def reset(self): 387 | self.val = 0 388 | self.avg = 0 389 | self.sum = 0 390 | self.count = 0 391 | 392 | def update(self, val, n=1): 393 | self.val = val 394 | self.sum += val * n 395 | self.count += n 396 | self.avg = self.sum / self.count 397 | 398 | def __str__(self): 399 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 400 | return fmtstr.format(**self.__dict__) 401 | 402 | 403 | class ProgressMeter(object): 404 | def __init__(self, num_batches, meters, prefix=""): 405 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 406 | self.meters = meters 407 | self.prefix = prefix 408 | 409 | def display(self, batch): 410 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 411 | entries += [str(meter) for meter in self.meters] 412 | print('\t'.join(entries)) 413 | 414 | def _get_batch_fmtstr(self, num_batches): 415 | num_digits = len(str(num_batches // 1)) 416 | fmt = '{:' + str(num_digits) + 'd}' 417 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 418 | 419 | 420 | def adjust_learning_rate(optimizer, epoch, args): 421 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 422 | lr = args.lr * (0.1 ** (epoch // 30)) 423 | print('lr', lr) 424 | for param_group in optimizer.param_groups: 425 | param_group['lr'] = lr 426 | 427 | 428 | def accuracy(output, target, topk=(1,)): 429 | """Computes the accuracy over the k top predictions for the specified values of k""" 430 | with torch.no_grad(): 431 | maxk = max(topk) 432 | batch_size = target.size(0) 433 | 434 | _, pred = output.topk(maxk, 1, True, True) 435 | pred = pred.t() 436 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 437 | 438 | res = [] 439 | for k in topk: 440 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 441 | res.append(correct_k.mul_(100.0 / batch_size)) 442 | return res 443 | 444 | 445 | if __name__ == '__main__': 446 | main() 447 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.nn import init 6 | from utils.mobile import Mobile, hswish, MobileDown 7 | from utils.former import Former 8 | from utils.bridge import Mobile2Former, Former2Mobile 9 | from utils.config import config_294, config_508, config_52 10 | 11 | class BaseBlock(nn.Module): 12 | def __init__(self, inp, exp, out, se, stride, heads, dim): 13 | super(BaseBlock, self).__init__() 14 | if stride == 2: 15 | self.mobile = MobileDown(3, inp, exp, out, se, stride, dim) 
16 | else: 17 | self.mobile = Mobile(3, inp, exp, out, se, stride, dim) 18 | self.mobile2former = Mobile2Former(dim=dim, heads=heads, channel=inp) 19 | self.former = Former(dim=dim) 20 | self.former2mobile = Former2Mobile(dim=dim, heads=heads, channel=out) 21 | 22 | def forward(self, inputs): 23 | x, z = inputs 24 | z_hid = self.mobile2former(x, z) 25 | z_out = self.former(z_hid) 26 | x_hid = self.mobile(x, z_out) 27 | x_out = self.former2mobile(x_hid, z_out) 28 | return [x_out, z_out] 29 | 30 | 31 | class MobileFormer(nn.Module): 32 | def __init__(self, cfg): 33 | super(MobileFormer, self).__init__() 34 | self.token = nn.Parameter(nn.Parameter(torch.randn(1, cfg['token'], cfg['embed']))) 35 | # stem 3 224 224 -> 16 112 112 36 | self.stem = nn.Sequential( 37 | nn.Conv2d(3, cfg['stem'], kernel_size=3, stride=2, padding=1, bias=False), 38 | nn.BatchNorm2d(cfg['stem']), 39 | hswish(), 40 | ) 41 | # bneck 42 | self.bneck = nn.Sequential( 43 | nn.Conv2d(cfg['stem'], cfg['bneck']['e'], 3, stride=cfg['bneck']['s'], padding=1, groups=cfg['stem']), 44 | hswish(), 45 | nn.Conv2d(cfg['bneck']['e'], cfg['bneck']['o'], kernel_size=1, stride=1), 46 | nn.BatchNorm2d(cfg['bneck']['o']) 47 | ) 48 | 49 | # body 50 | self.block = nn.ModuleList() 51 | for kwargs in cfg['body']: 52 | self.block.append(BaseBlock(**kwargs, dim=cfg['embed'])) 53 | inp = cfg['body'][-1]['out'] 54 | exp = cfg['body'][-1]['exp'] 55 | self.conv = nn.Conv2d(inp, exp, kernel_size=1, stride=1, padding=0, bias=False) 56 | self.bn = nn.BatchNorm2d(exp) 57 | self.avg = nn.AvgPool2d((7, 7)) 58 | self.head = nn.Sequential( 59 | nn.Linear(exp + cfg['embed'], cfg['fc1']), 60 | hswish(), 61 | nn.Linear(cfg['fc1'], cfg['fc2']) 62 | ) 63 | self.init_params() 64 | 65 | def init_params(self): 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | init.kaiming_normal_(m.weight, mode='fan_out') 69 | if m.bias is not None: 70 | init.constant_(m.bias, 0) 71 | elif isinstance(m, nn.BatchNorm2d): 72 | init.constant_(m.weight, 1) 73 | init.constant_(m.bias, 0) 74 | elif isinstance(m, nn.Linear): 75 | init.normal_(m.weight, std=0.001) 76 | if m.bias is not None: 77 | init.constant_(m.bias, 0) 78 | 79 | def forward(self, x): 80 | b, _, _, _ = x.shape 81 | z = self.token.repeat(b, 1, 1) 82 | x = self.bneck(self.stem(x)) 83 | for m in self.block: 84 | x, z = m([x, z]) 85 | # x, z = self.block([x, z]) 86 | x = self.avg(self.bn(self.conv(x))).view(b, -1) 87 | z = z[:, 0, :].view(b, -1) 88 | out = torch.cat((x, z), -1) 89 | return self.head(out) 90 | # return x, z 91 | 92 | 93 | if __name__ == "__main__": 94 | model = MobileFormer(config_52) 95 | inputs = torch.randn((3, 3, 224, 224)) 96 | print(inputs.shape) 97 | # for i in range(100): 98 | # t = time.time() 99 | # output = model(inputs) 100 | # print(time.time() - t) 101 | print("Total number of parameters in networks is {} M".format(sum(x.numel() for x in model.parameters()) / 1e6)) 102 | output = model(inputs) 103 | print(output.shape) 104 | -------------------------------------------------------------------------------- /utils/bridge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | 6 | # inputs: x(b c h w) z(b m d) 7 | # output: z(b m d) 8 | class Mobile2Former(nn.Module): 9 | def __init__(self, dim, heads, channel, dropout=0.): 10 | super(Mobile2Former, self).__init__() 11 | inner_dim = heads * channel 12 | self.heads = heads 13 | self.to_q = nn.Linear(dim, inner_dim) 
14 | self.attend = nn.Softmax(dim=-1) 15 | self.scale = channel ** -0.5 16 | self.to_out = nn.Sequential( 17 | nn.Linear(inner_dim, dim), 18 | nn.Dropout(dropout) 19 | ) 20 | 21 | def forward(self, x, z): 22 | b, m, d = z.shape 23 | b, c, h, w = x.shape 24 | x = x.reshape(b, c, h*w).transpose(1,2).unsqueeze(1) 25 | q = self.to_q(z).view(b, self.heads, m, c) 26 | dots = q @ x.transpose(2, 3) * self.scale 27 | attn = self.attend(dots) 28 | out = attn @ x 29 | out = rearrange(out, 'b h m c -> b m (h c)') 30 | return z + self.to_out(out) 31 | 32 | 33 | # inputs: x(b c h w) z(b m d) 34 | # output: x(b c h w) 35 | class Former2Mobile(nn.Module): 36 | def __init__(self, dim, heads, channel, dropout=0.): 37 | super(Former2Mobile, self).__init__() 38 | inner_dim = heads * channel 39 | self.heads = heads 40 | self.to_k = nn.Linear(dim, inner_dim) 41 | self.to_v = nn.Linear(dim, inner_dim) 42 | self.attend = nn.Softmax(dim=-1) 43 | self.scale = channel ** -0.5 44 | 45 | self.to_out = nn.Sequential( 46 | nn.Linear(inner_dim, channel), 47 | nn.Dropout(dropout) 48 | ) 49 | 50 | def forward(self, x, z): 51 | b, m, d = z.shape 52 | b, c, h, w = x.shape 53 | q = x.reshape(b, c, h*w).transpose(1,2).unsqueeze(1) 54 | k = self.to_k(z).view(b, self.heads, m, c) 55 | v = self.to_v(z).view(b, self.heads, m, c) 56 | dots = q @ k.transpose(2, 3) * self.scale 57 | attn = self.attend(dots) 58 | out = attn @ v 59 | out = rearrange(out, 'b h l c -> b l (h c)') 60 | out = self.to_out(out) 61 | out = out.view(b, c, h, w) 62 | return x + out 63 | -------------------------------------------------------------------------------- /utils/bridge_split.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | 6 | # inputs: x(b c h w) z(b m d) 7 | # output: z(b m d) 8 | class Mobile2Former(nn.Module): 9 | def __init__(self, dim, heads, c, dropout=0.): 10 | super(Mobile2Former, self).__init__() 11 | inner_dim = c 12 | dim_head = c // heads 13 | self.heads = heads 14 | self.to_q = nn.Linear(dim, inner_dim) 15 | self.attend = nn.Softmax(dim=-1) 16 | self.scale = dim_head ** -0.5 17 | self.to_out = nn.Sequential( 18 | nn.Linear(inner_dim, dim), 19 | nn.Dropout(dropout) 20 | ) 21 | # self.to_out = nn.Identity() 22 | 23 | def forward(self, x, z): 24 | b, m, d = z.shape 25 | b, c, h, w = x.shape 26 | # b l c -> b l h*c -> b h l c 27 | x = x.contiguous().view(b, h * w, c) 28 | x = rearrange(x, 'b (h i) c -> b h i c', h=self.heads) 29 | k, v = x, x 30 | # b m d -> b m c 31 | q = self.to_q(z) 32 | q = rearrange(q, 'b (h j) c -> b h j c', h=self.heads) 33 | dots = einsum('b h j c, b h i c -> b h j i', q, k) * self.scale 34 | # b h j i 35 | attn = self.attend(dots) 36 | out = einsum('b h j i, b h i c -> b h j c', attn, v) 37 | out = rearrange(out, 'b h j c -> b (h j) c') 38 | return z + self.to_out(out) 39 | 40 | 41 | # inputs: x(b c h w) z(b m d) 42 | # output: x(b c h w) 43 | class Former2Mobile(nn.Module): 44 | def __init__(self, dim, heads, c, dropout=0.): 45 | super(Former2Mobile, self).__init__() 46 | inner_dim = c 47 | dim_head = c // heads 48 | self.heads = heads 49 | self.to_k = nn.Linear(dim, inner_dim) 50 | self.to_v = nn.Linear(dim, inner_dim) 51 | self.attend = nn.Softmax(dim=-1) 52 | self.scale = dim_head ** -0.5 53 | 54 | # self.to_out = nn.Sequential( 55 | # nn.Linear(inner_dim, c), 56 | # nn.Dropout(dropout) 57 | # ) 58 | self.to_out = nn.Identity() 59 | 60 | def forward(self, x, z): 61 | b, m, d = z.shape 
62 | b, c, h, w = x.shape 63 | x_ = x.contiguous().view(b, h * w, c) 64 | x_ = rearrange(x_, 'b (h i) c -> b h i c', h=self.heads) 65 | q = x_ 66 | # b m c 67 | k = self.to_k(z) 68 | v = self.to_v(z) 69 | k = rearrange(k, 'b (h j) c -> b h j c', h=self.heads) 70 | v = rearrange(v, 'b (h j) c -> b h j c', h=self.heads) 71 | 72 | # b h l m 73 | dots = einsum('b h i c, b h j c -> b h i j', q, k) * self.scale 74 | # b h l m 75 | attn = self.attend(dots) 76 | out = einsum('b h i j, b h j c -> b h i c', attn, v) 77 | out = rearrange(out, 'b h i c -> b (h i) c') 78 | out = self.to_out(out) 79 | out = out.view(b, c, h, w) 80 | return x + out 81 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | config_52 = { 2 | 'name': 'mf52', 3 | 'token': 3, # num tokens 4 | 'embed': 128, # embed dim 5 | 'stem': 8, 6 | 'bneck': {'e': 24, 'o': 12, 's': 2}, # exp out stride 7 | 'body': [ 8 | # stage2 9 | {'inp': 12, 'exp': 36, 'out': 12, 'se': None, 'stride': 1, 'heads': 2}, 10 | # stage3 11 | {'inp': 12, 'exp': 72, 'out': 24, 'se': None, 'stride': 2, 'heads': 2}, 12 | {'inp': 24, 'exp': 72, 'out': 24, 'se': None, 'stride': 1, 'heads': 2}, 13 | # stage4 14 | {'inp': 24, 'exp': 144, 'out': 48, 'se': None, 'stride': 2, 'heads': 2}, 15 | {'inp': 48, 'exp': 192, 'out': 48, 'se': None, 'stride': 1, 'heads': 2}, 16 | {'inp': 48, 'exp': 288, 'out': 64, 'se': None, 'stride': 1, 'heads': 2}, 17 | # stage5 18 | {'inp': 64, 'exp': 384, 'out': 96, 'se': None, 'stride': 2, 'heads': 2}, 19 | {'inp': 96, 'exp': 576, 'out': 96, 'se': None, 'stride': 1, 'heads': 2}, 20 | ], 21 | 'fc1': 1024, # hid_layer 22 | 'fc2': 1000 # num_clasess 23 | , 24 | } 25 | 26 | config_294 = { 27 | 'name': 'mf294', 28 | 'token': 6, # tokens 29 | 'embed': 192, # embed_dim 30 | 'stem': 16, 31 | # stage1 32 | 'bneck': {'e': 32, 'o': 16, 's': 1}, # exp out stride 33 | 'body': [ 34 | # stage2 35 | {'inp': 16, 'exp': 96, 'out': 24, 'se': None, 'stride': 2, 'heads': 2}, 36 | {'inp': 24, 'exp': 96, 'out': 24, 'se': None, 'stride': 1, 'heads': 2}, 37 | # stage3 38 | {'inp': 24, 'exp': 144, 'out': 48, 'se': None, 'stride': 2, 'heads': 2}, 39 | {'inp': 48, 'exp': 192, 'out': 48, 'se': None, 'stride': 1, 'heads': 2}, 40 | # stage4 41 | {'inp': 48, 'exp': 288, 'out': 96, 'se': None, 'stride': 2, 'heads': 2}, 42 | {'inp': 96, 'exp': 384, 'out': 96, 'se': None, 'stride': 1, 'heads': 2}, 43 | {'inp': 96, 'exp': 576, 'out': 128, 'se': None, 'stride': 1, 'heads': 2}, 44 | {'inp': 128, 'exp': 768, 'out': 128, 'se': None, 'stride': 1, 'heads': 2}, 45 | # stage5 46 | {'inp': 128, 'exp': 768, 'out': 192, 'se': None, 'stride': 2, 'heads': 2}, 47 | {'inp': 192, 'exp': 1152, 'out': 192, 'se': None, 'stride': 1, 'heads': 2}, 48 | {'inp': 192, 'exp': 1152, 'out': 192, 'se': None, 'stride': 1, 'heads': 2}, 49 | ], 50 | 'fc1': 1920, # hid_layer 51 | 'fc2': 1000 # num_clasess 52 | , 53 | } 54 | 55 | config_508 = { 56 | 'name': 'mf508', 57 | 'token': 6, # tokens and embed_dim 58 | 'embed': 192, 59 | 'stem': 24, 60 | 'bneck': {'e': 48, 'o': 24, 's': 1}, 61 | 'body': [ 62 | {'inp': 24, 'exp': 144, 'out': 40, 'se': None, 'stride': 2, 'heads': 2}, 63 | {'inp': 40, 'exp': 120, 'out': 40, 'se': None, 'stride': 1, 'heads': 2}, 64 | 65 | {'inp': 40, 'exp': 240, 'out': 72, 'se': None, 'stride': 2, 'heads': 2}, 66 | {'inp': 72, 'exp': 216, 'out': 72, 'se': None, 'stride': 1, 'heads': 2}, 67 | 68 | {'inp': 72, 'exp': 432, 'out': 128, 'se': None, 'stride': 2, 
'heads': 2}, 69 | {'inp': 128, 'exp': 512, 'out': 128, 'se': None, 'stride': 1, 'heads': 2}, 70 | {'inp': 128, 'exp': 768, 'out': 176, 'se': None, 'stride': 1, 'heads': 2}, 71 | {'inp': 176, 'exp': 1056, 'out': 176, 'se': None, 'stride': 1, 'heads': 2}, 72 | 73 | {'inp': 176, 'exp': 1056, 'out': 240, 'se': None, 'stride': 2, 'heads': 2}, 74 | {'inp': 240, 'exp': 1440, 'out': 240, 'se': None, 'stride': 1, 'heads': 2}, 75 | {'inp': 240, 'exp': 1440, 'out': 240, 'se': None, 'stride': 1, 'heads': 2}, 76 | ], 77 | 'fc1': 1920, # hid_layer 78 | 'fc2': 1000 # num_classes 79 | , 80 | } 81 | 82 | config = { 83 | 'mf52': config_52, 84 | 'mf294': config_294, 85 | 'mf508': config_508 86 | } 87 | -------------------------------------------------------------------------------- /utils/former.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange, repeat 4 | from einops.layers.torch import Rearrange 5 | 6 | 7 | def pair(t): 8 | return t if isinstance(t, tuple) else (t, t) 9 | 10 | 11 | class PreNorm(nn.Module): 12 | def __init__(self, dim, fn): 13 | super(PreNorm, self).__init__() 14 | self.norm = nn.LayerNorm(dim) 15 | self.fn = fn 16 | 17 | def forward(self, x, **kwargs): 18 | return self.fn(self.norm(x), **kwargs) 19 | 20 | 21 | class FeedForward(nn.Module): 22 | def __init__(self, dim, hidden_dim, dropout=0.): 23 | super(FeedForward, self).__init__() 24 | self.net = nn.Sequential( 25 | nn.Linear(dim, hidden_dim), 26 | nn.GELU(), 27 | nn.Dropout(dropout), 28 | nn.Linear(hidden_dim, dim), 29 | nn.Dropout(dropout) 30 | ) 31 | 32 | def forward(self, x): 33 | return self.net(x) 34 | 35 | 36 | class Attention(nn.Module): 37 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.): 38 | super(Attention, self).__init__() 39 | inner_dim = heads * dim_head # number of heads times the per-head dimension 40 | project_out = not (heads == 1 and dim_head == dim) 41 | 42 | self.heads = heads 43 | self.scale = dim_head ** -0.5 44 | 45 | self.attend = nn.Softmax(dim=-1) 46 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 47 | 48 | self.to_out = nn.Sequential( 49 | nn.Linear(inner_dim, dim), 50 | nn.Dropout(dropout) 51 | ) if project_out else nn.Identity() 52 | 53 | def forward(self, x): # e.g. 2,65,1024: batch, patches+cls_token, dim (each patch acts as one token) 54 | b, n, _, h = *x.shape, self.heads 55 | # each input token of x has dim 1024; inside the attention every token is projected to 16 features of dim 64 (heads * dim_head), 56 | # and the features of all heads are finally concatenated back into a single (16*64=1024)-dim feature as the output of each token 57 | qkv = self.to_qkv(x).chunk(3, dim=-1) # 2,65,1024 -> 2,65,1024*3 58 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), 59 | qkv) # 2,65,(16*64) -> 2,16,65,64, i.e. 16 heads, each of dim 64 60 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale # b,16,65,64 @ b,16,64*65 -> b,16,65,65 : q@k.T 61 | attn = self.attend(dots) # attention map 2,16,65,65: 16 heads, a 65*65 map whose entry [i,j] is the attention between tokens (patches) i and j 62 | # output of each token after the attention of every head 63 | out = einsum('b h i j, b h j d -> b h i d', attn, v) # attn@v 2,16,65,65 @ 2,16,65,64 -> 2,16,65,64 64 | out = rearrange(out, 'b h n d -> b n (h d)') # concatenate the outputs of all heads (16*64) -> 1024 to get the current feature of each token 65 | return self.to_out(out) 66 | 67 | 68 | # inputs: n L C 69 | # output: n L C 70 | class Former(nn.Module): 71 | def __init__(self, dim, depth=1, heads=2, dim_head=32, dropout=0.3): 72 | super(Former, self).__init__() 73 | mlp_dim = dim * 2 74 | self.layers = nn.ModuleList([]) 75 | # dim_head = dim // heads 76 | for _ in range(depth): 77 | self.layers.append(nn.ModuleList([ 78 | PreNorm(dim, 
Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), 79 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)) 80 | ])) 81 | 82 | def forward(self, x): 83 | for attn, ff in self.layers: 84 | x = attn(x) + x 85 | x = ff(x) + x 86 | return x 87 | -------------------------------------------------------------------------------- /utils/mobile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.utils import MyDyRelu 5 | from torch.nn import init 6 | 7 | 8 | class hswish(nn.Module): 9 | def forward(self, x): 10 | out = x * F.relu6(x + 3, inplace=True) / 6 11 | return out 12 | 13 | 14 | class hsigmoid(nn.Module): 15 | def forward(self, x): 16 | out = F.relu6(x + 3, inplace=True) / 6 17 | return out 18 | 19 | 20 | class SeModule(nn.Module): 21 | def __init__(self, inp, reduction=4): 22 | super(SeModule, self).__init__() 23 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 24 | self.se = nn.Sequential( 25 | nn.Linear(inp, inp // reduction, bias=False), 26 | nn.ReLU(inplace=True), 27 | nn.Linear(inp // reduction, inp, bias=False), 28 | hsigmoid() 29 | ) 30 | 31 | def forward(self, x): 32 | se = self.avg_pool(x) 33 | b, c, _, _ = se.size() 34 | se = se.view(b, c) 35 | se = self.se(se).view(b, c, 1, 1) 36 | return x * se.expand_as(x) 37 | 38 | 39 | class Mobile(nn.Module): 40 | def __init__(self, ks, inp, hid, out, se, stride, dim, reduction=4, k=2): 41 | super(Mobile, self).__init__() 42 | self.hid = hid 43 | self.k = k 44 | self.fc1 = nn.Linear(dim, dim // reduction) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.fc2 = nn.Linear(dim // reduction, 2 * k * hid) 47 | self.sigmoid = nn.Sigmoid() 48 | 49 | self.register_buffer('lambdas', torch.Tensor([1.] * k + [0.5] * k).float()) 50 | self.register_buffer('init_v', torch.Tensor([1.] + [0.] 
* (2 * k - 1)).float()) 51 | self.stride = stride 52 | # self.se = DyReLUB(channels=out, k=1) if dyrelu else se 53 | self.se = se 54 | 55 | self.conv1 = nn.Conv2d(inp, hid, kernel_size=1, stride=1, padding=0, bias=False) 56 | self.bn1 = nn.BatchNorm2d(hid) 57 | self.act1 = MyDyRelu(2) 58 | 59 | self.conv2 = nn.Conv2d(hid, hid, kernel_size=ks, stride=stride, 60 | padding=ks // 2, groups=hid, bias=False) 61 | self.bn2 = nn.BatchNorm2d(hid) 62 | self.act2 = MyDyRelu(2) 63 | 64 | self.conv3 = nn.Conv2d(hid, out, kernel_size=1, stride=1, padding=0, bias=False) 65 | self.bn3 = nn.BatchNorm2d(out) 66 | 67 | self.shortcut = nn.Identity() 68 | if stride == 1 and inp != out: 69 | self.shortcut = nn.Sequential( 70 | nn.Conv2d(inp, out, kernel_size=1, stride=1, padding=0, bias=False), 71 | nn.BatchNorm2d(out), 72 | ) 73 | 74 | def get_relu_coefs(self, z): 75 | theta = z[:, 0, :] 76 | # b d -> b d//4 77 | theta = self.fc1(theta) 78 | theta = self.relu(theta) 79 | # b d//4 -> b 2*k 80 | theta = self.fc2(theta) 81 | theta = 2 * self.sigmoid(theta) - 1 82 | # b 2*k 83 | return theta 84 | 85 | def forward(self, x, z): 86 | theta = self.get_relu_coefs(z) 87 | # b 2*k*c -> b c 2*k 2*k 2*k 88 | relu_coefs = theta.view(-1, self.hid, 2 * self.k) * self.lambdas + self.init_v 89 | 90 | out = self.bn1(self.conv1(x)) 91 | out_ = [out, relu_coefs] 92 | out = self.act1(out_) 93 | 94 | out = self.bn2(self.conv2(out)) 95 | out_ = [out, relu_coefs] 96 | out = self.act2(out_) 97 | 98 | out = self.bn3(self.conv3(out)) 99 | if self.se is not None: 100 | out = self.se(out) 101 | out = out + self.shortcut(x) if self.stride == 1 else out 102 | return out 103 | 104 | 105 | class MobileDown(nn.Module): 106 | def __init__(self, ks, inp, hid, out, se, stride, dim, reduction=4, k=2): 107 | super(MobileDown, self).__init__() 108 | self.dim = dim 109 | self.hid, self.out = hid, out 110 | self.k = k 111 | self.fc1 = nn.Linear(dim, dim // reduction) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.fc2 = nn.Linear(dim // reduction, 2 * k * hid) 114 | self.sigmoid = nn.Sigmoid() 115 | self.register_buffer('lambdas', torch.Tensor([1.] * k + [0.5] * k).float()) 116 | self.register_buffer('init_v', torch.Tensor([1.] + [0.] 
* (2 * k - 1)).float()) 117 | self.stride = stride 118 | # self.se = DyReLUB(channels=out, k=1) if dyrelu else se 119 | self.se = se 120 | 121 | self.dw_conv1 = nn.Conv2d(inp, hid, kernel_size=ks, stride=stride, 122 | padding=ks // 2, groups=inp, bias=False) 123 | self.dw_bn1 = nn.BatchNorm2d(hid) 124 | self.dw_act1 = MyDyRelu(2) 125 | 126 | self.pw_conv1 = nn.Conv2d(hid, inp, kernel_size=1, stride=1, padding=0, bias=False) 127 | self.pw_bn1 = nn.BatchNorm2d(inp) 128 | self.pw_act1 = nn.ReLU() 129 | 130 | self.dw_conv2 = nn.Conv2d(inp, hid, kernel_size=ks, stride=1, 131 | padding=ks // 2, groups=inp, bias=False) 132 | self.dw_bn2 = nn.BatchNorm2d(hid) 133 | self.dw_act2 = MyDyRelu(2) 134 | 135 | self.pw_conv2 = nn.Conv2d(hid, out, kernel_size=1, stride=1, padding=0, bias=False) 136 | self.pw_bn2 = nn.BatchNorm2d(out) 137 | 138 | self.shortcut = nn.Identity() 139 | if stride == 1 and inp != out: 140 | self.shortcut = nn.Sequential( 141 | nn.Conv2d(inp, out, kernel_size=1, stride=1, padding=0, bias=False), 142 | nn.BatchNorm2d(out), 143 | ) 144 | 145 | def get_relu_coefs(self, z): 146 | theta = z[:, 0, :] 147 | # b d -> b d//4 148 | theta = self.fc1(theta) 149 | theta = self.relu(theta) 150 | # b d//4 -> b 2*k 151 | theta = self.fc2(theta) 152 | theta = 2 * self.sigmoid(theta) - 1 153 | # b 2*k 154 | return theta 155 | 156 | def forward(self, x, z): 157 | theta = self.get_relu_coefs(z) 158 | # b 2*k*c -> b c 2*k 2*k 2*k 159 | relu_coefs = theta.view(-1, self.hid, 2 * self.k) * self.lambdas + self.init_v 160 | 161 | out = self.dw_bn1(self.dw_conv1(x)) 162 | out_ = [out, relu_coefs] 163 | out = self.dw_act1(out_) 164 | out = self.pw_act1(self.pw_bn1(self.pw_conv1(out))) 165 | 166 | out = self.dw_bn2(self.dw_conv2(out)) 167 | out_ = [out, relu_coefs] 168 | out = self.dw_act2(out_) 169 | out = self.pw_bn2(self.pw_conv2(out)) 170 | 171 | if self.se is not None: 172 | out = self.se(out) 173 | out = out + self.shortcut(x) if self.stride == 1 else out 174 | return out 175 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | import numpy as np 5 | import torch.nn as nn 6 | 7 | 8 | class MyDyRelu(nn.Module): 9 | def __init__(self, k): 10 | super(MyDyRelu, self).__init__() 11 | self.k = k 12 | 13 | def forward(self, inputs): 14 | x, relu_coefs = inputs 15 | # BxCxHxW -> HxWxBxCx1 16 | x_perm = x.permute(2, 3, 0, 1).unsqueeze(-1) 17 | # h w b c 1 -> h w b c k 18 | output = x_perm * relu_coefs[:, :, :self.k] + relu_coefs[:, :, self.k:] 19 | # HxWxBxCxk -> BxCxHxW 20 | result = torch.max(output, dim=-1)[0].permute(2, 3, 0, 1) 21 | return result 22 | 23 | 24 | def mixup_data(x, y, alpha, use_cuda=True): 25 | if alpha > 0: 26 | lam = np.random.beta(alpha, alpha) 27 | else: 28 | lam = 1 29 | b = x.size()[0] 30 | if use_cuda: 31 | index = torch.randperm(b).cuda() 32 | else: 33 | index = torch.randperm(b) 34 | 35 | mixed_x = lam * x + (1 - lam) * x[index, :] 36 | y_a, y_b = y, y[index] 37 | return mixed_x, y_a, y_b, lam 38 | 39 | 40 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 41 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 42 | 43 | 44 | def cutmix(input, target, beta): 45 | lam = np.random.beta(beta, beta) 46 | b = input.size()[0] 47 | rand_index = torch.randperm(b).cuda() 48 | target_a = target 49 | target_b = target[rand_index] 50 | bx1, by1, bx2, by2 = rand_box(input.size(), lam) 51 | 
input[:, :, bx1:bx2, by1:by2] = input[rand_index, :, bx1:bx2, by1:by2] 52 | lam = 1 - ((bx2 - bx1) * (by2 - by1) / (input.size()[-1] * input.size()[-2])) 53 | return input, target_a, target_b, lam 54 | 55 | 56 | def cutmix_criterion(criterion, output, target_a, target_b, lam): 57 | return lam * criterion(output, target_a) + (1. - lam) * criterion(output, target_b) 58 | 59 | 60 | def rand_box(size, lam): 61 | _, _, h, w = size 62 | cut_rat = np.sqrt(1. - lam) 63 | cut_w = int(w * cut_rat) 64 | cut_h = int(h * cut_rat) 65 | # pick a random point on the image as the centre of the cut box 66 | cx = np.random.randint(w) 67 | cy = np.random.randint(h) 68 | bx1 = np.clip(cx - cut_w // 2, 0, w) 69 | by1 = np.clip(cy - cut_h // 2, 0, h) 70 | bx2 = np.clip(cx + cut_w // 2, 0, w) 71 | by2 = np.clip(cy + cut_h // 2, 0, h) 72 | return bx1, by1, bx2, by2 73 | 74 | 75 | ''' 76 | for batch_idx, (inputs, targets) in enumerate(trainloader): 77 | if use_cuda: 78 | inputs, targets = inputs.cuda(), targets.cuda() 79 | 80 | inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, 81 | args.alpha, use_cuda) 82 | inputs, targets_a, targets_b = map(Variable, (inputs, 83 | targets_a, targets_b)) 84 | outputs = net(inputs) 85 | loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam) 86 | train_loss += loss.data[0] 87 | _, predicted = torch.max(outputs.data, 1) 88 | total += targets.size(0) 89 | correct += (lam * predicted.eq(targets_a.data).cpu().sum().float() 90 | + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().float()) 91 | ''' 92 | 93 | 94 | class RandomErasing(object): 95 | ''' 96 | Class that performs Random Erasing, as described in 'Random Erasing Data Augmentation' by Zhong et al. 97 | ------------------------------------------------------------------------------------- 98 | probability: The probability that the operation will be performed. 99 | sl: min erasing area 100 | sh: max erasing area 101 | r1: min aspect ratio 102 | mean: erasing value 103 | ------------------------------------------------------------------------------------- 104 | ''' 105 | 106 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]): 107 | self.probability = probability 108 | self.mean = mean 109 | self.sl = sl 110 | self.sh = sh 111 | self.r1 = r1 112 | 113 | def __call__(self, img): 114 | 115 | if random.uniform(0, 1) > self.probability: 116 | return img 117 | for attempt in range(100): 118 | # compute the image area 119 | # c h w 120 | area = img.size()[1] * img.size()[2] 121 | # fraction of the area to erase 122 | target_area = random.uniform(self.sl, self.sh) * area 123 | # aspect ratio of the erased box 124 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 125 | h = int(round(math.sqrt(target_area * aspect_ratio))) 126 | w = int(round(math.sqrt(target_area / aspect_ratio))) 127 | if w < img.size()[2] and h < img.size()[1]: 128 | x1 = random.randint(0, img.size()[1] - h) 129 | y1 = random.randint(0, img.size()[2] - w) 130 | if img.size()[0] == 3: 131 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 132 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 133 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 134 | else: 135 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 136 | return img 137 | return img 138 | --------------------------------------------------------------------------------
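The README's Inference entry above is empty. Below is a minimal single-image inference sketch, not part of the repository: it assumes a checkpoint written by main.py's save_checkpoint (checkpoint.pth.tar, whose state_dict keys carry the 'module.' prefix added by DistributedDataParallel/DataParallel) and reuses the evaluation transform of the validation loader; the script name infer.py and the image path test.jpg are placeholders.

# infer.py - hypothetical single-image inference sketch (not part of the repo)
import torch
import torchvision.transforms as transforms
from PIL import Image

from utils.config import config
from model import MobileFormer


def load_model(name='mf294', ckpt_path='checkpoint.pth.tar', device='cpu'):
    model = MobileFormer(config[name])
    ckpt = torch.load(ckpt_path, map_location=device)
    # main.py saves the DDP/DataParallel-wrapped model, so strip the 'module.' prefix from the keys
    state_dict = {k.replace('module.', '', 1): v for k, v in ckpt['state_dict'].items()}
    model.load_state_dict(state_dict)
    return model.to(device).eval()


# same preprocessing as the validation loader in main.py
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

if __name__ == '__main__':
    model = load_model()
    img = preprocess(Image.open('test.jpg').convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        probs = model(img).softmax(dim=-1)
    top5 = probs.topk(5, dim=-1)
    print(top5.indices.tolist(), top5.values.tolist())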
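The README states that the parameter counts of the config-built models are close to the paper, and model.py's __main__ prints the count for config_52 only. A short sketch (assumed to be run from the repository root so that model and utils are importable) that reports the count for every config:

# param_count.py - hypothetical helper, not part of the repo
from utils.config import config
from model import MobileFormer

for name, cfg in config.items():
    model = MobileFormer(cfg)
    n_params = sum(p.numel() for p in model.parameters())
    print('{}: {:.2f} M parameters'.format(name, n_params / 1e6))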
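The README mentions two q/k/v alignment variants: utils/bridge.py broadcasts the feature map x against the per-head queries (the variant model.py imports), while utils/bridge_split.py splits the token and spatial dimensions into heads. A toy shape check, under the assumption that both the token count m and h*w are divisible by the number of heads (which the split variant requires):

# bridge_check.py - hypothetical shape check, not part of the repo
import torch
from utils import bridge, bridge_split

b, c, h, w = 2, 16, 14, 14   # feature map x
m, d, heads = 6, 192, 2      # tokens, embed dim, heads (token/embed sizes as in config_294)
x = torch.randn(b, c, h, w)
z = torch.randn(b, m, d)

# variant B (bridge.py): inner_dim = heads * channel, x is broadcast across heads
m2f_b = bridge.Mobile2Former(dim=d, heads=heads, channel=c)
# variant A (bridge_split.py): inner_dim = channel, token/spatial dims are split into heads
m2f_a = bridge_split.Mobile2Former(dim=d, heads=heads, c=c)

print(m2f_b(x, z).shape)  # torch.Size([2, 6, 192])
print(m2f_a(x, z).shape)  # torch.Size([2, 6, 192])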
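Mobile and MobileDown in utils/mobile.py generate dynamic-ReLU coefficients from the first global token and apply them through MyDyRelu, which computes max_j(a_j * x + b_j) over k candidate linear functions per channel. A tiny illustration of that contract (the coefficient values below are chosen by hand; inside the model they come from fc1/fc2 together with the lambdas and init_v buffers):

# dyrelu_demo.py - hypothetical illustration of MyDyRelu with k=2, not part of the repo
import torch
from utils.utils import MyDyRelu

b, c, h, w, k = 2, 4, 3, 3, 2
x = torch.randn(b, c, h, w)
# relu_coefs has shape (b, c, 2k): the first k entries are slopes a_j, the last k are offsets b_j.
# With a = [1, 0] and b = [0, 0] the activation max(1*x + 0, 0*x + 0) is exactly ReLU.
relu_coefs = torch.zeros(b, c, 2 * k)
relu_coefs[:, :, 0] = 1.0

act = MyDyRelu(k)
out = act([x, relu_coefs])
print(torch.allclose(out, torch.relu(x)))  # True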