├── .gitignore ├── README.md ├── figures ├── imagenet.png ├── imagenet_curve.png ├── overview.png └── style_transfer.png ├── imagenet.py ├── models ├── recalibration_modules.py └── resnet.py └── utils ├── __init__.py ├── eval.py ├── logger.py ├── misc.py ├── progress ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.rst ├── demo.gif ├── progress │ ├── __init__.py │ ├── bar.py │ ├── counter.py │ ├── helpers.py │ └── spinner.py ├── setup.py └── test_progress.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Style-Based Recalibration Module 2 | The official PyTorch implementation of "SRM : A Style-based Recalibration Module for Convolutional Neural Networks" for ImageNet. 3 | SRM is a lightweight architectural unit that dynamically recalibrates feature responses based on style importance. 
4 | 5 | ![](figures/overview.png) 6 | 7 | ## Overview of Results 8 | 9 | ### Training and validation curves on ImageNet with ResNet-50 10 | 11 | ![](figures/imagenet_curve.png) 12 | 13 | ### Top-1 and top-5 accuracy (%) on the ImageNet-1K validation set 14 | 15 | ![](figures/imagenet.png) 16 | 17 | ### Example results of style transfer 18 | 19 | ![](figures/style_transfer.png) 20 | 21 | ## Prerequisites 22 | - PyTorch 0.4.0+ 23 | - Python 3.6 24 | - CUDA 8.0+ 25 | 26 | ## Training Examples 27 | * Train **ResNet-50** 28 | ``` 29 | python imagenet.py --depth 50 --data /data/imagenet/ILSVRC2012 --gpu-id 0,1,2,3,4,5,6,7 --checkpoint resnet50/baseline 30 | ``` 31 | 32 | * Train **SRM-ResNet-50** 33 | ``` 34 | python imagenet.py --depth 50 --data /data/imagenet/ILSVRC2012 --gpu-id 0,1,2,3,4,5,6,7 --checkpoint resnet50/srm --recalibration-type srm 35 | ``` 36 | 37 | * Train **SE-ResNet-50** 38 | ``` 39 | python imagenet.py --depth 50 --data /data/imagenet/ILSVRC2012 --gpu-id 0,1,2,3,4,5,6,7 --checkpoint resnet50/se --recalibration-type se 40 | ``` 41 | 42 | * Train **GE-ResNet-50** 43 | ``` 44 | python imagenet.py --depth 50 --data /data/imagenet/ILSVRC2012 --gpu-id 0,1,2,3,4,5,6,7 --checkpoint resnet50/ge --recalibration-type ge 45 | ``` 46 | 47 | ## Acknowledgment 48 | This code is heavily borrowed from [pytorch-classification](https://github.com/bearpaw/pytorch-classification). 
49 | 50 | ## Note 51 | * 28/05/2019: initial code for ImageNet is released 52 | -------------------------------------------------------------------------------- /figures/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyunjaelee410/style-based-recalibration-module/f6221c77e797c4530dddba03616153708d636bd1/figures/imagenet.png -------------------------------------------------------------------------------- /figures/imagenet_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyunjaelee410/style-based-recalibration-module/f6221c77e797c4530dddba03616153708d636bd1/figures/imagenet_curve.png -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyunjaelee410/style-based-recalibration-module/f6221c77e797c4530dddba03616153708d636bd1/figures/overview.png -------------------------------------------------------------------------------- /figures/style_transfer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyunjaelee410/style-based-recalibration-module/f6221c77e797c4530dddba03616153708d636bd1/figures/style_transfer.png -------------------------------------------------------------------------------- /imagenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | import random 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.utils.data as data 14 | import torch.optim as optim 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as 
datasets

from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p
from models.resnet import resnet

# Parse arguments
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Datasets
parser.add_argument('-d', '--data', default='/data/imagenet/ILSVRC2012', type=str)
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
# Optimization options
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--train-batch', default=256, type=int, metavar='N',
                    help='train batchsize (default: 256)')
parser.add_argument('--test-batch', default=256, type=int, metavar='N',
                    help='test batchsize (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--drop', '--dropout', default=0, type=float,
                    metavar='Dropout', help='Dropout ratio')
parser.add_argument('--schedule', type=int, nargs='+', default=[31, 61],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
# Checkpoints
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                    help='path to save checkpoint (default: checkpoint)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', default='', type=str, metavar='PATH',
                    help='path to pretrained checkpoint to initialize weights from (default: none)')
# Miscs
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
#Device options
parser.add_argument('--gpu-id', default='0', type=str,
                    help='id(s) for CUDA_VISIBLE_DEVICES')

# Architecture
parser.add_argument('--depth', type=int, default=50, help='Model depth.')

# recalibration type
# `choices` rejects typos at parse time; the default (None) is not checked
# against `choices` by argparse, so the plain-ResNet baseline still works.
parser.add_argument('--recalibration-type', type=str, metavar='recalibration',
                    choices=['se', 'srm', 'ge'],
                    help='recalibration type {se, srm, ge} (default: None)')

args = parser.parse_args()
# Copy (not alias) the parsed options: state['lr'] is rewritten by the LR
# schedule and must not mutate args.lr.
state = dict(vars(args))

# Use CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

# Random seed
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)

best_acc = 0  # best test accuracy

# Column headers of log.txt; the order must match the logger.append() call in main().
_LOG_NAMES = ('Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.',
              'Train Acc.5', 'Valid Acc.5', 'Train Time', 'Test Time')


