├── .gitignore
├── README.md
├── cfgs
│   ├── fishnet150.yaml
│   └── fishnet99.yaml
├── checkpoints
│   └── README.md
├── head_pic.jpg
├── logs
│   └── README.md
├── main.py
├── models
│   ├── __init__.py
│   ├── fish_block.py
│   ├── fishnet.py
│   └── net_factory.py
└── utils
    ├── __init__.py
    ├── data_aug.py
    └── profile.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
models.txt
train.sh
cfgs/local_test.yaml

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FishNet

![ ](head_pic.jpg)

This repo holds the implementation code of the paper:

[FishNet: A Versatile Backbone for Image, Region, and Pixel Level Prediction](http://papers.nips.cc/paper/7356-fishnet-a-versatile-backbone-for-image-region-and-pixel-level-prediction.pdf), Shuyang Sun, Jiangmiao Pang, Jianping Shi, Shuai Yi, Wanli Ouyang, NeurIPS 2018.

FishNet was used as a key component for winning the 1st place in the [COCO Detection Challenge 2018](http://cocodataset.org/#detection-leaderboard).

Note that the results released here are slightly better than what we reported in the paper.

### Prerequisites
- Python 3.6.x
- PyTorch 0.4.0+

### Data Augmentation

| Method | Settings |
| ----- | -------- |
| Random Flip | True |
| Random Crop | 8% ~ 100% |
| Aspect Ratio | 3/4 ~ 4/3 |
| Random PCA Lighting | 0.1 |

**Note**: We apply weight decay to all weights and biases instead of just the weights of the convolution layers.
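
For reference, the table above corresponds to the following torchvision pipeline from `main.py` (a minimal sketch; the crop scale and aspect-ratio ranges are the `RandomResizedCrop` defaults, and `ColorAugmentation` is the PCA-lighting transform in `utils/data_aug.py`):

```
import torchvision.transforms as transforms
from utils.data_aug import ColorAugmentation

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
    # crops 8%~100% of the area with aspect ratio 3/4~4/3 (the defaults)
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    ColorAugmentation(),  # PCA lighting noise with std 0.1
    normalize,
])
```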

### Training
To train FishNet-150 with 8 GPUs and batch size 256, simply run

```
python main.py --config "cfgs/fishnet150.yaml" IMAGENET_ROOT_PATH
```

and set `batch_size` in the YAML file to match your setup.

### Models
**Models trained without tricks**

| Model | Params | FLOPs | Top-1 | Top-5 | Baidu Yun | Dropbox |
| ---------- | ------ | ----- | ------ | ------ | --------- | ------- |
| FishNet99 | 16.62M | 4.31G | 77.41% | 93.59% | [Download](https://pan.baidu.com/s/11U3sRod1VfbDBRbmXph6KA) | [Download](https://www.dropbox.com/s/hvojbdsad5ue7yb/fishnet99_ckpt.tar?dl=0) |
| FishNet150 | 24.96M | 6.45G | 78.14% | 93.95% | [Download](https://pan.baidu.com/s/1uOEFsBHIdqpDLrbfCZJGUg) | [Download](https://www.dropbox.com/s/hjadcef18ln3o2v/fishnet150_ckpt.tar?dl=0) |
| FishNet201 | 44.58M | 10.58G | 78.76% | 94.39% | Available Soon | Available Soon |

**Models trained with cosine lr schedule (200 epochs) and label smoothing**

| Model | Params | FLOPs | Top-1 | Top-5 | Baidu Yun | Dropbox |
| ---------- | ------ | ----- | ------ | ------ | --------- | ------- |
| FishNet150 | 24.96M | 6.45G | 79.35% | 94.75% | [Download](https://pan.baidu.com/s/1pt31cp-xGcsRJKZAPcp4yQ) | [Download](https://www.dropbox.com/s/ajy9p6f97y45f1r/fishnet150_ckpt_welltrained.tar?dl=0) |
| FishNet201 | 44.58M | 10.58G | 79.71% | 94.79% | [Download]() | [Download](https://www.dropbox.com/s/kvz2dmxe3fzn10m/fishnet201_ckpt_welltrain.tar?dl=0) |

To load one of these models, e.g. FishNet150, you need to first construct the FishNet150 structure like:

```
from models.net_factory import fishnet150
model = fishnet150()
```

and then you can load the weights from the pre-trained checkpoint by:

```
checkpoint = torch.load(model_path)  # model_path: your checkpoint path, e.g. checkpoints/fishnet150.tar
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])  # only needed when resuming training
```

Note that you do **NOT** need to decompress the model using the `tar` command. The model you download from the cloud can be loaded directly.
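
Since `main.py` saves checkpoints from a model wrapped in `torch.nn.DataParallel`, the `state_dict` keys typically carry a `module.` prefix. A minimal sketch for loading such a checkpoint into an unwrapped model on CPU (the prefix-stripping logic is our addition, not part of the repo):

```
import torch
from models.net_factory import fishnet150

model = fishnet150()
checkpoint = torch.load('checkpoints/fishnet150.tar', map_location='cpu')
# drop the DataParallel 'module.' prefix if present
state_dict = {k.replace('module.', '', 1): v
              for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()
```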

### TODO:
- [x] Update our arxiv paper.
- [x] Release pre-trained models.
- [ ] Train the model with more training tricks.

### Citation

If you find our research useful, please cite the paper:
```
@inproceedings{sun2018fishnet,
  title={FishNet: A Versatile Backbone for Image, Region, and Pixel Level Prediction},
  author={Sun, Shuyang and Pang, Jiangmiao and Shi, Jianping and Yi, Shuai and Ouyang, Wanli},
  booktitle={Advances in Neural Information Processing Systems},
  pages={760--770},
  year={2018}
}
```

### Contact
You can contact Shuyang Sun by sending an email to kevin.sysun@gmail.com.

--------------------------------------------------------------------------------
/cfgs/fishnet150.yaml:
--------------------------------------------------------------------------------
common:
    arch: fishnet150
    workers: 1
    batch_size: 32

    epochs: 100
    policy: "step"
    base_lr: 0.1
    momentum: 0.9
    weight_decay: 0.0001

    print_freq: 10
    save_path: checkpoints/fishnet150_bs256

--------------------------------------------------------------------------------
/cfgs/fishnet99.yaml:
--------------------------------------------------------------------------------
common:
    arch: fishnet99
    workers: 1
    batch_size: 32

    epochs: 100
    policy: "step"
    base_lr: 0.1
    momentum: 0.9
    weight_decay: 0.0001

    print_freq: 10
    save_path: checkpoints/fishnet99_bs256

--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
This is a placeholder.

--------------------------------------------------------------------------------
/head_pic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevin-ssy/FishNet/b968f0244827e11201471edd8a979bd85027b991/head_pic.jpg

--------------------------------------------------------------------------------
/logs/README.md:
--------------------------------------------------------------------------------
This is a placeholder.

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
import time

import yaml
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable

import models
from utils.profile import count_params
from utils.data_aug import ColorAugmentation


model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('--config', default='cfgs/local_test.yaml', help='path to config file')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=32, type=int,
                    metavar='N', help='mini-batch size (default: 32)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--train_image_list', default='', type=str, help='path to train image list')

parser.add_argument('--input_size', default=224, type=int, help='img crop size')
parser.add_argument('--image_size', default=256, type=int, help='ori img size')

parser.add_argument('--model_name', default='', type=str, help='name of the model')

best_prec1 = 0

USE_GPU = torch.cuda.is_available()


def main():
    global args, best_prec1, USE_GPU
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    # every key under `common` overrides the corresponding argparse default
    for k, v in config['common'].items():
        setattr(args, k, v)

    # create model
    image_size = args.image_size
    input_size = args.input_size
    print("Image (resize) size: {}, crop (input) size: {}".format(image_size, input_size))

    if "model" in config.keys():
        model = models.__dict__[args.arch](**config['model'])
    else:
        model = models.__dict__[args.arch]()

    if USE_GPU:
        model = model.cuda()
        model = torch.nn.DataParallel(model)

    count_params(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)
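
    # Note: `model.parameters()` hands every trainable parameter to the
    # optimizer, so weight decay is applied to all weights and biases
    # (including BatchNorm affine parameters), matching the note in the README.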

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_size = args.input_size

    ratio = 224.0 / float(img_size)
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ColorAugmentation(),
            normalize,
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(int(256 * ratio)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ]))
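
    # With the default img_size of 224, `ratio` is 1.0, so validation images
    # are resized to 256 and center-cropped to 224, the standard ImageNet
    # evaluation protocol.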

    # if args.distributed:
    #     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    #     val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    # else:
    train_sampler = None
    val_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=(train_sampler is None), sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # if args.distributed:
        #     train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if not os.path.exists(args.save_path):
            os.makedirs(args.save_path)
        save_name = '{}/{}_{}_best.pth.tar'.format(args.save_path, args.model_name, epoch) if is_best else \
            '{}/{}_{}.pth.tar'.format(args.save_path, args.model_name, epoch)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, filename=save_name)


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
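
        # PyTorch 0.4 merged Tensor and Variable, so tensors can be fed to
        # the model directly; on 0.3.x and earlier, inputs must still be
        # wrapped in autograd.Variable. The version check keeps both working.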
        if '0.4.' in torch.__version__:
            if USE_GPU:
                input_var = torch.cuda.FloatTensor(input.cuda())
                target_var = torch.cuda.LongTensor(target.cuda())
            else:
                input_var = torch.FloatTensor(input)
                target_var = torch.LongTensor(target)
        else:  # pytorch 0.3.1 or less compatible
            if USE_GPU:
                input = input.cuda()
                # `async` was renamed to `non_blocking` in PyTorch 0.4;
                # this branch only runs on <= 0.3.1 (Python 3.6)
                target = target.cuda(async=True)
            input_var = Variable(input)
            target_var = Variable(target)

        # compute output
        output = model(input_var)

        loss = criterion(output, target_var)
        prec1, prec5 = accuracy(output.data, target_var, topk=(1, 5))

        # measure accuracy and record loss
        reduced_prec1 = prec1.clone()
        reduced_prec5 = prec5.clone()

        top1.update(reduced_prec1[0])
        top5.update(reduced_prec5[0])

        reduced_loss = loss.data.clone()
        losses.update(reduced_loss)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            with open('logs/{}_{}.log'.format(time_stp, args.arch), 'a+') as flog:
                line = 'Epoch: [{0}][{1}/{2}]\t' \
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' \
                       'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                       batch_time=batch_time, loss=losses,
                                                                       top1=top1, top5=top5)
                print(line)
                flog.write('{}\n'.format(line))


def validate(val_loader, model, criterion):
    global time_stp
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        # pytorch 0.4.0 compatible
        if '0.4.' in torch.__version__:
            with torch.no_grad():
                if USE_GPU:
                    input_var = torch.cuda.FloatTensor(input.cuda())
                    target_var = torch.cuda.LongTensor(target.cuda())
                else:
                    input_var = torch.FloatTensor(input)
                    target_var = torch.LongTensor(target)
        else:  # pytorch 0.3.1 or less compatible
            if USE_GPU:
                input = input.cuda()
                target = target.cuda(async=True)
            input_var = Variable(input, volatile=True)
            target_var = Variable(target, volatile=True)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target_var, topk=(1, 5))
        losses.update(loss.data, input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            line = 'Test: [{0}/{1}]\t' \
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' \
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(i, len(val_loader), batch_time=batch_time,
                                                                   loss=losses, top1=top1, top5=top5)

            with open('logs/{}_{}.log'.format(time_stp, args.arch), 'a+') as flog:
                flog.write('{}\n'.format(line))
            print(line)

    return top1.avg


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    torch.save(state, filename)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    time_stp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
    main()

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .net_factory import *
from torchvision.models import *

--------------------------------------------------------------------------------
/models/fish_block.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class Bottleneck(nn.Module):
    def __init__(self, inplanes, planes, stride=1, mode='NORM', k=1, dilation=1):
        """
        Pre-activation residual block whose middle transformation is bottlenecked.
        :param inplanes: number of input channels
        :param planes: number of output channels
        :param stride: stride of the 3x3 convolution
        :param mode: 'NORM' for a regular (optionally strided) block,
                     'UP' for an up-sampling block whose shortcut reduces
                     channels by summation instead of a 1x1 convolution
        :param k: channel reduction factor of the 'UP' shortcut
        :param dilation: dilation of the 3x3 convolution
        """

        super(Bottleneck, self).__init__()
        self.mode = mode
        self.relu = nn.ReLU(inplace=True)
        self.k = k

        btnk_ch = planes // 4
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, btnk_ch, kernel_size=1, bias=False)

        self.bn2 = nn.BatchNorm2d(btnk_ch)
        self.conv2 = nn.Conv2d(btnk_ch, btnk_ch, kernel_size=3, stride=stride, padding=dilation,
                               dilation=dilation, bias=False)

        self.bn3 = nn.BatchNorm2d(btnk_ch)
        self.conv3 = nn.Conv2d(btnk_ch, planes, kernel_size=1, bias=False)

        if mode == 'UP':
            self.shortcut = None
        elif inplanes != planes or stride > 1:
            self.shortcut = nn.Sequential(
                nn.BatchNorm2d(inplanes),
                self.relu,
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
            )
        else:
            self.shortcut = None

    def _pre_act_forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.mode == 'UP':
            residual = self.squeeze_idt(x)
        elif self.shortcut is not None:
            residual = self.shortcut(residual)

        out += residual

        return out

    def squeeze_idt(self, idt):
        n, c, h, w = idt.size()
        return idt.view(n, c // self.k, self.k, h, w).sum(2)
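
    # Shape note for squeeze_idt: an (N, C, H, W) identity tensor is reshaped
    # to (N, C // k, k, H, W) and summed over dim 2, i.e. every k consecutive
    # channels are added together, a parameter-free channel reduction used by
    # the 'UP' shortcut.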

    def forward(self, x):
        out = self._pre_act_forward(x)
        return out

--------------------------------------------------------------------------------
/models/fishnet.py:
--------------------------------------------------------------------------------
'''
FishNet
Author: Shuyang Sun
'''
from __future__ import division
import math
import torch
import torch.nn as nn
from .fish_block import Bottleneck


__all__ = ['fish']


class Fish(nn.Module):
    def __init__(self, block, num_cls=1000, num_down_sample=5, num_up_sample=3, trans_map=(2, 1, 0, 6, 5, 4),
                 network_planes=None, num_res_blks=None, num_trans_blks=None):
        super(Fish, self).__init__()
        self.block = block
        self.trans_map = trans_map
        self.upsample = nn.Upsample(scale_factor=2)
        self.down_sample = nn.MaxPool2d(2, stride=2)
        self.num_cls = num_cls
        self.num_down = num_down_sample
        self.num_up = num_up_sample
        self.network_planes = network_planes[1:]
        self.depth = len(self.network_planes)
        self.num_trans_blks = num_trans_blks
        self.num_res_blks = num_res_blks
        self.fish = self._make_fish(network_planes[0])

    def _make_score(self, in_ch, out_ch=1000, has_pool=False):
        bn = nn.BatchNorm2d(in_ch)
        relu = nn.ReLU(inplace=True)
        conv_trans = nn.Conv2d(in_ch, in_ch // 2, kernel_size=1, bias=False)
        bn_out = nn.BatchNorm2d(in_ch // 2)
        conv = nn.Sequential(bn, relu, conv_trans, bn_out, relu)
        if has_pool:
            fc = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_ch // 2, out_ch, kernel_size=1, bias=True))
        else:
            fc = nn.Conv2d(in_ch // 2, out_ch, kernel_size=1, bias=True)
        return [conv, fc]

    def _make_se_block(self, in_ch, out_ch):
        bn = nn.BatchNorm2d(in_ch)
        sq_conv = nn.Conv2d(in_ch, out_ch // 16, kernel_size=1)
        ex_conv = nn.Conv2d(out_ch // 16, out_ch, kernel_size=1)
        return nn.Sequential(bn,
                             nn.ReLU(inplace=True),
                             nn.AdaptiveAvgPool2d(1),
                             sq_conv,
                             nn.ReLU(inplace=True),
                             ex_conv,
                             nn.Sigmoid())
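
    # _make_se_block is a standard squeeze-and-excitation gate: global average
    # pooling squeezes the feature map to 1x1, a 1x1 conv reduces channels by
    # 16x, a second 1x1 conv expands them back, and the sigmoid output gates
    # the features at the tail/body junction (see _fish_forward below).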

    def _make_residual_block(self, inplanes, outplanes, nstage, is_up=False, k=1, dilation=1):
        layers = []

        if is_up:
            layers.append(self.block(inplanes, outplanes, mode='UP', dilation=dilation, k=k))
        else:
            layers.append(self.block(inplanes, outplanes, stride=1))
        for i in range(1, nstage):
            layers.append(self.block(outplanes, outplanes, stride=1, dilation=dilation))
        return nn.Sequential(*layers)

    def _make_stage(self, is_down_sample, inplanes, outplanes, n_blk, has_trans=True,
                    has_score=False, trans_planes=0, no_sampling=False, num_trans=2, **kwargs):
        sample_block = []
        if has_score:
            sample_block.extend(self._make_score(outplanes, outplanes * 2, has_pool=False))

        if no_sampling or is_down_sample:
            res_block = self._make_residual_block(inplanes, outplanes, n_blk, **kwargs)
        else:
            res_block = self._make_residual_block(inplanes, outplanes, n_blk, is_up=True, **kwargs)

        sample_block.append(res_block)

        if has_trans:
            trans_in_planes = self.in_planes if trans_planes == 0 else trans_planes
            sample_block.append(self._make_residual_block(trans_in_planes, trans_in_planes, num_trans))

        if not no_sampling and is_down_sample:
            sample_block.append(self.down_sample)
        elif not no_sampling:  # up-sample
            sample_block.append(self.upsample)

        return nn.ModuleList(sample_block)

    def _make_fish(self, in_planes):
        def get_trans_planes(index):
            map_id = self.trans_map[index - self.num_down - 1] - 1
            p = in_planes if map_id == -1 else cated_planes[map_id]
            return p

        def get_trans_blk(index):
            return self.num_trans_blks[index - self.num_down - 1]

        def get_cur_planes(index):
            return self.network_planes[index]

        def get_blk_num(index):
            return self.num_res_blks[index]

        cated_planes, fish = [in_planes] * self.depth, []
        for i in range(self.depth):
            # stage layout: 0..num_down-1 down-sample (tail); num_down keeps
            # the resolution (bridge with score/SE); num_down+1..num_down+num_up
            # up-sample (body); the remaining stages down-sample again (head)
            is_down, has_trans, no_sampling = i not in range(self.num_down, self.num_down + self.num_up + 1), \
                i > self.num_down, i == self.num_down
            cur_planes, trans_planes, cur_blocks, num_trans = \
                get_cur_planes(i), get_trans_planes(i), get_blk_num(i), get_trans_blk(i)

            stg_args = [is_down, cated_planes[i - 1], cur_planes, cur_blocks]

            if is_down or no_sampling:
                k, dilation = 1, 1
            else:
                k, dilation = cated_planes[i - 1] // cur_planes, 2 ** (i - self.num_down - 1)

            sample_block = self._make_stage(*stg_args, has_trans=has_trans, trans_planes=trans_planes,
                                            has_score=(i == self.num_down), num_trans=num_trans, k=k,
                                            dilation=dilation, no_sampling=no_sampling)
            if i == self.depth - 1:
                sample_block.extend(self._make_score(cur_planes + trans_planes, out_ch=self.num_cls, has_pool=True))
            elif i == self.num_down:
                sample_block.append(nn.Sequential(self._make_se_block(cur_planes * 2, cur_planes)))

            if i == self.num_down - 1:
                cated_planes[i] = cur_planes * 2
            elif has_trans:
                cated_planes[i] = cur_planes + trans_planes
            else:
                cated_planes[i] = cur_planes
            fish.append(sample_block)
        return nn.ModuleList(fish)
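
    # The fish is a flat ModuleList; _fish_forward below walks it stage by
    # stage. Tail stages consume only the previous feature map, the bridge
    # stage applies the score/SE attention, and body/head stages additionally
    # receive a skip feature chosen by trans_map and concatenate it with
    # their own output.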

    def _fish_forward(self, all_feat):
        def _concat(a, b):
            return torch.cat([a, b], dim=1)

        def stage_factory(*blks):
            def stage_forward(*inputs):
                if stg_id < self.num_down:  # tail
                    tail_blk = nn.Sequential(*blks[:2])
                    return tail_blk(*inputs)
                elif stg_id == self.num_down:  # bridge with score/SE attention
                    score_blks = nn.Sequential(*blks[:2])
                    score_feat = score_blks(inputs[0])
                    att_feat = blks[3](score_feat)
                    return blks[2](score_feat) * att_feat + att_feat
                else:  # refine
                    feat_trunk = blks[2](blks[0](inputs[0]))
                    feat_branch = blks[1](inputs[1])
                    return _concat(feat_trunk, feat_branch)
            return stage_forward

        stg_id = 0
        # tail:
        while stg_id < self.depth:
            stg_blk = stage_factory(*self.fish[stg_id])
            if stg_id <= self.num_down:
                in_feat = [all_feat[stg_id]]
            else:
                trans_id = self.trans_map[stg_id - self.num_down - 1]
                in_feat = [all_feat[stg_id], all_feat[trans_id]]

            all_feat[stg_id + 1] = stg_blk(*in_feat)
            stg_id += 1
            # loop exit
            if stg_id == self.depth:
                score_feat = self.fish[self.depth - 1][-2](all_feat[-1])
                score = self.fish[self.depth - 1][-1](score_feat)
                return score

    def forward(self, x):
        all_feat = [None] * (self.depth + 1)
        all_feat[0] = x
        return self._fish_forward(all_feat)


class FishNet(nn.Module):
    def __init__(self, block, **kwargs):
        super(FishNet, self).__init__()

        inplanes = kwargs['network_planes'][0]
        # resolution: 224x224
        self.conv1 = self._conv_bn_relu(3, inplanes // 2, stride=2)
        self.conv2 = self._conv_bn_relu(inplanes // 2, inplanes // 2)
        self.conv3 = self._conv_bn_relu(inplanes // 2, inplanes)
        self.pool1 = nn.MaxPool2d(3, padding=1, stride=2)
        # construct fish, resolution 56x56
        self.fish = Fish(block, **kwargs)
        self._init_weights()

    def _conv_bn_relu(self, in_ch, out_ch, stride=1):
        return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=stride, bias=False),
                             nn.BatchNorm2d(out_ch),
                             nn.ReLU(inplace=True))

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
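
    # Convolutions use He (Kaiming) normal initialization: a zero-mean
    # Gaussian with std = sqrt(2 / n), where n is the fan-out
    # (kernel_h * kernel_w * out_channels); BatchNorm starts as the identity.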

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.pool1(x)
        score = self.fish(x)
        # the score map is 1x1 spatially; flatten it to (batch, num_cls)
        out = score.view(x.size(0), -1)

        return out


def fish(**kwargs):
    return FishNet(Bottleneck, **kwargs)

--------------------------------------------------------------------------------
/models/net_factory.py:
--------------------------------------------------------------------------------
from .fishnet import fish


def fishnet99(**kwargs):
    """
    FishNet-99 for ImageNet classification (1000 classes by default).
    :return: a FishNet model
    """
    net_cfg = {
        #  input size:    [224, 56,  28,  14  |  7,   7,   14,  28 | 56,  28,  14]
        #  output size:   [56,  28,  14,   7  |  7,   14,  28,  56 | 28,  14,   7]
        #                  |    |    |     |     |    |    |    |    |    |    |
        'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],
        'num_res_blks': [2, 2, 6, 2, 1, 1, 1, 1, 2, 2],
        'num_trans_blks': [1, 1, 1, 1, 1, 4],
        'num_cls': 1000,
        'num_down_sample': 3,
        'num_up_sample': 3,
    }
    cfg = {**net_cfg, **kwargs}
    return fish(**cfg)


def fishnet150(**kwargs):
    """
    FishNet-150 for ImageNet classification (1000 classes by default).
    :return: a FishNet model
    """
    net_cfg = {
        #  input size:    [224, 56,  28,  14  |  7,   7,   14,  28 | 56,  28,  14]
        #  output size:   [56,  28,  14,   7  |  7,   14,  28,  56 | 28,  14,   7]
        #                  |    |    |     |     |    |    |    |    |    |    |
        'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],
        'num_res_blks': [2, 4, 8, 4, 2, 2, 2, 2, 2, 4],
        'num_trans_blks': [2, 2, 2, 2, 2, 4],
        'num_cls': 1000,
        'num_down_sample': 3,
        'num_up_sample': 3,
    }
    cfg = {**net_cfg, **kwargs}
    return fish(**cfg)


def fishnet201(**kwargs):
    """
    FishNet-201 for ImageNet classification (1000 classes by default).
    :return: a FishNet model
    """
    net_cfg = {
        #  input size:    [224, 56,  28,  14  |  7,   7,   14,  28 | 56,  28,  14]
        #  output size:   [56,  28,  14,   7  |  7,   14,  28,  56 | 28,  14,   7]
        #                  |    |    |     |     |    |    |    |    |    |    |
        'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],
        'num_res_blks': [3, 4, 12, 4, 2, 2, 2, 2, 3, 10],
        'num_trans_blks': [2, 2, 2, 2, 2, 9],
        'num_cls': 1000,
        'num_down_sample': 3,
        'num_up_sample': 3,
    }
    cfg = {**net_cfg, **kwargs}
    return fish(**cfg)
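

# Any key in net_cfg can be overridden through kwargs, since the two dicts
# are merged with kwargs taking precedence, e.g. a 10-class variant (sketch):
#     model = fishnet99(num_cls=10)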

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
from .profile import *

--------------------------------------------------------------------------------
/utils/data_aug.py:
--------------------------------------------------------------------------------
import torch


class ColorAugmentation(object):
    """AlexNet-style PCA lighting noise in RGB space."""

    def __init__(self):
        self.eig_vec = torch.Tensor([
            [0.4009, 0.7192, -0.5675],
            [-0.8140, -0.0045, -0.5808],
            [0.4203, -0.6948, -0.5836],
        ])
        self.eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])

    def __call__(self, tensor):
        assert tensor.size(0) == 3
        # alpha ~ N(0, 0.1): random weights for the three eigenvalues
        alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
        quantity = torch.mm(self.eig_val * alpha, self.eig_vec)
        tensor = tensor + quantity.view(3, 1, 1)
        return tensor

--------------------------------------------------------------------------------
/utils/profile.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
import numpy as np


USE_GPU = torch.cuda.is_available()


def calc_flops(model, input_size):
    global USE_GPU

    def conv_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (
            2 if multiply_adds else 1)
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width

        list_conv.append(flops)

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1

        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        bias_ops = self.bias.nelement()

        flops = batch_size * (weight_ops + bias_ops)
        list_linear.append(flops)

    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement())

    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    def pooling_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = self.kernel_size * self.kernel_size
        bias_ops = 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width

        list_pooling.append(flops)

    def foo(net):
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, torch.nn.Conv2d):
                net.register_forward_hook(conv_hook)
            if isinstance(net, torch.nn.Linear):
                net.register_forward_hook(linear_hook)
            if isinstance(net, torch.nn.BatchNorm2d):
                net.register_forward_hook(bn_hook)
            if isinstance(net, torch.nn.ReLU):
                net.register_forward_hook(relu_hook)
            if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d):
                net.register_forward_hook(pooling_hook)
            return
        for c in childrens:
            foo(c)

    multiply_adds = False
    list_conv, list_bn, list_relu, list_linear, list_pooling = [], [], [], [], []
    foo(model)
    if '0.4.' in torch.__version__ or '1.0' in torch.__version__:
        if USE_GPU:
            input = torch.cuda.FloatTensor(torch.rand(2, 3, input_size, input_size).cuda())
        else:
            input = torch.FloatTensor(torch.rand(2, 3, input_size, input_size))
    else:
        input = Variable(torch.rand(2, 3, input_size, input_size), requires_grad=True)
    _ = model(input)

    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling))

    # the dummy input has batch size 2, so divide by 2 to report per-image FLOPs
    print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9 / 2))


def count_params(model, input_size=224):
    with open('models.txt', 'w') as fm:
        fm.write(str(model))
    calc_flops(model, input_size)

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])

    print('The network has {} params.'.format(params))
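

# Example usage (a minimal sketch, assuming the repo root is on PYTHONPATH so
# that the `models` package is importable, e.g. `python -m utils.profile`):
if __name__ == '__main__':
    from models.net_factory import fishnet99

    net = fishnet99()
    count_params(net, input_size=224)  # prints the FLOPs and parameter count

--------------------------------------------------------------------------------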