├── LICENSE.md ├── README.md ├── args_file.py ├── div │   ├── download_from_url.py │   ├── pyconv.jpg │   └── pyconv.pdf ├── main.py ├── models │   ├── build_model.py │   ├── pyconvhgresnet.py │   ├── pyconvresnet.py │   └── resnet.py ├── requirements.txt └── utils.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Pyramidal Convolution 2 | 3 | 4 | This is the PyTorch implementation of our paper ["Pyramidal Convolution: Rethinking Convolutional Neural Networks for Visual Recognition"](https://arxiv.org/pdf/2006.11538.pdf). 5 | (Note that this is the code for image recognition on ImageNet. For semantic image segmentation/parsing, refer to this repository: https://github.com/iduta/pyconvsegnet) 6 | 7 | ![Pyramidal Convolution: PyConv](div/pyconv.jpg) 8 | 9 | The models trained on ImageNet can be found [here](https://drive.google.com/drive/folders/1DGTXansI_JbxJsS0cQzvfrLEVdJ6l8Oh?usp=sharing). 10 | 11 | 12 | PyConv is able to provide improved recognition capabilities over the baseline 13 | (see [the paper](https://arxiv.org/pdf/2006.11538.pdf) for details).
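For intuition, the core operation can be sketched in a few lines of PyTorch: a PyConv layer runs several parallel convolutions with increasing kernel sizes (and increasing group counts, to keep the cost comparable to a standard 3×3 convolution) and concatenates their outputs. The sketch below mirrors the 4-level configuration used by the `PyConv4` module in `models/pyconvresnet.py`; the channel and input sizes are illustrative only:

```python
import torch
import torch.nn as nn

# Four parallel branches: kernels 3/5/7/9 with groups 1/4/8/16,
# as in the first stage of PyConvResNet.
branches = nn.ModuleList([
    nn.Conv2d(64, 16, kernel_size=k, padding=k // 2, groups=g, bias=False)
    for k, g in zip([3, 5, 7, 9], [1, 4, 8, 16])
])

x = torch.randn(2, 64, 56, 56)
out = torch.cat([b(x) for b in branches], dim=1)  # shape: (2, 64, 56, 56)
```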
14 | 15 | The accuracy on ImageNet (using the default training settings): 16 | 17 | 18 | | Network | 50-layers | 101-layers | 152-layers | 19 | | :----- | :-----: | :-----: | :-----: | 20 | | ResNet | 76.12% ([model](https://drive.google.com/uc?export=download&id=176TS0b6O0NALBbfzpz4mM1b47s4dwSVH)) | 78.00% ([model](https://drive.google.com/uc?export=download&id=1bermctRPLs5DIsHB0c4iDIGHvjfERPLG)) | 78.45% ([model](https://drive.google.com/uc?export=download&id=1FAggTH4m7Kec8MyRe8dx-ugI_yh-nLzL)) | 21 | | PyConvHGResNet | **78.48**% ([model](https://drive.google.com/uc?export=download&id=14x0uss32ASXr4FJTE7pip004XZpwNrZe)) | **79.22**% ([model](https://drive.google.com/uc?export=download&id=1Fm48GfOfn2Ivf5nBiR1SMhp66k67ePRh)) | **79.36**% ([model](https://drive.google.com/uc?export=download&id=1LRmdQWTceDkepnIxZ2mWbpEE2lFxy0QO)) | 22 | | PyConvResNet | **77.88**% ([model](https://drive.google.com/uc?export=download&id=128iMzBnHQSPNehgb8nUF5cJyKBIB7do5)) | **79.01**% ([model](https://drive.google.com/uc?export=download&id=1fn0eKdtGG7HA30O5SJ1XrmGR_FsQxTb1)) | **79.52**% ([model](https://drive.google.com/uc?export=download&id=1zR6HOTaHB0t15n6Nh12adX86AhBMo46m)) | 23 | 24 | 25 | 26 | 27 | 28 | The accuracy on ImageNet can be significantly improved using more complex training settings (for instance, adding data augmentation (CutMix), increasing the batch size to 1024, and using a learning rate of 0.4 with a cosine scheduler over 300 epochs and mixed precision to speed up training): 29 | 30 | 31 | | Network | test crop: 224×224 | test crop: 320×320 | model | 32 | | :----- | :-----: | :-----: | :-----: | 33 | | PyConvResNet-50 (+augment) | 79.44% | 80.59% | [model](https://drive.google.com/uc?export=download&id=19RFyaDnJ34IeqwS8QOX29JWH9I0r_ewM) | 34 | | PyConvResNet-101 (+augment) | 80.58% | 81.49% | [model](https://drive.google.com/uc?export=download&id=12PXOwgIF4eiApxDL5QrAMnjbiX69YQOi) | 35 | 36 | 37 | ### Requirements 38 | 39 | Install PyTorch and prepare the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). 40 | 41 | A fast alternative (without the need to install PyTorch and other deep learning libraries) is to use [NVIDIA-Docker](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/pullcontainer.html#pullcontainer); 42 | we used [this container image](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel_19-05.html#rel_19-05). 43 | 44 | 45 | ### Training 46 | To train a model (for instance, PyConvResNet with 50 layers) using DataParallel, run `main.py`; 47 | you also need to provide `--result_path` (the directory where the results 48 | and logs are saved) and `--data` (the path to the ImageNet dataset): 49 | ```bash 50 | result_path=/your/path/to/save/results/and/logs/ 51 | mkdir -p ${result_path} 52 | python main.py \ 53 | --data /your/path/to/ImageNet/dataset/ \ 54 | --result_path ${result_path} \ 55 | --arch pyconvresnet \ 56 | --model_depth 50 57 | ``` 58 | To train using Multi-processing Distributed Data Parallel Training, follow the instructions in the 59 | [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet).
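As a concrete reference, a single-node multi-GPU launch with the flags defined in `args_file.py` might look as follows (a sketch following the official PyTorch ImageNet example; the address/port and paths are placeholders to adapt to your setup):

```bash
result_path=/your/path/to/save/results/and/logs/
mkdir -p ${result_path}
python main.py \
--data /your/path/to/ImageNet/dataset/ \
--result_path ${result_path} \
--arch pyconvresnet \
--model_depth 50 \
--dist-url 'tcp://127.0.0.1:23456' \
--dist-backend nccl \
--world-size 1 \
--rank 0 \
--multiprocessing-distributed
```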
60 | 61 | 62 | ### Citation 63 | If you find our work useful, please consider citing: 64 | ``` 65 | @article{duta2020pyramidal, 66 | author = {Ionut Cosmin Duta and Li Liu and Fan Zhu and Ling Shao}, 67 | title = {Pyramidal Convolution: Rethinking Convolutional Neural Networks for Visual Recognition}, 68 | journal = {arXiv preprint arXiv:2006.11538}, 69 | year = {2020}, 70 | } 71 | ``` -------------------------------------------------------------------------------- /args_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training for PyConvResNets') 4 | parser.add_argument('--data', metavar='DIR', 5 | help='path to dataset') 6 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet', 7 | help='model architecture (default: resnet)') 8 | parser.add_argument('--result_path', default='results', type=str, 9 | help=' directory path where to save the results') 10 | parser.add_argument('--model_depth', default=50, type=int, 11 | help='depth of resnet (50 | 101 | 152)') 12 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 13 | help='number of data loading workers (default: 4)') 14 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 15 | help='number of total epochs to run') 16 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 17 | help='manual epoch number (useful on restarts)') 18 | parser.add_argument('-b', '--batch-size', default=256, type=int, 19 | metavar='N', 20 | help='mini-batch size (default: 256), this is the total ' 21 | 'batch size of all GPUs on the current node when ' 22 | 'using Data Parallel or Distributed Data Parallel') 23 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 24 | metavar='LR', help='initial learning rate', dest='lr') 25 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 26 | help='momentum') 27 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 28 | metavar='W', help='weight decay (default: 1e-4)', 29 | dest='weight_decay') 30 | parser.add_argument('-p', '--print-freq', default=10, type=int, 31 | metavar='N', help='print frequency (default: 10)') 32 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 33 | help='path to latest checkpoint (default: none)') 34 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 35 | help='evaluate model on validation set') 36 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 37 | help='use pre-trained model') 38 | parser.add_argument('--world-size', default=-1, type=int, 39 | help='number of nodes for distributed training') 40 | parser.add_argument('--rank', default=-1, type=int, 41 | help='node rank for distributed training') 42 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 43 | help='url used to set up distributed training') 44 | parser.add_argument('--dist-backend', default='nccl', type=str, 45 | help='distributed backend') 46 | parser.add_argument('--seed', default=None, type=int, 47 | help='seed for initializing training. ') 48 | parser.add_argument('--gpu', default=None, type=int, 49 | help='GPU id to use.') 50 | parser.add_argument('--multiprocessing-distributed', action='store_true', 51 | help='Use multi-processing distributed training to launch ' 52 | 'N processes per node, which has N GPUs. 
This is the ' 53 | 'fastest way to use PyTorch for either single node or ' 54 | 'multi node data parallel training') 55 | parser.add_argument('--n_classes', default=1000, type=int, 56 | help='Number of classes (default: 1000)') 57 | parser.add_argument('--lr_scheduler', default='MultiStepLR', type=str, 58 | help='The learning rate scheduler. Options: MultiStepLR') 59 | parser.add_argument('--lr_steps', default=[30, 60, 80], type=int, nargs="+", 60 | metavar='LRSteps', help='epochs to decay learning rate by 10') 61 | parser.add_argument('--lr_reduce_factor', default=0.1, type=float, 62 | help='Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.') 63 | parser.add_argument('--nesterov', action='store_true', default=False, help='Nesterov momentum') 64 | parser.add_argument('--zero_init_residual', action='store_true', 65 | help='If true, zero-initialize the last BN in each residual branch.') 66 | parser.add_argument('--groups', default=None, type=int, 67 | help='the number of groups to split the spatial convolution') 68 | parser.add_argument('--train_crop_size', default=224, type=int, 69 | help='The crop size for training. default: 224') 70 | parser.add_argument('--val_resize', default=256, type=int, 71 | help='The value to resize the shorter side of the image (for validation). default: 256') 72 | parser.add_argument('--val_crop_size', default=224, type=int, 73 | help='The crop size for validation. default: 224') 74 | -------------------------------------------------------------------------------- /div/download_from_url.py: -------------------------------------------------------------------------------- 1 | # main source: https://pytorch.org/text/_modules/torchtext/utils.html#download_from_url 2 | import requests 3 | import os 4 | import re 5 | 6 | 7 | def download_from_url(url, path=None, root='.data', overwrite=False): 8 | 9 | 10 | def _process_response(r, root, filename): 11 | chunk_size = 16 * 1024 12 | total_size = int(r.headers.get('Content-length', 0)) 13 | if filename is None: 14 | d = r.headers['content-disposition'] 15 | filename = re.findall("filename=\"(.+)\"", d) 16 | if not filename:  # re.findall returns an empty list (not None) when there is no match 17 | raise RuntimeError("Filename could not be autodetected") 18 | filename = filename[0] 19 | path = os.path.join(root, filename) 20 | if os.path.exists(path): 21 | print('File %s already exists.' % path) 22 | if not overwrite: 23 | return path 24 | print('Overwriting file %s.' % path) 25 | print('Downloading file {} to {} ...'.format(filename, path)) 26 | 27 | with open(path, "wb") as file: 28 | for chunk in r.iter_content(chunk_size): 29 | if chunk: 30 | file.write(chunk) 31 | print('File {} downloaded.'.format(path)) 32 | return path 33 | 34 | if path is None: 35 | _, filename = os.path.split(url) 36 | else: 37 | root, filename = os.path.split(path) 38 | 39 | if not os.path.exists(root): 40 | raise RuntimeError( 41 | "Download directory {} does not exist. 
" 42 | "Did you create it?".format(root)) 43 | 44 | if 'drive.google.com' not in url: 45 | response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) 46 | return _process_response(response, root, filename) 47 | else: 48 | # google drive links get filename from google drive 49 | filename = None 50 | 51 | print('Downloading from Google Drive; may take a few minutes') 52 | confirm_token = None 53 | session = requests.Session() 54 | response = session.get(url, stream=True) 55 | for k, v in response.cookies.items(): 56 | if k.startswith("download_warning"): 57 | confirm_token = v 58 | 59 | if confirm_token: 60 | url = url + "&confirm=" + confirm_token 61 | response = session.get(url, stream=True) 62 | 63 | return _process_response(response, root, filename) 64 | -------------------------------------------------------------------------------- /div/pyconv.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iduta/pyconv/d8b39cf43014b8fd277dcefc9eb7f8880511e977/div/pyconv.jpg -------------------------------------------------------------------------------- /div/pyconv.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iduta/pyconv/d8b39cf43014b8fd277dcefc9eb7f8880511e977/div/pyconv.pdf -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import time 5 | import warnings 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | import torch.optim 13 | import torch.multiprocessing as mp 14 | import torch.utils.data 15 | import torch.utils.data.distributed 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | import torchvision.models as models 19 | 20 | import args_file 21 | import json 22 | from utils import Logger 23 | from models.build_model import build_model 24 | 25 | 26 | model_names = sorted(name for name in models.__dict__ 27 | if name.islower() and not name.startswith("__") 28 | and callable(models.__dict__[name])) 29 | 30 | 31 | 32 | best_acc1 = 0 33 | 34 | 35 | def main(): 36 | args = args_file.parser.parse_args() 37 | 38 | if args.seed is not None: 39 | random.seed(args.seed) 40 | torch.manual_seed(args.seed) 41 | cudnn.deterministic = True 42 | warnings.warn('You have chosen to seed training. ' 43 | 'This will turn on the CUDNN deterministic setting, ' 44 | 'which can slow down your training considerably! ' 45 | 'You may see unexpected behavior when restarting ' 46 | 'from checkpoints.') 47 | 48 | if args.gpu is not None: 49 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 50 | 'disable data parallelism.') 51 | 52 | if args.dist_url == "env://" and args.world_size == -1: 53 | args.world_size = int(os.environ["WORLD_SIZE"]) 54 | 55 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 56 | 57 | ngpus_per_node = torch.cuda.device_count() 58 | if args.multiprocessing_distributed: 59 | # Since we have ngpus_per_node processes per node, the total world_size 60 | # needs to be adjusted accordingly 61 | args.world_size = ngpus_per_node * args.world_size 62 | # Use torch.multiprocessing.spawn to launch distributed processes: the 63 | # main_worker process function 64 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 65 | else: 66 | # Simply call main_worker function 67 | main_worker(args.gpu, ngpus_per_node, args) 68 | 69 | 70 | def main_worker(gpu, ngpus_per_node, args): 71 | global best_acc1 72 | args.gpu = gpu 73 | 74 | if not os.path.exists(args.result_path): 75 | os.makedirs(args.result_path) 76 | 77 | if not args.evaluate: 78 | if args.resume: 79 | with open(os.path.join(args.result_path, 'resume_args.json'), 'a') as f:  # append mode: the file accumulates one JSON object per resume 80 | json.dump(vars(args), f) 81 | else: 82 | with open(os.path.join(args.result_path, 'args.json'), 'w') as f: 83 | json.dump(vars(args), f) 84 | 85 | if args.gpu is not None: 86 | print("Use GPU: {} for training".format(args.gpu)) 87 | 88 | if args.distributed: 89 | if args.dist_url == "env://" and args.rank == -1: 90 | args.rank = int(os.environ["RANK"]) 91 | if args.multiprocessing_distributed: 92 | # For multiprocessing distributed training, rank needs to be the 93 | # global rank among all the processes 94 | args.rank = args.rank * ngpus_per_node + gpu 95 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 96 | world_size=args.world_size, rank=args.rank) 97 | # create model 98 | ''' 99 | if args.pretrained: 100 | print("=> using pre-trained model '{}'".format(args.arch)) 101 | model = models.__dict__[args.arch](pretrained=True) 102 | else: 103 | print("=> creating model '{}'".format(args.arch)) 104 | model = models.__dict__[args.arch]() 105 | ''' 106 | model = build_model(args) 107 | print(model) 108 | print(args) 109 | 110 | if args.distributed: 111 | # For multiprocessing distributed, DistributedDataParallel constructor 112 | # should always set the single device scope, otherwise, 113 | # DistributedDataParallel will use all available devices.
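# Illustrative arithmetic for the branch below (assuming --batch-size 256 on a
# node with 8 GPUs): each DDP process receives int(256 / 8) = 32 samples per
# step; --workers is divided the same way (int(4 / 8) = 0 with the default),
# so consider raising -j when using many GPUs per node.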
114 | if args.gpu is not None: 115 | torch.cuda.set_device(args.gpu) 116 | model.cuda(args.gpu) 117 | # When using a single GPU per process and per 118 | # DistributedDataParallel, we need to divide the batch size 119 | # ourselves based on the total number of GPUs we have 120 | args.batch_size = int(args.batch_size / ngpus_per_node) 121 | args.workers = int(args.workers / ngpus_per_node) 122 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 123 | else: 124 | model.cuda() 125 | # DistributedDataParallel will divide and allocate batch_size to all 126 | # available GPUs if device_ids are not set 127 | model = torch.nn.parallel.DistributedDataParallel(model) 128 | elif args.gpu is not None: 129 | torch.cuda.set_device(args.gpu) 130 | model = model.cuda(args.gpu) 131 | else: 132 | # DataParallel will divide and allocate batch_size to all available GPUs 133 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 134 | model.features = torch.nn.DataParallel(model.features) 135 | model.cuda() 136 | else: 137 | model = torch.nn.DataParallel(model).cuda() 138 | 139 | # define loss function (criterion) and optimizer 140 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 141 | 142 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 143 | momentum=args.momentum, 144 | weight_decay=args.weight_decay, 145 | nesterov=args.nesterov) 146 | 147 | if args.lr_scheduler == 'MultiStepLR': 148 | print("using MultiStepLR with steps: ", args.lr_steps) 149 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_steps, gamma=args.lr_reduce_factor) 150 | 151 | # optionally resume from a checkpoint 152 | if args.resume: 153 | if os.path.isfile(args.resume): 154 | print("=> loading checkpoint '{}'".format(args.resume)) 155 | checkpoint = torch.load(args.resume) 156 | args.start_epoch = checkpoint['epoch'] 157 | best_acc1 = checkpoint['best_acc1'] 158 | if args.gpu is not None: 159 | # best_acc1 may be from a checkpoint from a different GPU 160 | best_acc1 = best_acc1.to(args.gpu) 161 | model.load_state_dict(checkpoint['state_dict']) 162 | optimizer.load_state_dict(checkpoint['optimizer']) 163 | print("=> loaded checkpoint '{}' (epoch {})" 164 | .format(args.resume, checkpoint['epoch'])) 165 | if args.lr_scheduler == 'MultiStepLR': 166 | print("using MultiStepLR with steps: ", args.lr_steps) 167 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_steps, 168 | gamma=args.lr_reduce_factor, 169 | last_epoch=(checkpoint['epoch'] - 1)) 170 | print("last_epoch: ", scheduler.last_epoch) 171 | 172 | if args.start_epoch == 0: 173 | args.start_epoch = checkpoint['epoch'] 174 | else: 175 | print("=> no checkpoint found at '{}'".format(args.resume)) 176 | 177 | cudnn.benchmark = True 178 | 179 | # Data loading code 180 | traindir = os.path.join(args.data, 'train') 181 | valdir = os.path.join(args.data, 'val') 182 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 183 | std=[0.229, 0.224, 0.225]) 184 | 185 | train_trans = [ 186 | transforms.RandomResizedCrop(args.train_crop_size), # transforms.RandomResizedCrop(224) 187 | transforms.RandomHorizontalFlip(), 188 | transforms.ToTensor(), 189 | normalize, 190 | ] 191 | 192 | train_dataset = datasets.ImageFolder( 193 | traindir, 194 | transforms.Compose(train_trans)) 195 | 196 | if args.distributed: 197 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 198 | else: 199 | train_sampler = None 200 | 201 | train_loader = torch.utils.data.DataLoader( 202 |
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 203 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 204 | 205 | val_loader = torch.utils.data.DataLoader( 206 | datasets.ImageFolder(valdir, transforms.Compose([ 207 | transforms.Resize(args.val_resize),#transforms.Resize(256) 208 | transforms.CenterCrop(args.val_crop_size),#transforms.CenterCrop(224) 209 | transforms.ToTensor(), 210 | normalize, 211 | ])), 212 | batch_size=args.batch_size, shuffle=False, 213 | num_workers=args.workers, pin_memory=True) 214 | 215 | if args.evaluate: 216 | validate(val_loader, model, criterion, args) 217 | return 218 | 219 | if args.resume: 220 | mode = 'a' 221 | else: 222 | mode = 'w' 223 | 224 | train_logger = Logger(os.path.join(args.result_path, 'train.log'), 225 | ['epoch', 'loss', 'acc1', 'acc5', 'lr'], mode=mode) 226 | 227 | val_logger = Logger(os.path.join(args.result_path, 'val.log'), ['epoch', 'loss', 'acc1', 'acc5'], mode=mode) 228 | 229 | for epoch in range(args.start_epoch, args.epochs): 230 | if args.distributed: 231 | train_sampler.set_epoch(epoch) 232 | 233 | if args.lr_scheduler == 'MultiStepLR': 234 | scheduler.step() 235 | else: 236 | adjust_learning_rate(optimizer, epoch, args) 237 | 238 | # train for one epoch 239 | train_acc1, train_acc5, train_loss = \ 240 | train(train_loader, model, criterion, optimizer, epoch, args) 241 | 242 | train_logger.log({ 243 | 'epoch': epoch+1, 244 | 'loss': '{:.4f}'.format(train_loss), 245 | 'acc1': '{:.2f}'.format(train_acc1.item()), 246 | 'acc5': '{:.2f}'.format(train_acc5.item()), 247 | 'lr': '{:.6f}'.format(optimizer.param_groups[0]['lr']) 248 | }) 249 | 250 | # evaluate on validation set 251 | val_acc1, val_acc5, val_loss = validate(val_loader, model, criterion, args) 252 | 253 | val_logger.log({ 254 | 'epoch': epoch+1, 255 | 'loss': '{:.4f}'.format(val_loss), 256 | 'acc1': '{:.2f}'.format(val_acc1.item()), 257 | 'acc5': '{:.2f}'.format(val_acc5.item()) 258 | }) 259 | 260 | # remember best acc@1 and save checkpoint 261 | is_best = val_acc1 > best_acc1 262 | best_acc1 = max(val_acc1, best_acc1) 263 | 264 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 265 | and args.rank % ngpus_per_node == 0): 266 | save_checkpoint({ 267 | 'epoch': epoch + 1, 268 | 'arch': args.arch, 269 | 'state_dict': model.state_dict(), 270 | 'best_acc1': best_acc1, 271 | 'val_acc1': val_acc1, 272 | 'optimizer': optimizer.state_dict(), 273 | }, is_best, args.result_path + '/') 274 | 275 | 276 | def train(train_loader, model, criterion, optimizer, epoch, args): 277 | batch_time = AverageMeter('Time', ':6.3f') 278 | data_time = AverageMeter('Data', ':6.3f') 279 | losses = AverageMeter('Loss', ':.4e') 280 | top1 = AverageMeter('Acc@1', ':6.2f') 281 | top5 = AverageMeter('Acc@5', ':6.2f') 282 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, 283 | top5, prefix="Epoch: [{}]".format(epoch+1)) 284 | 285 | # switch to train mode 286 | model.train() 287 | 288 | end = time.time() 289 | for i, (input, target) in enumerate(train_loader): 290 | # measure data loading time 291 | data_time.update(time.time() - end) 292 | 293 | if args.gpu is not None: 294 | input = input.cuda(args.gpu, non_blocking=True) 295 | target = target.cuda(args.gpu, non_blocking=True) 296 | 297 | # compute output 298 | output = model(input) 299 | loss = criterion(output, target) 300 | 301 | # measure accuracy and record loss 302 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 303 | losses.update(loss.item(), 
input.size(0)) 304 | top1.update(acc1[0], input.size(0)) 305 | top5.update(acc5[0], input.size(0)) 306 | 307 | # compute gradient and do SGD step 308 | optimizer.zero_grad() 309 | loss.backward() 310 | optimizer.step() 311 | 312 | # measure elapsed time 313 | batch_time.update(time.time() - end) 314 | end = time.time() 315 | 316 | if i % args.print_freq == 0: 317 | progress.print(i) 318 | 319 | return top1.avg, top5.avg, losses.avg 320 | 321 | 322 | def validate(val_loader, model, criterion, args): 323 | batch_time = AverageMeter('Time', ':6.3f') 324 | losses = AverageMeter('Loss', ':.4e') 325 | top1 = AverageMeter('Acc@1', ':6.2f') 326 | top5 = AverageMeter('Acc@5', ':6.2f') 327 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 328 | prefix='Test: ') 329 | 330 | # switch to evaluate mode 331 | model.eval() 332 | 333 | with torch.no_grad(): 334 | end = time.time() 335 | for i, (input, target) in enumerate(val_loader): 336 | if args.gpu is not None: 337 | input = input.cuda(args.gpu, non_blocking=True) 338 | target = target.cuda(args.gpu, non_blocking=True) 339 | 340 | # compute output 341 | output = model(input) 342 | loss = criterion(output, target) 343 | 344 | # measure accuracy and record loss 345 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 346 | losses.update(loss.item(), input.size(0)) 347 | top1.update(acc1[0], input.size(0)) 348 | top5.update(acc5[0], input.size(0)) 349 | 350 | # measure elapsed time 351 | batch_time.update(time.time() - end) 352 | end = time.time() 353 | 354 | if i % args.print_freq == 0: 355 | progress.print(i) 356 | 357 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 358 | .format(top1=top1, top5=top5)) 359 | 360 | return top1.avg, top5.avg, losses.avg 361 | 362 | 363 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 364 | 365 | torch.save(state, filename + 'checkpoint.pth.tar') 366 | if is_best: 367 | shutil.copyfile(filename + 'checkpoint.pth.tar', filename + 'model_best.pth.tar') 368 | 369 | 370 | class AverageMeter(object): 371 | """Computes and stores the average and current value""" 372 | def __init__(self, name, fmt=':f'): 373 | self.name = name 374 | self.fmt = fmt 375 | self.reset() 376 | 377 | def reset(self): 378 | self.val = 0 379 | self.avg = 0 380 | self.sum = 0 381 | self.count = 0 382 | 383 | def update(self, val, n=1): 384 | self.val = val 385 | self.sum += val * n 386 | self.count += n 387 | self.avg = self.sum / self.count 388 | 389 | def __str__(self): 390 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 391 | return fmtstr.format(**self.__dict__) 392 | 393 | 394 | class ProgressMeter(object): 395 | def __init__(self, num_batches, *meters, prefix=""): 396 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 397 | self.meters = meters 398 | self.prefix = prefix 399 | 400 | def print(self, batch): 401 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 402 | entries += [str(meter) for meter in self.meters] 403 | print('\t'.join(entries)) 404 | 405 | def _get_batch_fmtstr(self, num_batches): 406 | num_digits = len(str(num_batches // 1)) 407 | fmt = '{:' + str(num_digits) + 'd}' 408 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 409 | 410 | 411 | def adjust_learning_rate(optimizer, epoch, args): 412 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 413 | lr = args.lr * (0.1 ** (epoch // 30)) 414 | for param_group in optimizer.param_groups: 415 | param_group['lr'] = lr 416 | 417 | 418 | def accuracy(output, target, 
topk=(1,)): 419 | """Computes the accuracy over the k top predictions for the specified values of k""" 420 | with torch.no_grad(): 421 | maxk = max(topk) 422 | batch_size = target.size(0) 423 | 424 | _, pred = output.topk(maxk, 1, True, True) 425 | pred = pred.t() 426 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 427 | 428 | res = [] 429 | for k in topk: 430 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: slicing the transposed tensor yields a non-contiguous view 431 | res.append(correct_k.mul_(100.0 / batch_size)) 432 | return res 433 | 434 | 435 | if __name__ == '__main__': 436 | main() -------------------------------------------------------------------------------- /models/build_model.py: -------------------------------------------------------------------------------- 1 | from models import resnet, pyconvresnet, pyconvhgresnet 2 | 3 | 4 | def build_model(args): 5 | 6 | if args.arch == 'pyconvhgresnet': 7 | assert args.model_depth in [50, 101, 152] 8 | 9 | if args.model_depth == 50: 10 | model = pyconvhgresnet.pyconvhgresnet50( 11 | pretrained=args.pretrained, 12 | num_classes=args.n_classes, 13 | zero_init_residual=args.zero_init_residual) 14 | elif args.model_depth == 101: 15 | model = pyconvhgresnet.pyconvhgresnet101( 16 | pretrained=args.pretrained, 17 | num_classes=args.n_classes, 18 | zero_init_residual=args.zero_init_residual) 19 | elif args.model_depth == 152: 20 | model = pyconvhgresnet.pyconvhgresnet152( 21 | pretrained=args.pretrained, 22 | num_classes=args.n_classes, 23 | zero_init_residual=args.zero_init_residual) 24 | 25 | elif args.arch == 'pyconvresnet': 26 | assert args.model_depth in [50, 101, 152] 27 | 28 | if args.model_depth == 50: 29 | model = pyconvresnet.pyconvresnet50( 30 | pretrained=args.pretrained, 31 | num_classes=args.n_classes, 32 | zero_init_residual=args.zero_init_residual) 33 | elif args.model_depth == 101: 34 | model = pyconvresnet.pyconvresnet101( 35 | pretrained=args.pretrained, 36 | num_classes=args.n_classes, 37 | zero_init_residual=args.zero_init_residual) 38 | elif args.model_depth == 152: 39 | model = pyconvresnet.pyconvresnet152( 40 | pretrained=args.pretrained, 41 | num_classes=args.n_classes, 42 | zero_init_residual=args.zero_init_residual) 43 | 44 | elif args.arch == 'resnet': 45 | assert args.model_depth in [18, 34, 50, 101, 152] 46 | 47 | if args.model_depth == 18: 48 | model = resnet.resnet18( 49 | pretrained=args.pretrained, 50 | num_classes=args.n_classes, 51 | zero_init_residual=args.zero_init_residual) 52 | elif args.model_depth == 34: 53 | model = resnet.resnet34( 54 | pretrained=args.pretrained, 55 | num_classes=args.n_classes, 56 | zero_init_residual=args.zero_init_residual) 57 | elif args.model_depth == 50: 58 | model = resnet.resnet50( 59 | pretrained=args.pretrained, 60 | num_classes=args.n_classes, 61 | zero_init_residual=args.zero_init_residual) 62 | elif args.model_depth == 101: 63 | model = resnet.resnet101( 64 | pretrained=args.pretrained, 65 | num_classes=args.n_classes, 66 | zero_init_residual=args.zero_init_residual) 67 | elif args.model_depth == 152: 68 | model = resnet.resnet152( 69 | pretrained=args.pretrained, 70 | num_classes=args.n_classes, 71 | zero_init_residual=args.zero_init_residual) 72 | else: 73 | raise ValueError('Unknown architecture: {}'.format(args.arch))  # fail loudly instead of returning an unbound name 74 | 75 | return model 76 | -------------------------------------------------------------------------------- /models/pyconvhgresnet.py: -------------------------------------------------------------------------------- 1 | """ PyConv networks for image recognition as presented in our paper: 2 | Duta et al.
"Pyramidal Convolution: Rethinking Convolutional Neural Networks for Visual Recognition" 3 | https://arxiv.org/pdf/2006.11538.pdf 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import os 8 | from div.download_from_url import download_from_url 9 | 10 | try: 11 | from torch.hub import _get_torch_home 12 | torch_cache_home = _get_torch_home() 13 | except ImportError: 14 | torch_cache_home = os.path.expanduser( 15 | os.getenv('TORCH_HOME', os.path.join( 16 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 17 | default_cache_path = os.path.join(torch_cache_home, 'pretrained') 18 | 19 | __all__ = ['PyConvHGResNet', 'pyconvhgresnet50', 'pyconvhgresnet101', 'pyconvhgresnet152'] 20 | 21 | 22 | model_urls = { 23 | 'pyconvhgresnet50': 'https://drive.google.com/uc?export=download&id=14x0uss32ASXr4FJTE7pip004XZpwNrZe', 24 | 'pyconvhgresnet101': 'https://drive.google.com/uc?export=download&id=1Fm48GfOfn2Ivf5nBiR1SMhp66k67ePRh', 25 | 'pyconvhgresnet152': 'https://drive.google.com/uc?export=download&id=1LRmdQWTceDkepnIxZ2mWbpEE2lFxy0QO', 26 | } 27 | 28 | 29 | class PyConv2d(nn.Module): 30 | """PyConv2d with padding (general case). Applies a 2D PyConv over an input signal composed of several input planes. 31 | 32 | Args: 33 | in_channels (int): Number of channels in the input image 34 | out_channels (list): Number of channels for each pyramid level produced by the convolution 35 | pyconv_kernels (list): Spatial size of the kernel for each pyramid level 36 | pyconv_groups (list): Number of blocked connections from input channels to output channels for each pyramid level 37 | stride (int or tuple, optional): Stride of the convolution. Default: 1 38 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 39 | bias (bool, optional): If ``True``, adds a learnable bias to the output. 
Default: ``False`` 40 | 41 | Example:: 42 | 43 | >>> # PyConv with two pyramid levels, kernels: 3x3, 5x5 44 | >>> m = PyConv2d(in_channels=64, out_channels=[32, 32], pyconv_kernels=[3, 5], pyconv_groups=[1, 4]) 45 | >>> input = torch.randn(4, 64, 56, 56) 46 | >>> output = m(input) 47 | 48 | >>> # PyConv with three pyramid levels, kernels: 3x3, 5x5, 7x7 49 | >>> m = PyConv2d(in_channels=64, out_channels=[16, 16, 32], pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]) 50 | >>> input = torch.randn(4, 64, 56, 56) 51 | >>> output = m(input) 52 | """ 53 | def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_groups, stride=1, dilation=1, bias=False): 54 | super(PyConv2d, self).__init__() 55 | 56 | assert len(out_channels) == len(pyconv_kernels) == len(pyconv_groups) 57 | 58 | self.pyconv_levels = [None] * len(pyconv_kernels) 59 | for i in range(len(pyconv_kernels)): 60 | self.pyconv_levels[i] = nn.Conv2d(in_channels, out_channels[i], kernel_size=pyconv_kernels[i], 61 | stride=stride, padding=pyconv_kernels[i] // 2, groups=pyconv_groups[i], 62 | dilation=dilation, bias=bias) 63 | self.pyconv_levels = nn.ModuleList(self.pyconv_levels) 64 | 65 | def forward(self, x): 66 | out = [] 67 | for level in self.pyconv_levels: 68 | out.append(level(x)) 69 | 70 | return torch.cat(out, 1) 71 | 72 | 73 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1): 74 | """standard convolution with padding""" 75 | return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 76 | padding=padding, dilation=dilation, groups=groups, bias=False) 77 | 78 | 79 | def conv1x1(in_planes, out_planes, stride=1): 80 | """1x1 convolution""" 81 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 82 | 83 | 84 | class PyConv4(nn.Module): 85 | 86 | def __init__(self, inplans, planes, pyconv_kernels=[3, 5, 7, 9], stride=1, pyconv_groups=[1, 4, 8, 16]): 87 | super(PyConv4, self).__init__() 88 | self.conv2_1 = conv(inplans, planes//4, kernel_size=pyconv_kernels[0], padding=pyconv_kernels[0]//2, 89 | stride=stride, groups=pyconv_groups[0]) 90 | self.conv2_2 = conv(inplans, planes//4, kernel_size=pyconv_kernels[1], padding=pyconv_kernels[1]//2, 91 | stride=stride, groups=pyconv_groups[1]) 92 | self.conv2_3 = conv(inplans, planes//4, kernel_size=pyconv_kernels[2], padding=pyconv_kernels[2]//2, 93 | stride=stride, groups=pyconv_groups[2]) 94 | self.conv2_4 = conv(inplans, planes//4, kernel_size=pyconv_kernels[3], padding=pyconv_kernels[3]//2, 95 | stride=stride, groups=pyconv_groups[3]) 96 | 97 | def forward(self, x): 98 | return torch.cat((self.conv2_1(x), self.conv2_2(x), self.conv2_3(x), self.conv2_4(x)), dim=1) 99 | 100 | 101 | class PyConv3(nn.Module): 102 | 103 | def __init__(self, inplans, planes, pyconv_kernels=[3, 5, 7], stride=1, pyconv_groups=[1, 4, 8]): 104 | super(PyConv3, self).__init__() 105 | self.conv2_1 = conv(inplans, planes // 4, kernel_size=pyconv_kernels[0], padding=pyconv_kernels[0] // 2, 106 | stride=stride, groups=pyconv_groups[0]) 107 | self.conv2_2 = conv(inplans, planes // 4, kernel_size=pyconv_kernels[1], padding=pyconv_kernels[1] // 2, 108 | stride=stride, groups=pyconv_groups[1]) 109 | self.conv2_3 = conv(inplans, planes // 2, kernel_size=pyconv_kernels[2], padding=pyconv_kernels[2] // 2, 110 | stride=stride, groups=pyconv_groups[2]) 111 | 112 | def forward(self, x): 113 | return torch.cat((self.conv2_1(x), self.conv2_2(x), self.conv2_3(x)), dim=1) 114 | 115 | 116 | class PyConv2(nn.Module): 117 | 118 | def 
__init__(self, inplans, planes,pyconv_kernels=[3, 5], stride=1, pyconv_groups=[1, 4]): 119 | super(PyConv2, self).__init__() 120 | self.conv2_1 = conv(inplans, planes // 2, kernel_size=pyconv_kernels[0], padding=pyconv_kernels[0] // 2, 121 | stride=stride, groups=pyconv_groups[0]) 122 | self.conv2_2 = conv(inplans, planes // 2, kernel_size=pyconv_kernels[1], padding=pyconv_kernels[1] // 2, 123 | stride=stride, groups=pyconv_groups[1]) 124 | 125 | def forward(self, x): 126 | return torch.cat((self.conv2_1(x), self.conv2_2(x)), dim=1) 127 | 128 | 129 | def get_pyconv(inplans, planes, pyconv_kernels, stride=1, pyconv_groups=[1]): 130 | if len(pyconv_kernels) == 1: 131 | return conv(inplans, planes, kernel_size=pyconv_kernels[0], stride=stride, groups=pyconv_groups[0]) 132 | elif len(pyconv_kernels) == 2: 133 | return PyConv2(inplans, planes, pyconv_kernels=pyconv_kernels, stride=stride, pyconv_groups=pyconv_groups) 134 | elif len(pyconv_kernels) == 3: 135 | return PyConv3(inplans, planes, pyconv_kernels=pyconv_kernels, stride=stride, pyconv_groups=pyconv_groups) 136 | elif len(pyconv_kernels) == 4: 137 | return PyConv4(inplans, planes, pyconv_kernels=pyconv_kernels, stride=stride, pyconv_groups=pyconv_groups) 138 | 139 | 140 | class PyConvBlock(nn.Module): 141 | expansion = 2 142 | 143 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None, pyconv_groups=1, pyconv_kernels=1): 144 | super(PyConvBlock, self).__init__() 145 | if norm_layer is None: 146 | norm_layer = nn.BatchNorm2d 147 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 148 | self.conv1 = conv1x1(inplanes, planes) 149 | self.bn1 = norm_layer(planes) 150 | self.conv2 = get_pyconv(planes, planes, pyconv_kernels=pyconv_kernels, stride=stride, 151 | pyconv_groups=pyconv_groups) 152 | self.bn2 = norm_layer(planes) 153 | self.conv3 = conv1x1(planes, planes * self.expansion) 154 | self.bn3 = norm_layer(planes * self.expansion) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.downsample = downsample 157 | self.stride = stride 158 | 159 | def forward(self, x): 160 | identity = x 161 | 162 | out = self.conv1(x) 163 | out = self.bn1(out) 164 | out = self.relu(out) 165 | 166 | out = self.conv2(out) 167 | out = self.bn2(out) 168 | out = self.relu(out) 169 | 170 | out = self.conv3(out) 171 | out = self.bn3(out) 172 | 173 | if self.downsample is not None: 174 | identity = self.downsample(x) 175 | 176 | out += identity 177 | out = self.relu(out) 178 | 179 | return out 180 | 181 | 182 | class PyConvHGResNet(nn.Module): 183 | 184 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, norm_layer=None, dropout_prob0=0.0): 185 | super(PyConvHGResNet, self).__init__() 186 | if norm_layer is None: 187 | norm_layer = nn.BatchNorm2d 188 | 189 | self.inplanes = 64 190 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 191 | self.bn1 = norm_layer(64) 192 | self.relu = nn.ReLU(inplace=True) 193 | 194 | self.layer1 = self._make_layer(block, 64*2, layers[0], stride=2, norm_layer=norm_layer, 195 | pyconv_kernels=[3, 5, 7, 9], pyconv_groups=[32, 32, 32, 32]) 196 | self.layer2 = self._make_layer(block, 128*2, layers[1], stride=2, norm_layer=norm_layer, 197 | pyconv_kernels=[3, 5, 7], pyconv_groups=[32, 64, 64]) 198 | self.layer3 = self._make_layer(block, 256*2, layers[2], stride=2, norm_layer=norm_layer, 199 | pyconv_kernels=[3, 5], pyconv_groups=[32, 64]) 200 | self.layer4 = self._make_layer(block, 512*2, layers[3], stride=2, 
norm_layer=norm_layer, 201 | pyconv_kernels=[3], pyconv_groups=[32]) 202 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 203 | 204 | if dropout_prob0 > 0.0: 205 | self.dp = nn.Dropout(dropout_prob0, inplace=True) 206 | print("Using Dropout with the prob to set to 0 of: ", dropout_prob0) 207 | else: 208 | self.dp = None 209 | 210 | self.fc = nn.Linear(512*2 * block.expansion, num_classes) 211 | 212 | for m in self.modules(): 213 | if isinstance(m, nn.Conv2d): 214 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 215 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 216 | nn.init.constant_(m.weight, 1) 217 | nn.init.constant_(m.bias, 0) 218 | 219 | # Zero-initialize the last BN in each residual branch, 220 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 221 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 222 | if zero_init_residual: 223 | for m in self.modules(): 224 | if isinstance(m, PyConvBlock): 225 | nn.init.constant_(m.bn3.weight, 0) 226 | 227 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None, pyconv_kernels=[3], pyconv_groups=[1]): 228 | if norm_layer is None: 229 | norm_layer = nn.BatchNorm2d 230 | downsample = None 231 | if stride != 1 and self.inplanes != planes * block.expansion: 232 | downsample = nn.Sequential( 233 | nn.MaxPool2d(kernel_size=3, stride=stride, padding=1), 234 | conv1x1(self.inplanes, planes * block.expansion), 235 | norm_layer(planes * block.expansion), 236 | ) 237 | elif self.inplanes != planes * block.expansion: 238 | downsample = nn.Sequential( 239 | conv1x1(self.inplanes, planes * block.expansion), 240 | norm_layer(planes * block.expansion), 241 | ) 242 | elif stride != 1: 243 | downsample = nn.MaxPool2d(kernel_size=3, stride=stride, padding=1) 244 | 245 | layers = [] 246 | layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample, norm_layer=norm_layer, 247 | pyconv_kernels=pyconv_kernels, pyconv_groups=pyconv_groups)) 248 | self.inplanes = planes * block.expansion 249 | 250 | for _ in range(1, blocks): 251 | layers.append(block(self.inplanes, planes, norm_layer=norm_layer, 252 | pyconv_kernels=pyconv_kernels, pyconv_groups=pyconv_groups)) 253 | 254 | return nn.Sequential(*layers) 255 | 256 | def forward(self, x): 257 | x = self.conv1(x) 258 | x = self.bn1(x) 259 | x = self.relu(x) 260 | 261 | x = self.layer1(x) 262 | x = self.layer2(x) 263 | x = self.layer3(x) 264 | x = self.layer4(x) 265 | 266 | x = self.avgpool(x) 267 | x = x.view(x.size(0), -1) 268 | 269 | if self.dp is not None: 270 | x = self.dp(x) 271 | 272 | x = self.fc(x) 273 | 274 | return x 275 | 276 | 277 | def pyconvhgresnet50(pretrained=False, **kwargs): 278 | """Constructs a PyConvHGResNet-50 model. 279 | 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | """ 283 | model = PyConvHGResNet(PyConvBlock, [3, 4, 6, 3], **kwargs) 284 | if pretrained: 285 | os.makedirs(default_cache_path, exist_ok=True) 286 | model.load_state_dict(torch.load(download_from_url(model_urls['pyconvhgresnet50'], 287 | root=default_cache_path))) 288 | return model 289 | 290 | 291 | def pyconvhgresnet101(pretrained=False, **kwargs): 292 | """Constructs a PyConvHGResNet-101 model. 
293 | 294 | Args: 295 | pretrained (bool): If True, returns a model pre-trained on ImageNet 296 | """ 297 | model = PyConvHGResNet(PyConvBlock, [3, 4, 23, 3], **kwargs) 298 | if pretrained: 299 | os.makedirs(default_cache_path, exist_ok=True) 300 | model.load_state_dict(torch.load(download_from_url(model_urls['pyconvhgresnet101'], 301 | root=default_cache_path))) 302 | return model 303 | 304 | 305 | def pyconvhgresnet152(pretrained=False, **kwargs): 306 | """Constructs a PyConvHGResNet-152 model. 307 | 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | """ 311 | model = PyConvHGResNet(PyConvBlock, [3, 8, 36, 3], **kwargs) 312 | if pretrained: 313 | os.makedirs(default_cache_path, exist_ok=True) 314 | model.load_state_dict(torch.load(download_from_url(model_urls['pyconvhgresnet152'], 315 | root=default_cache_path))) 316 | return model 317 | -------------------------------------------------------------------------------- /models/pyconvresnet.py: -------------------------------------------------------------------------------- 1 | """ PyConv networks for image recognition as presented in our paper: 2 | Duta et al. "Pyramidal Convolution: Rethinking Convolutional Neural Networks for Visual Recognition" 3 | https://arxiv.org/pdf/2006.11538.pdf 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import os 8 | from div.download_from_url import download_from_url 9 | 10 | try: 11 | from torch.hub import _get_torch_home 12 | torch_cache_home = _get_torch_home() 13 | except ImportError: 14 | torch_cache_home = os.path.expanduser( 15 | os.getenv('TORCH_HOME', os.path.join( 16 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 17 | default_cache_path = os.path.join(torch_cache_home, 'pretrained') 18 | 19 | __all__ = ['PyConvResNet', 'pyconvresnet18', 'pyconvresnet34', 'pyconvresnet50', 'pyconvresnet101', 'pyconvresnet152'] 20 | 21 | 22 | model_urls = { 23 | 'pyconvresnet50': 'https://drive.google.com/uc?export=download&id=128iMzBnHQSPNehgb8nUF5cJyKBIB7do5', 24 | 'pyconvresnet101': 'https://drive.google.com/uc?export=download&id=1fn0eKdtGG7HA30O5SJ1XrmGR_FsQxTb1', 25 | 'pyconvresnet152': 'https://drive.google.com/uc?export=download&id=1zR6HOTaHB0t15n6Nh12adX86AhBMo46m', 26 | } 27 | 28 | 29 | class PyConv2d(nn.Module): 30 | """PyConv2d with padding (general case). Applies a 2D PyConv over an input signal composed of several input planes. 31 | 32 | Args: 33 | in_channels (int): Number of channels in the input image 34 | out_channels (list): Number of channels for each pyramid level produced by the convolution 35 | pyconv_kernels (list): Spatial size of the kernel for each pyramid level 36 | pyconv_groups (list): Number of blocked connections from input channels to output channels for each pyramid level 37 | stride (int or tuple, optional): Stride of the convolution. Default: 1 38 | dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 39 | bias (bool, optional): If ``True``, adds a learnable bias to the output. 
Default: ``False`` 40 | 41 | Example:: 42 | 43 | >>> # PyConv with two pyramid levels, kernels: 3x3, 5x5 44 | >>> m = PyConv2d(in_channels=64, out_channels=[32, 32], pyconv_kernels=[3, 5], pyconv_groups=[1, 4]) 45 | >>> input = torch.randn(4, 64, 56, 56) 46 | >>> output = m(input) 47 | 48 | >>> # PyConv with three pyramid levels, kernels: 3x3, 5x5, 7x7 49 | >>> m = PyConv2d(in_channels=64, out_channels=[16, 16, 32], pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]) 50 | >>> input = torch.randn(4, 64, 56, 56) 51 | >>> output = m(input) 52 | """ 53 | def __init__(self, in_channels, out_channels, pyconv_kernels, pyconv_groups, stride=1, dilation=1, bias=False): 54 | super(PyConv2d, self).__init__() 55 | 56 | assert len(out_channels) == len(pyconv_kernels) == len(pyconv_groups) 57 | 58 | self.pyconv_levels = [None] * len(pyconv_kernels) 59 | for i in range(len(pyconv_kernels)): 60 | self.pyconv_levels[i] = nn.Conv2d(in_channels, out_channels[i], kernel_size=pyconv_kernels[i], 61 | stride=stride, padding=pyconv_kernels[i] // 2, groups=pyconv_groups[i], 62 | dilation=dilation, bias=bias) 63 | self.pyconv_levels = nn.ModuleList(self.pyconv_levels) 64 | 65 | def forward(self, x): 66 | out = [] 67 | for level in self.pyconv_levels: 68 | out.append(level(x)) 69 | 70 | return torch.cat(out, 1) 71 | 72 | 73 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1): 74 | """standard convolution with padding""" 75 | return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 76 | padding=padding, dilation=dilation, groups=groups, bias=False) 77 | 78 | 79 | def conv1x1(in_planes, out_planes, stride=1): 80 | """1x1 convolution""" 81 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 82 | 83 | 84 | class PyConv4(nn.Module): 85 | 86 | def __init__(self, inplans, planes, pyconv_kernels=[3, 5, 7, 9], stride=1, pyconv_groups=[1, 4, 8, 16]): 87 | super(PyConv4, self).__init__() 88 | self.conv2_1 = conv(inplans, planes//4, kernel_size=pyconv_kernels[0], padding=pyconv_kernels[0]//2, 89 | stride=stride, groups=pyconv_groups[0]) 90 | self.conv2_2 = conv(inplans, planes//4, kernel_size=pyconv_kernels[1], padding=pyconv_kernels[1]//2, 91 | stride=stride, groups=pyconv_groups[1]) 92 | self.conv2_3 = conv(inplans, planes//4, kernel_size=pyconv_kernels[2], padding=pyconv_kernels[2]//2, 93 | stride=stride, groups=pyconv_groups[2]) 94 | self.conv2_4 = conv(inplans, planes//4, kernel_size=pyconv_kernels[3], padding=pyconv_kernels[3]//2, 95 | stride=stride, groups=pyconv_groups[3]) 96 | 97 | def forward(self, x): 98 | return torch.cat((self.conv2_1(x), self.conv2_2(x), self.conv2_3(x), self.conv2_4(x)), dim=1) 99 | 100 | 101 | class PyConv3(nn.Module): 102 | 103 | def __init__(self, inplans, planes, pyconv_kernels=[3, 5, 7], stride=1, pyconv_groups=[1, 4, 8]): 104 | super(PyConv3, self).__init__() 105 | self.conv2_1 = conv(inplans, planes // 4, kernel_size=pyconv_kernels[0], padding=pyconv_kernels[0] // 2, 106 | stride=stride, groups=pyconv_groups[0]) 107 | self.conv2_2 = conv(inplans, planes // 4, kernel_size=pyconv_kernels[1], padding=pyconv_kernels[1] // 2, 108 | stride=stride, groups=pyconv_groups[1]) 109 | self.conv2_3 = conv(inplans, planes // 2, kernel_size=pyconv_kernels[2], padding=pyconv_kernels[2] // 2, 110 | stride=stride, groups=pyconv_groups[2]) 111 | 112 | def forward(self, x): 113 | return torch.cat((self.conv2_1(x), self.conv2_2(x), self.conv2_3(x)), dim=1) 114 | 115 | 116 | class PyConv2(nn.Module): 117 | 118 | def 
__init__(self, inplans, planes,pyconv_kernels=[3, 5], stride=1, pyconv_groups=[1, 4]): 119 | super(PyConv2, self).__init__() 120 | self.conv2_1 = conv(inplans, planes // 2, kernel_size=pyconv_kernels[0], padding=pyconv_kernels[0] // 2, 121 | stride=stride, groups=pyconv_groups[0]) 122 | self.conv2_2 = conv(inplans, planes // 2, kernel_size=pyconv_kernels[1], padding=pyconv_kernels[1] // 2, 123 | stride=stride, groups=pyconv_groups[1]) 124 | 125 | def forward(self, x): 126 | return torch.cat((self.conv2_1(x), self.conv2_2(x)), dim=1) 127 | 128 | 129 | def get_pyconv(inplans, planes, pyconv_kernels, stride=1, pyconv_groups=[1]): 130 | if len(pyconv_kernels) == 1: 131 | return conv(inplans, planes, kernel_size=pyconv_kernels[0], stride=stride, groups=pyconv_groups[0]) 132 | elif len(pyconv_kernels) == 2: 133 | return PyConv2(inplans, planes, pyconv_kernels=pyconv_kernels, stride=stride, pyconv_groups=pyconv_groups) 134 | elif len(pyconv_kernels) == 3: 135 | return PyConv3(inplans, planes, pyconv_kernels=pyconv_kernels, stride=stride, pyconv_groups=pyconv_groups) 136 | elif len(pyconv_kernels) == 4: 137 | return PyConv4(inplans, planes, pyconv_kernels=pyconv_kernels, stride=stride, pyconv_groups=pyconv_groups) 138 | 139 | 140 | class PyConvBlock(nn.Module): 141 | expansion = 4 142 | 143 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None, pyconv_groups=1, pyconv_kernels=1): 144 | super(PyConvBlock, self).__init__() 145 | if norm_layer is None: 146 | norm_layer = nn.BatchNorm2d 147 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 148 | self.conv1 = conv1x1(inplanes, planes) 149 | self.bn1 = norm_layer(planes) 150 | self.conv2 = get_pyconv(planes, planes, pyconv_kernels=pyconv_kernels, stride=stride, 151 | pyconv_groups=pyconv_groups) 152 | self.bn2 = norm_layer(planes) 153 | self.conv3 = conv1x1(planes, planes * self.expansion) 154 | self.bn3 = norm_layer(planes * self.expansion) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.downsample = downsample 157 | self.stride = stride 158 | 159 | def forward(self, x): 160 | identity = x 161 | 162 | out = self.conv1(x) 163 | out = self.bn1(out) 164 | out = self.relu(out) 165 | 166 | out = self.conv2(out) 167 | out = self.bn2(out) 168 | out = self.relu(out) 169 | 170 | out = self.conv3(out) 171 | out = self.bn3(out) 172 | 173 | if self.downsample is not None: 174 | identity = self.downsample(x) 175 | 176 | out += identity 177 | out = self.relu(out) 178 | 179 | return out 180 | 181 | 182 | class PyConvBasicBlock1(nn.Module): 183 | expansion = 1 184 | 185 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None, pyconv_groups=1, pyconv_kernels=1): 186 | super(PyConvBasicBlock1, self).__init__() 187 | if norm_layer is None: 188 | norm_layer = nn.BatchNorm2d 189 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 190 | self.conv1 = get_pyconv(inplanes, planes, pyconv_kernels=pyconv_kernels, stride=stride, 191 | pyconv_groups=pyconv_groups) 192 | self.bn1 = norm_layer(planes) 193 | self.relu = nn.ReLU(inplace=True) 194 | self.conv2 = get_pyconv(planes, planes, pyconv_kernels=pyconv_kernels, stride=1, 195 | pyconv_groups=pyconv_groups) 196 | self.bn2 = norm_layer(planes) 197 | self.downsample = downsample 198 | self.stride = stride 199 | 200 | def forward(self, x): 201 | identity = x 202 | 203 | out = self.conv1(x) 204 | out = self.bn1(out) 205 | out = self.relu(out) 206 | 207 | out = self.conv2(out) 208 | out = self.bn2(out) 209 | 210 
| if self.downsample is not None: 211 | identity = self.downsample(x) 212 | 213 | out += identity 214 | out = self.relu(out) 215 | 216 | return out 217 | 218 | 219 | class PyConvBasicBlock2(nn.Module): 220 | expansion = 1 221 | 222 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None, pyconv_groups=1, pyconv_kernels=1): 223 | super(PyConvBasicBlock2, self).__init__() 224 | if norm_layer is None: 225 | norm_layer = nn.BatchNorm2d 226 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 227 | self.conv1 = get_pyconv(inplanes, planes, pyconv_kernels=pyconv_kernels, stride=stride, 228 | pyconv_groups=pyconv_groups) 229 | self.bn1 = norm_layer(planes) 230 | self.relu = nn.ReLU(inplace=True) 231 | self.conv2 = conv1x1(planes, planes * self.expansion) 232 | self.bn2 = norm_layer(planes) 233 | self.downsample = downsample 234 | self.stride = stride 235 | 236 | def forward(self, x): 237 | identity = x 238 | 239 | out = self.conv1(x) 240 | out = self.bn1(out) 241 | out = self.relu(out) 242 | 243 | out = self.conv2(out) 244 | out = self.bn2(out) 245 | 246 | if self.downsample is not None: 247 | identity = self.downsample(x) 248 | 249 | out += identity 250 | out = self.relu(out) 251 | 252 | return out 253 | 254 | 255 | class PyConvResNet(nn.Module): 256 | 257 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, norm_layer=None, dropout_prob0=0.0): 258 | super(PyConvResNet, self).__init__() 259 | if norm_layer is None: 260 | norm_layer = nn.BatchNorm2d 261 | 262 | self.inplanes = 64 263 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 264 | self.bn1 = norm_layer(64) 265 | self.relu = nn.ReLU(inplace=True) 266 | 267 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2, norm_layer=norm_layer, 268 | pyconv_kernels=[3, 5, 7, 9], pyconv_groups=[1, 4, 8, 16]) 269 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer, 270 | pyconv_kernels=[3, 5, 7], pyconv_groups=[1, 4, 8]) 271 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer, 272 | pyconv_kernels=[3, 5], pyconv_groups=[1, 4]) 273 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer, 274 | pyconv_kernels=[3], pyconv_groups=[1]) 275 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 276 | 277 | if dropout_prob0 > 0.0: 278 | self.dp = nn.Dropout(dropout_prob0, inplace=True) 279 | print("Using Dropout with the prob to set to 0 of: ", dropout_prob0) 280 | else: 281 | self.dp = None 282 | 283 | self.fc = nn.Linear(512 * block.expansion, num_classes) 284 | 285 | for m in self.modules(): 286 | if isinstance(m, nn.Conv2d): 287 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 288 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 289 | nn.init.constant_(m.weight, 1) 290 | nn.init.constant_(m.bias, 0) 291 | 292 | # Zero-initialize the last BN in each residual branch, 293 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
300 |     def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None, pyconv_kernels=[3], pyconv_groups=[1]):
301 |         if norm_layer is None:
302 |             norm_layer = nn.BatchNorm2d
303 |         downsample = None
304 |         if stride != 1 and self.inplanes != planes * block.expansion:
305 |             downsample = nn.Sequential(
306 |                 nn.MaxPool2d(kernel_size=3, stride=stride, padding=1),
307 |                 conv1x1(self.inplanes, planes * block.expansion),
308 |                 norm_layer(planes * block.expansion),
309 |             )
310 |         elif self.inplanes != planes * block.expansion:
311 |             downsample = nn.Sequential(
312 |                 conv1x1(self.inplanes, planes * block.expansion),
313 |                 norm_layer(planes * block.expansion),
314 |             )
315 |         elif stride != 1:
316 |             downsample = nn.MaxPool2d(kernel_size=3, stride=stride, padding=1)
317 | 
318 |         layers = []
319 |         layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample, norm_layer=norm_layer,
320 |                             pyconv_kernels=pyconv_kernels, pyconv_groups=pyconv_groups))
321 |         self.inplanes = planes * block.expansion
322 | 
323 |         for _ in range(1, blocks):
324 |             layers.append(block(self.inplanes, planes, norm_layer=norm_layer,
325 |                                 pyconv_kernels=pyconv_kernels, pyconv_groups=pyconv_groups))
326 | 
327 |         return nn.Sequential(*layers)
328 | 
329 |     def forward(self, x):
330 |         x = self.conv1(x)
331 |         x = self.bn1(x)
332 |         x = self.relu(x)
333 | 
334 |         x = self.layer1(x)
335 |         x = self.layer2(x)
336 |         x = self.layer3(x)
337 |         x = self.layer4(x)
338 | 
339 |         x = self.avgpool(x)
340 |         x = x.view(x.size(0), -1)
341 | 
342 |         if self.dp is not None:
343 |             x = self.dp(x)
344 | 
345 |         x = self.fc(x)
346 | 
347 |         return x
348 | 
349 | 
350 | def pyconvresnet18(pretrained=False, **kwargs):
351 |     """Constructs a PyConvResNet-18 model.
352 | 
353 |     Args:
354 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
355 |     """
356 |     # model = PyConvResNet(PyConvBasicBlock1, [2, 2, 2, 2], **kwargs)  # params=11.21M GFLOPs 1.55
357 |     model = PyConvResNet(PyConvBasicBlock2, [2, 2, 2, 2], **kwargs)  # params=5.91M GFLOPs 0.88
358 |     if pretrained:
359 |         raise NotImplementedError("Pretrained model not available yet!")
360 | 
361 |     return model
362 | 
363 | 
364 | def pyconvresnet34(pretrained=False, **kwargs):
365 |     """Constructs a PyConvResNet-34 model.
366 | 
367 |     Args:
368 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
369 |     """
370 |     # model = PyConvResNet(PyConvBasicBlock1, [3, 4, 6, 3], **kwargs)  # params=20.44M GFLOPs 3.09
371 |     model = PyConvResNet(PyConvBasicBlock2, [3, 4, 6, 3], **kwargs)  # params=11.09M GFLOPs 1.75
372 |     if pretrained:
373 |         raise NotImplementedError("Pretrained model not available yet!")
374 | 
375 |     return model
376 | 
377 | 
378 | def pyconvresnet50(pretrained=False, **kwargs):
379 |     """Constructs a PyConvResNet-50 model.
380 | 
381 |     Args:
382 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
383 |     """
384 |     model = PyConvResNet(PyConvBlock, [3, 4, 6, 3], **kwargs)
385 |     if pretrained:
386 |         os.makedirs(default_cache_path, exist_ok=True)
387 |         model.load_state_dict(torch.load(download_from_url(model_urls['pyconvresnet50'],
388 |                                                            root=default_cache_path)))
389 |     return model
390 | 
391 | 
392 | def pyconvresnet101(pretrained=False, **kwargs):
393 |     """Constructs a PyConvResNet-101 model.
394 | 395 | Args: 396 | pretrained (bool): If True, returns a model pre-trained on ImageNet 397 | """ 398 | model = PyConvResNet(PyConvBlock, [3, 4, 23, 3], **kwargs) 399 | if pretrained: 400 | os.makedirs(default_cache_path, exist_ok=True) 401 | model.load_state_dict(torch.load(download_from_url(model_urls['pyconvresnet101'], 402 | root=default_cache_path))) 403 | return model 404 | 405 | 406 | def pyconvresnet152(pretrained=False, **kwargs): 407 | """Constructs a PyConvResNet-152 model. 408 | 409 | Args: 410 | pretrained (bool): If True, returns a model pre-trained on ImageNet 411 | """ 412 | model = PyConvResNet(PyConvBlock, [3, 8, 36, 3], **kwargs) 413 | if pretrained: 414 | os.makedirs(default_cache_path, exist_ok=True) 415 | model.load_state_dict(torch.load(download_from_url(model_urls['pyconvresnet152'], 416 | root=default_cache_path))) 417 | return model 418 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from div.download_from_url import download_from_url 5 | 6 | try: 7 | from torch.hub import _get_torch_home 8 | torch_cache_home = _get_torch_home() 9 | except ImportError: 10 | torch_cache_home = os.path.expanduser( 11 | os.getenv('TORCH_HOME', os.path.join( 12 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 13 | default_cache_path = os.path.join(torch_cache_home, 'pretrained') 14 | 15 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 16 | 17 | 18 | model_urls = { 19 | 'resnet18': 'Trained model not available yet!!', 20 | 'resnet34': 'Trained model not available yet!!', 21 | 'resnet50': 'https://drive.google.com/uc?export=download&id=176TS0b6O0NALBbfzpz4mM1b47s4dwSVH', 22 | 'resnet101': 'https://drive.google.com/uc?export=download&id=1bermctRPLs5DIsHB0c4iDIGHvjfERPLG', 23 | 'resnet152': 'https://drive.google.com/uc?export=download&id=1FAggTH4m7Kec8MyRe8dx-ugI_yh-nLzL', 24 | } 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False) 31 | 32 | 33 | def conv1x1(in_planes, out_planes, stride=1): 34 | """1x1 convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion = 1 40 | 41 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None): 42 | super(BasicBlock, self).__init__() 43 | if norm_layer is None: 44 | norm_layer = nn.BatchNorm2d 45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = norm_layer(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = norm_layer(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | identity = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | identity = self.downsample(x) 66 | 67 | out += identity 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | expansion = 4 75 | 76 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None): 77 | super(Bottleneck, 
self).__init__()
78 |         if norm_layer is None:
79 |             norm_layer = nn.BatchNorm2d
80 |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
81 |         self.conv1 = conv1x1(inplanes, planes)
82 |         self.bn1 = norm_layer(planes)
83 |         self.conv2 = conv3x3(planes, planes, stride)
84 |         self.bn2 = norm_layer(planes)
85 |         self.conv3 = conv1x1(planes, planes * self.expansion)
86 |         self.bn3 = norm_layer(planes * self.expansion)
87 |         self.relu = nn.ReLU(inplace=True)
88 |         self.downsample = downsample
89 |         self.stride = stride
90 | 
91 |     def forward(self, x):
92 |         identity = x
93 | 
94 |         out = self.conv1(x)
95 |         out = self.bn1(out)
96 |         out = self.relu(out)
97 | 
98 |         out = self.conv2(out)
99 |         out = self.bn2(out)
100 |         out = self.relu(out)
101 | 
102 |         out = self.conv3(out)
103 |         out = self.bn3(out)
104 | 
105 |         if self.downsample is not None:
106 |             identity = self.downsample(x)
107 | 
108 |         out += identity
109 |         out = self.relu(out)
110 | 
111 |         return out
112 | 
113 | 
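# Illustrative shape walk-through (not in the original source): Bottleneck(
# inplanes=256, planes=64) reduces 256 -> 64 channels with conv1 (1x1), keeps
# 64 through conv2 (3x3, which carries any stride), then expands
# 64 -> 64 * expansion = 256 with conv3 (1x1), so the residual sum matches the
# identity path.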
114 | class ResNet(nn.Module):
115 | 
116 |     def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, norm_layer=None, dropout_prob0=0.0):
117 |         super(ResNet, self).__init__()
118 |         if norm_layer is None:
119 |             norm_layer = nn.BatchNorm2d
120 | 
121 |         self.inplanes = 64
122 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
123 |                                bias=False)
124 |         self.bn1 = norm_layer(64)
125 |         self.relu = nn.ReLU(inplace=True)
126 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
127 |         self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
128 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
129 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
130 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
131 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
132 | 
133 |         if dropout_prob0 > 0.0:
134 |             self.dp = nn.Dropout(dropout_prob0, inplace=True)
135 |             print("Using Dropout with probability of zeroing elements:", dropout_prob0)
136 |         else:
137 |             self.dp = None
138 | 
139 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
140 | 
141 |         for m in self.modules():
142 |             if isinstance(m, nn.Conv2d):
143 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
144 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
145 |                 nn.init.constant_(m.weight, 1)
146 |                 nn.init.constant_(m.bias, 0)
147 | 
148 |         # Zero-initialize the last BN in each residual branch,
149 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
150 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
151 |         if zero_init_residual:
152 |             for m in self.modules():
153 |                 if isinstance(m, Bottleneck):
154 |                     nn.init.constant_(m.bn3.weight, 0)
155 |                 elif isinstance(m, BasicBlock):
156 |                     nn.init.constant_(m.bn2.weight, 0)
157 | 
158 |     def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
159 |         if norm_layer is None:
160 |             norm_layer = nn.BatchNorm2d
161 |         downsample = None
162 |         if stride != 1 or self.inplanes != planes * block.expansion:
163 |             downsample = nn.Sequential(
164 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
165 |                 norm_layer(planes * block.expansion),
166 |             )
167 | 
168 |         layers = []
169 |         layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))
170 |         self.inplanes = planes * block.expansion
171 |         for _ in range(1, blocks):
172 |             layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
173 | 
174 |         return nn.Sequential(*layers)
175 | 
176 |     def forward(self, x):
177 |         x = self.conv1(x)
178 |         x = self.bn1(x)
179 |         x = self.relu(x)
180 |         x = self.maxpool(x)
181 | 
182 |         x = self.layer1(x)
183 |         x = self.layer2(x)
184 |         x = self.layer3(x)
185 |         x = self.layer4(x)
186 | 
187 |         x = self.avgpool(x)
188 |         x = x.view(x.size(0), -1)
189 | 
190 |         if self.dp is not None:
191 |             x = self.dp(x)
192 | 
193 |         x = self.fc(x)
194 | 
195 |         return x
196 | 
197 | 
198 | def resnet18(pretrained=False, **kwargs):
199 |     """Constructs a ResNet-18 model.
200 | 
201 |     Args:
202 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
203 |     """
204 |     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
205 |     if pretrained:
206 |         # Weights for ResNet-18 are not released yet (see model_urls above).
207 |         raise NotImplementedError("Pretrained model not available yet!")
208 | 
209 |     return model
210 | 
211 | 
212 | def resnet34(pretrained=False, **kwargs):
213 |     """Constructs a ResNet-34 model.
214 | 
215 |     Args:
216 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
217 |     """
218 |     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
219 |     if pretrained:
220 |         # Weights for ResNet-34 are not released yet (see model_urls above).
221 |         raise NotImplementedError("Pretrained model not available yet!")
222 | 
223 |     return model
224 | 
225 | 
226 | def resnet50(pretrained=False, **kwargs):
227 |     """Constructs a ResNet-50 model.
228 | 
229 |     Args:
230 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
231 |     """
232 |     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
233 |     if pretrained:
234 |         os.makedirs(default_cache_path, exist_ok=True)
235 |         model.load_state_dict(torch.load(download_from_url(model_urls['resnet50'],
236 |                                                            root=default_cache_path)))
237 |     return model
238 | 
239 | 
240 | def resnet101(pretrained=False, **kwargs):
241 |     """Constructs a ResNet-101 model.
242 | 
243 |     Args:
244 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
245 |     """
246 |     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
247 |     if pretrained:
248 |         os.makedirs(default_cache_path, exist_ok=True)
249 |         model.load_state_dict(torch.load(download_from_url(model_urls['resnet101'],
250 |                                                            root=default_cache_path)))
251 |     return model
252 | 
253 | 
254 | def resnet152(pretrained=False, **kwargs):
255 |     """Constructs a ResNet-152 model.
256 | 
257 |     Args:
258 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
259 |     """
260 |     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
261 |     if pretrained:
262 |         os.makedirs(default_cache_path, exist_ok=True)
263 |         model.load_state_dict(torch.load(download_from_url(model_urls['resnet152'],
264 |                                                            root=default_cache_path)))
265 |     return model
266 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import csv
2 | 
3 | 
4 | class Logger(object):
5 | 
6 |     def __init__(self, path, header, mode='w'):
7 |         self.log_file = open(path, mode=mode)
8 |         self.logger = csv.writer(self.log_file, delimiter='\t')
9 | 
10 |         if mode != 'a':
11 |             self.logger.writerow(header)
12 | 
13 |         self.header = header
14 | 
15 |     def __del__(self):
16 |         self.log_file.close()
17 | 
18 |     def log(self, values):
19 |         write_values = []
20 |         for col in self.header:
21 |             assert col in values
22 |             write_values.append(values[col])
23 | 
24 |         self.logger.writerow(write_values)
25 |         self.log_file.flush()
26 | 
--------------------------------------------------------------------------------
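As a quick sanity check of the model definitions above, the following minimal sketch (ours, not part of the repository; it assumes the repository root is on `PYTHONPATH`) instantiates a randomly initialized PyConvResNet-50 and runs a dummy forward pass:

```python
import torch
from models.pyconvresnet import pyconvresnet50

# Build a randomly initialized PyConvResNet-50 (pass pretrained=True to
# download the released ImageNet weights instead).
model = pyconvresnet50()
model.eval()

# One ImageNet-sized RGB image. The stem conv (stride=2) and layer1..layer4
# (each stride=2) halve the spatial size five times, so a 224x224 input
# reaches global average pooling as a 7x7 feature map.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)

print(logits.shape)  # expected: torch.Size([1, 1000])
```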