├── .gitignore ├── LICENSE ├── README.md ├── SGDR.py ├── main.py └── model ├── __init__.py ├── sa_layer.py └── sa_resnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | .vscode 127 | *.ipynb 128 | data/ 129 | 130 | *.pth.tar -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ho Young Jhoo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch Implementation of [Stand-Alone Self-Attention in Vision Models](https://arxiv.org/pdf/1906.05909.pdf)
2 | 
3 | **This is NOT an official implementation**. Please let me know if this implementation contains any misreadings of the original paper.
4 | 
5 | ## Prerequisites
6 | * Python 3.6+
7 | * PyTorch 1.1.0+
8 | * scipy
9 | * Pillow
10 | * torchvision
11 | 
12 | ## Benchmark (WIP)
13 | 
14 | Trained on ImageNet. (WIP: CIFAR-10, CIFAR-100)
15 | 
16 | The backbone network and hyperparameters are based on the official torchvision ResNet and the official ImageNet trainer example.
17 | 
18 | Trained for up to 90 epochs with batch size 64 on a single NVIDIA 1080 Ti GPU, using SGD with an initial learning rate of 0.1 that is linearly warmed up for 10 epochs and then decayed with a cosine schedule (following the SASA paper).
19 | 
20 | 
--------------------------------------------------------------------------------
/SGDR.py:
--------------------------------------------------------------------------------
1 | """ modified from https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py """
2 | from torch.optim.lr_scheduler import _LRScheduler
3 | from torch.optim.lr_scheduler import ReduceLROnPlateau
4 | 
5 | 
6 | class LinearWarmupScheduler(_LRScheduler):
7 |     """ Linearly warms up the learning rate of the wrapped optimizer, starting from zero.
8 |     Args:
9 |         optimizer (Optimizer): Wrapped optimizer.
10 |         total_epoch: the target learning rate is reached gradually at total_epoch
11 |         after_scheduler: after total_epoch, use this scheduler (e.g. 
ReduceLROnPlateau) 12 | """ 13 | 14 | def __init__(self, optimizer, total_epoch, after_scheduler=None, last_epoch=-1): 15 | self.total_epoch = total_epoch 16 | self.after_scheduler = after_scheduler 17 | self.finished = False 18 | super(LinearWarmupScheduler, self).__init__(optimizer, last_epoch) 19 | 20 | def get_lr(self): 21 | if self.last_epoch >= self.total_epoch: 22 | if self.after_scheduler: 23 | if not self.finished: 24 | self.after_scheduler.base_lrs = [base_lr for base_lr in self.base_lrs] 25 | self.finished = True 26 | return self.after_scheduler.get_lr() 27 | return [base_lr for base_lr in self.base_lrs] 28 | 29 | return [base_lr * ((self.last_epoch + 1) / self.total_epoch) for base_lr in self.base_lrs] 30 | 31 | def step(self, epoch=None, metrics=None): 32 | if self.finished and self.after_scheduler: 33 | if epoch is None: 34 | self.after_scheduler.step(None) 35 | else: 36 | self.after_scheduler.step(epoch - self.total_epoch) 37 | else: 38 | return super(LinearWarmupScheduler, self).step(epoch) 39 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ Original Code: https://github.com/pytorch/examples/blob/master/imagenet/main.py """ 2 | import argparse 3 | import os 4 | import random 5 | import shutil 6 | import time 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.optim.lr_scheduler as lr_sched 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | 23 | import model.sa_resnet as sa_resnet 24 | from SGDR import LinearWarmupScheduler 25 | 26 | model_names = list(sorted(name for name in models.__dict__ 27 | if name.islower() and not name.startswith("__") 28 | and callable(models.__dict__[name]))) 29 | 30 | all_model_names = sa_resnet.model_names + model_names 31 | 32 | dataset_dict = { 33 | 'cifar10': (10, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]), 34 | 'cifar100': (100, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]), 35 | 'imagenet': (1000, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 36 | } 37 | 38 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training with SASA') 39 | parser.add_argument('--data_path', default='data', help='path to dataset directory') 40 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 41 | choices=list(dataset_dict.keys()), help='dataset to train/val') 42 | parser.add_argument('-a', '--arch', metavar='ARCH', default='cstem_sa_resnet50', 43 | choices=all_model_names, 44 | help='model architecture: ' + 45 | ' | '.join(all_model_names) + 46 | ' (default: cstem_sa_resnet50)') 47 | parser.add_argument('--width', default=224, type=int, 48 | help='height/width of input image') 49 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 50 | help='number of data loading workers (default: 4)') 51 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 52 | help='number of total epochs to run') 53 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 54 | help='manual epoch number (useful on restarts)') 55 | parser.add_argument('-b', '--batch-size', default=64, type=int, 56 | 
                    metavar='N',
57 |                     help='mini-batch size (default: 64), this is the total '
58 |                          'batch size of all GPUs on the current node when '
59 |                          'using Data Parallel or Distributed Data Parallel')
60 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
61 |                     metavar='LR', help='initial learning rate', dest='lr')
62 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
63 |                     help='momentum')
64 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
65 |                     metavar='W', help='weight decay (default: 1e-4)',
66 |                     dest='weight_decay')
67 | parser.add_argument('-p', '--print-freq', default=10, type=int,
68 |                     metavar='N', help='print frequency (default: 10)')
69 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
70 |                     help='path to latest checkpoint (default: none)')
71 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
72 |                     help='evaluate model on validation set')
73 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
74 |                     help='use pre-trained model')
75 | parser.add_argument('--world-size', default=-1, type=int,
76 |                     help='number of nodes for distributed training')
77 | parser.add_argument('--rank', default=-1, type=int,
78 |                     help='node rank for distributed training')
79 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
80 |                     help='url used to set up distributed training')
81 | parser.add_argument('--dist-backend', default='nccl', type=str,
82 |                     help='distributed backend')
83 | parser.add_argument('--seed', default=None, type=int,
84 |                     help='seed for initializing training')
85 | parser.add_argument('--gpu', default=None, type=int,
86 |                     help='GPU id to use.')
87 | parser.add_argument('--cpu', action='store_true',
88 |                     help='Use CPU only')
89 | parser.add_argument('--multiprocessing-distributed', action='store_true',
90 |                     help='Use multi-processing distributed training to launch '
91 |                          'N processes per node, which has N GPUs. This is the '
92 |                          'fastest way to use PyTorch for either single node or '
93 |                          'multi node data parallel training')
94 | 
95 | best_acc1 = 0
96 | 
97 | 
98 | def main():
99 |     args = parser.parse_args()
100 | 
101 |     if args.seed is not None:
102 |         random.seed(args.seed)
103 |         torch.manual_seed(args.seed)
104 |         if not args.cpu:
105 |             cudnn.deterministic = True
106 |         warnings.warn('You have chosen to seed training. '
107 |                       'This will turn on the CUDNN deterministic setting, '
108 |                       'which can slow down your training considerably! '
109 |                       'You may see unexpected behavior when restarting '
110 |                       'from checkpoints.')
111 | 
112 |     if args.gpu is not None:
113 |         warnings.warn('You have chosen a specific GPU. 
This will completely ' 114 | 'disable data parallelism.') 115 | 116 | if args.dist_url == "env://" and args.world_size == -1: 117 | args.world_size = int(os.environ["WORLD_SIZE"]) 118 | 119 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 120 | 121 | ngpus_per_node = torch.cuda.device_count() 122 | if args.multiprocessing_distributed: 123 | # Since we have ngpus_per_node processes per node, the total world_size 124 | # needs to be adjusted accordingly 125 | args.world_size = ngpus_per_node * args.world_size 126 | # Use torch.multiprocessing.spawn to launch distributed processes: the 127 | # main_worker process function 128 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 129 | else: 130 | # Simply call main_worker function 131 | main_worker(args.gpu, ngpus_per_node, args) 132 | 133 | 134 | def main_worker(gpu, ngpus_per_node, args): 135 | global best_acc1 136 | args.gpu = gpu 137 | 138 | if args.gpu is not None: 139 | print("Use GPU: {} for training".format(args.gpu)) 140 | 141 | if args.distributed: 142 | if args.dist_url == "env://" and args.rank == -1: 143 | args.rank = int(os.environ["RANK"]) 144 | if args.multiprocessing_distributed: 145 | # For multiprocessing distributed training, rank needs to be the 146 | # global rank among all the processes 147 | args.rank = args.rank * ngpus_per_node + gpu 148 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 149 | world_size=args.world_size, rank=args.rank) 150 | 151 | # create model 152 | num_classes, dataset_mean, dataset_std = dataset_dict[args.dataset] 153 | 154 | if args.arch not in model_names: 155 | print("=> creating model '{}'".format(args.arch)) 156 | model = sa_resnet.get_model(args, num_classes=num_classes) 157 | else: 158 | if args.pretrained: 159 | if args.dataset != 'imagenet': 160 | raise Exception('cannot download non-imagenet pretrained model') 161 | print("=> using pre-trained model '{}'".format(args.arch)) 162 | model = models.__dict__[args.arch](pretrained=True) 163 | else: 164 | print("=> creating model '{}'".format(args.arch)) 165 | model = models.__dict__[args.arch](num_classes=num_classes) 166 | 167 | if args.distributed: 168 | # For multiprocessing distributed, DistributedDataParallel constructor 169 | # should always set the single device scope, otherwise, 170 | # DistributedDataParallel will use all available devices. 
171 | if args.gpu is not None: 172 | torch.cuda.set_device(args.gpu) 173 | model.cuda(args.gpu) 174 | # When using a single GPU per process and per 175 | # DistributedDataParallel, we need to divide the batch size 176 | # ourselves based on the total number of GPUs we have 177 | args.batch_size = int(args.batch_size / ngpus_per_node) 178 | args.workers = int(args.workers / ngpus_per_node) 179 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 180 | else: 181 | model.cuda() 182 | # DistributedDataParallel will divide and allocate batch_size to all 183 | # available GPUs if device_ids are not set 184 | model = torch.nn.parallel.DistributedDataParallel(model) 185 | elif args.gpu is not None: 186 | torch.cuda.set_device(args.gpu) 187 | model = model.cuda(args.gpu) 188 | else: 189 | # DataParallel will divide and allocate batch_size to all available GPUs 190 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 191 | model.features = torch.nn.DataParallel(model.features) 192 | else: 193 | model = torch.nn.DataParallel(model) 194 | 195 | if not args.cpu: 196 | model = model.cuda() 197 | 198 | print("model param #: {}".format(sum(p.numel() for p in model.parameters()))) 199 | 200 | # define loss function (criterion) and optimizer 201 | criterion = nn.CrossEntropyLoss() 202 | if not args.cpu: 203 | criterion = criterion.cuda(args.gpu) 204 | 205 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 206 | momentum=args.momentum, 207 | weight_decay=args.weight_decay) 208 | 209 | # loss scheduler 210 | scheduler = LinearWarmupScheduler(optimizer, 10, lr_sched.CosineAnnealingLR(optimizer, args.epochs)) 211 | 212 | # optionally resume from a checkpoint 213 | if args.resume: 214 | if os.path.isfile(args.resume): 215 | print("=> loading checkpoint '{}'".format(args.resume)) 216 | checkpoint = torch.load(args.resume) 217 | args.start_epoch = checkpoint['epoch'] 218 | best_acc1 = checkpoint['best_acc1'] 219 | if args.gpu is not None: 220 | # best_acc1 may be from a checkpoint from a different GPU 221 | best_acc1 = best_acc1.to(args.gpu) 222 | model.load_state_dict(checkpoint['state_dict']) 223 | optimizer.load_state_dict(checkpoint['optimizer']) 224 | print("=> loaded checkpoint '{}' (epoch {})" 225 | .format(args.resume, checkpoint['epoch'])) 226 | scheduler = LinearWarmupScheduler(optimizer, total_epoch=10, 227 | after_scheduler=lr_sched.CosineAnnealingLR(optimizer, args.epochs), 228 | last_epoch=checkpoint['epoch']) 229 | else: 230 | print("=> no checkpoint found at '{}'".format(args.resume)) 231 | 232 | if not args.cpu: 233 | cudnn.benchmark = True 234 | 235 | # Data loading code 236 | traindir = os.path.join(args.data_path, 'train') 237 | valdir = os.path.join(args.data_path, 'val') 238 | if not os.path.exists(args.data_path): 239 | os.mkdir(args.data_path) 240 | if not os.path.exists(traindir): 241 | os.mkdir(traindir) 242 | if not os.path.exists(valdir): 243 | os.mkdir(valdir) 244 | 245 | normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std) 246 | train_transform = transforms.Compose([ 247 | transforms.RandomResizedCrop(args.width), 248 | transforms.RandomHorizontalFlip(), 249 | transforms.ToTensor(), 250 | normalize, 251 | ]) 252 | val_transform = transforms.Compose([ 253 | transforms.Resize(256), 254 | transforms.CenterCrop(args.width), 255 | transforms.ToTensor(), 256 | normalize, 257 | ]) 258 | 259 | if args.dataset == 'cifar10': 260 | train_dataset = datasets.CIFAR10(traindir, train=True, 261 | download=True, 
transform=train_transform) 262 | elif args.dataset == 'cifar100': 263 | train_dataset = datasets.CIFAR100(traindir, train=True, 264 | download=True, transform=train_transform) 265 | else: 266 | train_dataset = datasets.ImageNet(traindir, split='train', 267 | download=True, transform=train_transform) 268 | 269 | if args.distributed: 270 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 271 | else: 272 | train_sampler = None 273 | 274 | train_loader = torch.utils.data.DataLoader( 275 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 276 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 277 | 278 | 279 | if args.dataset == 'cifar10': 280 | val_dataset = datasets.CIFAR10(valdir, train=False, 281 | download=True, transform=val_transform) 282 | elif args.dataset == 'cifar100': 283 | val_dataset = datasets.CIFAR100(valdir, train=False, 284 | download=True, transform=val_transform) 285 | else: 286 | val_dataset = datasets.ImageNet(valdir, split='val', 287 | download=True, transform=val_transform) 288 | 289 | val_loader = torch.utils.data.DataLoader( 290 | val_dataset, 291 | batch_size=args.batch_size, shuffle=False, 292 | num_workers=args.workers, pin_memory=True) 293 | 294 | if args.evaluate: 295 | validate(val_loader, model, criterion, args) 296 | return 297 | 298 | for epoch in range(args.start_epoch, args.epochs): 299 | scheduler.step() 300 | 301 | if args.distributed: 302 | train_sampler.set_epoch(epoch) 303 | 304 | # train for one epoch 305 | train(train_loader, model, criterion, optimizer, epoch, args) 306 | 307 | # evaluate on validation set 308 | acc1 = validate(val_loader, model, criterion, args) 309 | 310 | # remember best acc@1 and save checkpoint 311 | is_best = acc1 > best_acc1 312 | best_acc1 = max(acc1, best_acc1) 313 | 314 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 315 | and args.rank % ngpus_per_node == 0): 316 | save_checkpoint({ 317 | 'epoch': epoch + 1, 318 | 'arch': args.arch, 319 | 'state_dict': model.state_dict(), 320 | 'best_acc1': best_acc1, 321 | 'optimizer' : optimizer.state_dict(), 322 | }, is_best) 323 | 324 | 325 | def train(train_loader, model, criterion, optimizer, epoch, args): 326 | batch_time = AverageMeter('Time', ':6.3f') 327 | data_time = AverageMeter('Data', ':6.3f') 328 | losses = AverageMeter('Loss', ':.4e') 329 | top1 = AverageMeter('Acc@1', ':6.2f') 330 | top5 = AverageMeter('Acc@5', ':6.2f') 331 | progress = ProgressMeter( 332 | len(train_loader), 333 | [batch_time, data_time, losses, top1, top5], 334 | prefix="Epoch: [{}]".format(epoch)) 335 | 336 | # switch to train mode 337 | model.train() 338 | 339 | end = time.time() 340 | for i, (images, target) in enumerate(train_loader): 341 | # measure data loading time 342 | data_time.update(time.time() - end) 343 | 344 | if not args.cpu: 345 | if args.gpu is not None: 346 | images = images.cuda(args.gpu, non_blocking=True) 347 | target = target.cuda(args.gpu, non_blocking=True) 348 | 349 | # compute output 350 | output = model(images) 351 | loss = criterion(output, target) 352 | 353 | # measure accuracy and record loss 354 | [acc1, acc5] = accuracy(output, target, topk=(1, 5)) 355 | losses.update(loss.item(), images.size(0)) 356 | top1.update(acc1[0], images.size(0)) 357 | top5.update(acc5[0], images.size(0)) 358 | 359 | # compute gradient and do SGD step 360 | optimizer.zero_grad() 361 | loss.backward() 362 | optimizer.step() 363 | 364 | # measure elapsed time 365 | batch_time.update(time.time() - 
end) 366 | end = time.time() 367 | 368 | if i % args.print_freq == 0: 369 | progress.display(i) 370 | 371 | 372 | def validate(val_loader, model, criterion, args): 373 | batch_time = AverageMeter('Time', ':6.3f') 374 | losses = AverageMeter('Loss', ':.4e') 375 | top1 = AverageMeter('Acc@1', ':6.2f') 376 | top5 = AverageMeter('Acc@5', ':6.2f') 377 | progress = ProgressMeter( 378 | len(val_loader), 379 | [batch_time, losses, top1, top5], 380 | prefix='Test: ') 381 | 382 | # switch to evaluate mode 383 | model.eval() 384 | 385 | with torch.no_grad(): 386 | end = time.time() 387 | for i, (images, target) in enumerate(val_loader): 388 | if not args.cpu: 389 | if args.gpu is not None: 390 | images = images.cuda(args.gpu, non_blocking=True) 391 | target = target.cuda(args.gpu, non_blocking=True) 392 | 393 | # compute output 394 | output = model(images) 395 | loss = criterion(output, target) 396 | 397 | # measure accuracy and record loss 398 | [acc1, acc5] = accuracy(output, target, topk=(1, 5)) 399 | losses.update(loss.item(), images.size(0)) 400 | top1.update(acc1[0], images.size(0)) 401 | top5.update(acc5[0], images.size(0)) 402 | 403 | # measure elapsed time 404 | batch_time.update(time.time() - end) 405 | end = time.time() 406 | 407 | if i % args.print_freq == 0: 408 | progress.display(i) 409 | 410 | # TODO: this should also be done with the ProgressMeter 411 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 412 | .format(top1=top1, top5=top5)) 413 | 414 | return top1.avg 415 | 416 | 417 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 418 | torch.save(state, filename) 419 | if is_best: 420 | shutil.copyfile(filename, 'model_best.pth.tar') 421 | 422 | 423 | class AverageMeter(object): 424 | """Computes and stores the average and current value""" 425 | def __init__(self, name, fmt=':f'): 426 | self.name = name 427 | self.fmt = fmt 428 | self.reset() 429 | 430 | def reset(self): 431 | self.val = 0 432 | self.avg = 0 433 | self.sum = 0 434 | self.count = 0 435 | 436 | def update(self, val, n=1): 437 | self.val = val 438 | self.sum += val * n 439 | self.count += n 440 | self.avg = self.sum / self.count 441 | 442 | def __str__(self): 443 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 444 | return fmtstr.format(**self.__dict__) 445 | 446 | 447 | class ProgressMeter(object): 448 | def __init__(self, num_batches, meters, prefix=""): 449 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 450 | self.meters = meters 451 | self.prefix = prefix 452 | 453 | def display(self, batch): 454 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 455 | entries += [str(meter) for meter in self.meters] 456 | print('\t'.join(entries)) 457 | 458 | def _get_batch_fmtstr(self, num_batches): 459 | num_digits = len(str(num_batches // 1)) 460 | fmt = '{:' + str(num_digits) + 'd}' 461 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 462 | 463 | def accuracy(output, target, topk=(1,)): 464 | """Computes the accuracy over the k top predictions for the specified values of k""" 465 | with torch.no_grad(): 466 | maxk = max(topk) 467 | batch_size = target.size(0) 468 | 469 | _, pred = output.topk(maxk, 1, True, True) 470 | pred = pred.t() 471 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 472 | 473 | res = [] 474 | for k in topk: 475 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 476 | res.append(correct_k.mul_(100.0 / batch_size)) 477 | return res 478 | 479 | 480 | if __name__ == '__main__': 481 | main() 
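As a quick illustration of the learning-rate policy wired up above (10-epoch linear warmup into cosine decay), here is a minimal standalone sketch that exercises the scheduler on a dummy optimizer. It assumes it is run from the repository root so that SGDR.py is importable; the 90-epoch budget simply mirrors the default --epochs value.

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from SGDR import LinearWarmupScheduler

# A throwaway parameter stands in for a model; the SGD settings mirror main.py.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1, momentum=0.9, weight_decay=1e-4)

# Linear warmup to lr=0.1 over the first 10 epochs, then cosine decay over the 90-epoch budget.
scheduler = LinearWarmupScheduler(optimizer, total_epoch=10,
                                  after_scheduler=CosineAnnealingLR(optimizer, T_max=90))

for epoch in range(90):
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    if epoch in (0, 9, 45, 89):
        # lr climbs linearly to 0.1 by epoch 10, then follows a cosine curve towards 0
        print(epoch, lr)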
-------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MerHS/SASA-pytorch/7d113852dce2e25d4de23caf87ad7d33758c322e/model/__init__.py -------------------------------------------------------------------------------- /model/sa_layer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.init as init 8 | from torch.nn.modules.utils import _pair 9 | from torchvision.models.resnet import conv1x1 10 | 11 | class SelfAttentionConv2d(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size, 13 | stride=1, padding=0, groups=1, bias=True): 14 | super(SelfAttentionConv2d, self).__init__() 15 | if in_channels % groups != 0: 16 | raise ValueError('in_channels must be divisible by groups') 17 | if out_channels % groups != 0: 18 | raise ValueError('out_channels must be divisible by groups') 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | self.kernel_size = _pair(kernel_size) 22 | self.stride = _pair(stride) 23 | self.padding = _pair(padding) 24 | self.groups = groups # multi-head count 25 | 26 | if bias: 27 | self.bias = nn.Parameter(torch.Tensor(1, out_channels, 1, 1)) 28 | else: 29 | self.register_parameter('bias', None) 30 | 31 | # relative position offsets are shared between multi-heads 32 | self.rel_size = (out_channels // groups) // 2 33 | self.relative_x = nn.Parameter(torch.Tensor(self.rel_size, self.kernel_size[1])) 34 | self.relative_y = nn.Parameter(torch.Tensor((out_channels // groups) - self.rel_size, self.kernel_size[0])) 35 | 36 | self.weight_query = nn.Conv2d(self.in_channels, self.out_channels, 1, groups=self.groups, bias=False) 37 | self.weight_key = nn.Conv2d(self.in_channels, self.out_channels, 1, groups=self.groups, bias=False) 38 | self.weight_value = nn.Conv2d(self.in_channels, self.out_channels, 1, groups=self.groups, bias=False) 39 | 40 | self.softmax = nn.Softmax(dim=3) 41 | 42 | self.reset_parameters() 43 | 44 | def reset_parameters(self): 45 | init.kaiming_normal_(self.weight_query.weight, mode='fan_out', nonlinearity='relu') 46 | init.kaiming_normal_(self.weight_key.weight, mode='fan_out', nonlinearity='relu') 47 | init.kaiming_normal_(self.weight_value.weight, mode='fan_out', nonlinearity='relu') 48 | 49 | if self.bias is not None: 50 | bound = 1 / math.sqrt(self.out_channels) 51 | init.uniform_(self.bias, -bound, bound) 52 | 53 | init.normal_(self.relative_x, 0, 1) 54 | init.normal_(self.relative_y, 0, 1) 55 | 56 | def forward(self, x): 57 | b, c, h, w = x.size() 58 | kh, kw = self.kernel_size 59 | ph, pw = h + self.padding[0] * 2, w + self.padding[1] * 2 60 | 61 | fh = (ph - kh) // self.stride[0] + 1 62 | fw = (pw - kw) // self.stride[1] + 1 63 | 64 | px, py = self.padding 65 | x = F.pad(x, (py, py, px, px)) 66 | 67 | vq = self.weight_query(x) 68 | vk = self.weight_key(x) 69 | vv = self.weight_value(x) # b, fc, ph, pw 70 | 71 | # b, fc, fh, fw 72 | win_q = vq[:, :, (kh-1)//2:ph-(kh//2):self.stride[0], (kw-1)//2:pw-(kw//2):self.stride[1]] 73 | 74 | win_q_b = win_q.view(b, self.groups, -1, fh, fw) # b, g, fc/g, fh, fw 75 | 76 | win_q_x, win_q_y = win_q_b.split(self.rel_size, dim=2) # (b, g, x, fh, fw), (b, g, y, fh, fw) 77 | win_q_x = torch.einsum('bgxhw,xk->bhwk', (win_q_x, self.relative_x)) # b, fh, fw, kw 
78 | win_q_y = torch.einsum('bgyhw,yk->bhwk', (win_q_y, self.relative_y)) # b, fh, fw, kh 79 | 80 | win_k = vk.unfold(2, kh, self.stride[0]).unfold(3, kw, self.stride[1]) # b, fc, fh, fw, kh, kw 81 | 82 | vx = (win_q.unsqueeze(4).unsqueeze(4) * win_k).sum(dim=1) # b, fh, fw, kh, kw 83 | vx = vx + win_q_x.unsqueeze(3) + win_q_y.unsqueeze(4) # add rel_x, rel_y 84 | vx = self.softmax(vx.view(b, fh, fw, -1)).view(b, 1, fh, fw, kh, kw) 85 | 86 | win_v = vv.unfold(2, kh, self.stride[0]).unfold(3, kw, self.stride[1]) 87 | fin_v = torch.einsum('bchwkl->bchw', (vx * win_v, )) # (b, fc, fh, fw, kh, kw) -> (b, fc, fh, fw) 88 | 89 | if self.bias is not None: 90 | fin_v += self.bias 91 | 92 | return fin_v 93 | 94 | class SAMixtureConv2d(nn.Module): 95 | """ spatially-aware SA / multiple value transformation for stem layer """ 96 | def __init__(self, in_height, in_width, in_channels, out_channels, kernel_size, 97 | stride=1, padding=0, groups=1, mix=4, bias=True): 98 | super(SAMixtureConv2d, self).__init__() 99 | if in_channels % groups != 0: 100 | raise ValueError('in_channels must be divisible by groups') 101 | if out_channels % groups != 0: 102 | raise ValueError('out_channels must be divisible by groups') 103 | 104 | self.in_height = in_height 105 | self.in_width = in_width 106 | self.in_channels = in_channels 107 | self.out_channels = out_channels 108 | self.kernel_size = _pair(kernel_size) 109 | self.stride = _pair(stride) 110 | self.padding = _pair(padding) 111 | self.groups = groups # multi-head count 112 | self.mix = mix # weight mixture 113 | 114 | if bias: 115 | self.bias = nn.Parameter(torch.Tensor(1, out_channels, 1, 1)) 116 | else: 117 | self.register_parameter('bias', None) 118 | 119 | # relative position offsets are shared between multi-heads 120 | self.rel_size = (out_channels // groups) // 2 121 | self.relative_x = nn.Parameter(torch.Tensor(self.rel_size, self.kernel_size[1])) 122 | self.relative_y = nn.Parameter(torch.Tensor((out_channels // groups) - self.rel_size, self.kernel_size[0])) 123 | 124 | self.weight_query = nn.Conv2d(self.in_channels, self.out_channels, 1, groups=self.groups, bias=False) 125 | self.weight_key = nn.Conv2d(self.in_channels, self.out_channels, 1, groups=self.groups, bias=False) 126 | self.weight_values = nn.ModuleList([nn.Conv2d(self.in_channels, self.out_channels, 1, groups=self.groups, bias=False) for _ in range(mix)]) 127 | 128 | self.emb_x = nn.Parameter(torch.Tensor(out_channels // groups, in_width + 2 * self.padding[1])) # fc/g, pw 129 | self.emb_y = nn.Parameter(torch.Tensor(out_channels // groups, in_height + 2 * self.padding[0])) # fc/g, ph 130 | self.emb_m = nn.Parameter(torch.Tensor(mix, out_channels // groups)) # m, fc/g 131 | 132 | self.softmax = nn.Softmax(dim=3) 133 | 134 | self.reset_parameters() 135 | 136 | def reset_parameters(self): 137 | init.kaiming_normal_(self.weight_query.weight, mode='fan_out', nonlinearity='relu') 138 | init.kaiming_normal_(self.weight_key.weight, mode='fan_out', nonlinearity='relu') 139 | for wv in self.weight_values: 140 | init.kaiming_normal_(wv.weight, mode='fan_out', nonlinearity='relu') 141 | 142 | if self.bias is not None: 143 | bound = 1 / math.sqrt(self.out_channels) 144 | init.uniform_(self.bias, -bound, bound) 145 | 146 | init.normal_(self.relative_x, 0, 1) 147 | init.normal_(self.relative_y, 0, 1) 148 | init.normal_(self.emb_x, 0, 1) 149 | init.normal_(self.emb_y, 0, 1) 150 | init.normal_(self.emb_m, 0, 1) 151 | 152 | def forward(self, x): 153 | b, c, h, w = x.size() 154 | kh, kw = self.kernel_size 155 | 
ph, pw = h + self.padding[0] * 2, w + self.padding[1] * 2 156 | 157 | fh = (ph - kh) // self.stride[0] + 1 158 | fw = (pw - kw) // self.stride[1] + 1 159 | 160 | px, py = self.padding 161 | x = F.pad(x, (py, py, px, px)) 162 | 163 | vq = self.weight_query(x) 164 | vk = self.weight_key(x) # b, fc, fh, fw 165 | 166 | # b, fc, fh, fw 167 | win_q = vq[:, :, (kh-1)//2:ph-(kh//2):self.stride[0], (kw-1)//2:pw-(kw//2):self.stride[1]] 168 | 169 | win_q_b = win_q.view(b, self.groups, -1, fh, fw) # b, g, fc/g, fh, fw 170 | 171 | win_q_x, win_q_y = win_q_b.split(self.rel_size, dim=2) # (b, g, x, fh, fw), (b, g, y, fh, fw) 172 | win_q_x = torch.einsum('bgxhw,xk->bhwk', (win_q_x, self.relative_x)) # b, fh, fw, kw 173 | win_q_y = torch.einsum('bgyhw,yk->bhwk', (win_q_y, self.relative_y)) # b, fh, fw, kh 174 | 175 | win_k = vk.unfold(2, kh, self.stride[0]).unfold(3, kw, self.stride[1]) # b, fc, fh, fw, kh, kw 176 | 177 | vx = (win_q.unsqueeze(4).unsqueeze(4) * win_k).sum(dim=1) # b, fh, fw, kh, kw 178 | vx = vx + win_q_x.unsqueeze(3) + win_q_y.unsqueeze(4) # add rel_x, rel_y 179 | vx = self.softmax(vx.view(b, fh, fw, -1)).view(b, 1, fh, fw, kh, kw) 180 | 181 | # spatially aware mixture embedding 182 | p_abm_x = torch.einsum('mc,cw->mw', (self.emb_m, self.emb_x)).unsqueeze(1) # m, 1, pw 183 | p_abm_y = torch.einsum('mc,ch->mh', (self.emb_m, self.emb_y)).unsqueeze(2) # m, ph, 1 184 | p_abm = F.softmax(p_abm_x + p_abm_y, dim=0) # m, ph, pw 185 | 186 | vv = torch.stack([weight_value(x) for weight_value in self.weight_values], dim=0) # m, b, fc, ph, pw 187 | vv = torch.einsum('mbchw,mhw->bchw', (vv, p_abm)) # b, fc, ph, pw 188 | 189 | win_v = vv.unfold(2, kh, self.stride[0]).unfold(3, kw, self.stride[1]) 190 | fin_v = torch.einsum('bchwkl->bchw', (vx * win_v, )) # (b, fc, fh, fw, kh, kw) -> (b, fc, fh, fw) 191 | 192 | if self.bias is not None: 193 | fin_v += self.bias 194 | 195 | return fin_v -------------------------------------------------------------------------------- /model/sa_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchvision.models.resnet import Bottleneck, ResNet, conv1x1 6 | 7 | from .sa_layer import SelfAttentionConv2d, SAMixtureConv2d 8 | 9 | def sa_conv7x7(in_planes, out_planes, stride=1, groups=1, padding=3): 10 | """ 7x7 SA Convolution with padding """ 11 | return SelfAttentionConv2d(in_planes, out_planes, kernel_size=7, stride=stride, 12 | padding=padding, groups=groups, bias=False) 13 | 14 | def sa_stem4x4(in_height, in_width, in_planes, out_planes, stride=1, groups=1, padding=2, mix=4): 15 | """ 4x4 mixed SA Convolution for stem """ 16 | return SAMixtureConv2d(in_height, in_width, in_planes, out_planes, kernel_size=4, stride=stride, 17 | padding=padding, groups=groups, mix=mix, bias=False) 18 | 19 | class SelfAttentionBottleneck(nn.Module): 20 | expansion = 4 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None, 23 | base_width=64, groups=1): 24 | super(SelfAttentionBottleneck, self).__init__() 25 | 26 | width = int(planes * (base_width / 64.)) 27 | 28 | self.conv1 = conv1x1(inplanes, width) 29 | self.bn1 = nn.BatchNorm2d(width) 30 | self.conv2 = sa_conv7x7(width, width, groups=groups) 31 | self.bn2 = nn.BatchNorm2d(width) 32 | self.conv3 = conv1x1(width, planes * self.expansion) 33 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.downsample = downsample 36 | self.stride = stride 
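        # NOTE: the 7x7 self-attention conv above always runs at stride 1; when this block
        # downsamples (stride >= 2), spatial reduction is done by the average pooling defined
        # below and applied after the attention layer in forward() (this mirrors the SASA
        # paper's use of post-attention average pooling for spatial downsampling).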
37 | 38 | if stride >= 2: 39 | self.avg_pool = nn.AvgPool2d(stride, stride) 40 | 41 | def forward(self, x): 42 | identity = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | 50 | if self.stride >= 2: 51 | out = self.avg_pool(out) 52 | 53 | out = self.bn2(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv3(out) 57 | out = self.bn3(out) 58 | 59 | if self.downsample is not None: 60 | identity = self.downsample(x) 61 | 62 | out += identity 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | class SAResNet(nn.Module): 68 | """ simpler version of torchvision official ResNet """ 69 | def __init__(self, block, layers, num_classes=1000, use_conv_stem=False, **kwargs): 70 | super(SAResNet, self).__init__() 71 | 72 | self.inplanes = 64 73 | self.head_count = 8 74 | 75 | if use_conv_stem: 76 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, 77 | padding=3, bias=False) 78 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 79 | 80 | else: 81 | self.conv1 = sa_stem4x4(kwargs['in_height'], kwargs['in_width'], 82 | 3, self.inplanes, groups=1, mix=4) 83 | self.maxpool = nn.MaxPool2d(kernel_size=4, stride=4) 84 | 85 | 86 | self.bn1 = nn.BatchNorm2d(self.inplanes) 87 | self.relu = nn.ReLU(inplace=True) 88 | 89 | self.layer1 = self._make_layer(block, 64, layers[0]) 90 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 91 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 92 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 93 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 94 | self.fc = nn.Linear(512 * block.expansion, num_classes) 95 | 96 | for m in self.modules(): 97 | if isinstance(m, nn.Conv2d): 98 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 99 | elif isinstance(m, nn.BatchNorm2d): 100 | nn.init.constant_(m.weight, 1) 101 | nn.init.constant_(m.bias, 0) 102 | 103 | def _make_layer(self, block, planes, blocks, stride=1): 104 | downsample = None 105 | 106 | if stride != 1 or self.inplanes != planes * block.expansion: 107 | downsample = nn.Sequential( 108 | conv1x1(self.inplanes, planes * block.expansion, stride), 109 | nn.BatchNorm2d(planes * block.expansion), 110 | ) 111 | 112 | layers = [] 113 | layers.append(block(self.inplanes, planes, stride, downsample, groups=self.head_count)) 114 | self.inplanes = planes * block.expansion 115 | for _ in range(1, blocks): 116 | layers.append(block(self.inplanes, planes, groups=self.head_count)) 117 | 118 | return nn.Sequential(*layers) 119 | 120 | def forward(self, x): 121 | x = self.conv1(x) 122 | x = self.bn1(x) 123 | x = self.relu(x) 124 | x = self.maxpool(x) 125 | 126 | x = self.layer1(x) 127 | x = self.layer2(x) 128 | x = self.layer3(x) 129 | x = self.layer4(x) 130 | 131 | x = self.avgpool(x) 132 | x = x.reshape(x.size(0), -1) 133 | x = self.fc(x) 134 | 135 | return x 136 | 137 | 138 | model_dict = { 139 | 'resnet26': (ResNet, Bottleneck, [1, 2, 4, 1]), 140 | 'resnet38': (ResNet, Bottleneck, [2, 3, 5, 2]), 141 | 'sa_resnet26': (SAResNet, SelfAttentionBottleneck, [1, 2, 4, 1]), 142 | 'sa_resnet38': (SAResNet, SelfAttentionBottleneck, [2, 3, 5, 2]), 143 | 'sa_resnet50': (SAResNet, SelfAttentionBottleneck, [3, 4, 6, 3]), 144 | 'sa_resnet101': (SAResNet, SelfAttentionBottleneck, [3, 4, 23, 3]), 145 | 'sa_resnet152': (SAResNet, SelfAttentionBottleneck, [3, 8, 36, 3]), 146 | 'cstem_sa_resnet26': (SAResNet, SelfAttentionBottleneck, [1, 2, 4, 1]), 147 | 'cstem_sa_resnet38': (SAResNet, 
SelfAttentionBottleneck, [2, 3, 5, 2]), 148 | 'cstem_sa_resnet50': (SAResNet, SelfAttentionBottleneck, [3, 4, 6, 3]), 149 | 'cstem_sa_resnet101': (SAResNet, SelfAttentionBottleneck, [3, 4, 23, 3]), 150 | 'cstem_sa_resnet152': (SAResNet, SelfAttentionBottleneck, [3, 8, 36, 3]), 151 | } 152 | 153 | model_names = list(model_dict.keys()) 154 | 155 | def get_model(args, **kwargs): 156 | arch = args.arch 157 | width = args.width 158 | 159 | model_fn, block, layers = model_dict[arch] 160 | use_conv_stem = arch.startswith('cstem') 161 | return model_fn(block, layers, use_conv_stem=use_conv_stem, in_height=width, in_width=width, **kwargs) 162 | --------------------------------------------------------------------------------
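For reference, a minimal smoke-test sketch for the models defined above. The 32x32 input, batch of 2, and 10 classes are arbitrary illustration values; the sa_* variants need in_height/in_width for the stem's positional embeddings, while the cstem_* variants use the plain convolutional stem instead.

import torch

from model.sa_resnet import SAResNet, SelfAttentionBottleneck

# sa_resnet26 configuration from model_dict above, using the self-attention mixture stem.
net = SAResNet(SelfAttentionBottleneck, [1, 2, 4, 1], num_classes=10,
               use_conv_stem=False, in_height=32, in_width=32)

out = net(torch.randn(2, 3, 32, 32))
print(out.shape)  # expected: torch.Size([2, 10])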