def main():
    """Train (or, with --evaluate, only validate) a ResNet on ImageNet.

    All configuration comes from the module-level ``args``. Training writes
    ``log.txt`` plus checkpoints into ``args.checkpoint``; with ``--resume``
    that directory is replaced by the resumed checkpoint's directory so the
    existing log is appended to.
    """
    print("Start time: "+time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))

    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.train_batch, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # create model
    model = resnet(
        depth=args.depth,
        recalibration_type=args.recalibration_type,
    )

    model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True
    print(model)
    print(' Total params: %.4fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = set_optimizer(model, args)

    # Resume / initialize from a checkpoint
    title = 'Resnet{}-{}'.format(args.depth, args.recalibration_type)
    if args.resume:
        print('==> Resuming from checkpoint..')
        if not os.path.isfile(args.resume):
            raise FileNotFoundError('Error: no checkpoint file found at {}'.format(args.resume))
        # Keep writing next to the resumed checkpoint so log.txt is continued.
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    elif args.pretrained:
        print('==> Start from pretrained checkpoint..')
        if not os.path.isfile(args.pretrained):
            raise FileNotFoundError('Error: no checkpoint file found at {}'.format(args.pretrained))
        # Outputs deliberately stay in --checkpoint: redirecting them to the
        # pretrained file's directory would truncate its log.txt and overwrite
        # its checkpoint files.
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint['state_dict'])

    if args.evaluate:
        # Evaluated before any Logger is opened so an existing log.txt is not truncated.
        print('\nEvaluation only')
        test_loss, test_acc, test_acc5, _ = test(val_loader, model, criterion, start_epoch, use_cuda)
        print(' Test Loss: %.8f, Test Acc: %.2f (Top-1), %.2f (Top-5)' % (test_loss, test_acc, test_acc5))
        return

    log_path = os.path.join(args.checkpoint, 'log.txt')
    if args.resume:
        logger = Logger(log_path, title=title, resume=True)
    else:
        logger = Logger(log_path, title=title)
        logger.set_names(list(_LOG_NAMES))

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc, train_acc5, train_time = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
        test_loss, test_acc, test_acc5, test_time = test(val_loader, model, criterion, epoch, use_cuda)

        # append logger file (order must match _LOG_NAMES)
        logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc, train_acc5,
                       test_acc5, train_time, test_time])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'acc': test_acc,
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint=args.checkpoint)

    logger.close()

    print('Best acc:')
    print(best_acc)

    print("Finish time: "+time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))


def train(train_loader, model, criterion, optimizer, epoch, use_cuda):
    """Run one training epoch over ``train_loader``.

    ``epoch`` is accepted for symmetry with ``test`` but not used here.

    Returns:
        tuple: (average loss, average top-1 %, average top-5 %,
        ``bar.elapsed / 60``; presumably minutes if ``Bar.elapsed`` is in
        seconds -- TODO confirm against the progress package).
    """
    # switch to train mode
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            # `async=` was renamed `non_blocking=`; `async` is a reserved
            # keyword since Python 3.7, so the old spelling is a SyntaxError.
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
            batch=batch_idx + 1,
            size=len(train_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            top1=top1.avg,
            top5=top5.avg,
        )
        bar.next()
        # Release this batch's tensors before the next batch is loaded.
        del inputs, targets, outputs, loss, prec1, prec5
    bar.finish()
    return (losses.avg, top1.avg, top5.avg, bar.elapsed/60.)
def test(val_loader, model, criterion, epoch, use_cuda):
    """Evaluate ``model`` on ``val_loader`` with gradient tracking disabled.

    ``epoch`` is accepted for symmetry with ``train`` but not used here.

    Returns:
        tuple: (average loss, average top-1 %, average top-5 %,
        ``bar.elapsed / 60``).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode (BatchNorm uses running stats)
    model.eval()

    end = time.time()
    bar = Bar('Processing', max=len(val_loader))
    # The whole loop runs under no_grad: nothing here needs autograd.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                batch=batch_idx + 1,
                size=len(val_loader),
                data=data_time.avg,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg,
                top1=top1.avg,
                top5=top5.avg,
            )
            bar.next()
    bar.finish()
    return (losses.avg, top1.avg, top5.avg, bar.elapsed/60.)
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Save ``state`` to ``checkpoint/filename``.

    When ``is_best`` is true, also copy the file to ``checkpoint/model_best.pth.tar``.
    """
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))


def adjust_learning_rate(optimizer, epoch):
    """Set the step-decayed learning rate for ``epoch`` on every param group.

    The rate is recomputed from ``args.lr`` as
    ``args.lr * args.gamma ** (number of milestones in args.schedule <= epoch)``
    rather than decayed incrementally. An incremental update loses the decays
    of milestones already passed when training is resumed (``state['lr']``
    restarts at ``args.lr``). For an uninterrupted run the resulting schedule
    is the same. The value is also stored in the global ``state['lr']`` for
    printing and logging.
    """
    # set() so a milestone listed twice still decays only once, as before.
    passed = sum(1 for milestone in set(args.schedule) if epoch >= milestone)
    state['lr'] = args.lr * (args.gamma ** passed)
    for param_group in optimizer.param_groups:
        param_group['lr'] = state['lr']


