├── LICENSE
├── README.md
├── classify.py
├── data_transforms.py
├── datasets
│   ├── cityscapes
│   │   ├── README.md
│   │   ├── create_lists.sh
│   │   ├── drn-d-105.csv
│   │   ├── info.json
│   │   └── prepare_data.py
│   └── compute_mean_std.py
├── doc
│   └── drn_comp.png
├── drn.py
├── lib
│   ├── Makefile
│   ├── build.py
│   ├── dense
│   │   ├── __init__.py
│   │   ├── batch_norm
│   │   │   ├── __init__.py
│   │   │   └── _batch_norm.so
│   │   └── batchnormp_kernel.so
│   ├── functions
│   │   ├── __init__.py
│   │   └── batchnormp.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── batchnormsync.py
│   ├── src
│   │   ├── batchnormp.c
│   │   ├── batchnormp.h
│   │   ├── batchnormp_cuda.c
│   │   ├── batchnormp_cuda.h
│   │   ├── batchnormp_cuda_kernel.cu
│   │   ├── batchnormp_cuda_kernel.h
│   │   └── generic
│   │       └── batchnormp_cuda.cu
│   └── test.py
├── requirements.txt
└── segment.py
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2017, Fisher Yu
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | * Redistributions of source code must retain the above copyright notice, this
10 |   list of conditions and the following disclaimer.
11 | 
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 |   this list of conditions and the following disclaimer in the documentation
14 |   and/or other materials provided with the distribution.
15 | 
16 | * Neither the name of the copyright holder nor the names of its
17 |   contributors may be used to endorse or promote products derived from
18 |   this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Overview
2 | 
3 | This code provides various models combining dilated convolutions with residual networks. Our models can achieve better performance with fewer parameters than ResNet on [image classification](#image-classification) and [semantic segmentation](#semantic-image-segmentation).
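To make the central idea concrete, here is a small, self-contained sketch (illustrative only, though it mirrors `conv3x3` in `drn.py`) showing that a dilated 3x3 convolution enlarges the receptive field without shrinking the feature map:

```
import torch
import torch.nn as nn

# A 3x3 kernel with dilation d samples a (2d+1)x(2d+1) neighborhood with
# only 9 weights; padding=d keeps the spatial resolution unchanged.
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, bias=False, dilation=dilation)

x = torch.randn(1, 16, 56, 56)
for d in (1, 2, 4):  # dilation levels used in the higher DRN layers
    y = conv3x3(16, 16, dilation=d)(x)
    print(d, tuple(y.shape))  # spatial size stays (56, 56) for every d
```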
4 | 
5 | If you find this code useful for your publications, please consider citing
6 | 
7 | ```
8 | @inproceedings{Yu2017,
9 |     title = {Dilated Residual Networks},
10 |     author = {Fisher Yu and Vladlen Koltun and Thomas Funkhouser},
11 |     booktitle = {Computer Vision and Pattern Recognition (CVPR)},
12 |     year = {2017},
13 | }
14 | 
15 | @inproceedings{Yu2016,
16 |     title = {Multi-scale context aggregation by dilated convolutions},
17 |     author = {Yu, Fisher and Koltun, Vladlen},
18 |     booktitle = {International Conference on Learning Representations (ICLR)},
19 |     year = {2016}
20 | }
21 | ```
22 | 
23 | ## Code Highlights
24 | 
25 | - Pretrained models can be loaded with the PyTorch model zoo API. [Example here](https://github.com/fyu/drn/blob/master/drn.py#L264).
26 | - PyTorch-based image classification and semantic image segmentation.
27 | - BatchNorm synchronization across multiple GPUs.
28 | - High-resolution class activation maps for state-of-the-art weakly supervised object localization.
29 | - [DRN-D-105](#semantic-image-segmentation) gets 76.3% mIoU on Cityscapes with only fine training annotations and no context module.
30 | 
31 | ## Image Classification
32 | 
33 | Our image classification experiments are a controlled study of the role of high-resolution feature maps in image classification and of the class activations arising from them. Based on this investigation, we design more efficient networks that learn high-resolution image representations, which are directly useful for semantic image segmentation, as detailed in the [image segmentation section](#semantic-image-segmentation).
34 | 
35 | ### Models
36 | 
37 | Comparison of classification error rates on the ImageNet validation set and numbers of parameters. Models are evaluated on a single center 224x224 crop from resized images whose shorter side is 256 pixels long.
38 | 
39 | | Name | Top-1 | Top-5 | Params |
40 | | --- | :---: | :---: | :---: |
41 | | ResNet-18 | 30.4% | 10.8% | 11.7M |
42 | | DRN-A-18 | 28.0% | 9.5% | 11.7M |
43 | | DRN-D-22 | 25.8% | 8.2% | 16.4M |
44 | | DRN-C-26 | 24.9% | 7.6% | 21.1M |
45 | | ResNet-34 | 27.7% | 8.7% | 21.8M |
46 | | DRN-A-34 | 24.8% | 7.5% | 21.8M |
47 | | DRN-D-38 | 23.8% | 6.9% | 26.5M |
48 | | DRN-C-42 | 22.9% | 6.6% | 31.2M |
49 | | ResNet-50 | 24.0% | 7.0% | 25.6M |
50 | | DRN-A-50 | 22.9% | 6.6% | 25.6M |
51 | | DRN-D-54 | 21.2% | 5.9% | 35.8M |
52 | | DRN-C-58 | 21.7% | 6.0% | 41.6M |
53 | | ResNet-101 | 22.4% | 6.2% | 44.5M |
54 | | DRN-D-105 | 20.6% | 5.5% | 54.8M |
55 | | ResNet-152 | 22.2% | 6.2% | 60.2M |
56 | 
57 | The figure below groups the parameter and error rate comparison by network structure.
58 | 
59 | ![comparison](doc/drn_comp.png)
60 | 
61 | 
62 | ### Training and Testing
63 | 
64 | The code is written in Python using [PyTorch](https://github.com/pytorch/pytorch). I started with code from [torchvision](https://github.com/pytorch/vision). Please check their license as well if copyright is a concern. Software dependencies:
65 | 
66 | * Python 3
67 | * Pillow
68 | * pytorch
69 | * torchvision
70 | 
71 | **Note:** If you want to train your own semantic segmentation model, make sure your PyTorch version is greater than [0.2.0](https://github.com/pytorch/pytorch/releases) or includes commit [78020a](https://github.com/pytorch/pytorch/pull/2077/commits/78020a52abb76fcb1c344b3c42fbe8610cc387e4).
72 | 
73 | Go to [this page](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset) to prepare the ImageNet 1K data.
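Before running the scripts below, you can sanity-check the setup by constructing a pretrained classification model directly. This is a minimal sketch (assuming a recent PyTorch) built on the model constructors in `drn.py`:

```
import torch
import drn  # this repository's drn.py

# Downloads the drn_c_26 ImageNet weights through torch.utils.model_zoo
# on first use and caches them locally.
model = drn.drn_c_26(pretrained=True)
model.eval()

x = torch.randn(1, 3, 224, 224)  # stand-in for a normalized 224x224 crop
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # (1, 1000) ImageNet class scores
```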
74 | 
75 | To test a model on the ImageNet validation set:
76 | ```
77 | python3 classify.py test --arch drn_c_26 -j 4 <imagenet_dir> --pretrained
78 | ```
79 | 
80 | To train a new model:
81 | ```
82 | python3 classify.py train --arch drn_c_26 -j 8 <imagenet_dir> --epochs 120
83 | ```
84 | 
85 | Besides `drn_c_26`, we also provide `drn_c_42` and `drn_c_58`. They are in the DRN-C family, as described in [Dilated Residual Networks](https://umich.app.box.com/v/drn). DRN-D models are simplified versions of DRN-C. Their code names are `drn_d_22`, `drn_d_38`, `drn_d_54`, and `drn_d_105`.
86 | 
87 | ## Semantic Image Segmentation
88 | 
89 | ### Models
90 | 
91 | Comparison of mIoU on Cityscapes and numbers of parameters.
92 | 
93 | | Name | mIoU | Params |
94 | | --- | :---: | :---: |
95 | | DRN-A-50 | 67.3% | 25.6M |
96 | | DRN-C-26 | 68.0% | 21.1M |
97 | | DRN-C-42 | 70.9% | 31.2M |
98 | | DRN-D-22 | 68.0% | 16.4M |
99 | | DRN-D-38 | 71.4% | 26.5M |
100 | | DRN-D-105* | 75.6% | 54.8M |
101 | 
102 | *Trained with the poly learning rate schedule and random scaling and rotation.
103 | 
104 | DRN-D-105 gets 76.3% mIoU on the Cityscapes testing set with multi-scale testing, a poly learning rate, and data augmentation by random rotation and scaling during training. Full results are [here](datasets/cityscapes/drn-d-105.csv).
105 | 
106 | ### Prepare Data
107 | 
108 | The segmentation data folder is expected to contain the following image lists:
109 | 
110 | * train_images.txt
111 | * train_labels.txt
112 | * val_images.txt
113 | * val_labels.txt
114 | * test_images.txt
115 | 
116 | The code will also look for `info.json` in the folder. It contains the mean and std of the training images. For example, below is the `info.json` used for training on Cityscapes.
117 | 
118 | ```
119 | {
120 |     "mean": [
121 |         0.290101,
122 |         0.328081,
123 |         0.286964
124 |     ],
125 |     "std": [
126 |         0.182954,
127 |         0.186566,
128 |         0.184475
129 |     ]
130 | }
131 | ```
132 | 
133 | Each line in a list is the path of an input image or its label map, relative to the segmentation data folder.
134 | 
135 | For example, if the data folder is "/foo/bar" and train_images.txt in it contains
136 | ```
137 | leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png
138 | leftImg8bit/train/aachen/aachen_000001_000019_leftImg8bit.png
139 | ```
140 | and train_labels.txt contains
141 | ```
142 | gtFine/train/aachen/aachen_000000_000019_gtFine_trainIds.png
143 | gtFine/train/aachen/aachen_000001_000019_gtFine_trainIds.png
144 | ```
145 | then the first image is expected at
146 | ```
147 | /foo/bar/leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png
148 | ```
149 | and its label map at
150 | ```
151 | /foo/bar/gtFine/train/aachen/aachen_000000_000019_gtFine_trainIds.png
152 | ```
153 | 
154 | In the training phase, both train_\* and val_\* lists are assumed to be in the data folder. In the validation phase, only val_images.txt and val_labels.txt are needed. In the testing phase, when no labels are available, only test_images.txt is needed. `segment.py` has a command line option `--phase` whose accepted arguments are `train`, `val`, and `test`.
155 | 
156 | To set up the Cityscapes data, please check this [document](datasets/cityscapes).
157 | 
158 | ### Optimization Setup
159 | 
160 | The current segmentation models are trained with basic data augmentation (random crops and flips). The learning rate follows a step schedule: it is decreased by a factor of 10 at each step.
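For reference, here is a sketch of the two schedules referred to in this README: the step policy above, and the `poly` mode used for the best DRN-D-105 results below. The exact implementation lives in `segment.py`; the 0.9 power for `poly` is an assumption (a common choice for poly schedules).

```
def step_lr(base_lr, epoch, step=100, ratio=0.1):
    # Decrease the learning rate by a factor of 10 (ratio) at each step.
    return base_lr * ratio ** (epoch // step)

def poly_lr(base_lr, epoch, max_epochs, power=0.9):
    # Anneal the learning rate smoothly toward zero over the whole run.
    return base_lr * (1 - epoch / max_epochs) ** power

print(step_lr(0.01, epoch=150))           # 0.001 after the first step
print(round(poly_lr(0.01, 250, 500), 5))  # 0.00536 halfway through the run
```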
161 | 
162 | ### Training
163 | 
164 | To train a new model, use
165 | ```
166 | python3 segment.py train -d <data_folder> -c <category_number> -s 896 \
167 |     --arch drn_d_22 --batch-size 32 --epochs 250 --lr 0.01 --momentum 0.9 \
168 |     --step 100
169 | ```
170 | 
171 | `category_number` is the number of segmentation categories: 19 for Cityscapes and 11 for CamVid. The label maps should contain values in the range `[0, category_number)`. Invalid pixels can be labeled 255; they are ignored in training and evaluation. Depending on the batch size, lr and momentum can be 0.01/0.9 or 0.001/0.99.
172 | 
173 | If you want to train drn_d_105 to achieve the best results on the Cityscapes dataset, you need to turn on data augmentation and use the poly learning rate:
174 | 
175 | ```
176 | python3 segment.py train -d <data_folder> -c 19 -s 840 --arch drn_d_105 --random-scale 2 --random-rotate 10 --batch-size 16 --epochs 500 --lr 0.01 --momentum 0.9 -j 16 --lr-mode poly --bn-sync
177 | ```
178 | 
179 | Notes:
180 | 
181 | - If you use 8 GPUs for 16 crops per batch, the memory usage on each GPU is more than 12GB. If you don't have enough GPU memory, you can try a smaller batch size or crop size. A smaller crop size usually hurts the performance more.
182 | - Batch normalization synchronization across multiple GPUs is necessary to train very deep convolutional networks for semantic segmentation. We provide an implementation as a PyTorch extension in `lib/`. However, building it from scratch is not for the faint-hearted, even though a Makefile is provided, so a prebuilt binary library for 64-bit Ubuntu is included and has been tested on Ubuntu 16.04. Also remember to add `lib/` to your `PYTHONPATH`.
183 | 
184 | ### Testing
185 | 
186 | Evaluate models on the testing set, or on any images without ground truth labels, using our pretrained models:
187 | ```
188 | python3 segment.py test -d <data_folder> -c <category_number> --arch drn_d_22 \
189 |     --pretrained <pretrained_model> --phase test --batch-size 1
190 | ```
191 | 
192 | You can download the pretrained DRN models on Cityscapes here: http://go.yf.io/drn-cityscapes-models.
193 | 
194 | If you want to evaluate a checkpoint from your own training, use `--resume` instead of `--pretrained`:
195 | ```
196 | python3 segment.py test -d <data_folder> -c <category_number> --arch drn_d_22 \
197 |     --resume <checkpoint_file> --phase test --batch-size 1
198 | ```
199 | 
200 | You can also turn on multi-scale testing for better results by adding `--ms`:
201 | 
202 | ```
203 | python3 segment.py test -d <data_folder> -c <category_number> --arch drn_d_105 \
204 |     --resume <checkpoint_file> --phase val --batch-size 1 --ms
205 | ```
--------------------------------------------------------------------------------
/classify.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import shutil
3 | import time
4 | 
5 | import numpy as np
6 | import os
7 | from os.path import exists, split, join, splitext
8 | 
9 | import sys
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import torch.backends.cudnn as cudnn
14 | import torch.optim
15 | import torch.utils.data
16 | import torchvision.transforms as transforms
17 | import torchvision.datasets as datasets
18 | 
19 | import drn as models
20 | 
21 | model_names = sorted(name for name in models.__dict__
22 |                      if name.islower() and not name.startswith("__")
23 |                      and callable(models.__dict__[name]))
24 | 
25 | 
26 | def parse_args():
27 |     parser = argparse.ArgumentParser(description='')
28 |     parser.add_argument('cmd', choices=['train', 'test', 'map', 'locate'])
29 |     parser.add_argument('data', metavar='DIR',
30 |                         help='path to dataset')
31 |     parser.add_argument('--arch', '-a', metavar='ARCH', default='drn_c_26',
32 |                         choices=model_names,
33 |                         help='model architecture: ' +
34 |                         ' | '.join(model_names) +
35 |                         ' (default: drn_c_26)')
36 |     parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
37 |                         help='number of data loading workers (default: 4)')
38 |     parser.add_argument('--epochs', default=90, type=int, metavar='N',
39 |                         help='number of total epochs to run')
40 |     parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
41 |                         help='manual epoch number (useful on restarts)')
42 |     parser.add_argument('-b', '--batch-size', default=256, type=int,
43 |                         metavar='N', help='mini-batch size (default: 256)')
44 |     parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
45 |                         metavar='LR', help='initial learning rate')
46 |     parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
47 |                         help='momentum')
48 |     parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
49 |                         metavar='W', help='weight decay (default: 1e-4)')
50 |     parser.add_argument('--print-freq', '-p', default=10, type=int,
51 |                         metavar='N', help='print frequency (default: 10)')
52 |     parser.add_argument('--check-freq', default=10, type=int,
53 |                         metavar='N', help='checkpoint frequency (default: 10)')
54 |     parser.add_argument('--resume', default='', type=str, metavar='PATH',
55 |                         help='path to latest checkpoint (default: none)')
56 |     parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
57 |                         help='evaluate model on validation set')
58 |     parser.add_argument('--pretrained', dest='pretrained', action='store_true',
59 |                         help='use pre-trained model')
60 |     parser.add_argument('--lr-adjust', dest='lr_adjust',
61 |                         choices=['linear', 'step'], default='step')
62 |     parser.add_argument('--crop-size', dest='crop_size', type=int, default=224)
63 |     parser.add_argument('--scale-size', dest='scale_size', type=int, default=256)
64 |     parser.add_argument('--step-ratio', dest='step_ratio', type=float, default=0.1)
65 |     args = parser.parse_args()
66 |     return args
67 | 
68 | 
69 | def main():
70 |     print(' '.join(sys.argv))
71 |     args = parse_args()
72 |     print(args)
73 |     if args.cmd == 'train':
74 |         run_training(args)
75 |     elif args.cmd == 'test':
76 |         test_model(args)
77 | 
78 | 
79 | def run_training(args):
80 |     # create model
81 |     model = models.__dict__[args.arch](args.pretrained)
82 | 
83 |     model = torch.nn.DataParallel(model).cuda()
84 | 
85 |     best_prec1 = 0
86 | 
87 |     # optionally resume from a checkpoint
88 |     if args.resume:
89 |         if os.path.isfile(args.resume):
90 |             print("=> loading checkpoint '{}'".format(args.resume))
91 |             checkpoint = torch.load(args.resume)
92 |             args.start_epoch = checkpoint['epoch']
93 |             best_prec1 = checkpoint['best_prec1']
94 |             model.load_state_dict(checkpoint['state_dict'])
95 |             print("=> loaded checkpoint '{}' (epoch {})"
96 |                   .format(args.resume, checkpoint['epoch']))
97 |         else:
98 |             print("=> no checkpoint found at '{}'".format(args.resume))
99 | 
100 |     cudnn.benchmark = True
101 | 
102 |     # Data loading code
103 |     traindir = os.path.join(args.data, 'train')
104 |     valdir = os.path.join(args.data, 'val')
105 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
106 |                                      std=[0.229, 0.224, 0.225])
107 | 
108 |     train_loader = torch.utils.data.DataLoader(
109 |         datasets.ImageFolder(traindir, transforms.Compose([
110 |             transforms.RandomSizedCrop(224),
111 |             transforms.RandomHorizontalFlip(),
112 |             transforms.ToTensor(),
113 |             normalize,
114 |         ])),
115 |         batch_size=args.batch_size, shuffle=True,
116 |         num_workers=args.workers, pin_memory=True)
117 | 
118 |     val_loader = torch.utils.data.DataLoader(
119 |         datasets.ImageFolder(valdir, transforms.Compose([
120 |             transforms.Scale(256),
121 |             transforms.CenterCrop(224),
122 |             transforms.ToTensor(),
123 |             normalize,
124 |         ])),
125 |         batch_size=args.batch_size, shuffle=False,
126 |         num_workers=args.workers, pin_memory=True)
127 | 
128 |     # define loss function (criterion) and optimizer
129 |     criterion = nn.CrossEntropyLoss().cuda()
130 | 
131 |     optimizer = torch.optim.SGD(model.parameters(), args.lr,
132 |                                 momentum=args.momentum,
133 |                                 weight_decay=args.weight_decay)
134 | 
135 |     for epoch in range(args.start_epoch, args.epochs):
136 |         adjust_learning_rate(args, optimizer, epoch)
137 | 
138 |         # train for one epoch
139 |         train(args, train_loader, model, criterion, optimizer, epoch)
140 | 
141 |         # evaluate on validation set
142 |         prec1 = validate(args, val_loader, model, criterion)
143 | 
144 |         # remember best prec@1 and save checkpoint
145 |         is_best = prec1 > best_prec1
146 |         best_prec1 = max(prec1, best_prec1)
147 | 
148 |         checkpoint_path = 'checkpoint_latest.pth.tar'
149 |         save_checkpoint({
150 |             'epoch': epoch + 1,
151 |             'arch': args.arch,
152 |             'state_dict': model.state_dict(),
153 |             'best_prec1': best_prec1,
154 |         }, is_best, filename=checkpoint_path)
155 |         if (epoch + 1) % args.check_freq == 0:
156 |             history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
157 |             shutil.copyfile(checkpoint_path, history_path)
158 | 
159 | 
160 | def test_model(args):
161 |     # create model
162 |     model = models.__dict__[args.arch](args.pretrained)
163 | 
164 |     model = torch.nn.DataParallel(model).cuda()
165 | 
166 |     if args.resume:
167 |         if os.path.isfile(args.resume):
168 |             print("=> loading checkpoint '{}'".format(args.resume))
169 |             checkpoint = torch.load(args.resume)
170 |             args.start_epoch = checkpoint['epoch']
171 |             best_prec1 = checkpoint['best_prec1']
172 |             model.load_state_dict(checkpoint['state_dict'])
173 |             print("=> loaded checkpoint '{}' (epoch {})"
174 |                   .format(args.resume, checkpoint['epoch']))
175 |         else:
176 |             print("=> no checkpoint found at '{}'".format(args.resume))
177 | 
178 |     cudnn.benchmark = True
179 | 
180 |     # Data loading code
181 |     valdir = os.path.join(args.data, 'val')
182 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
183 |                                      std=[0.229, 0.224, 0.225])
184 | 
185 |     t = transforms.Compose([
186 |         transforms.Scale(args.scale_size),
187 |         transforms.CenterCrop(args.crop_size),
188 |         transforms.ToTensor(),
189 |         normalize])
190 |     val_loader = torch.utils.data.DataLoader(
191 |         datasets.ImageFolder(valdir, t),
192 |         batch_size=args.batch_size, shuffle=False,
193 |         num_workers=args.workers, pin_memory=True)
194 | 
195 |     criterion = nn.CrossEntropyLoss().cuda()
196 | 
197 |     validate(args, val_loader, model, criterion)
198 | 
199 | 
200 | def train(args, train_loader, model, criterion, optimizer, epoch):
201 |     batch_time = AverageMeter()
202 |     data_time = AverageMeter()
203 |     losses = AverageMeter()
204 |     top1 = AverageMeter()
205 |     top5 = AverageMeter()
206 | 
207 |     # switch to train mode
208 |     model.train()
209 | 
210 |     end = time.time()
211 |     for i, (input, target) in enumerate(train_loader):
212 |         # measure data loading time
213 |         data_time.update(time.time() - end)
214 | 
215 |         target = target.cuda(async=True)
216 |         input_var = torch.autograd.Variable(input)
217 |         target_var = torch.autograd.Variable(target)
218 | 
219 |         # compute output
220 |         output = model(input_var)
221 |         loss = criterion(output, target_var)
222 | 
223 |         # measure accuracy and record loss
224 |         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
225 |         losses.update(loss.data[0], input.size(0))
226 |         top1.update(prec1[0], input.size(0))
227 |         top5.update(prec5[0], input.size(0))
228 | 
229 |         # compute gradient and do SGD step
230 |         optimizer.zero_grad()
231 |         loss.backward()
232 |         optimizer.step()
233 | 
234 |         # measure elapsed time
235 |         batch_time.update(time.time() - end)
236 |         end = time.time()
237 | 
238 |         if i % args.print_freq == 0:
239 |             print('Epoch: [{0}][{1}/{2}]\t'
240 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
241 |                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
242 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
243 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
244 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
245 |                       epoch, i, len(train_loader), batch_time=batch_time,
246 |                       data_time=data_time, loss=losses, top1=top1, top5=top5))
247 | 
248 | 
249 | def validate(args, val_loader, model, criterion):
250 |     batch_time = AverageMeter()
251 |     losses = AverageMeter()
252 |     top1 = AverageMeter()
253 |     top5 = AverageMeter()
254 | 
255 |     # switch to evaluate mode
256 |     model.eval()
257 | 
258 |     end = time.time()
259 |     for i, (input, target) in enumerate(val_loader):
260 |         target = target.cuda(async=True)
261 |         input_var = torch.autograd.Variable(input, volatile=True)
262 |         target_var = torch.autograd.Variable(target, volatile=True)
263 | 
264 |         # compute output
265 |         output = model(input_var)
266 |         loss = criterion(output, target_var)
267 | 
268 |         # measure accuracy and record loss
269 |         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
270 |         losses.update(loss.data[0], input.size(0))
271 |         top1.update(prec1[0], input.size(0))
272 |         top5.update(prec5[0], input.size(0))
273 | 
274 |         # measure elapsed time
275 |         batch_time.update(time.time() - end)
276 |         end = time.time()
277 | 
278 |         if i % args.print_freq == 0:
279 |             print('Test: [{0}/{1}]\t'
280 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
281 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
282 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
283 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
284 |                       i, len(val_loader), batch_time=batch_time, loss=losses,
285 |                       top1=top1, top5=top5))
286 | 
287 |     print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
288 |           .format(top1=top1, top5=top5))
289 | 
290 |     return top1.avg
291 | 
292 | 
293 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
294 |     torch.save(state, filename)
295 |     if is_best:
296 |         shutil.copyfile(filename, 'model_best.pth.tar')
297 | 
298 | 
299 | class AverageMeter(object):
300 |     """Computes and stores the average and current value"""
301 |     def __init__(self):
302 |         self.reset()
303 | 
304 |     def reset(self):
305 |         self.val = 0
306 |         self.avg = 0
307 |         self.sum = 0
308 |         self.count = 0
309 | 
310 |     def update(self, val, n=1):
311 |         self.val = val
312 |         self.sum += val * n
313 |         self.count += n
314 |         self.avg = self.sum / self.count
315 | 
316 | 
317 | def adjust_learning_rate(args, optimizer, epoch):
318 |     """Sets the learning rate to the initial LR decayed by step_ratio every 30 epochs"""
319 |     lr = args.lr * (args.step_ratio ** (epoch // 30))
320 |     print('Epoch [{}] Learning rate: {}'.format(epoch, lr))
321 |     for param_group in optimizer.param_groups:
322 |         param_group['lr'] = lr
323 | 
324 | 
325 | def accuracy(output, target, topk=(1,)):
326 |     """Computes the precision@k for the specified values of k"""
327 |     maxk = max(topk)
328 |     batch_size = target.size(0)
329 | 
330 |     _, pred = output.topk(maxk, 1, True, True)
331 |     pred = pred.t()
332 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
333 | 
334 |     res = []
335 |     for k in topk:
336 |         correct_k = correct[:k].view(-1).float().sum(0)
337 |         res.append(correct_k.mul_(100.0 / batch_size))
338 |     return res
339 | 
340 | 
341 | if __name__ == '__main__':
342 |     main()
343 | 
--------------------------------------------------------------------------------
/data_transforms.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import random
3 | 
4 | import numpy as np
5 | from PIL import Image, ImageOps
6 | import torch
7 | 
8 | 
9 | class RandomCrop(object):
10 |     def __init__(self, size):
11 |         if isinstance(size, numbers.Number):
12 |             self.size = (int(size), int(size))
13 |         else:
14 |             self.size = size
15 | 
16 |     def __call__(self, image, label, *args):
17 |         assert label is None or image.size == label.size, \
18 |             "image and label don't have the same size {} / {}".format(
19 |                 image.size, label.size)
20 | 
21 |         w, h = image.size
22 |         tw, th = self.size
23 |         top = bottom = left = right = 0
24 |         if w < tw:
25 |             left = (tw - w) // 2
26 |             right = tw - w - left
27 |         if h < th:
28 |             top = (th - h) // 2
29 |             bottom = th - h - top
30 |         if left > 0 or right > 0 or top > 0 or bottom > 0:
31 |             label = pad_image(
32 |                 'constant', label, top, bottom, left, right, value=255)
33 |             image = pad_image(
34 |                 'reflection', image, top, bottom, left, right)
35 |         w, h = image.size
36 |         if w == tw and h == th:
37 |             return (image, label, *args)
38 | 
39 |         x1 = random.randint(0, w - tw)
40 |         y1 = random.randint(0, h - th)
41 |         results = [image.crop((x1, y1, x1 + tw, y1 + th))]
42 |         if label is not None:
43 |             results.append(label.crop((x1, y1, x1 + tw, y1 + th)))
44 |         results.extend(args)
45 |         return results
46 | 
47 | 
48 | class RandomScale(object):
49 |     def __init__(self, scale):
50 |         if isinstance(scale, numbers.Number):
51 |             scale = [1 / scale, scale]
52 |         self.scale = scale
53 | 
54 |     def __call__(self, image, label):
55 |         ratio = random.uniform(self.scale[0], self.scale[1])
56 |         w, h = image.size
57 |         tw = int(ratio * w)
58 |         th = int(ratio * h)
59 |         if ratio == 1:
60 |             return image, label
61 |         elif ratio < 1:
62 |             interpolation = Image.ANTIALIAS
63 |         else:
64 |             interpolation = Image.CUBIC
65 |         return image.resize((tw, th), interpolation), \
66 |             label.resize((tw, th), Image.NEAREST)
67 | 
68 | 
69 | class RandomRotate(object):
70 |     """Rotates the given PIL.Image by a random integer angle in [-angle, angle]
71 |     degrees. The image is reflection-padded (the label constant-padded with
72 |     255) before the rotation and cropped back to its original size afterwards.
73 |     """
74 | 
75 |     def __init__(self, angle):
76 |         self.angle = angle
77 | 
78 |     def __call__(self, image, label=None, *args):
79 |         assert label is None or image.size == label.size
80 | 
81 |         w, h = image.size
82 |         # pad by (h, w) on each side so the rotated content stays in view
83 |         angle = random.randint(0, self.angle * 2) - self.angle
84 | 
85 |         if label is not None:
86 |             label = pad_image('constant', label, h, h, w, w, value=255)
87 |             label = label.rotate(angle, resample=Image.NEAREST)
88 |             label = label.crop((w, h, w + w, h + h))
89 | 
90 |         image = pad_image('reflection', image, h, h, w, w)
91 |         image = image.rotate(angle, resample=Image.BILINEAR)
92 |         image = image.crop((w, h, w + w, h + h))
93 |         return image, label
94 | 
95 | 
96 | class RandomHorizontalFlip(object):
97 |     """Randomly horizontally flips the given PIL.Image with a probability of 0.5
98 |     """
99 | 
100 |     def __call__(self, image, label):
101 |         if random.random() < 0.5:
102 |             results = [image.transpose(Image.FLIP_LEFT_RIGHT),
103 |                        label.transpose(Image.FLIP_LEFT_RIGHT)]
104 |         else:
105 |             results = [image, label]
106 |         return results
107 | 
108 | 
109 | class Normalize(object):
110 |     """Given mean: (R, G, B) and std: (R, G, B),
111 |     will normalize each channel of the torch.*Tensor, i.e.
112 | channel = (channel - mean) / std 113 | """ 114 | 115 | def __init__(self, mean, std): 116 | self.mean = torch.FloatTensor(mean) 117 | self.std = torch.FloatTensor(std) 118 | 119 | def __call__(self, image, label=None): 120 | for t, m, s in zip(image, self.mean, self.std): 121 | t.sub_(m).div_(s) 122 | if label is None: 123 | return image, 124 | else: 125 | return image, label 126 | 127 | 128 | def pad_reflection(image, top, bottom, left, right): 129 | if top == 0 and bottom == 0 and left == 0 and right == 0: 130 | return image 131 | h, w = image.shape[:2] 132 | next_top = next_bottom = next_left = next_right = 0 133 | if top > h - 1: 134 | next_top = top - h + 1 135 | top = h - 1 136 | if bottom > h - 1: 137 | next_bottom = bottom - h + 1 138 | bottom = h - 1 139 | if left > w - 1: 140 | next_left = left - w + 1 141 | left = w - 1 142 | if right > w - 1: 143 | next_right = right - w + 1 144 | right = w - 1 145 | new_shape = list(image.shape) 146 | new_shape[0] += top + bottom 147 | new_shape[1] += left + right 148 | new_image = np.empty(new_shape, dtype=image.dtype) 149 | new_image[top:top+h, left:left+w] = image 150 | new_image[:top, left:left+w] = image[top:0:-1, :] 151 | new_image[top+h:, left:left+w] = image[-1:-bottom-1:-1, :] 152 | new_image[:, :left] = new_image[:, left*2:left:-1] 153 | new_image[:, left+w:] = new_image[:, -right-1:-right*2-1:-1] 154 | return pad_reflection(new_image, next_top, next_bottom, 155 | next_left, next_right) 156 | 157 | 158 | def pad_constant(image, top, bottom, left, right, value): 159 | if top == 0 and bottom == 0 and left == 0 and right == 0: 160 | return image 161 | h, w = image.shape[:2] 162 | new_shape = list(image.shape) 163 | new_shape[0] += top + bottom 164 | new_shape[1] += left + right 165 | new_image = np.empty(new_shape, dtype=image.dtype) 166 | new_image.fill(value) 167 | new_image[top:top+h, left:left+w] = image 168 | return new_image 169 | 170 | 171 | def pad_image(mode, image, top, bottom, left, right, value=0): 172 | if mode == 'reflection': 173 | return Image.fromarray( 174 | pad_reflection(np.asarray(image), top, bottom, left, right)) 175 | elif mode == 'constant': 176 | return Image.fromarray( 177 | pad_constant(np.asarray(image), top, bottom, left, right, value)) 178 | else: 179 | raise ValueError('Unknown mode {}'.format(mode)) 180 | 181 | 182 | class Pad(object): 183 | """Pads the given PIL.Image on all sides with the given "pad" value""" 184 | 185 | def __init__(self, padding, fill=0): 186 | assert isinstance(padding, numbers.Number) 187 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or \ 188 | isinstance(fill, tuple) 189 | self.padding = padding 190 | self.fill = fill 191 | 192 | def __call__(self, image, label=None, *args): 193 | if label is not None: 194 | label = pad_image( 195 | 'constant', label, 196 | self.padding, self.padding, self.padding, self.padding, 197 | value=255) 198 | if self.fill == -1: 199 | image = pad_image( 200 | 'reflection', image, 201 | self.padding, self.padding, self.padding, self.padding) 202 | else: 203 | image = pad_image( 204 | 'constant', image, 205 | self.padding, self.padding, self.padding, self.padding, 206 | value=self.fill) 207 | return (image, label, *args) 208 | 209 | 210 | class PadImage(object): 211 | def __init__(self, padding, fill=0): 212 | assert isinstance(padding, numbers.Number) 213 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or \ 214 | isinstance(fill, tuple) 215 | self.padding = padding 216 | self.fill = fill 217 | 218 | def 
__call__(self, image, label=None, *args):
219 |         if self.fill == -1:
220 |             image = pad_image(
221 |                 'reflection', image,
222 |                 self.padding, self.padding, self.padding, self.padding)
223 |         else:
224 |             image = ImageOps.expand(image, border=self.padding, fill=self.fill)
225 |         return (image, label, *args)
226 | 
227 | 
228 | class ToTensor(object):
229 |     """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
230 |     [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
231 |     """
232 | 
233 |     def __call__(self, pic, label=None):
234 |         if isinstance(pic, np.ndarray):
235 |             # handle numpy array
236 |             img = torch.from_numpy(pic)
237 |         else:
238 |             # handle PIL Image
239 |             img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
240 |             # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
241 |             if pic.mode == 'YCbCr':
242 |                 nchannel = 3
243 |             else:
244 |                 nchannel = len(pic.mode)
245 |             img = img.view(pic.size[1], pic.size[0], nchannel)
246 |             # put it from HWC to CHW format
247 |             # yikes, this transpose takes 80% of the loading time/CPU
248 |             img = img.transpose(0, 1).transpose(0, 2).contiguous()
249 |         img = img.float().div(255)
250 |         if label is None:
251 |             return img,
252 |         else:
253 |             return img, torch.LongTensor(np.array(label, dtype=np.int))
254 | 
255 | 
256 | class Compose(object):
257 |     """Composes several transforms together.
258 |     """
259 | 
260 |     def __init__(self, transforms):
261 |         self.transforms = transforms
262 | 
263 |     def __call__(self, *args):
264 |         for t in self.transforms:
265 |             args = t(*args)
266 |         return args
267 | 
--------------------------------------------------------------------------------
/datasets/cityscapes/README.md:
--------------------------------------------------------------------------------
1 | ## Prepare Cityscapes training data
2 | 
3 | ### Step 1
4 | 
5 | After you get a vanilla copy of the Cityscapes label maps, first convert the original segmentation label IDs into the 19 training IDs:
6 | 
7 | ```
8 | python3 datasets/cityscapes/prepare_data.py <cityscapes_folder>/gtFine/
9 | ```
10 | 
11 | ### Step 2
12 | 
13 | - Run `create_lists.sh` in the Cityscapes data folder (the one containing `gtFine` and `leftImg8bit`) to create the image and label lists.
14 | - Move [info.json](info.json) to the data folder.
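Putting the two steps together, a typical session looks like this (a sketch; `$CITYSCAPES` and `$DRN` are hypothetical placeholders for the Cityscapes root and this repository's checkout):

```
python3 $DRN/datasets/cityscapes/prepare_data.py $CITYSCAPES/gtFine/
cd $CITYSCAPES
bash $DRN/datasets/cityscapes/create_lists.sh
cp $DRN/datasets/cityscapes/info.json .
```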
-------------------------------------------------------------------------------- /datasets/cityscapes/create_lists.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | find leftImg8bit/train -maxdepth 3 -name "*_leftImg8bit.png" | sort > train_images.txt 4 | find leftImg8bit/val -maxdepth 3 -name "*_leftImg8bit.png" | sort > val_images.txt 5 | find leftImg8bit/test -maxdepth 3 -name "*_leftImg8bit.png" | sort > test_images.txt 6 | 7 | find gtFine/train -maxdepth 3 -name "*_trainIds.png" | sort > train_labels.txt 8 | find gtFine/val -maxdepth 3 -name "*_trainIds.png" | sort > val_labels.txt 9 | 10 | -------------------------------------------------------------------------------- /datasets/cityscapes/drn-d-105.csv: -------------------------------------------------------------------------------- 1 | 2 | Challenge,Method 3 | pixellevel,drn 4 | 5 | "Class level" 6 | Metric,Average,road,sidewalk,building,wall,fence,pole,trafficlight,trafficsign,vegetation,terrain,sky,person,rider,car,truck,bus,train,motorcycle,bicycle 7 | IoU,76.3271,98.5788,86.0478,92.9234,48.0022,57.6084,67.0207,76.2925,79.7833,93.5752,71.9984,95.3354,86.4153,68.4009,95.7727,58.2886,68.8757,58.2999,69.7327,77.2638 8 | iIoU,54.956,,,,,,,,,,,,69.2543,49.5521,90.3489,35.8764,47.2958,37.416,46.5884,63.3161 9 | 10 | "Category level" 11 | Metric,Average,flat,nature,object,sky,construction,human,vehicle 12 | IoU,90.9445,98.6986,93.2919,73.7047,95.3354,93.3202,87.1861,95.0743 13 | iIoU,79.4787,,,,,,70.6883,88.2692 14 | -------------------------------------------------------------------------------- /datasets/cityscapes/info.json: -------------------------------------------------------------------------------- 1 | {"std": [0.1829540508368939, 0.18656561047509476, 0.18447508988480435], "mean": [0.29010095242892997, 0.32808144844279574, 0.28696394422942517]} -------------------------------------------------------------------------------- /datasets/cityscapes/prepare_data.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import os 3 | from os.path import join, split, exists 4 | import sys 5 | 6 | import numpy as np 7 | from PIL import Image 8 | 9 | # a label and all meta information 10 | Label = namedtuple( 'Label' , [ 11 | 12 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 13 | # We use them to uniquely name a class 14 | 15 | 'id' , # An integer ID that is associated with this label. 16 | # The IDs are used to represent the label in ground truth images 17 | # An ID of -1 means that this label does not have an ID and thus 18 | # is ignored when creating ground truth images (e.g. license plate). 19 | # Do not modify these IDs, since exactly these IDs are expected by the 20 | # evaluation server. 21 | 22 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 23 | # ground truth images with train IDs, using the tools provided in the 24 | # 'preparation' folder. However, make sure to validate or submit results 25 | # to our evaluation server using the regular IDs above! 26 | # For trainIds, multiple labels might have the same ID. Then, these labels 27 | # are mapped to the same class in the ground truth images. For the inverse 28 | # mapping, we use the label that is defined first in the list below. 29 | # For example, mapping all void-type classes to the same ID in training, 30 | # might make sense for some approaches. 31 | # Max value is 255! 
32 | 33 | 'category' , # The name of the category that this label belongs to 34 | 35 | 'categoryId' , # The ID of this category. Used to create ground truth images 36 | # on category level. 37 | 38 | 'hasInstances', # Whether this label distinguishes between single instances or not 39 | 40 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 41 | # during evaluations or not 42 | 43 | 'color' , # The color of this label 44 | ] ) 45 | 46 | 47 | labels = [ 48 | # name id trainId category catId hasInstances ignoreInEval color 49 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 50 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 51 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 52 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 53 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 54 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 55 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 56 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 57 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 58 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 59 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 60 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 61 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 62 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 63 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 64 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 65 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 66 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 67 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 68 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 69 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 70 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 71 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 72 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 73 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 74 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 75 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 76 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 77 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 78 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 79 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 80 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 81 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 82 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 83 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 84 | ] 85 | 86 | 87 | def label2id(image): 88 | array = np.array(image) 89 | out_array = np.empty(array.shape, dtype=array.dtype) 90 | for l in 
labels:
91 |         if 0 <= l.trainId < 255:
92 |             out_array[array == l.trainId] = l.id
93 |     return Image.fromarray(out_array)
94 | 
95 | 
96 | def id2label(image):
97 |     array = np.array(image)
98 |     out_array = np.empty(array.shape, dtype=array.dtype)
99 |     for l in labels:
100 |         out_array[array == l.id] = l.trainId
101 |     return Image.fromarray(out_array)
102 | 
103 | 
104 | def prepare_cityscape_submission(in_dir):
105 |     out_dir = in_dir + '_id'
106 |     for root, dirs, filenames in os.walk(in_dir):
107 |         for name in filenames:
108 |             in_path = join(root, name)
109 |             out_path = join(root.replace(in_dir, out_dir), name)
110 |             file_dir = split(out_path)[0]
111 |             if not exists(file_dir):
112 |                 os.makedirs(file_dir)
113 |             image = Image.open(in_path)
114 |             id_map = label2id(image)
115 |             print('Writing', out_path)
116 |             id_map.save(out_path)
117 | 
118 | 
119 | def prepare_cityscape_training(in_dir):
120 |     for root, dirs, filenames in os.walk(in_dir):
121 |         for name in filenames:
122 |             parts = name.split('_')
123 |             if parts[-1] != 'labelIds.png':
124 |                 continue
125 |             parts[-1] = 'trainIds.png'
126 |             out_name = '_'.join(parts)
127 |             in_path = join(root, name)
128 |             out_path = join(root, out_name)
129 |             image = Image.open(in_path)
130 |             id_map = id2label(image)
131 |             print('Writing', out_path)
132 |             id_map.save(out_path)
133 | 
134 | 
135 | if __name__ == '__main__':
136 |     prepare_cityscape_training(sys.argv[1])
137 | 
--------------------------------------------------------------------------------
/datasets/compute_mean_std.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import numpy as np
4 | from PIL import Image
5 | from os import path as osp
6 | 
7 | 
8 | def compute_mean_std(data_dir, list_dir):
9 |     image_list_path = osp.join(list_dir, 'train_images.txt')
10 |     image_list = [line.strip() for line in open(image_list_path, 'r')]
11 |     np.random.shuffle(image_list)
12 |     pixels = []
13 |     for image_path in image_list[:500]:
14 |         image = Image.open(osp.join(data_dir, image_path), 'r')
15 |         pixels.append(np.asarray(image).reshape(-1, 3))
16 |     pixels = np.vstack(pixels)
17 |     mean = np.mean(pixels, axis=0) / 255
18 |     std = np.std(pixels, axis=0) / 255
19 |     print(mean, std)
20 |     info = {'mean': mean.tolist(), 'std': std.tolist()}
21 |     with open(osp.join(data_dir, 'info.json'), 'w') as fp:
22 |         json.dump(info, fp)
23 | 
24 | 
25 | def parse_args():
26 |     parser = argparse.ArgumentParser(
27 |         description='Compute mean and std of a dataset.')
28 |     parser.add_argument('data_dir',
29 |                         help='data folder where the images reside.')
30 |     parser.add_argument('list_dir', nargs='?', default=None,
31 |                         help='folder with train_images.txt (defaults to data_dir).')
32 |     args = parser.parse_args()
33 |     if args.list_dir is None:
34 |         args.list_dir = args.data_dir
35 |     return args
36 | 
37 | 
38 | def main():
39 |     args = parse_args()
40 |     compute_mean_std(args.data_dir, args.list_dir)
41 | 
42 | 
43 | if __name__ == '__main__':
44 |     main()
45 | 
--------------------------------------------------------------------------------
/doc/drn_comp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fyu/drn/d75db2ee7070426db7a9264ee61cf489f8cf178c/doc/drn_comp.png
--------------------------------------------------------------------------------
/drn.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | 
3 | import torch.nn as nn
4 | import
math 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | BatchNorm = nn.BatchNorm2d 8 | 9 | 10 | # __all__ = ['DRN', 'drn26', 'drn42', 'drn58'] 11 | 12 | 13 | webroot = 'http://dl.yf.io/drn/' 14 | 15 | model_urls = { 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'drn-c-26': webroot + 'drn_c_26-ddedf421.pth', 18 | 'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth', 19 | 'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth', 20 | 'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth', 21 | 'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth', 22 | 'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth', 23 | 'drn-d-105': webroot + 'drn_d_105-12b40979.pth' 24 | } 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=padding, bias=False, dilation=dilation) 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None, 36 | dilation=(1, 1), residual=True): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = conv3x3(inplanes, planes, stride, 39 | padding=dilation[0], dilation=dilation[0]) 40 | self.bn1 = BatchNorm(planes) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.conv2 = conv3x3(planes, planes, 43 | padding=dilation[1], dilation=dilation[1]) 44 | self.bn2 = BatchNorm(planes) 45 | self.downsample = downsample 46 | self.stride = stride 47 | self.residual = residual 48 | 49 | def forward(self, x): 50 | residual = x 51 | 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | if self.residual: 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None, 72 | dilation=(1, 1), residual=True): 73 | super(Bottleneck, self).__init__() 74 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 75 | self.bn1 = BatchNorm(planes) 76 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 77 | padding=dilation[1], bias=False, 78 | dilation=dilation[1]) 79 | self.bn2 = BatchNorm(planes) 80 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 81 | self.bn3 = BatchNorm(planes * 4) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.downsample = downsample 84 | self.stride = stride 85 | 86 | def forward(self, x): 87 | residual = x 88 | 89 | out = self.conv1(x) 90 | out = self.bn1(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv2(out) 94 | out = self.bn2(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv3(out) 98 | out = self.bn3(out) 99 | 100 | if self.downsample is not None: 101 | residual = self.downsample(x) 102 | 103 | out += residual 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | 109 | class DRN(nn.Module): 110 | 111 | def __init__(self, block, layers, num_classes=1000, 112 | channels=(16, 32, 64, 128, 256, 512, 512, 512), 113 | out_map=False, out_middle=False, pool_size=28, arch='D'): 114 | super(DRN, self).__init__() 115 | self.inplanes = channels[0] 116 | self.out_map = out_map 117 | self.out_dim = channels[-1] 118 | self.out_middle = out_middle 119 | self.arch = arch 120 | 121 | if arch == 'C': 122 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, 123 | padding=3, bias=False) 124 | self.bn1 = BatchNorm(channels[0]) 
125 | self.relu = nn.ReLU(inplace=True) 126 | 127 | self.layer1 = self._make_layer( 128 | BasicBlock, channels[0], layers[0], stride=1) 129 | self.layer2 = self._make_layer( 130 | BasicBlock, channels[1], layers[1], stride=2) 131 | elif arch == 'D': 132 | self.layer0 = nn.Sequential( 133 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, 134 | bias=False), 135 | BatchNorm(channels[0]), 136 | nn.ReLU(inplace=True) 137 | ) 138 | 139 | self.layer1 = self._make_conv_layers( 140 | channels[0], layers[0], stride=1) 141 | self.layer2 = self._make_conv_layers( 142 | channels[1], layers[1], stride=2) 143 | 144 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2) 145 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2) 146 | self.layer5 = self._make_layer(block, channels[4], layers[4], 147 | dilation=2, new_level=False) 148 | self.layer6 = None if layers[5] == 0 else \ 149 | self._make_layer(block, channels[5], layers[5], dilation=4, 150 | new_level=False) 151 | 152 | if arch == 'C': 153 | self.layer7 = None if layers[6] == 0 else \ 154 | self._make_layer(BasicBlock, channels[6], layers[6], dilation=2, 155 | new_level=False, residual=False) 156 | self.layer8 = None if layers[7] == 0 else \ 157 | self._make_layer(BasicBlock, channels[7], layers[7], dilation=1, 158 | new_level=False, residual=False) 159 | elif arch == 'D': 160 | self.layer7 = None if layers[6] == 0 else \ 161 | self._make_conv_layers(channels[6], layers[6], dilation=2) 162 | self.layer8 = None if layers[7] == 0 else \ 163 | self._make_conv_layers(channels[7], layers[7], dilation=1) 164 | 165 | if num_classes > 0: 166 | self.avgpool = nn.AvgPool2d(pool_size) 167 | self.fc = nn.Conv2d(self.out_dim, num_classes, kernel_size=1, 168 | stride=1, padding=0, bias=True) 169 | for m in self.modules(): 170 | if isinstance(m, nn.Conv2d): 171 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 172 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 173 | elif isinstance(m, BatchNorm): 174 | m.weight.data.fill_(1) 175 | m.bias.data.zero_() 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 178 | new_level=True, residual=True): 179 | assert dilation == 1 or dilation % 2 == 0 180 | downsample = None 181 | if stride != 1 or self.inplanes != planes * block.expansion: 182 | downsample = nn.Sequential( 183 | nn.Conv2d(self.inplanes, planes * block.expansion, 184 | kernel_size=1, stride=stride, bias=False), 185 | BatchNorm(planes * block.expansion), 186 | ) 187 | 188 | layers = list() 189 | layers.append(block( 190 | self.inplanes, planes, stride, downsample, 191 | dilation=(1, 1) if dilation == 1 else ( 192 | dilation // 2 if new_level else dilation, dilation), 193 | residual=residual)) 194 | self.inplanes = planes * block.expansion 195 | for i in range(1, blocks): 196 | layers.append(block(self.inplanes, planes, residual=residual, 197 | dilation=(dilation, dilation))) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def _make_conv_layers(self, channels, convs, stride=1, dilation=1): 202 | modules = [] 203 | for i in range(convs): 204 | modules.extend([ 205 | nn.Conv2d(self.inplanes, channels, kernel_size=3, 206 | stride=stride if i == 0 else 1, 207 | padding=dilation, bias=False, dilation=dilation), 208 | BatchNorm(channels), 209 | nn.ReLU(inplace=True)]) 210 | self.inplanes = channels 211 | return nn.Sequential(*modules) 212 | 213 | def forward(self, x): 214 | y = list() 215 | 216 | if self.arch == 'C': 217 | x = self.conv1(x) 218 | x = self.bn1(x) 219 | x = self.relu(x) 220 | elif self.arch == 'D': 221 | x = self.layer0(x) 222 | 223 | x = self.layer1(x) 224 | y.append(x) 225 | x = self.layer2(x) 226 | y.append(x) 227 | 228 | x = self.layer3(x) 229 | y.append(x) 230 | 231 | x = self.layer4(x) 232 | y.append(x) 233 | 234 | x = self.layer5(x) 235 | y.append(x) 236 | 237 | if self.layer6 is not None: 238 | x = self.layer6(x) 239 | y.append(x) 240 | 241 | if self.layer7 is not None: 242 | x = self.layer7(x) 243 | y.append(x) 244 | 245 | if self.layer8 is not None: 246 | x = self.layer8(x) 247 | y.append(x) 248 | 249 | if self.out_map: 250 | x = self.fc(x) 251 | else: 252 | x = self.avgpool(x) 253 | x = self.fc(x) 254 | x = x.view(x.size(0), -1) 255 | 256 | if self.out_middle: 257 | return x, y 258 | else: 259 | return x 260 | 261 | 262 | class DRN_A(nn.Module): 263 | 264 | def __init__(self, block, layers, num_classes=1000): 265 | self.inplanes = 64 266 | super(DRN_A, self).__init__() 267 | self.out_dim = 512 * block.expansion 268 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 269 | bias=False) 270 | self.bn1 = nn.BatchNorm2d(64) 271 | self.relu = nn.ReLU(inplace=True) 272 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 273 | self.layer1 = self._make_layer(block, 64, layers[0]) 274 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 275 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 276 | dilation=2) 277 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 278 | dilation=4) 279 | self.avgpool = nn.AvgPool2d(28, stride=1) 280 | self.fc = nn.Linear(512 * block.expansion, num_classes) 281 | 282 | for m in self.modules(): 283 | if isinstance(m, nn.Conv2d): 284 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 285 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 286 | elif isinstance(m, BatchNorm): 287 | m.weight.data.fill_(1) 288 | m.bias.data.zero_() 289 | 290 | # for m in self.modules(): 291 | # if isinstance(m, nn.Conv2d): 292 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 293 | # elif isinstance(m, nn.BatchNorm2d): 294 | # nn.init.constant_(m.weight, 1) 295 | # nn.init.constant_(m.bias, 0) 296 | 297 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 298 | downsample = None 299 | if stride != 1 or self.inplanes != planes * block.expansion: 300 | downsample = nn.Sequential( 301 | nn.Conv2d(self.inplanes, planes * block.expansion, 302 | kernel_size=1, stride=stride, bias=False), 303 | nn.BatchNorm2d(planes * block.expansion), 304 | ) 305 | 306 | layers = [] 307 | layers.append(block(self.inplanes, planes, stride, downsample)) 308 | self.inplanes = planes * block.expansion 309 | for i in range(1, blocks): 310 | layers.append(block(self.inplanes, planes, 311 | dilation=(dilation, dilation))) 312 | 313 | return nn.Sequential(*layers) 314 | 315 | def forward(self, x): 316 | x = self.conv1(x) 317 | x = self.bn1(x) 318 | x = self.relu(x) 319 | x = self.maxpool(x) 320 | 321 | x = self.layer1(x) 322 | x = self.layer2(x) 323 | x = self.layer3(x) 324 | x = self.layer4(x) 325 | 326 | x = self.avgpool(x) 327 | x = x.view(x.size(0), -1) 328 | x = self.fc(x) 329 | 330 | return x 331 | 332 | 333 | def drn_a_50(pretrained=False, **kwargs): 334 | model = DRN_A(Bottleneck, [3, 4, 6, 3], **kwargs) 335 | if pretrained: 336 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 337 | return model 338 | 339 | 340 | def drn_c_26(pretrained=False, **kwargs): 341 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', **kwargs) 342 | if pretrained: 343 | model.load_state_dict(model_zoo.load_url(model_urls['drn-c-26'])) 344 | return model 345 | 346 | 347 | def drn_c_42(pretrained=False, **kwargs): 348 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', **kwargs) 349 | if pretrained: 350 | model.load_state_dict(model_zoo.load_url(model_urls['drn-c-42'])) 351 | return model 352 | 353 | 354 | def drn_c_58(pretrained=False, **kwargs): 355 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', **kwargs) 356 | if pretrained: 357 | model.load_state_dict(model_zoo.load_url(model_urls['drn-c-58'])) 358 | return model 359 | 360 | 361 | def drn_d_22(pretrained=False, **kwargs): 362 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', **kwargs) 363 | if pretrained: 364 | model.load_state_dict(model_zoo.load_url(model_urls['drn-d-22'])) 365 | return model 366 | 367 | 368 | def drn_d_24(pretrained=False, **kwargs): 369 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', **kwargs) 370 | if pretrained: 371 | model.load_state_dict(model_zoo.load_url(model_urls['drn-d-24'])) 372 | return model 373 | 374 | 375 | def drn_d_38(pretrained=False, **kwargs): 376 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', **kwargs) 377 | if pretrained: 378 | model.load_state_dict(model_zoo.load_url(model_urls['drn-d-38'])) 379 | return model 380 | 381 | 382 | def drn_d_40(pretrained=False, **kwargs): 383 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', **kwargs) 384 | if pretrained: 385 | model.load_state_dict(model_zoo.load_url(model_urls['drn-d-40'])) 386 | return model 387 | 388 | 389 | def drn_d_54(pretrained=False, **kwargs): 390 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', **kwargs) 391 | if pretrained: 392 | 
model.load_state_dict(model_zoo.load_url(model_urls['drn-d-54'])) 393 | return model 394 | 395 | 396 | def drn_d_56(pretrained=False, **kwargs): 397 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', **kwargs) 398 | if pretrained: 399 | model.load_state_dict(model_zoo.load_url(model_urls['drn-d-56'])) 400 | return model 401 | 402 | 403 | def drn_d_105(pretrained=False, **kwargs): 404 | model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', **kwargs) 405 | if pretrained: 406 | model.load_state_dict(model_zoo.load_url(model_urls['drn-d-105'])) 407 | return model 408 | 409 | 410 | def drn_d_107(pretrained=False, **kwargs): 411 | model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 2, 2], arch='D', **kwargs) 412 | if pretrained: 413 | model.load_state_dict(model_zoo.load_url(model_urls['drn-d-107'])) 414 | return model -------------------------------------------------------------------------------- /lib/Makefile: -------------------------------------------------------------------------------- 1 | PYTORCH_LIB_DIR := /home/fy/pytorch/torch/lib 2 | 3 | 4 | PYTHON := python3 5 | NVCC_COMPILE := nvcc -c -o 6 | RM_RF := rm -rf 7 | 8 | # Library compilation rules. 9 | NVCC_FLAGS := -x cu -Xcompiler -fPIC -shared 10 | 11 | # File structure. 12 | BUILD_DIR := dense 13 | INCLUDE_DIRS := TH THC THCUNN include include/TH 14 | TORCH_FFI_BUILD := build.py 15 | BN_KERNEL := $(BUILD_DIR)/batchnormp_kernel.so 16 | TORCH_FFI_TARGET := $(BUILD_DIR)/batch_norm/_batch_norm.so 17 | 18 | INCLUDE_FLAGS := $(foreach d, $(INCLUDE_DIRS), -I$(PYTORCH_LIB_DIR)/$d) 19 | 20 | all: $(TORCH_FFI_TARGET) 21 | 22 | $(TORCH_FFI_TARGET): $(BN_KERNEL) $(TORCH_FFI_BUILD) 23 | $(PYTHON) $(TORCH_FFI_BUILD) 24 | 25 | $(BUILD_DIR)/batchnormp_kernel.so: src/batchnormp_cuda_kernel.cu 26 | @mkdir -p $(BUILD_DIR) 27 | $(NVCC_COMPILE) $@ $? 
$(NVCC_FLAGS) $(INCLUDE_FLAGS) -Isrc 28 | 29 | clean: 30 | $(RM_RF) $(BUILD_DIR) -------------------------------------------------------------------------------- /lib/build.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import torch 4 | from torch.utils.ffi import create_extension 5 | 6 | this_file = os.path.dirname(__file__) 7 | 8 | sources = ['src/batchnormp.c'] 9 | headers = ['src/batchnormp.h'] 10 | defines = [] 11 | with_cuda = False 12 | 13 | abs_path = os.path.dirname(os.path.realpath(__file__)) 14 | extra_objects = [os.path.join(abs_path, 'dense/batchnormp_kernel.so')] 15 | extra_objects += glob.glob('/usr/local/cuda/lib64/*.a') 16 | 17 | if torch.cuda.is_available(): 18 | print('Including CUDA code.') 19 | sources += ['src/batchnormp_cuda.c'] 20 | headers += ['src/batchnormp_cuda.h'] 21 | defines += [('WITH_CUDA', None)] 22 | with_cuda = True 23 | 24 | ffi = create_extension( 25 | 'dense.batch_norm', 26 | headers=headers, 27 | sources=sources, 28 | define_macros=defines, 29 | relative_to=__file__, 30 | with_cuda=with_cuda, 31 | extra_objects=extra_objects) 32 | 33 | if __name__ == '__main__': 34 | ffi.build() -------------------------------------------------------------------------------- /lib/dense/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fyu/drn/d75db2ee7070426db7a9264ee61cf489f8cf178c/lib/dense/__init__.py -------------------------------------------------------------------------------- /lib/dense/batch_norm/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._batch_norm import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | locals[symbol] = _wrap_function(fn, _ffi) 10 | __all__.append(symbol) 11 | 12 | _import_symbols(locals()) 13 | -------------------------------------------------------------------------------- /lib/dense/batch_norm/_batch_norm.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fyu/drn/d75db2ee7070426db7a9264ee61cf489f8cf178c/lib/dense/batch_norm/_batch_norm.so -------------------------------------------------------------------------------- /lib/dense/batchnormp_kernel.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fyu/drn/d75db2ee7070426db7a9264ee61cf489f8cf178c/lib/dense/batchnormp_kernel.so -------------------------------------------------------------------------------- /lib/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fyu/drn/d75db2ee7070426db7a9264ee61cf489f8cf178c/lib/functions/__init__.py -------------------------------------------------------------------------------- /lib/functions/batchnormp.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import numpy as np 4 | 5 | import torch 6 | from torch.autograd import Function 7 | from dense import batch_norm 8 | 9 | from queue import Queue 10 | from threading import Condition 11 | 12 | cum_queue = Queue() 13 | broadcast_queue = Queue() 14 | broadcast_cv = Condition() 15 | 16 | 17 | class BatchNormPFunction(Function): 18 | def __init__(self, running_mean, running_var, 
training, 19 | cum_queue, broadcast_queue, device_ids, sync, 20 | eps=1e-5, momentum=0.1, affine=True): 21 | self.affine = affine 22 | self.eps = eps 23 | self.momentum = momentum 24 | self.running_mean = running_mean 25 | self.running_var = running_var 26 | self.mean = None 27 | self.var = None 28 | self.training = training 29 | self.cum_queue = cum_queue 30 | self.broadcast_queue = broadcast_queue 31 | self.device_ids = device_ids 32 | self.sync = sync 33 | 34 | def forward(self, input, weight, bias): 35 | output = input.new() 36 | self.save_for_backward(input, weight, bias) 37 | 38 | # input_t = input.transpose(0, 1).double() 39 | # input_size = input_t.size() 40 | batch_size = int(input.size(0)) 41 | # input_t.resize_(int(input_size[0]), int(np.prod(input_size[1:]))) 42 | # self.mean = input_t.mean(dim=1) 43 | 44 | device_ids = self.device_ids 45 | # print('device', input.get_device(), flush=True) 46 | if input.is_cuda: 47 | # self.mean.copy_(torch.from_numpy( 48 | # self.cum_mean(input.get_device(), 49 | # self.mean.cpu().numpy(), 50 | # batch_size))) 51 | # var = input_t - torch.unsqueeze(self.mean, 1) 52 | # var *= var 53 | # var = var.mean(dim=1) 54 | # total_var = self.cum_mean( 55 | # input.get_device(), var.cpu().numpy(), batch_size) 56 | # self.std = input_t.new().resize_as_(self.mean). \ 57 | # copy_(torch.from_numpy(total_var)).sqrt() 58 | 59 | mean_cuda = input.new().resize_(input.size(1)) 60 | var_cuda = input.new().resize_(input.size(1)) 61 | batch_norm.BatchNormalizationP_mean_cuda(input, mean_cuda) 62 | 63 | if len(device_ids) > 1 and self.sync and self.training: 64 | mean_cuda.copy_(torch.from_numpy(self.cum_mean( 65 | input.get_device(), mean_cuda.cpu().numpy(), batch_size))) 66 | batch_norm.BatchNormalizationP_var_cuda(input, mean_cuda, var_cuda) 67 | if len(device_ids) > 1 and self.sync and self.training: 68 | var_cuda.copy_(torch.from_numpy(self.cum_mean( 69 | input.get_device(), var_cuda.cpu().numpy(), batch_size))) 70 | else: 71 | # self.std = input_t.std(dim=1, unbiased=False) 72 | batch_norm.BatchNormalizationP_var_cuda(input, mean_cuda, var_cuda) 73 | self.mean = mean_cuda 74 | self.var = var_cuda 75 | 76 | if not input.is_cuda: 77 | input_t = input.transpose(0, 1).contiguous().view(input.size(1), -1)  # fix: input_t was previously undefined on the CPU path 78 | self.mean = input_t.mean(dim=1) 79 | self.std = input_t.std(dim=1, unbiased=False) 80 | batch_norm.BatchNormalizationP_forward( 81 | input, output, weight, bias, self.running_mean, self.running_var, self.mean, self.std, self.training, self.momentum, self.eps) 82 | else: 83 | batch_norm.BatchNormalizationP_forward_cuda( 84 | input, output, weight, bias, 85 | self.running_mean, self.running_var, self.mean, self.var, 86 | self.training, self.momentum, self.eps) 87 | return output 88 | 89 | def cum_mean(self, this_device, this_mean, batch_size): 90 | cum_queue.put((batch_size, this_mean)) 91 | total_mean = np.zeros(this_mean.shape, dtype=np.float64) 92 | total_batch_size = 0 93 | if this_device == self.device_ids[0]: 94 | for _ in self.device_ids: 95 | item = cum_queue.get() 96 | total_batch_size += item[0] 97 | total_mean += item[0] * item[1] 98 | cum_queue.task_done() 99 | total_mean /= total_batch_size 100 | broadcast_cv.acquire() 101 | for _ in range(len(self.device_ids) - 1): 102 | broadcast_queue.put(total_mean) 103 | broadcast_cv.notify_all() 104 | broadcast_cv.release() 105 | else: 106 | broadcast_cv.acquire() 107 | if broadcast_queue.qsize() == 0: 108 | broadcast_cv.wait() 109 | total_mean = broadcast_queue.get() 110 | broadcast_queue.task_done() 111 | broadcast_cv.release() 112 | # assert cum_queue.empty() 113 | broadcast_queue.join() 114 | return 
total_mean 115 | 116 | def backward(self, grad_output): 117 | input, weight, bias = self.saved_tensors 118 | grad_input = grad_output.new().resize_as_(input) 119 | grad_weight = grad_output.new().resize_as_(weight).zero_() 120 | grad_bias = grad_output.new().resize_as_(bias).zero_() 121 | if not grad_output.is_cuda: 122 | batch_norm.BatchNormalizationP_backward( 123 | input, grad_output, grad_input, grad_weight, grad_bias, 124 | weight, self.running_mean, self.running_var, self.mean, 125 | self.std, self.training, 1, self.eps) 126 | else: 127 | # grad_output_t = grad_output.transpose(0, 1).double() 128 | # batch_size = int(grad_output.size(0)) 129 | # grad_output_t.resize_(int(grad_output_t.size(0)), 130 | # int(np.prod(grad_output_t.size()[1:]))) 131 | # grad_output_mean = grad_output_t.mean(dim=1) 132 | # device_ids = self.device_ids 133 | # if len(device_ids) > 1 and self.sync: 134 | # grad_output_mean.copy_(torch.from_numpy( 135 | # self.cum_mean(grad_output.get_device(), 136 | # grad_output_mean.cpu().numpy(), 137 | # batch_size))) 138 | # grad_output_mean = grad_output_mean.float() 139 | # 140 | # input_t = input.transpose(0, 1).double() 141 | # input_size = input_t.size() 142 | # input_t.resize_(int(input_size[0]), int(np.prod(input_size[1:]))) 143 | # dotP = (input_t - torch.unsqueeze(self.mean.double(), 1)) * \ 144 | # grad_output_t 145 | # dotP = dotP.mean(dim=1) 146 | # if len(device_ids) > 1 and self.sync: 147 | # dotP.copy_(torch.from_numpy( 148 | # self.cum_mean(grad_output.get_device(), 149 | # dotP.cpu().numpy(), 150 | # batch_size))) 151 | # dotP = dotP.float() 152 | 153 | batch_size = int(grad_output.size(0)) 154 | grad_output_mean_cuda = grad_output.new().resize_(grad_output.size(1)) 155 | dotP_cuda = grad_output.new().resize_( 156 | grad_output.size(1)) 157 | batch_norm.BatchNormalizationP_mean_grad_cuda( 158 | input, grad_output, self.running_mean, 159 | self.mean, grad_output_mean_cuda, dotP_cuda, self.training 160 | ) 161 | if len(self.device_ids) > 1 and self.sync: 162 | grad_output_mean_cuda.copy_(torch.from_numpy( 163 | self.cum_mean(grad_output.get_device(), 164 | grad_output_mean_cuda.cpu().numpy(), 165 | batch_size))) 166 | dotP_cuda.copy_(torch.from_numpy( 167 | self.cum_mean(grad_output.get_device(), 168 | dotP_cuda.cpu().numpy(), 169 | batch_size))) 170 | 171 | # pdb.set_trace() 172 | 173 | batch_norm.BatchNormalizationP_backward_cuda( 174 | input, grad_output, grad_output_mean_cuda, dotP_cuda, 175 | grad_input, grad_weight, grad_bias, 176 | weight, self.running_mean, self.running_var, 177 | self.mean, self.var, self.training, 1, self.eps) 178 | return grad_input, grad_weight, grad_bias 179 | -------------------------------------------------------------------------------- /lib/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fyu/drn/d75db2ee7070426db7a9264ee61cf489f8cf178c/lib/modules/__init__.py -------------------------------------------------------------------------------- /lib/modules/batchnormsync.py: -------------------------------------------------------------------------------- 1 | from queue import Queue 2 | 3 | import torch 4 | from torch.nn import Module 5 | from torch.nn.parameter import Parameter 6 | from functions.batchnormp import BatchNormPFunction 7 | 8 | 9 | class BatchNormSync(Module): 10 | 11 | sync = True 12 | checking_mode = False 13 | 14 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 15 | device_ids=None): 16 | 
super(BatchNormSync, self).__init__() 17 | self.num_features = num_features 18 | self.affine = affine 19 | self.eps = eps 20 | self.momentum = momentum 21 | if self.affine: 22 | self.weight = Parameter(torch.Tensor(num_features)) 23 | self.bias = Parameter(torch.Tensor(num_features)) 24 | else: 25 | self.register_parameter('weight', None) 26 | self.register_parameter('bias', None) 27 | self.register_buffer('running_mean', torch.zeros(num_features)) 28 | self.register_buffer('running_var', torch.ones(num_features)) 29 | self.mean = torch.zeros(num_features) 30 | self.std = torch.ones(num_features) 31 | self.reset_parameters() 32 | self.cum_queue = Queue() 33 | self.broadcast_queue = Queue() 34 | if device_ids is None: 35 | self.device_ids = list(range(torch.cuda.device_count())) 36 | else: 37 | self.device_ids = device_ids 38 | 39 | def reset_parameters(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | self.mean.zero_() 43 | self.std.fill_(1) 44 | if self.affine: 45 | if BatchNormSync.checking_mode: 46 | self.weight.data.fill_(1) 47 | else: 48 | self.weight.data.uniform_() 49 | self.bias.data.zero_() 50 | 51 | def forward(self, input): 52 | training = int(self.training) 53 | assert input.size(1) == self.num_features 54 | 55 | bn_func = BatchNormPFunction( 56 | self.running_mean, self.running_var, # self.mean, self.std, 57 | training, self.cum_queue, self.broadcast_queue, self.device_ids, 58 | BatchNormSync.sync, self.eps, self.momentum, self.affine) 59 | return bn_func(input, self.weight, self.bias) 60 | 61 | def __repr__(self): 62 | return ('{name}({num_features}, eps={eps}, momentum={momentum},' 63 | ' affine={affine})' 64 | .format(name=self.__class__.__name__, **self.__dict__)) -------------------------------------------------------------------------------- /lib/src/batchnormp.c: -------------------------------------------------------------------------------- 1 | #include <TH/TH.h> /* assumed: the include target was lost in extraction */ 2 | #include "batchnormp.h" 3 | 4 | #define THNN_CHECK_SHAPE(I1, I2) \ 5 | if (I1 != NULL && I2 != NULL && !THFloatTensor_isSameSizeAs(I1, I2)) \ 6 | { \ 7 | THDescBuff s1 = THFloatTensor_sizeDesc(I1); \ 8 | THDescBuff s2 = THFloatTensor_sizeDesc(I2); \ 9 | THError(#I1 " and " #I2 " shapes do not match: " \ 10 | #I1 " %s, " #I2 " %s", s1.str, s2.str); \ 11 | } 12 | 13 | void BatchNormalizationP_forward( 14 | THFloatTensor *input, THFloatTensor *output, 15 | THFloatTensor *weight, THFloatTensor *bias, 16 | THFloatTensor *running_mean, THFloatTensor *running_var, 17 | THFloatTensor *save_mean, THFloatTensor *save_std, 18 | int train, double momentum, double eps) 19 | { 20 | THFloatTensor_resizeAs(output, input); 21 | int64_t nInput = THFloatTensor_size(input, 1); 22 | int64_t f; 23 | ptrdiff_t n = THFloatTensor_nElement(input) / nInput; 24 | 25 | #pragma omp parallel for 26 | for (f = 0; f < nInput; ++f) { 27 | THFloatTensor *in = THFloatTensor_newSelect(input, 1, f); 28 | THFloatTensor *out = THFloatTensor_newSelect(output, 1, f); 29 | 30 | float mean, invstd, std; 31 | 32 | if (train) { 33 | // compute mean per input 34 | // double sum = 0; 35 | // TH_TENSOR_APPLY(float, in, sum += *in_data;); 36 | // 37 | // mean = (float) sum / n; 38 | // THFloatTensor_set1d(save_mean, f, (float) mean); 39 | 40 | mean = THFloatTensor_get1d(save_mean, f); 41 | std = THFloatTensor_get1d(save_std, f); 42 | invstd = (float) (1 / (std + eps)); 43 | 44 | // compute variance per input 45 | // sum = 0; 46 | // TH_TENSOR_APPLY(float, in, 47 | // sum += (*in_data - mean) * (*in_data - mean);); 48 | // 49 | // if (sum == 0 && eps == 0.0) { 50 | // invstd = 0; 51 | // } else { 52 | // invstd = (float) (1 / sqrt(sum/n + eps)); 53 | // } 54 | // THFloatTensor_set1d(save_std, f, (float) invstd); 55 | 56 | // update running averages 57 | THFloatTensor_set1d(running_mean, f, 58 | (float) (momentum * mean + (1 - momentum) * THFloatTensor_get1d(running_mean, f))); 59 | 60 | double unbiased_var = std * std * n / (n - 1); /* fix: save_std holds the biased std, so square it to get the variance */ 61 | THFloatTensor_set1d(running_var, f, 62 | (float) (momentum * unbiased_var + (1 - momentum) * THFloatTensor_get1d(running_var, f))); 63 | } else { 64 | mean = THFloatTensor_get1d(running_mean, f); 65 | invstd = 1 / sqrt(THFloatTensor_get1d(running_var, f) + eps); 66 | } 67 | 68 | // compute output 69 | float w = weight ? THFloatTensor_get1d(weight, f) : 1; 70 | float b = bias ? THFloatTensor_get1d(bias, f) : 0; 71 | 72 | TH_TENSOR_APPLY2(float, in, float, out, 73 | *out_data = (float) (((*in_data - mean) * invstd) * w + b);); 74 | 75 | THFloatTensor_free(out); 76 | THFloatTensor_free(in); 77 | } 78 | } 79 | 80 | void BatchNormalizationP_backward( 81 | THFloatTensor *input, THFloatTensor *gradOutput, THFloatTensor *gradInput, 82 | THFloatTensor *gradWeight, THFloatTensor *gradBias, THFloatTensor *weight, 83 | THFloatTensor *running_mean, THFloatTensor *running_var, 84 | THFloatTensor *save_mean, THFloatTensor *save_std, 85 | int train, double scale, double eps) 86 | { 87 | THNN_CHECK_SHAPE(input, gradOutput); 88 | int64_t nInput = THFloatTensor_size(input, 1); 89 | int64_t f; 90 | ptrdiff_t n = THFloatTensor_nElement(input) / nInput; 91 | 92 | #pragma omp parallel for 93 | for (f = 0; f < nInput; ++f) { 94 | THFloatTensor *in = THFloatTensor_newSelect(input, 1, f); 95 | THFloatTensor *gradOut = THFloatTensor_newSelect(gradOutput, 1, f); 96 | float w = weight ? THFloatTensor_get1d(weight, f) : 1; 97 | float mean, invstd; 98 | if (train) { 99 | mean = THFloatTensor_get1d(save_mean, f); 100 | invstd = 1 / (THFloatTensor_get1d(save_std, f) + eps); 101 | } else { 102 | mean = THFloatTensor_get1d(running_mean, f); 103 | invstd = 1 / sqrt(THFloatTensor_get1d(running_var, f) + eps); 104 | } 105 | 106 | // sum over all gradOutput in feature plane 107 | double sum = 0; 108 | TH_TENSOR_APPLY(float, gradOut, sum += *gradOut_data;); 109 | 110 | // dot product of the Q(X) and gradOuput 111 | double dotp = 0; 112 | TH_TENSOR_APPLY2(float, in, float, gradOut, 113 | dotp += (*in_data - mean) * (*gradOut_data);); 114 | 115 | if (gradInput) { 116 | THFloatTensor_resizeAs(gradInput, input); 117 | THFloatTensor *gradIn = THFloatTensor_newSelect(gradInput, 1, f); 118 | 119 | if (train) { 120 | // when in training mode 121 | // Q(X) = X - E[x] ; i.e. input centered to zero mean 122 | // Y = Q(X) / σ ; i.e. BN output before weight and bias 123 | // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w 124 | 125 | // projection of gradOutput on to output scaled by std 126 | float k = (float) dotp * invstd * invstd / n; 127 | TH_TENSOR_APPLY2(float, gradIn, float, in, 128 | *gradIn_data = (*in_data - mean) * k;); 129 | 130 | double gradMean = sum / n; 131 | TH_TENSOR_APPLY2(float, gradIn, float, gradOut, 132 | *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * invstd * w;); 133 | 134 | } else { 135 | // when in evaluation mode 136 | // Q(X) = X - running_mean ; i.e. input centered to zero mean 137 | // Y = Q(X) / running_std ; i.e. BN output before weight and bias 138 | // dL/dX = w / running_std 139 | TH_TENSOR_APPLY2(float, gradIn, float, gradOut, 140 | *gradIn_data = *gradOut_data * invstd * w;); 141 | } 142 | 143 | THFloatTensor_free(gradIn); 144 | } 145 | 146 | if (gradWeight) { 147 | float val = THFloatTensor_get1d(gradWeight, f); 148 | THFloatTensor_set1d(gradWeight, f, val + scale * dotp * invstd); 149 | } 150 | 151 | if (gradBias) { 152 | float val = THFloatTensor_get1d(gradBias, f); 153 | THFloatTensor_set1d(gradBias, f, val + scale * sum); 154 | } 155 | 156 | THFloatTensor_free(gradOut); 157 | THFloatTensor_free(in); 158 | } 159 | }
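The training-mode gradient that `BatchNormalizationP_backward` computes above, `dL/dX = (dL/dY - gradMean - Q(X) * k) * invstd * w` with `k = dotp * invstd^2 / n`, is easy to sanity-check numerically. The sketch below is an editorial illustration rather than library code; it assumes only NumPy, `eps = 0`, and a single feature plane, and compares the closed form against central finite differences:

```python
import numpy as np

def bn_backward_plane(x, g, w=1.0):
    # mirrors the train branch of BatchNormalizationP_backward with eps = 0
    n = x.size
    mean = x.mean()
    invstd = 1.0 / x.std()              # biased std, as in the C code
    dotp = ((x - mean) * g).sum()       # dot(Q(X), dL/dY)
    k = dotp * invstd ** 2 / n          # the projection factor "k" above
    grad_mean = g.sum() / n
    return (g - grad_mean - (x - mean) * k) * invstd * w

rng = np.random.RandomState(0)
x, g = rng.randn(16), rng.randn(16)

def loss(x):
    y = (x - x.mean()) / x.std()        # BN output before weight and bias
    return float((g * y).sum())         # fixed g, so dL/dY == g

h = 1e-5
num = np.zeros_like(x)
for i in range(x.size):
    e = np.zeros_like(x)
    e[i] = h
    num[i] = (loss(x + e) - loss(x - e)) / (2 * h)

print(np.abs(bn_backward_plane(x, g) - num).max())  # agrees to ~1e-9
```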
-------------------------------------------------------------------------------- /lib/src/batchnormp.h: -------------------------------------------------------------------------------- 1 | // #include 2 | 3 | void BatchNormalizationP_forward( 4 | THFloatTensor *input, THFloatTensor *output, 5 | THFloatTensor *weight, THFloatTensor *bias, 6 | THFloatTensor *running_mean, THFloatTensor *running_var, 7 | THFloatTensor *save_mean, THFloatTensor *save_std, 8 | int train, double momentum, double eps); 9 | 10 | 11 | void BatchNormalizationP_backward( 12 | THFloatTensor *input, THFloatTensor *gradOutput, THFloatTensor *gradInput, 13 | THFloatTensor *gradWeight, THFloatTensor *gradBias, THFloatTensor *weight, 14 | THFloatTensor *running_mean, THFloatTensor *running_var, 15 | THFloatTensor *save_mean, THFloatTensor *save_std, 16 | int train, double scale, double eps); 17 | -------------------------------------------------------------------------------- /lib/src/batchnormp_cuda.c: -------------------------------------------------------------------------------- 1 | // #include "auto_gpu.h" 2 | #include <THC/THC.h> /* assumed: the include target was lost in extraction */ 3 | 4 | #include "batchnormp_cuda_kernel.h" 5 | 6 | 7 | extern THCState *state; 8 | 9 | void BatchNormalizationP_forward_cuda( 10 | THCudaTensor *input, THCudaTensor *output, 11 | THCudaTensor *weight, THCudaTensor *bias, 12 | THCudaTensor *running_mean, THCudaTensor *running_var, 13 | THCudaTensor *save_mean, THCudaTensor *save_std, 14 | int train, double momentum, double eps) { 15 | THNN_CudaBatchNormalization_updateOutputhaha( 16 | state, input, output, weight, bias, running_mean, running_var, 17 | save_mean, save_std, train, momentum, eps); 18 | } 19 | 20 | void BatchNormalizationP_mean_cuda( 21 | THCudaTensor *input, THCudaTensor *save_mean) { 22 | THNN_CudaBatchNormalization_mean( 23 | state, input, save_mean); 24 | } 25 | 26 | 27 | void BatchNormalizationP_var_cuda( 28 | THCudaTensor *input, THCudaTensor *save_mean, THCudaTensor *save_var) { 29 | THNN_CudaBatchNormalization_var( 30 | state, input, save_mean, save_var); 31 | } 32 | 33 | 34 | void BatchNormalizationP_backward_cuda( 35 | THCudaTensor *input, THCudaTensor *gradOutput, 36 | THCudaTensor *gradOutputMean, THCudaTensor *dotP, 37 | THCudaTensor *gradInput, 38 | THCudaTensor *gradWeight, THCudaTensor *gradBias, THCudaTensor *weight, 39 | THCudaTensor *running_mean, THCudaTensor *running_var, 40 | THCudaTensor *save_mean, THCudaTensor *save_std, 41 | int train, double scale, double eps) { 42 | THNN_CudaBatchNormalization_backwardhaha( 43 | state, input, gradOutput, gradOutputMean, dotP, 44 | gradInput, gradWeight, gradBias, weight, 45 | running_mean, running_var, save_mean, save_std, train, scale, eps); 46 | } 47 | 48 | void BatchNormalizationP_mean_grad_cuda( 49 | THCudaTensor *input, THCudaTensor *gradOutput, 50 | THCudaTensor *runningMean, THCudaTensor *saveMean, 51 | THCudaTensor *gradOutputMean, THCudaTensor *dotP, int 
train) { 52 | THNN_CudaBatchNormalization_mean_grad( 53 | state, input, gradOutput, runningMean, saveMean, 54 | gradOutputMean, dotP, train); 55 | } -------------------------------------------------------------------------------- /lib/src/batchnormp_cuda.h: -------------------------------------------------------------------------------- 1 | void BatchNormalizationP_forward_cuda( 2 | THCudaTensor *input, THCudaTensor *output, 3 | THCudaTensor *weight, THCudaTensor *bias, 4 | THCudaTensor *running_mean, THCudaTensor *running_var, 5 | THCudaTensor *save_mean, THCudaTensor *save_std, 6 | int train, double momentum, double eps); 7 | 8 | 9 | void BatchNormalizationP_mean_cuda( 10 | THCudaTensor *input, THCudaTensor *save_mean); 11 | 12 | 13 | void BatchNormalizationP_var_cuda( 14 | THCudaTensor *input, THCudaTensor *save_mean, THCudaTensor *save_var); 15 | 16 | 17 | void BatchNormalizationP_backward_cuda( 18 | THCudaTensor *input, THCudaTensor *gradOutput, 19 | THCudaTensor *gradOutputMean, THCudaTensor *dotP, 20 | THCudaTensor *gradInput, 21 | THCudaTensor *gradWeight, THCudaTensor *gradBias, THCudaTensor *weight, 22 | THCudaTensor *running_mean, THCudaTensor *running_var, 23 | THCudaTensor *save_mean, THCudaTensor *save_std, 24 | int train, double scale, double eps); 25 | 26 | 27 | void BatchNormalizationP_mean_grad_cuda( 28 | THCudaTensor *input, THCudaTensor *gradOutput, 29 | THCudaTensor *runningMean, THCudaTensor *saveMean, 30 | THCudaTensor *gradOutputMean, THCudaTensor *dotP, int train); -------------------------------------------------------------------------------- /lib/src/batchnormp_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "THCUNN.h" 2 | #include "common.h" 3 | #include "THCHalf.h" 4 | #include "THCHalfAutoNumerics.cuh" 5 | 6 | #include "THCDeviceTensor.cuh" 7 | #include "THCDeviceTensorUtils.cuh" 8 | #include "THCDeviceUtils.cuh" 9 | const int WARP_SIZE = 32; 10 | 11 | // The maximum number of threads in a block 12 | const int MAX_BLOCK_SIZE = 512; 13 | 14 | // Number of threads in a block given an input size up to MAX_BLOCK_SIZE 15 | static int getNumThreads(int nElem) { 16 | int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; 17 | for (int i = 0; i != 5; ++i) { 18 | if (nElem <= threadSizes[i]) { 19 | return threadSizes[i]; 20 | } 21 | } 22 | return MAX_BLOCK_SIZE; 23 | } 24 | 25 | // Returns the index of the most significant 1 bit in `val`. 
26 | __device__ __forceinline__ int getMSB(int val) { 27 | return 31 - __clz(val); 28 | } 29 | 30 | template <typename Dtype, typename Acctype> 31 | struct Float2 { 32 | Acctype v1, v2; 33 | __device__ Float2() {} 34 | __device__ Float2(Dtype v1, Dtype v2) : v1(ScalarConvert<Dtype, Acctype>::to(v1)), v2(ScalarConvert<Dtype, Acctype>::to(v2)) {} 35 | __device__ Float2(Dtype v) : v1(ScalarConvert<Dtype, Acctype>::to(v)), v2(ScalarConvert<Dtype, Acctype>::to(v)) {} 36 | __device__ Float2(int v) : v1(ScalarConvert<int, Acctype>::to(v)), v2(ScalarConvert<int, Acctype>::to(v)) {} 37 | __device__ Float2& operator+=(const Float2& a) { 38 | v1 += a.v1; 39 | v2 += a.v2; 40 | return *this; 41 | } 42 | }; 43 | 44 | template <typename Dtype, typename Acctype, typename DeviceTensor3> 45 | struct SumOp { 46 | __device__ SumOp(const DeviceTensor3 t) : tensor(t) {} 47 | __device__ __forceinline__ Acctype operator()(int batch, int plane, int n) { 48 | return ScalarConvert<Dtype, Acctype>::to(tensor[batch][plane][n]); 49 | } 50 | const DeviceTensor3 tensor; 51 | }; 52 | 53 | template <typename Dtype, typename Acctype, typename DeviceTensor3> 54 | struct VarOp { 55 | __device__ VarOp(Acctype m, const DeviceTensor3 t) : mean(m), tensor(t) {} 56 | __device__ __forceinline__ Acctype operator()(int batch, int plane, int n) { 57 | Dtype val = tensor[batch][plane][n]; 58 | return (val - mean) * (val - mean); 59 | } 60 | const Acctype mean; 61 | const DeviceTensor3 tensor; 62 | }; 63 | 64 | template <typename Dtype, typename Acctype, typename DeviceTensor3> 65 | struct GradOp { 66 | __device__ GradOp(Acctype m, const DeviceTensor3 i, const DeviceTensor3 g) 67 | : mean(m), input(i), gradOutput(g) {} 68 | __device__ __forceinline__ Float2<Dtype, Acctype> operator()(int batch, int plane, int n) { 69 | Dtype g = gradOutput[batch][plane][n]; 70 | Dtype c = ScalarConvert<Acctype, Dtype>::to(input[batch][plane][n] - mean); 71 | return Float2<Dtype, Acctype>(g, g * c); 72 | } 73 | const Acctype mean; 74 | const DeviceTensor3 input; 75 | const DeviceTensor3 gradOutput; 76 | }; 77 | 78 | // Sum across all threads within a warp 79 | template <typename T> 80 | static __device__ __forceinline__ T warpSum(T val) { 81 | #if __CUDA_ARCH__ >= 300 82 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 83 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 84 | } 85 | #else 86 | __shared__ T values[MAX_BLOCK_SIZE]; 87 | values[threadIdx.x] = val; 88 | __threadfence_block(); 89 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 90 | for (int i = 1; i < WARP_SIZE; i++) { 91 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 92 | } 93 | #endif 94 | return val; 95 | } 96 | 97 | template <typename Dtype, typename Acctype> 98 | static __device__ __forceinline__ Float2<Dtype, Acctype> warpSum(Float2<Dtype, Acctype> value) { 99 | value.v1 = warpSum(value.v1); 100 | value.v2 = warpSum(value.v2); 101 | return value; 102 | } 103 | 104 | // Sum across (batch, x/y/z) applying Op() pointwise 105 | template<typename T, typename Op, typename DeviceTensor3> 106 | __device__ T reduce(Op op, DeviceTensor3 tensor, int plane) { 107 | T sum = (T)0; 108 | for (int batch = 0; batch < tensor.getSize(0); ++batch) { 109 | for (int x = threadIdx.x; x < tensor.getSize(2); x += blockDim.x) { 110 | sum += op(batch, plane, x); 111 | } 112 | } 113 | 114 | // sum over NumThreads within a warp 115 | sum = warpSum(sum); 116 | 117 | // 'transpose', and reduce within warp again 118 | __shared__ T shared[32]; 119 | __syncthreads(); 120 | if (threadIdx.x % WARP_SIZE == 0) { 121 | shared[threadIdx.x / WARP_SIZE] = sum; 122 | } 123 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 124 | // zero out the other entries in shared 125 | shared[threadIdx.x] = (T)0; 126 | } 127 | __syncthreads(); 128 | if (threadIdx.x / WARP_SIZE == 0) { 129 | sum = warpSum(shared[threadIdx.x]); 130 | if (threadIdx.x == 0) { 131 | shared[0] = sum; 132 | } 133 | } 134 | __syncthreads(); 135 | 136 | // Everyone picks it up, should be broadcast into the whole 
gradInput 137 | return shared[0]; 138 | } 139 | 140 | template <typename Dtype, typename Acctype, typename DeviceTensor1, typename DeviceTensor3> 141 | __global__ void BatchNormalizationUpdateOutputInference_kernel( 142 | const DeviceTensor3 input, 143 | DeviceTensor3 output, 144 | DeviceTensor1 runningMean, 145 | DeviceTensor1 runningVar, 146 | const DeviceTensor1 weight, 147 | const DeviceTensor1 bias, 148 | Acctype epsilon) { 149 | 150 | int plane = blockIdx.x; 151 | 152 | Acctype invstd = Acctype(1) / sqrt(runningVar[plane].ldg() + epsilon); 153 | Acctype mean = ScalarConvert<Dtype, Acctype>::to(runningMean[plane].ldg()); 154 | Acctype gamma = weight.numElements() > 0 ? ScalarConvert<Dtype, Acctype>::to(weight[plane].ldg()) : Acctype(1); 155 | Acctype beta = bias.numElements() > 0 ? ScalarConvert<Dtype, Acctype>::to(bias[plane].ldg()) : Acctype(0); 156 | 157 | // Write normalized and update the output 158 | for (int batch = 0; batch < input.getSize(0); batch++) { 159 | for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) { 160 | Dtype inp = input[batch][plane][x].ldg(); 161 | output[batch][plane][x] = ScalarConvert<Acctype, Dtype>::to(gamma * (inp - mean) * invstd + beta); 162 | } 163 | } 164 | } 165 | 166 | template <typename Dtype, typename Acctype, typename DeviceTensor1, typename DeviceTensor3> 167 | __global__ void BatchNormalizationMean_kernel( 168 | const DeviceTensor3 input, 169 | DeviceTensor1 out_mean) { 170 | int plane = blockIdx.x; 171 | int N = input.getSize(0) * input.getSize(2); 172 | 173 | Acctype norm = Acctype(1) / N; 174 | Acctype mean = reduce<Acctype>(SumOp<Dtype, Acctype, DeviceTensor3>(input), input, plane) * norm; 175 | if (threadIdx.x == 0) { 176 | out_mean[plane] = ScalarConvert<Acctype, Dtype>::to(mean); 177 | } 178 | } 179 | 180 | 181 | template <typename Dtype, typename Acctype, typename DeviceTensor1, typename DeviceTensor3> 182 | __global__ void BatchNormalizationVar_kernel( 183 | const DeviceTensor3 input, 184 | const DeviceTensor1 in_mean, 185 | DeviceTensor1 out_var) { 186 | int plane = blockIdx.x; 187 | int N = input.getSize(0) * input.getSize(2); 188 | 189 | Acctype norm = Acctype(1) / N; 190 | Acctype mean = ScalarConvert<Dtype, Acctype>::to(in_mean[plane]); 191 | 192 | Acctype var = reduce<Acctype>(VarOp<Dtype, Acctype, DeviceTensor3>(mean, input), input, plane) * norm; 193 | if (threadIdx.x == 0) { 194 | out_var[plane] = ScalarConvert<Acctype, Dtype>::to(var); 195 | } 196 | } 197 | 198 | template <typename Dtype, typename Acctype, typename DeviceTensor1, typename DeviceTensor3> 199 | __global__ void BatchNormalizationUpdateOutput_kernelhaha( 200 | const DeviceTensor3 input, 201 | DeviceTensor3 output, 202 | const DeviceTensor1 weight, 203 | const DeviceTensor1 bias, 204 | const Acctype epsilon, 205 | const Acctype momentum, 206 | DeviceTensor1 runningMean, 207 | DeviceTensor1 runningVar, 208 | DeviceTensor1 saveMean, 209 | DeviceTensor1 saveVar) { 210 | 211 | 212 | int plane = blockIdx.x; 213 | int N = input.getSize(0) * input.getSize(2); 214 | 215 | 216 | // Compute the mean and variance across (batch, x/y/z) 217 | 218 | /* Acctype norm = Acctype(1) / N; 219 | Acctype mean = reduce<Acctype>(SumOp<Dtype, Acctype, DeviceTensor3>(input), input, plane) * norm; 220 | __syncthreads(); 221 | Acctype varN = reduce<Acctype>(VarOp<Dtype, Acctype, DeviceTensor3>(mean, input), input, plane); 222 | Acctype invStd = 0; 223 | if (varN != Acctype(0) || epsilon != Acctype(0)) { 224 | invStd = 1 / sqrt(varN * norm + epsilon); 225 | } */ 226 | 227 | Acctype mean = ScalarConvert<Dtype, Acctype>::to(saveMean[plane]); 228 | Acctype var = ScalarConvert<Dtype, Acctype>::to(saveVar[plane]); 229 | Acctype invStd = 1 / sqrt(var + epsilon); 230 | 231 | // Save the mean, variance, and moving averages 232 | if (threadIdx.x == 0) { 233 | // Momentum based writeback 234 | // Acctype unbiasedVar = varN / (N - 1); 235 | Acctype unbiasedVar = var * N / (N - 1); 236 | // saveMean[plane] = ScalarConvert<Acctype, Dtype>::to(mean); 237 | // saveStd[plane] = ScalarConvert<Acctype, Dtype>::to(invStd); 238 | runningMean[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningMean[plane] + momentum * mean); 239 | runningVar[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningVar[plane] + momentum * unbiasedVar); 240 | } 241 | 242 | // Write normalized and update the output 243 | Acctype gamma = weight.numElements() > 0 ? ScalarConvert<Dtype, Acctype>::to(weight[plane]) : ScalarConvert<int, Acctype>::to(1); 244 | Acctype beta = bias.numElements() > 0 ? ScalarConvert<Dtype, Acctype>::to(bias[plane]) : ScalarConvert<int, Acctype>::to(0); 245 | for (int batch = 0; batch < input.getSize(0); ++batch) { 246 | for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) { 247 | Dtype inp = input[batch][plane][x].ldg(); 248 | output[batch][plane][x] = ScalarConvert<Acctype, Dtype>::to(gamma * (inp - mean) * invStd + beta); 249 | } 250 | } 251 | } 252 | 253 | 254 | template <typename Dtype, typename Acctype, typename DeviceTensor1, typename DeviceTensor3> 255 | __global__ void BatchNormalizationMeanGrad_kernel( 256 | const DeviceTensor3 input, 257 | const DeviceTensor3 gradOutput, 258 | const DeviceTensor1 runningMean, 259 | const DeviceTensor1 saveMean, 260 | DeviceTensor1 gradOutputMean_all, 261 | DeviceTensor1 dotP_all, 262 | bool train) { 263 | int plane = blockIdx.x; 264 | int N = gradOutput.getSize(0) * gradOutput.getSize(2); 265 | 266 | Acctype mean; 267 | if (train) { 268 | mean = ScalarConvert<Dtype, Acctype>::to(saveMean[plane]); 269 | } else { 270 | mean = ScalarConvert<Dtype, Acctype>::to(runningMean[plane]); 271 | } 272 | 273 | Acctype norm = Acctype(1) / N; 274 | GradOp<Dtype, Acctype, DeviceTensor3> g(mean, input, gradOutput); 275 | Float2<Dtype, Acctype> res = reduce<Float2<Dtype, Acctype>, GradOp<Dtype, Acctype, DeviceTensor3>, DeviceTensor3>(g, gradOutput, plane); 276 | Acctype gradOutputMean = res.v1 * norm; 277 | Acctype dotP = res.v2 * norm; 278 | 279 | if (threadIdx.x == 0) { 280 | gradOutputMean_all[plane] = ScalarConvert<Acctype, Dtype>::to(gradOutputMean); 281 | dotP_all[plane] = ScalarConvert<Acctype, Dtype>::to(dotP); 282 | } 283 | } 284 | 285 | template <typename Dtype, typename Acctype, typename DeviceTensor1, typename DeviceTensor3> 286 | __global__ void BatchNormalizationBackward_kernel( 287 | const DeviceTensor3 input, 288 | const DeviceTensor3 gradOutput, 289 | const DeviceTensor1 gradOutputMean, 290 | const DeviceTensor1 dotP_all, 291 | DeviceTensor3 gradInput, 292 | DeviceTensor1 gradWeight, 293 | DeviceTensor1 gradBias, 294 | const DeviceTensor1 weight, 295 | const DeviceTensor1 runningMean, 296 | const DeviceTensor1 runningVar, 297 | const DeviceTensor1 saveMean, 298 | const DeviceTensor1 saveVar, 299 | bool train, 300 | Acctype scale, 301 | double eps) { 302 | 303 | int plane = blockIdx.x; 304 | int N = gradOutput.getSize(0) * gradOutput.getSize(2); 305 | 306 | Acctype mean, stdVal; 307 | if (train) { 308 | mean = ScalarConvert<Dtype, Acctype>::to(saveMean[plane]); 309 | stdVal = 1 / sqrt(ScalarConvert<Dtype, Acctype>::to(saveVar[plane]) + eps); 310 | } else { 311 | mean = ScalarConvert<Dtype, Acctype>::to(runningMean[plane]); 312 | stdVal = 1 / sqrt(runningVar[plane] + eps); 313 | } 314 | 315 | Acctype weightVal = weight.numElements() > 0 ? ScalarConvert<Dtype, Acctype>::to(weight[plane]) : Acctype(1); 316 | // Acctype norm = Acctype(1) / N; 317 | 318 | // Compute two values across (batch, x/y/z) in one pass: 319 | // 1. Sum(gradOutput) 320 | // 2. 
DotProduct(input - mean, gradOutput) 321 | // GradOp<Dtype, Acctype, DeviceTensor3> g(mean, input, gradOutput); 322 | // Float2<Dtype, Acctype> res = reduce<Float2<Dtype, Acctype>, GradOp<Dtype, Acctype, DeviceTensor3>, DeviceTensor3>(g, gradOutput, plane); 323 | // Acctype gradOutputSum = res.v1; 324 | Acctype gradOutputSum = ScalarConvert<Dtype, Acctype>::to(gradOutputMean[plane]) * N; 325 | // Acctype dotP = res.v2; 326 | Acctype dotP = ScalarConvert<Dtype, Acctype>::to(dotP_all[plane]); 327 | 328 | // Acctype gradMean = gradOutputSum * norm; 329 | Acctype gradMean = ScalarConvert<Dtype, Acctype>::to(gradOutputMean[plane]); 330 | // Acctype projScale = dotP * norm * stdVal * stdVal; 331 | Acctype projScale = dotP * stdVal * stdVal; 332 | Acctype gradScale = stdVal * weightVal; 333 | 334 | if (gradInput.numElements() > 0) { 335 | for (int batch = 0; batch < gradOutput.getSize(0); ++batch) { 336 | for (int x = threadIdx.x; x < gradOutput.getSize(2); x += blockDim.x) { 337 | Dtype gradOut = gradOutput[batch][plane][x]; 338 | if (train) { 339 | Dtype inp = input[batch][plane][x]; 340 | Acctype proj = (inp - mean) * projScale; 341 | gradInput[batch][plane][x] = ScalarConvert<Acctype, Dtype>::to((gradOut - proj - gradMean) * gradScale); 342 | } else { 343 | gradInput[batch][plane][x] = ScalarConvert<Acctype, Dtype>::to(gradOut * gradScale); 344 | } 345 | } 346 | } 347 | } 348 | 349 | if (gradWeight.numElements() > 0) { 350 | if (threadIdx.x == 0) { 351 | gradWeight[plane] += ScalarConvert<Acctype, Dtype>::to(scale * dotP * stdVal); 352 | } 353 | } 354 | 355 | if (gradBias.numElements() > 0) { 356 | if (threadIdx.x == 0) { 357 | gradBias[plane] += ScalarConvert<Acctype, Dtype>::to(scale * gradOutputSum); 358 | } 359 | } 360 | } 361 | 362 | #include "generic/batchnormp_cuda.cu" 363 | #include "THCGenerateFloatTypes.h" -------------------------------------------------------------------------------- /lib/src/batchnormp_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #include <THC/THC.h> /* assumed: the include target was lost in extraction */ 2 | 3 | void THNN_CudaBatchNormalization_updateOutputhaha( 4 | THCState *state, THCudaTensor *input_, THCudaTensor *output_, 5 | THCudaTensor *weight_, THCudaTensor *bias_, THCudaTensor *runningMean_, 6 | THCudaTensor *runningVar_, THCudaTensor *saveMean_, THCudaTensor *saveStd_, 7 | int train, double momentum, double eps); 8 | 9 | 10 | void THNN_CudaBatchNormalization_backwardhaha( 11 | THCState *state, THCudaTensor *input_, THCudaTensor *gradOutput_, 12 | THCudaTensor *gradOutputMean_, THCudaTensor *dotP, 13 | THCudaTensor *gradInput_, THCudaTensor *gradWeight_, THCudaTensor *gradBias_, 14 | THCudaTensor *weight_, THCudaTensor *runningMean_, THCudaTensor *runningVar_, 15 | THCudaTensor *saveMean_, THCudaTensor *saveStd_, int train, double scale, 16 | double eps);
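The host wiring around these declarations is worth spelling out: `THNN_CudaBatchNormalization_mean` runs first on every GPU, the per-device means are combined on the Python side (`cum_mean` in lib/functions/batchnormp.py weights each one by its batch size), and only then does `THNN_CudaBatchNormalization_var` accumulate squared deviations around the already-synchronized global mean. Averaging those per-device results with the same weights therefore recovers the exact global variance, not an approximation. A NumPy sketch of the identity (editorial illustration only, not repo code):

```python
import numpy as np

rng = np.random.RandomState(0)
chunks = [rng.randn(n) for n in (4, 6, 2)]   # uneven per-GPU batches
total = float(sum(c.size for c in chunks))

# pass 1: local means -> batch-size-weighted global mean (what cum_mean does)
global_mean = sum(c.size * c.mean() for c in chunks) / total

# pass 2: local mean squared deviation around the *global* mean,
# combined with the same batch-size weights
global_var = sum(c.size * ((c - global_mean) ** 2).mean()
                 for c in chunks) / total

all_x = np.concatenate(chunks)
assert np.isclose(global_mean, all_x.mean())
assert np.isclose(global_var, all_x.var())   # exact biased global variance
```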
-------------------------------------------------------------------------------- /lib/src/generic/batchnormp_cuda.cu: -------------------------------------------------------------------------------- 1 | #ifndef THC_GENERIC_FILE 2 | #define THC_GENERIC_FILE "generic/batchnormp_cuda.cu" 3 | #else 4 | 5 | #define DeviceTensor3 THCDeviceTensor<real, 3> 6 | #define DeviceTensor1 THCDeviceTensor<real, 1> 7 | 8 | template <int Dim> 9 | static THCDeviceTensor<real, Dim> devicetensor(THCState *state, THCTensor *t) { 10 | if (!t) { 11 | return THCDeviceTensor<real, Dim>(); 12 | } 13 | 14 | int inDim = THCTensor_(nDimension)(state, t); 15 | if (inDim == Dim) { 16 | return toDeviceTensor<real, Dim>(state, t); 17 | } 18 | 19 | // View in which the last dimensions are collapsed or expanded as needed 20 | THAssert(THCTensor_(isContiguous)(state, t)); 21 | int size[Dim]; 22 | for (int i = 0; i < Dim || i < inDim; ++i) { 23 | if (i < Dim && i < inDim) { 24 | size[i] = t->size[i]; 25 | } else if (i < Dim) { 26 | size[i] = 1; 27 | } else { 28 | size[Dim - 1] *= t->size[i]; 29 | } 30 | } 31 | return THCDeviceTensor<real, Dim>(THCTensor_(data)(state, t), size); 32 | } 33 | 34 | extern "C" void THNN_(BatchNormalization_updateOutputhaha)( 35 | THCState *state, THCTensor *input_, THCTensor *output_, 36 | THCTensor *weight_, THCTensor *bias_, THCTensor *runningMean_, 37 | THCTensor *runningVar_, THCTensor *saveMean_, THCTensor *saveStd_, 38 | int train, double momentum, double eps); 39 | 40 | extern "C" void THNN_(BatchNormalization_mean)( 41 | THCState *state, THCTensor *input_, THCTensor *saveMean_); 42 | 43 | extern "C" void THNN_(BatchNormalization_var)( 44 | THCState *state, THCTensor *input_, THCTensor *saveMean_, 45 | THCTensor *saveVar_); 46 | 47 | 48 | void THNN_(BatchNormalization_mean)( 49 | THCState *state, THCTensor *input_, THCTensor *saveMean_) { 50 | DeviceTensor3 input = devicetensor<3>(state, input_); 51 | DeviceTensor1 saveMean = devicetensor<1>(state, saveMean_); 52 | 53 | cudaStream_t s = THCState_getCurrentStream(state); 54 | cudaDeviceProp *prop = THCState_getCurrentDeviceProperties(state); 55 | 56 | dim3 blocks(input.getSize(1)); 57 | dim3 threads(getNumThreads(input.getSize(2))); 58 | BatchNormalizationMean_kernel<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>( 59 | input, saveMean); 60 | THCudaCheck(cudaGetLastError()); 61 | } 62 | 63 | void THNN_(BatchNormalization_var)( 64 | THCState *state, THCTensor *input_, THCTensor *saveMean_, THCTensor *saveVar_) { 65 | DeviceTensor3 input = devicetensor<3>(state, input_); 66 | DeviceTensor1 saveMean = devicetensor<1>(state, saveMean_); 67 | DeviceTensor1 saveVar = devicetensor<1>(state, saveVar_); 68 | 69 | cudaStream_t s = THCState_getCurrentStream(state); 70 | cudaDeviceProp *prop = THCState_getCurrentDeviceProperties(state); 71 | 72 | dim3 blocks(input.getSize(1)); 73 | dim3 threads(getNumThreads(input.getSize(2))); 74 | BatchNormalizationVar_kernel<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>( 75 | input, saveMean, saveVar); 76 | THCudaCheck(cudaGetLastError()); 77 | } 78 | 79 | void THNN_(BatchNormalization_updateOutputhaha)( 80 | THCState *state, THCTensor *input_, THCTensor *output_, 81 | THCTensor *weight_, THCTensor *bias_, THCTensor *runningMean_, 82 | THCTensor *runningVar_, THCTensor *saveMean_, THCTensor *saveStd_, 83 | int train, double momentum, double eps) { 84 | 85 | THCTensor_(resizeAs)(state, output_, input_); 86 | DeviceTensor3 input = devicetensor<3>(state, input_); 87 | DeviceTensor3 output = devicetensor<3>(state, output_); 88 | DeviceTensor1 weight = devicetensor<1>(state, weight_); 89 | DeviceTensor1 bias = devicetensor<1>(state, bias_); 90 | DeviceTensor1 runningMean = devicetensor<1>(state, runningMean_); 91 | DeviceTensor1 runningVar = devicetensor<1>(state, runningVar_); 92 | DeviceTensor1 saveMean = devicetensor<1>(state, saveMean_); 93 | DeviceTensor1 saveStd = devicetensor<1>(state, saveStd_); 94 | 95 | cudaStream_t s = THCState_getCurrentStream(state); 96 | cudaDeviceProp *prop = THCState_getCurrentDeviceProperties(state); 97 | 98 | if (!train) { 99 | dim3 blocks(input.getSize(1)); 100 | dim3 threads(getNumThreads(input.getSize(2))); 101 | BatchNormalizationUpdateOutputInference_kernel<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>( 102 | input, output, runningMean, runningVar, weight, bias, eps); 103 | } else { 104 | dim3 blocks(input.getSize(1)); 105 | dim3 threads(getNumThreads(input.getSize(2))); 106 | BatchNormalizationUpdateOutput_kernelhaha<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>( 107 | input, output, weight, bias, eps, momentum, runningMean, runningVar, 108 | saveMean, saveStd); 109 | } 110 | THCudaCheck(cudaGetLastError()); 111 | } 112 | 113 | extern "C" void THNN_(BatchNormalization_backwardhaha)( 114 | THCState *state, THCTensor *input_, THCTensor *gradOutput_, 115 | THCTensor *gradOutputMean_, THCTensor *dotP, 116 | THCTensor *gradInput_, THCTensor *gradWeight_, THCTensor *gradBias_, 117 | THCTensor *weight_, THCTensor *runningMean_, THCTensor *runningVar_, 118 | THCTensor *saveMean_, THCTensor *saveStd_, int train, double scale, double eps); 119 | 120 | 121 | extern "C" void THNN_(BatchNormalization_mean_grad)( 122 | THCState *state, THCTensor *input_, THCTensor *gradOutput_, 123 | THCTensor *runningMean_, THCTensor *saveMean_, 124 | THCTensor *gradOutputMean_, THCTensor *dotP_, int train); 125 | 126 | 127 | void THNN_(BatchNormalization_mean_grad)( 128 | THCState *state, THCTensor *input_, THCTensor *gradOutput_, 129 | THCTensor *runningMean_, THCTensor *saveMean_, 130 | THCTensor *gradOutputMean_, THCTensor *dotP_, int train) { 131 | 132 | THCUNN_check_shape(state, input_, gradOutput_); 133 | DeviceTensor3 input = devicetensor<3>(state, input_); 134 | DeviceTensor3 gradOutput = devicetensor<3>(state, gradOutput_); 135 | DeviceTensor1 gradOutputMean = devicetensor<1>(state, gradOutputMean_); 136 | DeviceTensor1 dotP = devicetensor<1>(state, dotP_); 137 | 138 | DeviceTensor1 runningMean = devicetensor<1>(state, runningMean_); 139 | DeviceTensor1 saveMean = devicetensor<1>(state, saveMean_); 140 | 141 | cudaStream_t s = THCState_getCurrentStream(state); 142 | 143 | dim3 blocks(gradOutput.getSize(1)); 144 | dim3 threads(getNumThreads(gradOutput.getSize(2))); 145 | BatchNormalizationMeanGrad_kernel<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>( 146 | input, gradOutput, runningMean, saveMean, gradOutputMean, dotP, train); 147 | THCudaCheck(cudaGetLastError()); 148 | } 149 | 150 | 151 | void THNN_(BatchNormalization_backwardhaha)( 152 | THCState *state, THCTensor *input_, THCTensor *gradOutput_, 153 | THCTensor *gradOutputMean_, THCTensor *dotP_, 154 | THCTensor *gradInput_, THCTensor *gradWeight_, THCTensor *gradBias_, 155 | THCTensor *weight_, THCTensor *runningMean_, THCTensor *runningVar_, 156 | THCTensor *saveMean_, THCTensor *saveStd_, int train, double scale, double eps) { 157 | 158 | THCUNN_check_shape(state, input_, gradOutput_); 159 | DeviceTensor3 input = devicetensor<3>(state, input_); 160 | DeviceTensor3 gradOutput = devicetensor<3>(state, gradOutput_); 161 | DeviceTensor1 gradOutputMean = devicetensor<1>(state, gradOutputMean_); 162 | DeviceTensor1 dotP = devicetensor<1>(state, dotP_); 163 | DeviceTensor3 gradInput = devicetensor<3>(state, gradInput_); 164 | DeviceTensor1 gradWeight = devicetensor<1>(state, gradWeight_); 165 | DeviceTensor1 gradBias = devicetensor<1>(state, gradBias_); 166 | DeviceTensor1 weight = devicetensor<1>(state, weight_); 167 | DeviceTensor1 runningMean = devicetensor<1>(state, runningMean_); 168 | DeviceTensor1 runningVar = devicetensor<1>(state, runningVar_); 169 | DeviceTensor1 saveMean = devicetensor<1>(state, saveMean_); 170 | DeviceTensor1 saveStd = devicetensor<1>(state, saveStd_); 171 | 172 | cudaStream_t s = THCState_getCurrentStream(state); 173 | 174 | dim3 blocks(gradOutput.getSize(1)); 175 | dim3 threads(getNumThreads(gradOutput.getSize(2))); 176 | BatchNormalizationBackward_kernel<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>( 177 | input, gradOutput, gradOutputMean, dotP, gradInput, gradWeight, gradBias, weight, runningMean, runningVar, 178 | saveMean, saveStd, train, scale, eps); 179 | THCudaCheck(cudaGetLastError()); 180 | } 181 | 182 | #undef DeviceTensor3 183 | #undef DeviceTensor1 184 | 185 | #endif 186 | 
-------------------------------------------------------------------------------- /lib/test.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import time 3 | import logging 4 | 5 | import torch 6 | from torch.autograd import Variable 7 | from torch.autograd import gradcheck 8 | 9 | from modules import batchnormsync 10 | 11 | FORMAT = "[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s" 12 | logging.basicConfig(format=FORMAT) 13 | logger = logging.getLogger(__name__) 14 | logger.setLevel(logging.DEBUG) 15 | 16 | batchnormsync.BatchNormSync.checking_mode = True 17 | batchnormsync.BatchNormSync.sync = True 18 | 19 | cuda = True 20 | batch_size = 3 21 | input = torch.randn(3, 3, 2, 2).float() 22 | # input = torch.Tensor(range(60 * batch_size)).float().resize_(batch_size, 3, 2, 2) / 100 23 | bn = batchnormsync.BatchNormSync(3, eps=0, affine=True, 24 | device_ids=None) 25 | bn2 = torch.nn.BatchNorm2d(3, eps=0, affine=False) 26 | # bn.train() 27 | 28 | bn1 = batchnormsync.BatchNormSync(3, eps=0, affine=True, device_ids=[0]) 29 | 30 | bn1.train() 31 | 32 | if cuda: 33 | bn = torch.nn.DataParallel(bn) 34 | bn2 = torch.nn.DataParallel(bn2) 35 | 36 | bn = bn.cuda() 37 | bn1 = bn1.cuda() 38 | bn2 = bn2.cuda() 39 | input = input.cuda() 40 | 41 | 42 | inputs = (Variable(input, requires_grad=True),) 43 | # output = bn(inputs[0]) 44 | 45 | # output1 = bn1(inputs[0]) 46 | # output2 = bn2(inputs[0]) 47 | # print((output1 - output2).abs().max()) 48 | # print((output - output2).abs().max()) 49 | # test = gradcheck(bn, inputs, eps=1e-4, atol=1e-4, rtol=1e-8) 50 | for i in range(1000): 51 | logger.info(i) 52 | start_time = time.time() 53 | test = gradcheck(bn, inputs, eps=1e-4, atol=1e-2, rtol=1e-3) 54 | logger.info('%s %f', test, time.time() - start_time) 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pillow -------------------------------------------------------------------------------- /segment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import json 6 | import logging 7 | import math 8 | import os 9 | from os.path import exists, join, split 10 | import threading 11 | 12 | import time 13 | 14 | import numpy as np 15 | import shutil 16 | 17 | import sys 18 | from PIL import Image 19 | import torch 20 | from torch import nn 21 | import torch.backends.cudnn as cudnn 22 | import torch.optim as optim 23 | from torchvision import datasets, transforms 24 | from torch.autograd import Variable 25 | 26 | import drn 27 | import data_transforms as transforms 28 | 29 | try: 30 | from modules import batchnormsync 31 | except ImportError: 32 | pass 33 | 34 | FORMAT = "[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s" 35 | logging.basicConfig(format=FORMAT) 36 | logger = logging.getLogger(__name__) 37 | logger.setLevel(logging.DEBUG) 38 | 39 | 40 | CITYSCAPE_PALETTE = np.asarray([ 41 | [128, 64, 128], 42 | [244, 35, 232], 43 | [70, 70, 70], 44 | [102, 102, 156], 45 | [190, 153, 153], 46 | [153, 153, 153], 47 | [250, 170, 30], 48 | [220, 220, 0], 49 | [107, 142, 35], 50 | [152, 251, 152], 51 | [70, 130, 180], 52 | [220, 20, 60], 53 | [255, 0, 0], 54 | [0, 0, 142], 55 | [0, 0, 70], 56 | [0, 60, 100], 57 | [0, 80, 100], 58 | [0, 0, 230], 59 | [119, 11, 32], 60 | [0, 0, 0]], 
dtype=np.uint8) 61 | 62 | 63 | TRIPLET_PALETTE = np.asarray([ 64 | [0, 0, 0, 255], 65 | [217, 83, 79, 255], 66 | [91, 192, 222, 255]], dtype=np.uint8) 67 | 68 | 69 | def fill_up_weights(up): 70 | w = up.weight.data 71 | f = math.ceil(w.size(2) / 2) 72 | c = (2 * f - 1 - f % 2) / (2. * f) 73 | for i in range(w.size(2)): 74 | for j in range(w.size(3)): 75 | w[0, 0, i, j] = \ 76 | (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) 77 | for c in range(1, w.size(0)): 78 | w[c, 0, :, :] = w[0, 0, :, :] 79 | 80 | 81 | class DRNSeg(nn.Module): 82 | def __init__(self, model_name, classes, pretrained_model=None, 83 | pretrained=True, use_torch_up=False): 84 | super(DRNSeg, self).__init__() 85 | model = drn.__dict__.get(model_name)( 86 | pretrained=pretrained, num_classes=1000) 87 | pmodel = nn.DataParallel(model) 88 | if pretrained_model is not None: 89 | pmodel.load_state_dict(pretrained_model) 90 | self.base = nn.Sequential(*list(model.children())[:-2]) 91 | 92 | self.seg = nn.Conv2d(model.out_dim, classes, 93 | kernel_size=1, bias=True) 94 | self.softmax = nn.LogSoftmax() 95 | m = self.seg 96 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 97 | m.weight.data.normal_(0, math.sqrt(2. / n)) 98 | m.bias.data.zero_() 99 | if use_torch_up: 100 | self.up = nn.UpsamplingBilinear2d(scale_factor=8) 101 | else: 102 | up = nn.ConvTranspose2d(classes, classes, 16, stride=8, padding=4, 103 | output_padding=0, groups=classes, 104 | bias=False) 105 | fill_up_weights(up) 106 | up.weight.requires_grad = False 107 | self.up = up 108 | 109 | def forward(self, x): 110 | x = self.base(x) 111 | x = self.seg(x) 112 | y = self.up(x) 113 | return self.softmax(y), x 114 | 115 | def optim_parameters(self, memo=None): 116 | for param in self.base.parameters(): 117 | yield param 118 | for param in self.seg.parameters(): 119 | yield param 120 | 121 | 122 | class SegList(torch.utils.data.Dataset): 123 | def __init__(self, data_dir, phase, transforms, list_dir=None, 124 | out_name=False): 125 | self.list_dir = data_dir if list_dir is None else list_dir 126 | self.data_dir = data_dir 127 | self.out_name = out_name 128 | self.phase = phase 129 | self.transforms = transforms 130 | self.image_list = None 131 | self.label_list = None 132 | self.bbox_list = None 133 | self.read_lists() 134 | 135 | def __getitem__(self, index): 136 | data = [Image.open(join(self.data_dir, self.image_list[index]))] 137 | if self.label_list is not None: 138 | data.append(Image.open( 139 | join(self.data_dir, self.label_list[index]))) 140 | data = list(self.transforms(*data)) 141 | if self.out_name: 142 | if self.label_list is None: 143 | data.append(data[0][0, :, :]) 144 | data.append(self.image_list[index]) 145 | return tuple(data) 146 | 147 | def __len__(self): 148 | return len(self.image_list) 149 | 150 | def read_lists(self): 151 | image_path = join(self.list_dir, self.phase + '_images.txt') 152 | label_path = join(self.list_dir, self.phase + '_labels.txt') 153 | assert exists(image_path) 154 | self.image_list = [line.strip() for line in open(image_path, 'r')] 155 | if exists(label_path): 156 | self.label_list = [line.strip() for line in open(label_path, 'r')] 157 | assert len(self.image_list) == len(self.label_list) 158 | 159 | 160 | class SegListMS(torch.utils.data.Dataset): 161 | def __init__(self, data_dir, phase, transforms, scales, list_dir=None): 162 | self.list_dir = data_dir if list_dir is None else list_dir 163 | self.data_dir = data_dir 164 | self.phase = phase 165 | self.transforms = transforms 166 | self.image_list = None 
167 | self.label_list = None 168 | self.bbox_list = None 169 | self.read_lists() 170 | self.scales = scales 171 | 172 | def __getitem__(self, index): 173 | data = [Image.open(join(self.data_dir, self.image_list[index]))] 174 | w, h = data[0].size 175 | if self.label_list is not None: 176 | data.append(Image.open( 177 | join(self.data_dir, self.label_list[index]))) 178 | # data = list(self.transforms(*data)) 179 | out_data = list(self.transforms(*data)) 180 | ms_images = [self.transforms(data[0].resize((int(w * s), int(h * s)), 181 | Image.BICUBIC))[0] 182 | for s in self.scales] 183 | out_data.append(self.image_list[index]) 184 | out_data.extend(ms_images) 185 | return tuple(out_data) 186 | 187 | def __len__(self): 188 | return len(self.image_list) 189 | 190 | def read_lists(self): 191 | image_path = join(self.list_dir, self.phase + '_images.txt') 192 | label_path = join(self.list_dir, self.phase + '_labels.txt') 193 | assert exists(image_path) 194 | self.image_list = [line.strip() for line in open(image_path, 'r')] 195 | if exists(label_path): 196 | self.label_list = [line.strip() for line in open(label_path, 'r')] 197 | assert len(self.image_list) == len(self.label_list) 198 | 199 | 200 | def validate(val_loader, model, criterion, eval_score=None, print_freq=10): 201 | batch_time = AverageMeter() 202 | losses = AverageMeter() 203 | score = AverageMeter() 204 | 205 | # switch to evaluate mode 206 | model.eval() 207 | 208 | end = time.time() 209 | for i, (input, target) in enumerate(val_loader): 210 | if type(criterion) in [torch.nn.modules.loss.L1Loss, 211 | torch.nn.modules.loss.MSELoss]: 212 | target = target.float() 213 | input = input.cuda() 214 | target = target.cuda(async=True) 215 | input_var = torch.autograd.Variable(input, volatile=True) 216 | target_var = torch.autograd.Variable(target, volatile=True) 217 | 218 | # compute output 219 | output = model(input_var)[0] 220 | loss = criterion(output, target_var) 221 | 222 | # measure accuracy and record loss 223 | # prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 224 | losses.update(loss.data[0], input.size(0)) 225 | if eval_score is not None: 226 | score.update(eval_score(output, target_var), input.size(0)) 227 | 228 | # measure elapsed time 229 | batch_time.update(time.time() - end) 230 | end = time.time() 231 | 232 | if i % print_freq == 0: 233 | logger.info('Test: [{0}/{1}]\t' 234 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 235 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 236 | 'Score {score.val:.3f} ({score.avg:.3f})'.format( 237 | i, len(val_loader), batch_time=batch_time, loss=losses, 238 | score=score)) 239 | 240 | logger.info(' * Score {top1.avg:.3f}'.format(top1=score)) 241 | 242 | return score.avg 243 | 244 | 245 | class AverageMeter(object): 246 | """Computes and stores the average and current value""" 247 | def __init__(self): 248 | self.reset() 249 | 250 | def reset(self): 251 | self.val = 0 252 | self.avg = 0 253 | self.sum = 0 254 | self.count = 0 255 | 256 | def update(self, val, n=1): 257 | self.val = val 258 | self.sum += val * n 259 | self.count += n 260 | self.avg = self.sum / self.count 261 | 262 | 263 | def accuracy(output, target): 264 | """Computes the precision@k for the specified values of k""" 265 | # batch_size = target.size(0) * target.size(1) * target.size(2) 266 | _, pred = output.max(1) 267 | pred = pred.view(1, -1) 268 | target = target.view(1, -1) 269 | correct = pred.eq(target) 270 | correct = correct[target != 255] 271 | correct = correct.view(-1) 272 | score = 
correct.float().sum(0).mul(100.0 / correct.size(0)) 273 | return score.data[0] 274 | 275 | 276 | def train(train_loader, model, criterion, optimizer, epoch, 277 | eval_score=None, print_freq=10): 278 | batch_time = AverageMeter() 279 | data_time = AverageMeter() 280 | losses = AverageMeter() 281 | scores = AverageMeter() 282 | 283 | # switch to train mode 284 | model.train() 285 | 286 | end = time.time() 287 | 288 | for i, (input, target) in enumerate(train_loader): 289 | # measure data loading time 290 | data_time.update(time.time() - end) 291 | 292 | if type(criterion) in [torch.nn.modules.loss.L1Loss, 293 | torch.nn.modules.loss.MSELoss]: 294 | target = target.float() 295 | 296 | input = input.cuda() 297 | target = target.cuda(async=True) 298 | input_var = torch.autograd.Variable(input) 299 | target_var = torch.autograd.Variable(target) 300 | 301 | # compute output 302 | output = model(input_var)[0] 303 | loss = criterion(output, target_var) 304 | 305 | # measure accuracy and record loss 306 | # prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 307 | losses.update(loss.data[0], input.size(0)) 308 | if eval_score is not None: 309 | scores.update(eval_score(output, target_var), input.size(0)) 310 | 311 | # compute gradient and do SGD step 312 | optimizer.zero_grad() 313 | loss.backward() 314 | optimizer.step() 315 | 316 | # measure elapsed time 317 | batch_time.update(time.time() - end) 318 | end = time.time() 319 | 320 | if i % print_freq == 0: 321 | logger.info('Epoch: [{0}][{1}/{2}]\t' 322 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 323 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 324 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 325 | 'Score {top1.val:.3f} ({top1.avg:.3f})'.format( 326 | epoch, i, len(train_loader), batch_time=batch_time, 327 | data_time=data_time, loss=losses, top1=scores)) 328 | 329 | 330 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 331 | torch.save(state, filename) 332 | if is_best: 333 | shutil.copyfile(filename, 'model_best.pth.tar') 334 | 335 | 336 | def train_seg(args): 337 | batch_size = args.batch_size 338 | num_workers = args.workers 339 | crop_size = args.crop_size 340 | 341 | print(' '.join(sys.argv)) 342 | 343 | for k, v in args.__dict__.items(): 344 | print(k, ':', v) 345 | 346 | single_model = DRNSeg(args.arch, args.classes, None, 347 | pretrained=True) 348 | if args.pretrained: 349 | single_model.load_state_dict(torch.load(args.pretrained)) 350 | model = torch.nn.DataParallel(single_model).cuda() 351 | criterion = nn.NLLLoss2d(ignore_index=255) 352 | 353 | criterion.cuda() 354 | 355 | # Data loading code 356 | data_dir = args.data_dir 357 | info = json.load(open(join(data_dir, 'info.json'), 'r')) 358 | normalize = transforms.Normalize(mean=info['mean'], 359 | std=info['std']) 360 | t = [] 361 | if args.random_rotate > 0: 362 | t.append(transforms.RandomRotate(args.random_rotate)) 363 | if args.random_scale > 0: 364 | t.append(transforms.RandomScale(args.random_scale)) 365 | t.extend([transforms.RandomCrop(crop_size), 366 | transforms.RandomHorizontalFlip(), 367 | transforms.ToTensor(), 368 | normalize]) 369 | train_loader = torch.utils.data.DataLoader( 370 | SegList(data_dir, 'train', transforms.Compose(t), 371 | list_dir=args.list_dir), 372 | batch_size=batch_size, shuffle=True, num_workers=num_workers, 373 | pin_memory=True, drop_last=True 374 | ) 375 | val_loader = torch.utils.data.DataLoader( 376 | SegList(data_dir, 'val', transforms.Compose([ 377 | transforms.RandomCrop(crop_size), 378 | 
def train_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size

    print(' '.join(sys.argv))

    for k, v in args.__dict__.items():
        print(k, ':', v)

    single_model = DRNSeg(args.arch, args.classes, None,
                          pretrained=True)
    if args.pretrained:
        single_model.load_state_dict(torch.load(args.pretrained))
    model = torch.nn.DataParallel(single_model).cuda()

    # define the loss function; labels with value 255 are ignored
    criterion = nn.NLLLoss2d(ignore_index=255)
    criterion.cuda()

    # Data loading code
    data_dir = args.data_dir
    info = json.load(open(join(data_dir, 'info.json'), 'r'))
    normalize = transforms.Normalize(mean=info['mean'],
                                     std=info['std'])
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.extend([transforms.RandomCrop(crop_size),
              transforms.RandomHorizontalFlip(),
              transforms.ToTensor(),
              normalize])
    train_loader = torch.utils.data.DataLoader(
        SegList(data_dir, 'train', transforms.Compose(t),
                list_dir=args.list_dir),
        batch_size=batch_size, shuffle=True, num_workers=num_workers,
        pin_memory=True, drop_last=True
    )
    val_loader = torch.utils.data.DataLoader(
        SegList(data_dir, 'val', transforms.Compose([
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]), list_dir=args.list_dir),
        batch_size=batch_size, shuffle=False, num_workers=num_workers,
        pin_memory=True, drop_last=True
    )

    # define the optimizer
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        logger.info('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch,
              eval_score=accuracy)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = os.path.join(args.save_path,
                                       'checkpoint_latest.pth.tar')
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=checkpoint_path)
        if (epoch + 1) % args.save_iter == 0:
            history_path = os.path.join(
                args.save_path, 'checkpoint_{:03d}.pth.tar'.format(epoch + 1))
            shutil.copyfile(checkpoint_path, history_path)


def adjust_learning_rate(args, optimizer, epoch):
    """
    Sets the learning rate: in 'step' mode the initial LR is decayed by 10
    every args.step epochs; in 'poly' mode it follows a polynomial decay
    with power 0.9 over args.epochs.
    """
    if args.lr_mode == 'step':
        lr = args.lr * (0.1 ** (epoch // args.step))
    elif args.lr_mode == 'poly':
        lr = args.lr * (1 - epoch / args.epochs) ** 0.9
    else:
        raise ValueError('Unknown lr mode {}'.format(args.lr_mode))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def fast_hist(pred, label, n):
    # accumulate an n x n confusion matrix: rows are ground-truth labels,
    # columns are predictions; labels outside [0, n) are ignored
    k = (label >= 0) & (label < n)
    return np.bincount(
        n * label[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)


def per_class_iu(hist):
    # IoU per class: true positives / (ground truth + predictions - overlap)
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))

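# Example (illustrative, not executed): computing mean IoU with the two
# helpers above on a tiny 3-class problem.
#
#   pred = np.array([0, 1, 1, 2])
#   label = np.array([0, 1, 2, 2])
#   hist = fast_hist(pred, label, 3)
#   ious = per_class_iu(hist)       # -> [1.0, 0.5, 0.5]
#   miou = np.nanmean(ious) * 100   # -> 66.66...
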
466 | """ 467 | # pdb.set_trace() 468 | for ind in range(len(filenames)): 469 | im = Image.fromarray(predictions[ind].astype(np.uint8)) 470 | fn = os.path.join(output_dir, filenames[ind][:-4] + '.png') 471 | out_dir = split(fn)[0] 472 | if not exists(out_dir): 473 | os.makedirs(out_dir) 474 | im.save(fn) 475 | 476 | 477 | def save_colorful_images(predictions, filenames, output_dir, palettes): 478 | """ 479 | Saves a given (B x C x H x W) into an image file. 480 | If given a mini-batch tensor, will save the tensor as a grid of images. 481 | """ 482 | for ind in range(len(filenames)): 483 | im = Image.fromarray(palettes[predictions[ind].squeeze()]) 484 | fn = os.path.join(output_dir, filenames[ind][:-4] + '.png') 485 | out_dir = split(fn)[0] 486 | if not exists(out_dir): 487 | os.makedirs(out_dir) 488 | im.save(fn) 489 | 490 | 491 | def test(eval_data_loader, model, num_classes, 492 | output_dir='pred', has_gt=True, save_vis=False): 493 | model.eval() 494 | batch_time = AverageMeter() 495 | data_time = AverageMeter() 496 | end = time.time() 497 | hist = np.zeros((num_classes, num_classes)) 498 | for iter, (image, label, name) in enumerate(eval_data_loader): 499 | data_time.update(time.time() - end) 500 | image_var = Variable(image, requires_grad=False, volatile=True) 501 | final = model(image_var)[0] 502 | _, pred = torch.max(final, 1) 503 | pred = pred.cpu().data.numpy() 504 | batch_time.update(time.time() - end) 505 | if save_vis: 506 | save_output_images(pred, name, output_dir) 507 | save_colorful_images( 508 | pred, name, output_dir + '_color', 509 | TRIPLET_PALETTE if num_classes == 3 else CITYSCAPE_PALETTE) 510 | if has_gt: 511 | label = label.numpy() 512 | hist += fast_hist(pred.flatten(), label.flatten(), num_classes) 513 | logger.info('===> mAP {mAP:.3f}'.format( 514 | mAP=round(np.nanmean(per_class_iu(hist)) * 100, 2))) 515 | end = time.time() 516 | logger.info('Eval: [{0}/{1}]\t' 517 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 518 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 519 | .format(iter, len(eval_data_loader), batch_time=batch_time, 520 | data_time=data_time)) 521 | if has_gt: #val 522 | ious = per_class_iu(hist) * 100 523 | logger.info(' '.join('{:.03f}'.format(i) for i in ious)) 524 | return round(np.nanmean(ious), 2) 525 | 526 | 527 | def resize_4d_tensor(tensor, width, height): 528 | tensor_cpu = tensor.cpu().numpy() 529 | if tensor.size(2) == height and tensor.size(3) == width: 530 | return tensor_cpu 531 | out_size = (tensor.size(0), tensor.size(1), height, width) 532 | out = np.empty(out_size, dtype=np.float32) 533 | 534 | def resize_one(i, j): 535 | out[i, j] = np.array( 536 | Image.fromarray(tensor_cpu[i, j]).resize( 537 | (width, height), Image.BILINEAR)) 538 | 539 | def resize_channel(j): 540 | for i in range(tensor.size(0)): 541 | out[i, j] = np.array( 542 | Image.fromarray(tensor_cpu[i, j]).resize( 543 | (width, height), Image.BILINEAR)) 544 | 545 | # workers = [threading.Thread(target=resize_one, args=(i, j)) 546 | # for i in range(tensor.size(0)) for j in range(tensor.size(1))] 547 | 548 | workers = [threading.Thread(target=resize_channel, args=(j,)) 549 | for j in range(tensor.size(1))] 550 | for w in workers: 551 | w.start() 552 | for w in workers: 553 | w.join() 554 | # for i in range(tensor.size(0)): 555 | # for j in range(tensor.size(1)): 556 | # out[i, j] = np.array( 557 | # Image.fromarray(tensor_cpu[i, j]).resize( 558 | # (w, h), Image.BILINEAR)) 559 | # out = tensor.new().resize_(*out.shape).copy_(torch.from_numpy(out)) 560 | 
def test_ms(eval_data_loader, model, num_classes, scales,
            output_dir='pred', has_gt=True, save_vis=False):
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()
    hist = np.zeros((num_classes, num_classes))
    num_scales = len(scales)
    for iter, input_data in enumerate(eval_data_loader):
        data_time.update(time.time() - end)
        if has_gt:
            name = input_data[2]
            label = input_data[1]
        else:
            name = input_data[1]
        h, w = input_data[0].size()[2:4]
        images = [input_data[0]]
        images.extend(input_data[-num_scales:])
        outputs = []
        for image in images:
            image_var = Variable(image, requires_grad=False, volatile=True)
            final = model(image_var)[0]
            outputs.append(final.data)
        final = sum([resize_4d_tensor(out, w, h) for out in outputs])
        pred = final.argmax(axis=1)
        batch_time.update(time.time() - end)
        if save_vis:
            save_output_images(pred, name, output_dir)
            save_colorful_images(pred, name, output_dir + '_color',
                                 CITYSCAPE_PALETTE)
        if has_gt:
            label = label.numpy()
            hist += fast_hist(pred.flatten(), label.flatten(), num_classes)
            logger.info('===> mIoU {mIoU:.3f}'.format(
                mIoU=round(np.nanmean(per_class_iu(hist)) * 100, 2)))
        end = time.time()
        logger.info('Eval: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    .format(iter, len(eval_data_loader), batch_time=batch_time,
                            data_time=data_time))
    if has_gt:  # val
        ious = per_class_iu(hist) * 100
        logger.info(' '.join('{:.03f}'.format(i) for i in ious))
        return round(np.nanmean(ious), 2)

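# Multi-scale evaluation above runs the model on the original image and on
# every rescaled copy produced by SegListMS, resizes each score map back to
# the original resolution, and sums them before the argmax, so every scale
# votes with its full class scores. Sketch (illustrative, not executed):
#
#   outputs = [model(im)[0].data for im in images]        # one per scale
#   final = sum(resize_4d_tensor(o, w, h) for o in outputs)
#   pred = final.argmax(axis=1)
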
(epoch {})" 659 | .format(args.resume, checkpoint['epoch'])) 660 | else: 661 | logger.info("=> no checkpoint found at '{}'".format(args.resume)) 662 | 663 | out_dir = '{}_{:03d}_{}'.format(args.arch, start_epoch, phase) 664 | if len(args.test_suffix) > 0: 665 | out_dir += '_' + args.test_suffix 666 | if args.ms: 667 | out_dir += '_ms' 668 | 669 | if args.ms: 670 | mAP = test_ms(test_loader, model, args.classes, save_vis=True, 671 | has_gt=phase != 'test' or args.with_gt, 672 | output_dir=out_dir, 673 | scales=scales) 674 | else: 675 | mAP = test(test_loader, model, args.classes, save_vis=True, 676 | has_gt=phase != 'test' or args.with_gt, output_dir=out_dir) 677 | logger.info('mAP: %f', mAP) 678 | 679 | 680 | def parse_args(): 681 | # Training settings 682 | parser = argparse.ArgumentParser(description='') 683 | parser.add_argument('cmd', choices=['train', 'test']) 684 | parser.add_argument('-d', '--data-dir', default=None, required=True) 685 | parser.add_argument('-l', '--list-dir', default=None, 686 | help='List dir to look for train_images.txt etc. ' 687 | 'It is the same with --data-dir if not set.') 688 | parser.add_argument('-c', '--classes', default=0, type=int) 689 | parser.add_argument('-s', '--crop-size', default=0, type=int) 690 | parser.add_argument('--step', type=int, default=200) 691 | parser.add_argument('--arch') 692 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 693 | help='input batch size for training (default: 64)') 694 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 695 | help='number of epochs to train (default: 10)') 696 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 697 | help='learning rate (default: 0.01)') 698 | parser.add_argument('--lr-mode', type=str, default='step') 699 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 700 | help='SGD momentum (default: 0.9)') 701 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 702 | metavar='W', help='weight decay (default: 1e-4)') 703 | parser.add_argument('-e', '--evaluate', dest='evaluate', 704 | action='store_true', 705 | help='evaluate model on validation set') 706 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 707 | help='path to latest checkpoint (default: none)') 708 | parser.add_argument('--pretrained', dest='pretrained', 709 | default='', type=str, metavar='PATH', 710 | help='use pre-trained model') 711 | parser.add_argument('--save_path', default='', type=str, metavar='PATH', 712 | help='output path for training checkpoints') 713 | parser.add_argument('--save_iter', default=1, type=int, 714 | help='number of training iterations between' 715 | 'checkpoint history saves') 716 | parser.add_argument('-j', '--workers', type=int, default=8) 717 | parser.add_argument('--load-release', dest='load_rel', default=None) 718 | parser.add_argument('--phase', default='val') 719 | parser.add_argument('--random-scale', default=0, type=float) 720 | parser.add_argument('--random-rotate', default=0, type=int) 721 | parser.add_argument('--bn-sync', action='store_true') 722 | parser.add_argument('--ms', action='store_true', 723 | help='Turn on multi-scale testing') 724 | parser.add_argument('--with-gt', action='store_true') 725 | parser.add_argument('--test-suffix', default='', type=str) 726 | args = parser.parse_args() 727 | 728 | assert args.classes > 0 729 | 730 | print(' '.join(sys.argv)) 731 | print(args) 732 | 733 | if args.bn_sync: 734 | drn.BatchNorm = batchnormsync.BatchNormSync 735 | 
def main():
    args = parse_args()
    if args.cmd == 'train':
        train_seg(args)
    elif args.cmd == 'test':
        test_seg(args)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------