def set_optimizer(model, args):
    """Build SGD with two param groups.

    Group 0 holds all parameters not tagged ``srm_param`` and uses
    ``args.weight_decay``. Group 1 holds the tagged SRM parameters and uses
    weight decay 0. Group 1 is kept even when empty so checkpointed optimizer
    state always has the same group layout.
    """
    regular, srm = [], []
    for p in model.parameters():
        (srm if getattr(p, 'srm_param', False) else regular).append(p)
    params = [{'params': regular},
              {'params': srm, 'lr': args.lr, 'weight_decay': 0}]

    optimizer = optim.SGD(params,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    return optimizer


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /models/recalibration_modules.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
import functools
import math
import torch
from torch.nn.parameter import Parameter


class SRMLayer(nn.Module):
    """Style-based Recalibration Module.

    Pools each channel's mean and standard deviation ("style"), mixes the two
    with a learned per-channel weight pair (``cfc``), applies BatchNorm and a
    sigmoid, and multiplies the input by the resulting channel gate.
    """
    def __init__(self, channel):
        super(SRMLayer, self).__init__()

        # Channel-wise "fully connected" weights over (mean, std); zero-initialized.
        self.cfc = Parameter(torch.Tensor(channel, 2))
        self.cfc.data.fill_(0)

        self.bn = nn.BatchNorm2d(channel)
        self.activation = nn.Sigmoid()

        # Tag SRM parameters; set_optimizer() puts tagged params in a group
        # with weight_decay=0.
        setattr(self.cfc, 'srm_param', True)
        setattr(self.bn.weight, 'srm_param', True)
        setattr(self.bn.bias, 'srm_param', True)

    def _style_pooling(self, x, eps=1e-5):
        """Return per-channel (mean, std) of ``x`` stacked as B x C x 2."""
        N, C, _, _ = x.size()

        channel_mean = x.view(N, C, -1).mean(dim=2, keepdim=True)
        # eps keeps sqrt away from 0 for constant channels.
        channel_var = x.view(N, C, -1).var(dim=2, keepdim=True) + eps
        channel_std = channel_var.sqrt()

        t = torch.cat((channel_mean, channel_std), dim=2)
        return t

    def _style_integration(self, t):
        """Map B x C x 2 style features to a B x C x 1 x 1 sigmoid gate."""
        z = t * self.cfc[None, :, :]  # B x C x 2
        z = torch.sum(z, dim=2)[:, :, None, None]  # B x C x 1 x 1

        z_hat = self.bn(z)
        g = self.activation(z_hat)

        return g

    def forward(self, x):
        # B x C x 2
        t = self._style_pooling(x)

        # B x C x 1 x 1
        g = self._style_integration(t)

        return x * g


class SELayer(nn.Module):
    """Squeeze-and-Excitation: global average pool -> bottleneck MLP -> sigmoid gate."""
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.activation = nn.Sigmoid()

        self.reduction = reduction

        self.fc = nn.Sequential(
            nn.Linear(channel, channel // self.reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // self.reduction, channel),
        )

    def forward(self, x):
        b, c, _, _ = x.size()

        avg_y = self.avgpool(x).view(b, c)

        gate = self.fc(avg_y).view(b, c, 1, 1)
        gate = self.activation(gate)

        return x * gate


class GELayer(nn.Module):
    """Gather-Excite with a global depth-wise convolution as the gather step."""
    def __init__(self, channel, layer_idx):
        super(GELayer, self).__init__()

        # Kernel size w.r.t each layer for global depth-wise convolution.
        # These equal the stage feature-map sizes for 224x224 inputs
        # (224 / 4 = 56 after the ResNet stem, then halved per stage), so the
        # conv collapses the map to 1x1. Other input sizes give a gate that
        # does not broadcast against x.
        kernel_size = [-1, 56, 28, 14, 7][layer_idx]

        self.conv = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=kernel_size, groups=channel),
            nn.BatchNorm2d(channel),
        )

        self.activation = nn.Sigmoid()

    def forward(self, x):
        gate = self.conv(x)
        gate = self.activation(gate)

        return x * gate

# --------------------------------------------------------------------------------
# /models/resnet.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import functools


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def _make_rclb(rclb_layer, channels, layer_idx, is_ge):
    """Instantiate a block's recalibration module, or return None if there is none."""
    if rclb_layer is None:
        return None
    if is_ge:
        # GELayer additionally needs the stage index to choose its kernel size.
        return rclb_layer(channels, layer_idx)
    return rclb_layer(channels)


class BasicBlock(nn.Module):
    """Two 3x3 convolutions plus shortcut (ResNet-18/34 building block).

    If ``rclb_layer`` is given, the residual branch is recalibrated before the
    shortcut is added. ``layer_idx``/``is_ge`` are only used for GELayer.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, rclb_layer=None, layer_idx=1, is_ge=False):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

        self.rclb = _make_rclb(rclb_layer, planes, layer_idx, is_ge)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if self.rclb is not None:
            out = self.rclb(out)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck with 4x channel expansion (ResNet-50/101/152).

    If ``rclb_layer`` is given, the residual branch is recalibrated before the
    shortcut is added.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, rclb_layer=None, layer_idx=1, is_ge=False):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        self.rclb = _make_rclb(rclb_layer, planes * 4, layer_idx, is_ge)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if self.rclb is not None:
            out = self.rclb(out)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet backbone with an optional per-block recalibration module.

    Args:
        block: ``BasicBlock`` or ``Bottleneck``.
        layers: number of blocks in each of the four stages.
        input_channels: channels of the input image (first conv's in_channels).
        num_classes: size of the final linear layer.
        recalibration_type: None, 'srm', 'se' or 'ge'.

    Raises:
        NotImplementedError: for any other ``recalibration_type``.
    """

    def __init__(self, block, layers, input_channels=3, num_classes=1000, recalibration_type=None):
        super(ResNet, self).__init__()

        self.is_ge = recalibration_type == 'ge'

        if recalibration_type is None:
            self.rclb_layer = None
        elif recalibration_type == 'srm':
            from .recalibration_modules import SRMLayer
            self.rclb_layer = SRMLayer
        elif recalibration_type == 'se':
            from .recalibration_modules import SELayer
            self.rclb_layer = SELayer
        elif recalibration_type == 'ge':
            from .recalibration_modules import GELayer
            self.rclb_layer = GELayer
        else:
            raise NotImplementedError(
                'unknown recalibration_type: {!r}'.format(recalibration_type))

        self.inplanes = 64
        # input_channels was previously ignored (hard-coded 3).
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], layer_idx=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, layer_idx=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, layer_idx=3)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, layer_idx=4)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-normal init for convs (fan_out), identity init for BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, layer_idx=1):
        """Build one stage of ``blocks`` blocks; only the first may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, rclb_layer=self.rclb_layer, layer_idx=layer_idx, is_ge=self.is_ge))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, rclb_layer=self.rclb_layer, layer_idx=layer_idx, is_ge=self.is_ge))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


# depth -> (block type, blocks per stage)
_DEPTH_CONFIGS = {
    18: (BasicBlock, [2, 2, 2, 2]),
    34: (BasicBlock, [3, 4, 6, 3]),
    50: (Bottleneck, [3, 4, 6, 3]),
    101: (Bottleneck, [3, 4, 23, 3]),
    152: (Bottleneck, [3, 8, 36, 3]),
}


def resnet(depth, **kwargs):
    """Build a ResNet of the given ``depth`` (18, 34, 50, 101 or 152).

    Extra keyword arguments are forwarded to ``ResNet``.

    Raises:
        ValueError: if ``depth`` is not supported (previously an
            UnboundLocalError).
    """
    if depth not in _DEPTH_CONFIGS:
        raise ValueError('unsupported ResNet depth {!r}; expected one of {}'.format(
            depth, sorted(_DEPTH_CONFIGS)))
    block, layers = _DEPTH_CONFIGS[depth]
    model = ResNet(block, layers, **kwargs)
    return model

# --------------------------------------------------------------------------------
# /utils/__init__.py:
-------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | #import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save 
training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def plot(self, names=None): 71 | names = self.names if names == None else names 72 | numbers = self.numbers 73 | for _, name in enumerate(names): 74 | x = np.arange(len(numbers[name])) 75 | plt.plot(x, np.asarray(numbers[name])) 76 | plt.legend([self.title + '(' + name + ')' for name in names]) 77 | plt.grid(True) 78 | 79 | def close(self): 80 | if self.file is not None: 81 | self.file.close() 82 | 83 | class LoggerMonitor(object): 84 | '''Load and visualize multiple logs.''' 85 | def __init__ (self, paths): 86 | '''paths is a distionary with {name:filepath} pair''' 87 | self.loggers = 
[] 88 | for title, path in paths.items(): 89 | logger = Logger(path, title=title, resume=True) 90 | self.loggers.append(logger) 91 | 92 | def plot(self, names=None): 93 | plt.figure() 94 | plt.subplot(121) 95 | legend_text = [] 96 | for logger in self.loggers: 97 | legend_text += plot_overlap(logger, names) 98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') 128 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
# utils/misc.py, continued after its module docstring (previous dump line).
import errno
import os
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils.data
from torch.autograd import Variable

__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']


def get_mean_and_std(dataset, num_workers=2):
    '''Compute per-channel mean and std statistics of an image dataset.

    Each sample is loaded on its own (batch size 1). For each of the first
    three channels, the per-sample mean and per-sample standard deviation are
    accumulated and then averaged over the dataset. The returned std is
    therefore the average of per-image stds, not the std over all pixels.

    Args:
        dataset: a ``torch.utils.data.Dataset`` yielding ``(image, target)``
            pairs, where ``image`` is channels-first (``[C, H, W]``) with at
            least 3 channels.
        num_workers: worker processes for the internal ``DataLoader``
            (default 2).

    Returns:
        ``(mean, std)``: two float tensors of shape ``(3,)``.

    Raises:
        ValueError: if ``dataset`` is empty (the averages would be NaN).
    '''
    if len(dataset) == 0:
        raise ValueError('get_mean_and_std() requires a non-empty dataset')

    # batch_size=1 so every statistic is computed per image; order does not
    # matter for the averages, so no shuffling (keeps the float sums deterministic).
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    '''Initialise Conv2d, BatchNorm2d and Linear parameters of ``net`` in place.

    - Conv2d: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    - BatchNorm2d: weight 1, bias 0 (skipped when the layer has no affine params).
    - Linear: normal weights with std 1e-3, zero bias.

    Other module types are left untouched. Returns None.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Compare with None: truth-testing a multi-element bias tensor
            # raises "Boolean value of Tensor with more than one value is ambiguous".
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


def mkdir_p(path):
    '''Create ``path`` and any missing parents, like ``mkdir -p``.

    Does nothing if ``path`` already is a directory. Raises ``OSError``
    (``FileExistsError``) if ``path`` exists but is not a directory.
    '''
    os.makedirs(path, exist_ok=True)


class AverageMeter(object):
    """Computes and stores the average and current value.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

    Attributes:
        val: the most recent value passed to ``update``.
        sum: running sum of ``val * n`` over all updates.
        count: total weight ``n`` accumulated so far.
        avg: ``sum / count``.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        '''Zero all statistics.'''
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        '''Record ``val`` with weight ``n`` (e.g. a batch mean over ``n`` samples).'''
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
-------------------------------------------------------------------------------- /utils/progress/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | build/ 4 | dist/ 5 | -------------------------------------------------------------------------------- /utils/progress/LICENSE: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | -------------------------------------------------------------------------------- /utils/progress/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst LICENSE 2 | -------------------------------------------------------------------------------- /utils/progress/README.rst: -------------------------------------------------------------------------------- 1 | Easy progress reporting for Python 2 | ================================== 3 | 4 | |pypi| 5 | 6 | |demo| 7 | 8 | .. |pypi| image:: https://img.shields.io/pypi/v/progress.svg 9 | .. 
|demo| image:: https://raw.github.com/verigak/progress/master/demo.gif 10 | :alt: Demo 11 | 12 | Bars 13 | ---- 14 | 15 | There are 7 progress bars to choose from: 16 | 17 | - ``Bar`` 18 | - ``ChargingBar`` 19 | - ``FillingSquaresBar`` 20 | - ``FillingCirclesBar`` 21 | - ``IncrementalBar`` 22 | - ``PixelBar`` 23 | - ``ShadyBar`` 24 | 25 | To use them, just call ``next`` to advance and ``finish`` to finish: 26 | 27 | .. code-block:: python 28 | 29 | from progress.bar import Bar 30 | 31 | bar = Bar('Processing', max=20) 32 | for i in range(20): 33 | # Do some work 34 | bar.next() 35 | bar.finish() 36 | 37 | The result will be a bar like the following: :: 38 | 39 | Processing |############# | 42/100 40 | 41 | To simplify the common case where the work is done in an iterator, you can 42 | use the ``iter`` method: 43 | 44 | .. code-block:: python 45 | 46 | for i in Bar('Processing').iter(it): 47 | # Do some work 48 | 49 | Progress bars are very customizable, you can change their width, their fill 50 | character, their suffix and more: 51 | 52 | .. 
code-block:: python 53 | 54 | bar = Bar('Loading', fill='@', suffix='%(percent)d%%') 55 | 56 | This will produce a bar like the following: :: 57 | 58 | Loading |@@@@@@@@@@@@@ | 42% 59 | 60 | You can use a number of template arguments in ``message`` and ``suffix``: 61 | 62 | ========== ================================ 63 | Name Value 64 | ========== ================================ 65 | index current value 66 | max maximum value 67 | remaining max - index 68 | progress index / max 69 | percent progress * 100 70 | avg simple moving average time per item (in seconds) 71 | elapsed elapsed time in seconds 72 | elapsed_td elapsed as a timedelta (useful for printing as a string) 73 | eta avg * remaining 74 | eta_td eta as a timedelta (useful for printing as a string) 75 | ========== ================================ 76 | 77 | Instead of passing all configuration options on instantiation, you can create 78 | your custom subclass: 79 | 80 | .. code-block:: python 81 | 82 | class FancyBar(Bar): 83 | message = 'Loading' 84 | fill = '*' 85 | suffix = '%(percent).1f%% - %(eta)ds' 86 | 87 | You can also override any of the arguments or create your own: 88 | 89 | .. code-block:: python 90 | 91 | class SlowBar(Bar): 92 | suffix = '%(remaining_hours)d hours remaining' 93 | @property 94 | def remaining_hours(self): 95 | return self.eta // 3600 96 | 97 | 98 | Spinners 99 | ======== 100 | 101 | For actions with an unknown number of steps you can use a spinner: 102 | 103 | ..
code-block:: python 104 | 105 | from progress.spinner import Spinner 106 | 107 | spinner = Spinner('Loading ') 108 | while state != 'FINISHED': 109 | # Do some work 110 | spinner.next() 111 | 112 | There are 5 predefined spinners: 113 | 114 | - ``Spinner`` 115 | - ``PieSpinner`` 116 | - ``MoonSpinner`` 117 | - ``LineSpinner`` 118 | - ``PixelSpinner`` 119 | 120 | 121 | Other 122 | ===== 123 | 124 | There are a number of other classes available too, please check the source or 125 | subclass one of them to create your own. 126 | 127 | 128 | License 129 | ======= 130 | 131 | progress is licensed under ISC 132 | -------------------------------------------------------------------------------- /utils/progress/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyunjaelee410/style-based-recalibration-module/f6221c77e797c4530dddba03616153708d636bd1/utils/progress/demo.gif -------------------------------------------------------------------------------- /utils/progress/progress/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
# [review] Dump section: utils/progress/progress/__init__.py (code after its
# license header), followed by all of utils/progress/progress/bar.py.
#
# __init__.py
#   Infinite: next(n) advances `index` by n and calls update(); `avg` is a
#     simple moving average of seconds per item over the last `sma_window`
#     (10) samples, where a sample is only added when n > 0. Every keyword
#     argument becomes an attribute, and __getitem__ returns public attributes
#     (None for names starting with '_'), so templates such as '%(index)d' can
#     be rendered with `template % self` (as bar.py does).
#   Progress: adds `max` (default 100) and the derived properties
#     remaining = max(max - index, 0), progress = min(1, index / max),
#     percent = progress * 100, eta = int(ceil(avg * remaining)).
#     iter() sets `max` to len(it) when the iterable supports len().
# bar.py
#   Bar renders message + bar_prefix + filled cells + empty cells + bar_suffix
#   + suffix via WritelnMixin.writeln. ChargingBar, FillingSquaresBar and
#   FillingCirclesBar only change glyphs. IncrementalBar (and PixelBar,
#   ShadyBar) also draw a partially-filled last cell chosen from `phases`.
#
# NOTE(review): Progress.progress evaluates `self.index / self.max`, so a bar
#   whose max is 0 raises ZeroDivisionError the first time Bar.update() or
#   IncrementalBar.update() runs (via start(), next() or goto()). `percent`
#   depends on it as well. Consider returning 0 from `progress` when max == 0.
14 | 15 | from __future__ import division 16 | 17 | from collections import deque 18 | from datetime import timedelta 19 | from math import ceil 20 | from sys import stderr 21 | from time import time 22 | 23 | 24 | __version__ = '1.3' 25 | 26 | 27 | class Infinite(object): 28 | file = stderr 29 | sma_window = 10 # Simple Moving Average window 30 | 31 | def __init__(self, *args, **kwargs): 32 | self.index = 0 33 | self.start_ts = time() 34 | self.avg = 0 35 | self._ts = self.start_ts 36 | self._xput = deque(maxlen=self.sma_window) 37 | for key, val in kwargs.items(): 38 | setattr(self, key, val) 39 | 40 | def __getitem__(self, key): 41 | if key.startswith('_'): 42 | return None 43 | return getattr(self, key, None) 44 | 45 | @property 46 | def elapsed(self): 47 | return int(time() - self.start_ts) 48 | 49 | @property 50 | def elapsed_td(self): 51 | return timedelta(seconds=self.elapsed) 52 | 53 | def update_avg(self, n, dt): 54 | if n > 0: 55 | self._xput.append(dt / n) 56 | self.avg = sum(self._xput) / len(self._xput) 57 | 58 | def update(self): 59 | pass 60 | 61 | def start(self): 62 | pass 63 | 64 | def finish(self): 65 | pass 66 | 67 | def next(self, n=1): 68 | now = time() 69 | dt = now - self._ts 70 | self.update_avg(n, dt) 71 | self._ts = now 72 | self.index = self.index + n 73 | self.update() 74 | 75 | def iter(self, it): 76 | try: 77 | for x in it: 78 | yield x 79 | self.next() 80 | finally: 81 | self.finish() 82 | 83 | 84 | class Progress(Infinite): 85 | def __init__(self, *args, **kwargs): 86 | super(Progress, self).__init__(*args, **kwargs) 87 | self.max = kwargs.get('max', 100) 88 | 89 | @property 90 | def eta(self): 91 | return int(ceil(self.avg * self.remaining)) 92 | 93 | @property 94 | def eta_td(self): 95 | return timedelta(seconds=self.eta) 96 | 97 | @property 98 | def percent(self): 99 | return self.progress * 100 100 | 101 | @property 102 | def progress(self): 103 | return min(1, self.index / self.max) 104 | 105 | @property 106 | def
remaining(self): 107 | return max(self.max - self.index, 0) 108 | 109 | def start(self): 110 | self.update() 111 | 112 | def goto(self, index): 113 | incr = index - self.index 114 | self.next(incr) 115 | 116 | def iter(self, it): 117 | try: 118 | self.max = len(it) 119 | except TypeError: 120 | pass 121 | 122 | try: 123 | for x in it: 124 | yield x 125 | self.next() 126 | finally: 127 | self.finish() 128 | -------------------------------------------------------------------------------- /utils/progress/progress/bar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Progress 19 | from .helpers import WritelnMixin 20 | 21 | 22 | class Bar(WritelnMixin, Progress): 23 | width = 32 24 | message = '' 25 | suffix = '%(index)d/%(max)d' 26 | bar_prefix = ' |' 27 | bar_suffix = '| ' 28 | empty_fill = ' ' 29 | fill = '#' 30 | hide_cursor = True 31 | 32 | def update(self): 33 | filled_length = int(self.width * self.progress) 34 | empty_length = self.width - filled_length 35 | 36 | message = self.message % self 37 | bar = self.fill * filled_length 38 | empty = self.empty_fill * empty_length 39 | suffix = self.suffix % self 40 | line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix, 41 | suffix]) 42 | self.writeln(line) 43 | 44 | 45 | class ChargingBar(Bar): 46 | suffix = '%(percent)d%%' 47 | bar_prefix = ' ' 48 | bar_suffix = ' ' 49 | empty_fill = '∙' 50 | fill = '█' 51 | 52 | 53 | class FillingSquaresBar(ChargingBar): 54 | empty_fill = '▢' 55 | fill = '▣' 56 | 57 | 58 | class FillingCirclesBar(ChargingBar): 59 | empty_fill = '◯' 60 | fill = '◉' 61 | 62 | 63 | class IncrementalBar(Bar): 64 | phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█') 65 | 66 | def update(self): 67 | nphases = len(self.phases) 68 | filled_len = self.width * self.progress 69 | nfull = int(filled_len) # Number of full chars 70 | phase = int((filled_len - nfull) * nphases) # Phase of last char 71 | nempty = self.width - nfull # Number of empty chars 72 | 73 | message = self.message % self 74 | bar = self.phases[-1] * nfull 75 | current = self.phases[phase] if phase > 0 else '' 76 | empty = self.empty_fill * max(0, nempty - len(current)) 77 | suffix = self.suffix % self 78 | line = ''.join([message, self.bar_prefix, bar, current, empty, 79 | self.bar_suffix, suffix]) 80 | self.writeln(line) 81 | 82 | 83 | class PixelBar(IncrementalBar): 84 | phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿') 85 | 86 | 87 | class ShadyBar(IncrementalBar): 88 | phases = (' ', '░', '▒', '▓', '█') 89 |
# [review] Next dump section: all of utils/progress/progress/counter.py, then
# the license header of utils/progress/progress/helpers.py.
#   Counter (an Infinite) writes the current `index`. Countdown (a Progress)
#   writes `remaining`. Stack (a Progress) writes
#   phases[min(len(phases) - 1, int(progress * len(phases)))]. Pie is Stack
#   with circle glyphs. All of them output through WriteMixin.write, which is
#   imported from .helpers.
# NOTE(review): Stack/Pie read Progress.progress, which divides by `max`, so
#   their update() raises ZeroDivisionError when max == 0.
-------------------------------------------------------------------------------- /utils/progress/progress/counter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Infinite, Progress 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Counter(WriteMixin, Infinite): 23 | message = '' 24 | hide_cursor = True 25 | 26 | def update(self): 27 | self.write(str(self.index)) 28 | 29 | 30 | class Countdown(WriteMixin, Progress): 31 | hide_cursor = True 32 | 33 | def update(self): 34 | self.write(str(self.remaining)) 35 | 36 | 37 | class Stack(WriteMixin, Progress): 38 | phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█') 39 | hide_cursor = True 40 | 41 | def update(self): 42 | nphases = len(self.phases) 43 | i = min(nphases - 1, int(self.progress * nphases)) 44 | self.write(self.phases[i]) 45 | 46 | 47 | class Pie(Stack): 48 | phases = ('○', '◔', '◑', '◕', '●') 49 | -------------------------------------------------------------------------------- /utils/progress/progress/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
# [review] Dump section: utils/progress/progress/helpers.py (code after its
# license header), spinner.py, utils/progress/setup.py,
# utils/progress/test_progress.py and utils/visualize.py.
#
# helpers.py
#   WriteMixin: on a TTY, optionally hides the cursor and prints `message`.
#     write() backspaces over the previous output and prints the new text,
#     left-justified to the widest text written so far.
#   WritelnMixin: on a TTY, writeln() clears the current line ('\r\x1b[K')
#     and rewrites it; finish() ends the line and shows the cursor again.
#   SigIntMixin: installs a SIGINT handler (replacing any previous one) that
#     calls finish() and then exit(0). None of the classes shown here mix it in.
# NOTE(review): exit(0) makes a Ctrl-C look like a successful run; a non-zero
#   status may be more appropriate. Confirm the intent.
# spinner.py: Spinner writes phases[index % len(phases)]; subclasses only
#   change the glyphs.
# setup.py: imports the local `progress` package for __version__ and reads
#   README.rst with open(...).read() without closing the file handle.
# test_progress.py: a demo script with no test functions. It drives every
#   bar, spinner and counter (about 10 ms of sleep per step) at module import
#   time, so a test runner that collects it by its test_*.py name will execute
#   the whole demo.
# visualize.py
# NOTE(review): `import matplotlib.pyplot as plt` is commented out, and
#   `.misc`'s __all__ does not export a `plt`. show_batch, show_mask_single
#   and show_mask will therefore raise NameError when they reach plt.*.
# NOTE(review): `upsampling`, used by show_mask_single and show_mask, is not
#   defined here or exported by `.misc`, so it is likely a NameError.
#   torch.nn.functional.interpolate was presumably intended; confirm.
# NOTE(review): colorize() calls torch.unsqueeze(x, 0, out=x). torch.unsqueeze
#   does not document an `out=` argument; `x = x.unsqueeze(0)` is probably
#   intended (verify). The 4-D branch does not clamp values > 1 as the 3-D
#   branch does, and an input of any other rank leaves `cl` unbound.
# NOTE(review): make_image() un-normalises `img` in place (img[i] = ...),
#   which mutates the caller's tensor.
14 | 15 | from __future__ import print_function 16 | 17 | 18 | HIDE_CURSOR = '\x1b[?25l' 19 | SHOW_CURSOR = '\x1b[?25h' 20 | 21 | 22 | class WriteMixin(object): 23 | hide_cursor = False 24 | 25 | def __init__(self, message=None, **kwargs): 26 | super(WriteMixin, self).__init__(**kwargs) 27 | self._width = 0 28 | if message: 29 | self.message = message 30 | 31 | if self.file.isatty(): 32 | if self.hide_cursor: 33 | print(HIDE_CURSOR, end='', file=self.file) 34 | print(self.message, end='', file=self.file) 35 | self.file.flush() 36 | 37 | def write(self, s): 38 | if self.file.isatty(): 39 | b = '\b' * self._width 40 | c = s.ljust(self._width) 41 | print(b + c, end='', file=self.file) 42 | self._width = max(self._width, len(s)) 43 | self.file.flush() 44 | 45 | def finish(self): 46 | if self.file.isatty() and self.hide_cursor: 47 | print(SHOW_CURSOR, end='', file=self.file) 48 | 49 | 50 | class WritelnMixin(object): 51 | hide_cursor = False 52 | 53 | def __init__(self, message=None, **kwargs): 54 | super(WritelnMixin, self).__init__(**kwargs) 55 | if message: 56 | self.message = message 57 | 58 | if self.file.isatty() and self.hide_cursor: 59 | print(HIDE_CURSOR, end='', file=self.file) 60 | 61 | def clearln(self): 62 | if self.file.isatty(): 63 | print('\r\x1b[K', end='', file=self.file) 64 | 65 | def writeln(self, line): 66 | if self.file.isatty(): 67 | self.clearln() 68 | print(line, end='', file=self.file) 69 | self.file.flush() 70 | 71 | def finish(self): 72 | if self.file.isatty(): 73 | print(file=self.file) 74 | if self.hide_cursor: 75 | print(SHOW_CURSOR, end='', file=self.file) 76 | 77 | 78 | from signal import signal, SIGINT 79 | from sys import exit 80 | 81 | 82 | class SigIntMixin(object): 83 | """Registers a signal handler that calls finish on SIGINT""" 84 | 85 | def __init__(self, *args, **kwargs): 86 | super(SigIntMixin, self).__init__(*args, **kwargs) 87 | signal(SIGINT, self._sigint_handler) 88 | 89 | def _sigint_handler(self, signum, frame): 90 |
self.finish() 91 | exit(0) 92 | -------------------------------------------------------------------------------- /utils/progress/progress/spinner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Infinite 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Spinner(WriteMixin, Infinite): 23 | message = '' 24 | phases = ('-', '\\', '|', '/') 25 | hide_cursor = True 26 | 27 | def update(self): 28 | i = self.index % len(self.phases) 29 | self.write(self.phases[i]) 30 | 31 | 32 | class PieSpinner(Spinner): 33 | phases = ['◷', '◶', '◵', '◴'] 34 | 35 | 36 | class MoonSpinner(Spinner): 37 | phases = ['◑', '◒', '◐', '◓'] 38 | 39 | 40 | class LineSpinner(Spinner): 41 | phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] 42 | 43 | class PixelSpinner(Spinner): 44 | phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] 45 | -------------------------------------------------------------------------------- /utils/progress/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | import progress 6 | 7 | 8 | setup( 9 | name='progress', 10 | version=progress.__version__, 11 | description='Easy to use progress bars', 12 | long_description=open('README.rst').read(), 13 | author='Giorgos Verigakis', 14 | author_email='verigak@gmail.com', 15 | url='http://github.com/verigak/progress/', 16 | license='ISC', 17 | packages=['progress'], 18 | classifiers=[ 19 | 'Environment :: Console', 20 | 'Intended Audience :: Developers', 21 | 'License :: OSI Approved :: ISC License (ISCL)', 22 | 'Programming Language :: Python :: 2.6', 23 | 'Programming Language :: Python :: 2.7', 24 | 'Programming Language :: Python :: 3.3', 25 | 'Programming Language :: Python :: 3.4', 26 | 'Programming Language :: Python :: 3.5', 27 | 'Programming Language :: Python :: 3.6', 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /utils/progress/test_progress.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import random 6 | import time 7 | 8 | from 
progress.bar import (Bar, ChargingBar, FillingSquaresBar, 9 | FillingCirclesBar, IncrementalBar, PixelBar, 10 | ShadyBar) 11 | from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner, 12 | PixelSpinner) 13 | from progress.counter import Counter, Countdown, Stack, Pie 14 | 15 | 16 | def sleep(): 17 | t = 0.01 18 | t += t * random.uniform(-0.1, 0.1) # Add some variance 19 | time.sleep(t) 20 | 21 | 22 | for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar): 23 | suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]' 24 | bar = bar_cls(bar_cls.__name__, suffix=suffix) 25 | for i in bar.iter(range(200)): 26 | sleep() 27 | 28 | for bar_cls in (IncrementalBar, PixelBar, ShadyBar): 29 | suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]' 30 | bar = bar_cls(bar_cls.__name__, suffix=suffix) 31 | for i in bar.iter(range(200)): 32 | sleep() 33 | 34 | for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner): 35 | for i in spin(spin.__name__ + ' ').iter(range(100)): 36 | sleep() 37 | print() 38 | 39 | for singleton in (Counter, Countdown, Stack, Pie): 40 | for i in singleton(singleton.__name__ + ' ').iter(range(100)): 41 | sleep() 42 | print() 43 | 44 | bar = IncrementalBar('Random', suffix='%(index)d') 45 | for i in range(100): 46 | bar.goto(random.randint(0, 100)) 47 | sleep() 48 | bar.finish() 49 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | #import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] +
mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 | 44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 51 | 52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | plt.imshow(images) 55 | plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask =
make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | 104 | 105 | 106 | # x = torch.zeros(1, 3, 3) 107 | # out = colorize(x) 108 | # out_im = make_image(out) 109 | # plt.imshow(out_im) 110 | # plt.show() 111 | --------------------------------------------------------------------------------