├── .gitignore
├── img
│   └── framework.png
├── script
│   ├── train.sh
│   └── command.sh
├── util
│   ├── accuracy.py
│   ├── torch_dist_sum.py
│   ├── meter.py
│   └── dist_init.py
├── network
│   ├── backbone.py
│   ├── head.py
│   ├── ressl.py
│   ├── ressl_multi.py
│   └── resnet.py
├── README.md
├── data
│   ├── imagenet.py
│   ├── augmentation.py
│   └── randaugment.py
├── ressl.py
├── ressl_multi.py
└── linear_eval.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | **/__pycache__/**
3 | .vscode/**
4 | checkpoints/*
--------------------------------------------------------------------------------
/img/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingkai-zheng/ReSSL/HEAD/img/framework.png
--------------------------------------------------------------------------------
/script/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # ./script/command.sh ressl-200 8 1 "python -u ressl.py --backbone resnet50 --epochs 200 --lr 0.05 --t 0.04"
4 | # ./script/command.sh ressl-multi-200 8 1 "python -u ressl_multi.py --backbone resnet50 --epochs 200 --lr 0.05 --t 0.04"
5 |
6 | # ./script/command.sh eval_ressl-200 8 1 "python -u linear_eval.py --backbone resnet50 --checkpoint ressl-200.pth"
7 | # ./script/command.sh eval_ressl-multi-200 8 1 "python -u linear_eval.py --backbone resnet50 --checkpoint ressl-multi-200.pth"
8 |
9 |
--------------------------------------------------------------------------------
/script/command.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | job_name=$1
5 | train_gpu=$2
6 | num_node=$3
7 | command=$4
8 | total_process=$((train_gpu*num_node))
9 |
10 | mkdir -p log
11 |
12 | now=$(date +"%Y%m%d_%H%M%S")
13 |
14 | # nohup
15 | GLOG_vmodule=MemcachedClient=-1 \
16 | srun --partition=3dv-share \
17 | --mpi=pmi2 -n$total_process \
18 | --gres=gpu:$train_gpu \
19 | --ntasks-per-node=$train_gpu \
20 | --job-name=$job_name \
21 | --kill-on-bad-exit=1 \
22 | --cpus-per-task=5 \
23 | $command 2>&1|tee -a log/$job_name.log &
24 |
25 |
--------------------------------------------------------------------------------
/util/accuracy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(output, target, topk=(1,)):
5 |     """Computes the accuracy over the k top predictions for the specified values of k"""
6 |     with torch.no_grad():
7 |         maxk = max(topk)
8 |         batch_size = target.size(0)
9 |
10 |         _, pred = output.topk(maxk, 1, True, True)
11 |         pred = pred.t()
12 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
13 |
14 |         res = []
15 |         for k in topk:
16 |             correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
17 |             res.append(correct_k.mul_(100.0 / batch_size))
18 |         return res
19 |
--------------------------------------------------------------------------------
/network/backbone.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from network.resnet import *
3 | from network.head import *  # also brings torch.nn.functional in as F, used by forward() below
4 |
5 |
6 | backbone_dict = {
7 |     'resnet18': resnet18,
8 |     'resnet34': resnet34,
9 |     'resnet50': resnet50,
10 | }
11 |
12 | dim_dict = {
13 |     'resnet18': 512,
14 |     'resnet34': 512,
15 |     'resnet50': 2048,
16 | }
17 |
18 |
19 | class BackBone(nn.Module):
20 |     def __init__(self, backbone='resnet50', hidden_dim=4096, dim=512):
21 |         super().__init__()
22 |
dim_in = dim_dict[backbone] 23 | self.net = backbone_dict[backbone]() 24 | self.head = ProjectionHead(dim_in=dim_in, hidden_dim=hidden_dim, dim_out=dim) 25 | 26 | def forward(self, x): 27 | feat = self.net(x) 28 | embedding = self.head(feat) 29 | return F.normalize(embedding) 30 | 31 | -------------------------------------------------------------------------------- /util/torch_dist_sum.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import torch 12 | 13 | __all__ = ['torch_dist_sum'] 14 | 15 | def torch_dist_sum(gpu, *args): 16 | process_group = torch.distributed.group.WORLD 17 | tensor_args = [] 18 | pending_res = [] 19 | for arg in args: 20 | if isinstance(arg, torch.Tensor): 21 | tensor_arg = arg.clone().reshape(-1).detach().cuda(gpu) 22 | else: 23 | tensor_arg = torch.tensor(arg).reshape(-1).cuda(gpu) 24 | torch.distributed.all_reduce(tensor_arg, group=process_group) 25 | tensor_args.append(tensor_arg) 26 | 27 | return tensor_args 28 | -------------------------------------------------------------------------------- /network/head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class LinearHead(nn.Module): 7 | def __init__(self, net, dim_in=2048, dim_out=1000): 8 | super().__init__() 9 | self.net = net 10 | self.fc = nn.Linear(dim_in, dim_out) 11 | 12 | for param in self.net.parameters(): 13 | param.requires_grad = False 14 | 15 | self.fc.weight.data.normal_(mean=0.0, std=0.01) 16 | self.fc.bias.data.zero_() 17 | 18 | def forward(self, x): 19 | with torch.no_grad(): 20 | feat = self.net(x) 21 | return self.fc(feat) 22 | 23 | 24 | class ProjectionHead(nn.Module): 25 | def __init__(self, dim_in=2048, hidden_dim=4096, dim_out=512): 26 | super().__init__() 27 | 28 | self.linear1 = nn.Linear(dim_in, hidden_dim) 29 | self.relu1 = nn.ReLU(True) 30 | self.linear2 = nn.Linear(hidden_dim, dim_out) 31 | 32 | def forward(self, x): 33 | x = self.linear1(x) 34 | x = self.relu1(x) 35 | x = self.linear2(x) 36 | return x 37 | -------------------------------------------------------------------------------- /util/meter.py: -------------------------------------------------------------------------------- 1 | 2 | class AverageMeter(object): 3 | """Computes and stores the average and current value""" 4 | def __init__(self, name, fmt=':f'): 5 | self.name = name 6 | self.fmt = fmt 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=1): 16 | self.val = val 17 | self.sum += val * n 18 | self.count += n 19 | self.avg = self.sum / self.count 20 | 21 | def __str__(self): 22 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 23 | return fmtstr.format(**self.__dict__) 24 | 25 | 26 | class ProgressMeter(object): 27 | def __init__(self, num_batches, meters, prefix=""): 28 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 29 | self.meters = meters 30 | self.prefix = prefix 31 | 32 | def 
display(self, batch):
33 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
34 |         entries += [str(meter) for meter in self.meters]
35 |         print('\t'.join(entries))
36 |
37 |     def _get_batch_fmtstr(self, num_batches):
38 |         num_digits = len(str(num_batches // 1))
39 |         fmt = '{:' + str(num_digits) + 'd}'
40 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
41 |
42 |
--------------------------------------------------------------------------------
/util/dist_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | def dist_init(port=23456):
5 |
6 |     def init_parrots(host_addr, rank, local_rank, world_size, port):
7 |         os.environ['MASTER_ADDR'] = str(host_addr)
8 |         os.environ['MASTER_PORT'] = str(port)
9 |         os.environ['WORLD_SIZE'] = str(world_size)
10 |         os.environ['RANK'] = str(rank)
11 |         torch.distributed.init_process_group(backend="nccl")
12 |         torch.cuda.set_device(local_rank)
13 |
14 |     def init(host_addr, rank, local_rank, world_size, port):
15 |         host_addr_full = 'tcp://' + host_addr + ':' + str(port)
16 |         torch.distributed.init_process_group("nccl", init_method=host_addr_full,
17 |                                              rank=rank, world_size=world_size)
18 |         torch.cuda.set_device(local_rank)
19 |         assert torch.distributed.is_initialized()
20 |
21 |
22 |     def parse_host_addr(s):
23 |         if '[' in s:
24 |             left_bracket = s.index('[')
25 |             right_bracket = s.index(']')
26 |             prefix = s[:left_bracket]
27 |             first_number = s[left_bracket+1:right_bracket].split(',')[0].split('-')[0]
28 |             return prefix + first_number
29 |         else:
30 |             return s
31 |
32 |     rank = int(os.environ['SLURM_PROCID'])
33 |     local_rank = int(os.environ['SLURM_LOCALID'])
34 |     world_size = int(os.environ['SLURM_NTASKS'])
35 |
36 |     ip = parse_host_addr(os.environ['SLURM_STEP_NODELIST'])
37 |
38 |     if torch.__version__ == 'parrots':
39 |         init_parrots(ip, rank, local_rank, world_size, port)
40 |     else:
41 |         init(ip, rank, local_rank, world_size, port)
42 |
43 |     return rank, local_rank, world_size
44 |
45 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReSSL: Relational Self-Supervised Learning with Weak Augmentation (NeurIPS 2021)
2 |
3 | This repository contains PyTorch evaluation code, training code and pretrained models for ReSSL.
4 |
5 | For details see [ReSSL: Relational Self-Supervised Learning with Weak Augmentation](https://proceedings.neurips.cc/paper/2021/file/14c4f36143b4b09cbc320d7c95a50ee7-Paper.pdf) by Mingkai Zheng, Shan You, Fei Wang, Chen Qian, Changshui Zhang, Xiaogang Wang and Chang Xu.
6 |
7 | ![ReSSL](img/framework.png)
8 |
9 |
10 | ## CIFAR-10 / STL-10
11 | This repository targets the ImageNet dataset. We also provide training code and pretrained models for CIFAR-10/100, STL-10 and Tiny ImageNet; please download them from [this link](https://drive.google.com/file/d/1j2I1Lh9Dy7cHb6YO0PZ8HXDNewXrHO-j/view?usp=sharing).
12 |
13 | ## Reproducing
14 |
15 | To run the code, you will probably need to adapt the dataset paths (data/imagenet.py) and the PyTorch DDP setup (util/dist_init.py) to your own server environment.
16 |
17 | The distributed training in this code is based on Slurm; the training scripts are provided under the script folder.
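If your cluster does not use Slurm, one simple option is to swap util/dist_init.py for a launcher-agnostic version. Below is a minimal sketch (an editor's addition, not part of the repository) that assumes the processes are started with `torchrun`, which exports `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT` for every worker:

```
# Hypothetical drop-in replacement for util/dist_init.py under torchrun.
import os
import torch

def dist_init(port=23456):
    # torchrun already exports MASTER_ADDR / MASTER_PORT, so `port` is unused here.
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size
```

With that in place, single-node training could be launched as, e.g., `torchrun --nproc_per_node=8 ressl.py --backbone resnet50 --epochs 200 --lr 0.05 --t 0.04`.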
18 |
19 |
20 | We also provide pretrained ResNet-50 models (single crop and 5 crops):
21 |
22 | | Method | Arch | BatchSize | Epochs | Crops | Linear Eval | Download |
23 | |----------|:----:|:---:|:---:|:---:|:---:|:---:|
24 | | ReSSL | ResNet50 | 256 | 200 | 1 | 69.9 % | [ressl-200.pth](https://drive.google.com/file/d/16Ib4rvEvB_rdQThPxkoOb9wvCALzPTZd/view?usp=sharing) |
25 | | ReSSL | ResNet50 | 256 | 200 | 5 | 74.7 % | [ressl-multi-200.pth](https://drive.google.com/file/d/1usvvFAw_1bOaiXBgxXG9kwOOPb0VAy0Y/view?usp=sharing) |
26 |
27 | If you want to test the pretrained models, please download the weights from the links above and move them into the checkpoints folder (create a checkpoints/ directory if you don't have one). The evaluation commands are also provided in script/train.sh.
28 |
29 |
30 | ## Citation
31 | If you find ReSSL interesting and helpful to your research, please consider citing it:
32 | ```
33 |
34 | @inproceedings{
35 | zheng2021ressl,
36 | title={Re{SSL}: Relational Self-Supervised Learning with Weak Augmentation},
37 | author={Mingkai Zheng and Shan You and Fei Wang and Chen Qian and Changshui Zhang and Xiaogang Wang and Chang Xu},
38 | booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
39 | year={2021},
40 | url={https://openreview.net/forum?id=ErivP29kYnx}
41 | }
42 | ```
43 |
44 |
--------------------------------------------------------------------------------
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import os
3 | import torchvision.transforms as transforms
4 | from PIL import Image
5 | import mc
6 | import io
7 |
8 |
9 | class DatasetCache(data.Dataset):
10 |     def __init__(self):
11 |         super().__init__()
12 |         self.initialized = False
13 |
14 |
15 |     def _init_memcached(self):
16 |         if not self.initialized:
17 |             server_list_config_file = "/mnt/lustre/share/memcached_client/server_list.conf"
18 |             client_config_file = "/mnt/lustre/share/memcached_client/client.conf"
19 |             self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, client_config_file)
20 |             self.initialized = True
21 |
22 |     def load_image(self, filename):
23 |         self._init_memcached()
24 |         value = mc.pyvector()
25 |         self.mclient.Get(filename, value)
26 |         value_str = mc.ConvertBuffer(value)
27 |
28 |         buff = io.BytesIO(value_str)
29 |         with Image.open(buff) as img:
30 |             img = img.convert('RGB')
31 |         return img
32 |
33 |
34 |
35 | class BaseDataset(DatasetCache):
36 |     def __init__(self, mode='train', max_class=1000, aug=None):
37 |         super().__init__()
38 |         self.initialized = False
39 |
40 |
41 |         prefix = '/mnt/lustreold/share/images/meta'
42 |         image_folder_prefix = '/mnt/lustreold/share/images'
43 |         if mode == 'train':
44 |             image_list = os.path.join(prefix, 'train.txt')
45 |             self.image_folder = os.path.join(image_folder_prefix, 'train')
46 |         elif mode == 'test':
47 |             image_list = os.path.join(prefix, 'test.txt')
48 |             self.image_folder = os.path.join(image_folder_prefix, 'test')
49 |         elif mode == 'val':
50 |             image_list = os.path.join(prefix, 'val.txt')
51 |             self.image_folder = os.path.join(image_folder_prefix, 'val')
52 |         else:
53 |             raise NotImplementedError('mode: ' + mode + ' does not exist, please select from [train, test, val]')
54 |
55 |
56 |         self.samples = []
57 |         with open(image_list) as f:
58 |             for line in f:
59 |                 name, label = line.split()
60 |                 label = int(label)
61 |                 if label < max_class:
62 |                     self.samples.append((label, name))
63 |
64 |         if aug is None:
65 |             self.transform =
transforms.Compose([ 66 | transforms.RandomResizedCrop(224), 67 | transforms.RandomHorizontalFlip(), 68 | transforms.ToTensor(), 69 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 70 | std=[0.229, 0.224, 0.225]) 71 | ]) 72 | else: 73 | self.transform = aug 74 | 75 | 76 | 77 | 78 | class ImagenetContrastive(BaseDataset): 79 | def __init__(self, mode='train', max_class=1000, aug=None): 80 | super().__init__(mode, max_class, aug) 81 | 82 | def __len__(self): 83 | return self.samples.__len__() 84 | 85 | def __getitem__(self, index): 86 | _, name = self.samples[index] 87 | filename = os.path.join(self.image_folder, name) 88 | img = self.load_image(filename) 89 | if isinstance(self.transform, list): 90 | return self.transform[0](img), self.transform[1](img) 91 | return self.transform(img), self.transform(img) 92 | 93 | 94 | 95 | class Imagenet(BaseDataset): 96 | def __init__(self, mode='train', max_class=1000, aug=None): 97 | super().__init__(mode, max_class, aug) 98 | 99 | def __len__(self): 100 | return self.samples.__len__() 101 | 102 | def __getitem__(self, index): 103 | label, name = self.samples[index] 104 | filename = os.path.join(self.image_folder, name) 105 | img = self.load_image(filename) 106 | return self.transform(img), label 107 | 108 | -------------------------------------------------------------------------------- /data/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as transforms 3 | 4 | from PIL import ImageFilter, Image, ImageOps 5 | from data.randaugment import RandAugment 6 | 7 | class GaussianBlur(object): 8 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 9 | 10 | def __init__(self, sigma=[.1, 2.]): 11 | self.sigma = sigma 12 | 13 | def __call__(self, x): 14 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 15 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 16 | return x 17 | 18 | 19 | class Solarize(): 20 | def __init__(self, threshold=128): 21 | self.threshold = threshold 22 | def __call__(self, sample): 23 | return ImageOps.solarize(sample, self.threshold) 24 | 25 | 26 | 27 | moco_aug = transforms.Compose([ 28 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 29 | transforms.RandomApply([ 30 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 31 | ], p=0.8), 32 | transforms.RandomGrayscale(p=0.2), 33 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 37 | ]) 38 | 39 | target_aug = transforms.Compose([ 40 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 44 | ]) 45 | 46 | 47 | eval_aug = transforms.Compose([ 48 | transforms.Resize(256), 49 | transforms.CenterCrop(224), 50 | transforms.ToTensor(), 51 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 52 | std=[0.229, 0.224, 0.225]), 53 | ]) 54 | 55 | class Multi_Transform(object): 56 | def __init__( 57 | self, 58 | size_crops=[224, 192, 160, 128, 96], 59 | nmb_crops=[1, 1, 1, 1, 1], 60 | min_scale_crops=[0.2, 0.172, 0.143, 0.114, 0.086], 61 | max_scale_crops=[1.0, 0.86, 0.715, 0.571, 0.429], 62 | init_size=224, 63 | strong=False): 64 | assert len(size_crops) == len(nmb_crops) 65 | assert len(min_scale_crops) == len(nmb_crops) 66 | assert len(max_scale_crops) 
== len(nmb_crops) 67 | trans=[] 68 | 69 | self.strong = strong 70 | 71 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 72 | 73 | #image_k 74 | weak = transforms.Compose([ 75 | transforms.RandomResizedCrop(init_size, scale=(0.2, 1.)), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.ToTensor(), 78 | normalize 79 | ]) 80 | trans.append(weak) 81 | 82 | 83 | trans_weak=[] 84 | if strong: 85 | min_scale_crops=[0.08, 0.08, 0.08, 0.08, 0.08] 86 | jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2) 87 | else: 88 | jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 89 | 90 | 91 | for i in range(len(size_crops)): 92 | aug_list = [ 93 | transforms.RandomResizedCrop( 94 | size_crops[i], 95 | scale=(min_scale_crops[i], max_scale_crops[i]) 96 | ), 97 | transforms.RandomApply([jitter], p=0.8), 98 | transforms.RandomGrayscale(p=0.2), 99 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 100 | transforms.RandomHorizontalFlip(), 101 | ] 102 | 103 | if self.strong: 104 | aug_list.append(RandAugment(5, 10)) 105 | 106 | aug_list.extend([ 107 | transforms.ToTensor(), 108 | normalize 109 | ]) 110 | 111 | aug = transforms.Compose(aug_list) 112 | trans_weak.extend([aug]*nmb_crops[i]) 113 | 114 | trans.extend(trans_weak) 115 | self.trans=trans 116 | print("in total we have %d transforms"%(len(self.trans))) 117 | def __call__(self, x): 118 | multi_crops = list(map(lambda trans: trans(x), self.trans)) 119 | return multi_crops 120 | 121 | 122 | -------------------------------------------------------------------------------- /network/ressl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | import torch.nn as nn 4 | from network.backbone import * 5 | 6 | 7 | class ReSSL(nn.Module): 8 | """ 9 | Build a MoCo model with: a query encoder, a key encoder, and a queue 10 | https://arxiv.org/abs/1911.05722 11 | """ 12 | def __init__(self, backbone='resnet50', dim=512, K=65536*2, m=0.999): 13 | """ 14 | dim: feature dimension (default: 512) 15 | K: queue size; number of negative keys (default: 65536*2) 16 | m: moco momentum of updating key encoder (default: 0.999) 17 | """ 18 | super(ReSSL, self).__init__() 19 | 20 | self.K = K 21 | self.m = m 22 | 23 | # create the encoders 24 | self.encoder_q = BackBone(backbone=backbone, dim=dim) 25 | self.encoder_k = BackBone(backbone=backbone, dim=dim) 26 | 27 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 28 | param_k.data.copy_(param_q.data) # initialize 29 | param_k.requires_grad = False # not update by gradient 30 | 31 | # create the queue 32 | self.register_buffer("queue", torch.randn(dim, K)) 33 | self.queue = nn.functional.normalize(self.queue, dim=0) 34 | 35 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 36 | 37 | @torch.no_grad() 38 | def _momentum_update_key_encoder(self): 39 | """ 40 | Momentum update of the key encoder 41 | """ 42 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 43 | param_k.data = param_k.data * self.m + param_q.data * (1. 
- self.m) 44 | 45 | @torch.no_grad() 46 | def _dequeue_and_enqueue(self, keys): 47 | # gather keys before updating queue 48 | keys = concat_all_gather(keys) 49 | 50 | batch_size = keys.shape[0] 51 | 52 | ptr = int(self.queue_ptr) 53 | assert self.K % batch_size == 0 # for simplicity 54 | 55 | # replace the keys at ptr (dequeue and enqueue) 56 | self.queue[:, ptr:ptr + batch_size] = keys.T 57 | ptr = (ptr + batch_size) % self.K # move pointer 58 | 59 | self.queue_ptr[0] = ptr 60 | 61 | @torch.no_grad() 62 | def _batch_shuffle_ddp(self, x): 63 | """ 64 | Batch shuffle, for making use of BatchNorm. 65 | *** Only support DistributedDataParallel (DDP) model. *** 66 | """ 67 | # gather from all gpus 68 | batch_size_this = x.shape[0] 69 | x_gather = concat_all_gather(x) 70 | batch_size_all = x_gather.shape[0] 71 | 72 | num_gpus = batch_size_all // batch_size_this 73 | 74 | # random shuffle index 75 | idx_shuffle = torch.randperm(batch_size_all).cuda() 76 | 77 | # broadcast to all gpus 78 | torch.distributed.broadcast(idx_shuffle, src=0) 79 | 80 | # index for restoring 81 | idx_unshuffle = torch.argsort(idx_shuffle) 82 | 83 | # shuffled index for this gpu 84 | gpu_idx = torch.distributed.get_rank() 85 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 86 | 87 | return x_gather[idx_this], idx_unshuffle 88 | 89 | @torch.no_grad() 90 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 91 | """ 92 | Undo batch shuffle. 93 | *** Only support DistributedDataParallel (DDP) model. *** 94 | """ 95 | # gather from all gpus 96 | batch_size_this = x.shape[0] 97 | x_gather = concat_all_gather(x) 98 | batch_size_all = x_gather.shape[0] 99 | 100 | num_gpus = batch_size_all // batch_size_this 101 | 102 | # restored index for this gpu 103 | gpu_idx = torch.distributed.get_rank() 104 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 105 | 106 | return x_gather[idx_this] 107 | 108 | def forward(self, im_q, im_k): 109 | """ 110 | Input: 111 | im_q: contrastive augmented image 112 | im_k: weak augmented image 113 | Output: 114 | logitsq, logitsk 115 | """ 116 | 117 | q = self.encoder_q(im_q) 118 | 119 | # compute key features 120 | with torch.no_grad(): # no gradient to keys 121 | self._momentum_update_key_encoder() # update the key encoder 122 | 123 | # shuffle for making use of BN 124 | im, idx_unshuffle = self._batch_shuffle_ddp(im_k) 125 | k = self.encoder_k(im) # keys: NxC 126 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 127 | 128 | logitsq = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 129 | logitsk = torch.einsum('nc,ck->nk', [k, self.queue.clone().detach()]) 130 | 131 | # dequeue and enqueue 132 | self._dequeue_and_enqueue(k) 133 | 134 | return logitsq, logitsk 135 | 136 | 137 | # utils 138 | @torch.no_grad() 139 | def concat_all_gather(tensor): 140 | """ 141 | Performs all_gather operation on the provided tensors. 142 | *** Warning ***: torch.distributed.all_gather has no gradient. 143 | """ 144 | tensors_gather = [torch.ones_like(tensor) 145 | for _ in range(torch.distributed.get_world_size())] 146 | torch.distributed.all_gather(tensors_gather, tensor) 147 | 148 | output = torch.cat(tensors_gather, dim=0) 149 | return output 150 | -------------------------------------------------------------------------------- /network/ressl_multi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch 3 | import torch.nn as nn 4 | from network.backbone import * 5 | 6 | 7 | class ReSSL(nn.Module): 8 | """ 9 | Build a MoCo model with: a query encoder, a key encoder, and a queue 10 | https://arxiv.org/abs/1911.05722 11 | """ 12 | def __init__(self, backbone='resnet50', dim=512, K=65536*2, m=0.999): 13 | """ 14 | dim: feature dimension (default: 512) 15 | K: queue size; number of negative keys (default: 65536*2) 16 | m: moco momentum of updating key encoder (default: 0.999) 17 | """ 18 | super(ReSSL, self).__init__() 19 | 20 | self.K = K 21 | self.m = m 22 | 23 | # create the encoders 24 | # num_classes is the output fc dimension 25 | self.encoder_q = BackBone(backbone=backbone, dim=dim) 26 | self.encoder_k = BackBone(backbone=backbone, dim=dim) 27 | 28 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 29 | param_k.data.copy_(param_q.data) # initialize 30 | param_k.requires_grad = False # not update by gradient 31 | 32 | # create the queue 33 | self.register_buffer("queue", torch.randn(dim, K)) 34 | self.queue = nn.functional.normalize(self.queue, dim=0) 35 | 36 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 37 | 38 | @torch.no_grad() 39 | def _momentum_update_key_encoder(self): 40 | """ 41 | Momentum update of the key encoder 42 | """ 43 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 44 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 45 | 46 | @torch.no_grad() 47 | def _dequeue_and_enqueue(self, keys): 48 | # gather keys before updating queue 49 | keys = concat_all_gather(keys) 50 | 51 | batch_size = keys.shape[0] 52 | 53 | ptr = int(self.queue_ptr) 54 | assert self.K % batch_size == 0 # for simplicity 55 | 56 | # replace the keys at ptr (dequeue and enqueue) 57 | self.queue[:, ptr:ptr + batch_size] = keys.T 58 | ptr = (ptr + batch_size) % self.K # move pointer 59 | 60 | self.queue_ptr[0] = ptr 61 | 62 | @torch.no_grad() 63 | def _batch_shuffle_ddp(self, x): 64 | """ 65 | Batch shuffle, for making use of BatchNorm. 66 | *** Only support DistributedDataParallel (DDP) model. *** 67 | """ 68 | # gather from all gpus 69 | batch_size_this = x.shape[0] 70 | x_gather = concat_all_gather(x) 71 | batch_size_all = x_gather.shape[0] 72 | 73 | num_gpus = batch_size_all // batch_size_this 74 | 75 | # random shuffle index 76 | idx_shuffle = torch.randperm(batch_size_all).cuda() 77 | 78 | # broadcast to all gpus 79 | torch.distributed.broadcast(idx_shuffle, src=0) 80 | 81 | # index for restoring 82 | idx_unshuffle = torch.argsort(idx_shuffle) 83 | 84 | # shuffled index for this gpu 85 | gpu_idx = torch.distributed.get_rank() 86 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 87 | 88 | return x_gather[idx_this], idx_unshuffle 89 | 90 | @torch.no_grad() 91 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 92 | """ 93 | Undo batch shuffle. 94 | *** Only support DistributedDataParallel (DDP) model. 
*** 95 | """ 96 | # gather from all gpus 97 | batch_size_this = x.shape[0] 98 | x_gather = concat_all_gather(x) 99 | batch_size_all = x_gather.shape[0] 100 | 101 | num_gpus = batch_size_all // batch_size_this 102 | 103 | # restored index for this gpu 104 | gpu_idx = torch.distributed.get_rank() 105 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 106 | 107 | return x_gather[idx_this] 108 | 109 | def forward(self, im_q, im_k): 110 | """ 111 | Input: 112 | im_q: list of contrastive augmented image 113 | im_k: weak augmented image 114 | Output: 115 | logitsq, logitsk 116 | """ 117 | 118 | q_list = [self.encoder_q(im) for im in im_q] 119 | q = torch.cat(q_list) 120 | 121 | # compute key features 122 | with torch.no_grad(): # no gradient to keys 123 | self._momentum_update_key_encoder() # update the key encoder 124 | 125 | # shuffle for making use of BN 126 | im, idx_unshuffle = self._batch_shuffle_ddp(im_k) 127 | k = self.encoder_k(im) # keys: NxC 128 | k = self._batch_unshuffle_ddp(k, idx_unshuffle) 129 | 130 | logitsq = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 131 | logitsk = torch.einsum('nc,ck->nk', [k, self.queue.clone().detach()]).repeat(len(im_q), 1) 132 | 133 | self._dequeue_and_enqueue(k) 134 | return logitsq, logitsk 135 | 136 | 137 | # utils 138 | @torch.no_grad() 139 | def concat_all_gather(tensor): 140 | """ 141 | Performs all_gather operation on the provided tensors. 142 | *** Warning ***: torch.distributed.all_gather has no gradient. 143 | """ 144 | tensors_gather = [torch.ones_like(tensor) 145 | for _ in range(torch.distributed.get_world_size())] 146 | torch.distributed.all_gather(tensors_gather, tensor) 147 | 148 | output = torch.cat(tensors_gather, dim=0) 149 | return output 150 | -------------------------------------------------------------------------------- /ressl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from util.torch_dist_sum import * 3 | from data.imagenet import * 4 | from data.augmentation import * 5 | from util.meter import * 6 | from network.ressl import ReSSL 7 | import time 8 | import torch.nn as nn 9 | import argparse 10 | import math 11 | import torch.nn.functional as F 12 | import os 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--port', type=int, default=23457) 16 | parser.add_argument('--epochs', type=int, default=200) 17 | parser.add_argument('--lr', type=float, default=0.05) 18 | parser.add_argument('--t', type=float, default=0.04) 19 | parser.add_argument('--backbone', type=str, default='resnet50') 20 | args = parser.parse_args() 21 | print(args) 22 | 23 | epochs = args.epochs 24 | warm_up = 5 25 | 26 | 27 | def adjust_learning_rate(optimizer, epoch, base_lr, i, iteration_per_epoch): 28 | T = epoch * iteration_per_epoch + i 29 | warmup_iters = warm_up * iteration_per_epoch 30 | total_iters = (epochs - warm_up) * iteration_per_epoch 31 | 32 | if epoch < warm_up: 33 | lr = base_lr * 1.0 * T / warmup_iters 34 | else: 35 | T = T - warmup_iters 36 | lr = 0.5 * base_lr * (1 + math.cos(1.0 * T / total_iters * math.pi)) 37 | 38 | for param_group in optimizer.param_groups: 39 | param_group['lr'] = lr 40 | 41 | 42 | def train(train_loader, model, local_rank, rank, criterion, optimizer, base_lr, epoch): 43 | batch_time = AverageMeter('Time', ':6.3f') 44 | data_time = AverageMeter('Data', ':6.3f') 45 | losses = AverageMeter('Loss', ':.4e') 46 | progress = ProgressMeter( 47 | len(train_loader), 48 | [batch_time, data_time, losses], 49 | prefix="Epoch: 
[{}]".format(epoch)) 50 | 51 | # switch to train mode 52 | model.train() 53 | 54 | iteration_per_epoch = len(train_loader) 55 | 56 | end = time.time() 57 | for i, (img1, img2) in enumerate(train_loader): 58 | adjust_learning_rate(optimizer, epoch, base_lr, i, iteration_per_epoch) 59 | # measure data loading time 60 | data_time.update(time.time() - end) 61 | 62 | if local_rank is not None: 63 | img1 = img1.cuda(local_rank, non_blocking=True) 64 | img2 = img2.cuda(local_rank, non_blocking=True) 65 | 66 | # compute output 67 | logitsq, ligitsk = model(im_q=img1, im_k=img2) 68 | loss = - torch.sum(F.softmax(ligitsk.detach() / args.t, dim=1) * F.log_softmax(logitsq / 0.1, dim=1), dim=1).mean() 69 | 70 | # acc1/acc5 are (K+1)-way contrast classifier accuracy 71 | # measure accuracy and record loss 72 | losses.update(loss.item(), img1.size(0)) 73 | 74 | # compute gradient and do SGD step 75 | optimizer.zero_grad() 76 | loss.backward() 77 | optimizer.step() 78 | 79 | # measure elapsed time 80 | batch_time.update(time.time() - end) 81 | end = time.time() 82 | 83 | if i % 20 == 0 and rank == 0: 84 | progress.display(i) 85 | 86 | 87 | def main(): 88 | from torch.nn.parallel import DistributedDataParallel 89 | from util.dist_init import dist_init 90 | 91 | rank, local_rank, world_size = dist_init(args.port) 92 | 93 | batch_size = 32 # single gpu 94 | num_workers = 8 95 | base_lr = args.lr 96 | 97 | model = ReSSL(backbone=args.backbone) 98 | model = DistributedDataParallel(model.to(local_rank), device_ids=[local_rank], output_device=local_rank) 99 | 100 | param_dict = {} 101 | for k, v in model.named_parameters(): 102 | param_dict[k] = v 103 | 104 | bn_params = [v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)] 105 | rest_params = [v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)] 106 | 107 | optimizer = torch.optim.SGD([{'params': bn_params, 'weight_decay': 0,}, 108 | {'params': rest_params, 'weight_decay': 1e-4}], 109 | lr=base_lr, momentum=0.9, weight_decay=1e-4) 110 | 111 | torch.backends.cudnn.benchmark = True 112 | 113 | train_dataset = ImagenetContrastive(aug=[moco_aug, target_aug], max_class=1000) 114 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 115 | train_loader = torch.utils.data.DataLoader( 116 | train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), 117 | num_workers=num_workers, pin_memory=True, sampler=train_sampler, drop_last=True) 118 | 119 | criterion = nn.CrossEntropyLoss().cuda(local_rank) 120 | 121 | if not os.path.exists('checkpoints'): 122 | os.makedirs('checkpoints') 123 | 124 | checkpoint_path = 'checkpoints/ressl-{}-{}.pth'.format(args.backbone, epochs) 125 | print('checkpoint_path:', checkpoint_path) 126 | if os.path.exists(checkpoint_path): 127 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 128 | model.load_state_dict(checkpoint['model']) 129 | optimizer.load_state_dict(checkpoint['optimizer']) 130 | start_epoch = checkpoint['epoch'] 131 | print(checkpoint_path, 'found, start from epoch', start_epoch) 132 | else: 133 | start_epoch = 0 134 | print(checkpoint_path, 'not found, start from epoch 0') 135 | 136 | 137 | model.train() 138 | for epoch in range(start_epoch, epochs): 139 | train_sampler.set_epoch(epoch) 140 | train(train_loader, model, local_rank, rank, criterion, optimizer, base_lr, epoch) 141 | 142 | if rank == 0: 143 | torch.save( 144 | { 145 | 'model': model.state_dict(), 146 | 'optimizer': optimizer.state_dict(), 147 | 'epoch': epoch + 1 148 | }, checkpoint_path) 
149 |
150 |
151 | if __name__ == "__main__":
152 |     main()
153 |
--------------------------------------------------------------------------------
/ressl_multi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from util.torch_dist_sum import *
3 | from data.imagenet import *
4 | from data.augmentation import *
5 | from util.meter import *
6 | from network.ressl_multi import ReSSL
7 | import time
8 | import torch.nn as nn
9 | import argparse
10 | import math
11 | import torch.nn.functional as F
12 | import os
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--port', type=int, default=23457)
16 | parser.add_argument('--epochs', type=int, default=200)
17 | parser.add_argument('--t', type=float, default=0.04)
18 | parser.add_argument('--lr', type=float, default=0.05)
19 | parser.add_argument('--backbone', type=str, default='resnet50')
20 | args = parser.parse_args()
21 | print(args)
22 |
23 | epochs = args.epochs
24 | warm_up = 5
25 |
26 |
27 | def adjust_learning_rate(optimizer, epoch, base_lr, i, iteration_per_epoch):
28 |     T = epoch * iteration_per_epoch + i
29 |     warmup_iters = warm_up * iteration_per_epoch
30 |     total_iters = (epochs - warm_up) * iteration_per_epoch
31 |
32 |     if epoch < warm_up:
33 |         lr = base_lr * 1.0 * T / warmup_iters
34 |     else:
35 |         T = T - warmup_iters
36 |         lr = 0.5 * base_lr * (1 + math.cos(1.0 * T / total_iters * math.pi))
37 |
38 |     for param_group in optimizer.param_groups:
39 |         param_group['lr'] = lr
40 |
41 |
42 | def train(train_loader, model, local_rank, rank, criterion, optimizer, base_lr, epoch):
43 |     batch_time = AverageMeter('Time', ':6.3f')
44 |     data_time = AverageMeter('Data', ':6.3f')
45 |     losses = AverageMeter('Loss', ':.4e')
46 |     progress = ProgressMeter(
47 |         len(train_loader),
48 |         [batch_time, data_time, losses],
49 |         prefix="Epoch: [{}]".format(epoch))
50 |
51 |     # switch to train mode
52 |     model.train()
53 |
54 |     iteration_per_epoch = len(train_loader)
55 |
56 |     end = time.time()
57 |     for i, (img_list, _) in enumerate(train_loader):
58 |         adjust_learning_rate(optimizer, epoch, base_lr, i, iteration_per_epoch)
59 |         # measure data loading time
60 |         data_time.update(time.time() - end)
61 |
62 |         if local_rank is not None:
63 |             img_list = [img.cuda(local_rank, non_blocking=True) for img in img_list]
64 |
65 |         # compute output
66 |         logitsq, logitsk = model(im_q=img_list[1:], im_k=img_list[0])
67 |         loss = - torch.sum(F.softmax(logitsk.detach() / args.t, dim=1) * F.log_softmax(logitsq / 0.1, dim=1), dim=1).mean()
68 |
69 |         # record loss (no classification accuracy is computed for the
70 |         # relational objective)
71 |         losses.update(loss.item(), logitsq.size(0))
72 |
73 |         # compute gradient and do SGD step
74 |         optimizer.zero_grad()
75 |         loss.backward()
76 |         optimizer.step()
77 |
78 |         # measure elapsed time
79 |         batch_time.update(time.time() - end)
80 |         end = time.time()
81 |
82 |         if i % 20 == 0 and rank == 0:
83 |             progress.display(i)
84 |
85 |
86 | def main():
87 |     from torch.nn.parallel import DistributedDataParallel
88 |     import math
89 |     from util.dist_init import dist_init
90 |
91 |     rank, local_rank, world_size = dist_init(args.port)
92 |
93 |     batch_size = 32  # single gpu
94 |     num_workers = 8
95 |     base_lr = args.lr
96 |
97 |     model = ReSSL(backbone=args.backbone)
98 |     model = DistributedDataParallel(model.to(local_rank), device_ids=[local_rank], output_device=local_rank)
99 |
100 |     param_dict = {}
101 |     for k, v in model.named_parameters():
102 |         param_dict[k] = v
103 |
104 |
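# --- Editor's note (added) ---------------------------------------------------
# The two parameter groups built below exclude BatchNorm and bias parameters
# from weight decay (weight_decay=0), a common recipe for contrastive
# training. Note the 'bn' substring filter: BN layers inside ResNet
# downsample blocks are registered as 'downsample.1' (see network/resnet.py),
# so they fall into the weight-decayed group.
# -----------------------------------------------------------------------------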
bn_params = [v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)] 105 | rest_params = [v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)] 106 | 107 | optimizer = torch.optim.SGD([{'params': bn_params, 'weight_decay': 0,}, 108 | {'params': rest_params, 'weight_decay': 1e-4}], 109 | lr=base_lr, momentum=0.9, weight_decay=1e-4) 110 | 111 | torch.backends.cudnn.benchmark = True 112 | 113 | train_dataset = Imagenet(aug=Multi_Transform(nmb_crops=[1, 1, 1, 1, 1]), max_class=1000) 114 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 115 | train_loader = torch.utils.data.DataLoader( 116 | train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), 117 | num_workers=num_workers, pin_memory=True, sampler=train_sampler, drop_last=True) 118 | 119 | criterion = nn.CrossEntropyLoss().cuda(local_rank) 120 | 121 | if not os.path.exists('checkpoints'): 122 | os.makedirs('checkpoints') 123 | 124 | checkpoint_path = 'checkpoints/ressl-multi-{}-{}.pth'.format(args.backbone, epochs) 125 | print('checkpoint_path:', checkpoint_path) 126 | if os.path.exists(checkpoint_path): 127 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 128 | model.load_state_dict(checkpoint['model']) 129 | optimizer.load_state_dict(checkpoint['optimizer']) 130 | start_epoch = checkpoint['epoch'] 131 | print(checkpoint_path, 'found, start from epoch', start_epoch) 132 | else: 133 | start_epoch = 0 134 | print(checkpoint_path, 'not found, start from epoch 0') 135 | 136 | 137 | model.train() 138 | for epoch in range(start_epoch, epochs): 139 | train_sampler.set_epoch(epoch) 140 | train(train_loader, model, local_rank, rank, criterion, optimizer, base_lr, epoch) 141 | 142 | if rank == 0: 143 | torch.save( 144 | { 145 | 'model': model.state_dict(), 146 | 'optimizer': optimizer.state_dict(), 147 | 'epoch': epoch + 1 148 | }, checkpoint_path) 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /linear_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data.imagenet import * 3 | from data.augmentation import * 4 | from network.head import * 5 | from torch.nn.parallel import DistributedDataParallel 6 | import torch.nn.functional as F 7 | from util.meter import * 8 | import time 9 | from util.torch_dist_sum import * 10 | from util.dist_init import dist_init 11 | import argparse 12 | from network.resnet import * 13 | from network.backbone import backbone_dict, dim_dict 14 | 15 | 16 | 17 | def accuracy(output, target, topk=(1,)): 18 | with torch.no_grad(): 19 | maxk = max(topk) 20 | batch_size = target.size(0) 21 | 22 | _, pred = output.topk(maxk, 1, True, True) 23 | pred = pred.t() 24 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 25 | 26 | res = [] 27 | for k in topk: 28 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 29 | res.append(correct_k.mul_(100.0 / batch_size)) 30 | return res 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--port', type=int, default=23456) 35 | parser.add_argument('--backbone', type=str, default='resnet50') 36 | parser.add_argument('--checkpoint', type=str) 37 | args = parser.parse_args() 38 | print(args) 39 | 40 | 41 | 42 | def main(): 43 | 44 | rank, local_rank, world_size = dist_init() 45 | 46 | epochs = 100 47 | batch_size = 32 48 | num_workers = 6 49 | 50 | 51 | pre_train = backbone_dict[args.backbone]() 52 | state_dict = 
torch.load('checkpoints/' + args.checkpoint, map_location='cpu')['model'] 53 | 54 | for k in list(state_dict.keys()): 55 | if not k.startswith('module.encoder_q.net.'): 56 | del state_dict[k] 57 | if k.startswith('module.encoder_q.net.'): 58 | state_dict[k[len("module.encoder_q.net."):]] = state_dict[k] 59 | del state_dict[k] 60 | 61 | pre_train.load_state_dict(state_dict) 62 | model = LinearHead(pre_train, dim_in=dim_dict[args.backbone]) 63 | model = DistributedDataParallel(model.to(local_rank), device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) 64 | optimizer = torch.optim.SGD(model.module.fc.parameters(), lr=0.3, momentum=0.9, weight_decay=0) 65 | 66 | torch.backends.cudnn.benchmark = True 67 | 68 | train_dataset = Imagenet() 69 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 70 | train_loader = torch.utils.data.DataLoader( 71 | train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), 72 | num_workers=num_workers, pin_memory=True, sampler=train_sampler) 73 | 74 | 75 | test_dataset = Imagenet(mode='val', aug=eval_aug) 76 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) 77 | test_loader = torch.utils.data.DataLoader( 78 | test_dataset, batch_size=batch_size, shuffle=(test_sampler is None), 79 | num_workers=num_workers, pin_memory=True, sampler=test_sampler) 80 | 81 | 82 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs*len(train_loader)) 83 | 84 | 85 | best_acc = 0 86 | best_acc5 = 0 87 | for epoch in range(epochs): 88 | # ---------------------- Train -------------------------- 89 | train_sampler.set_epoch(epoch) 90 | 91 | batch_time = AverageMeter('Time', ':6.3f') 92 | data_time = AverageMeter('Data', ':6.3f') 93 | losses = AverageMeter('Loss', ':.4e') 94 | progress = ProgressMeter( 95 | train_loader.__len__(), 96 | [batch_time, data_time, losses], 97 | prefix="Epoch: [{}]".format(epoch) 98 | ) 99 | end = time.time() 100 | 101 | model.eval() 102 | for i, (image, label) in enumerate(train_loader): 103 | data_time.update(time.time() - end) 104 | 105 | image = image.cuda(local_rank, non_blocking=True) 106 | label = label.cuda(local_rank, non_blocking=True) 107 | 108 | out = model(image) 109 | loss = F.cross_entropy(out, label) 110 | 111 | optimizer.zero_grad() 112 | loss.backward() 113 | optimizer.step() 114 | 115 | batch_time.update(time.time() - end) 116 | end = time.time() 117 | 118 | losses.update(loss.item()) 119 | 120 | if i % 20 == 0 and rank == 0: 121 | progress.display(i) 122 | 123 | scheduler.step() 124 | 125 | # ---------------------- Test -------------------------- 126 | model.eval() 127 | top1 = AverageMeter('Acc@1', ':6.2f') 128 | top5 = AverageMeter('Acc@5', ':6.2f') 129 | with torch.no_grad(): 130 | end = time.time() 131 | for i, (image, label) in enumerate(test_loader): 132 | 133 | image = image.cuda(local_rank, non_blocking=True) 134 | label = label.cuda(local_rank, non_blocking=True) 135 | 136 | # compute output 137 | output = model(image) 138 | 139 | # measure accuracy and record loss 140 | acc1, acc5 = accuracy(output, label, topk=(1, 5)) 141 | top1.update(acc1[0], image.size(0)) 142 | top5.update(acc5[0], image.size(0)) 143 | 144 | # measure elapsed time 145 | batch_time.update(time.time() - end) 146 | end = time.time() 147 | 148 | sum1, cnt1, sum5, cnt5 = torch_dist_sum(local_rank, top1.sum, top1.count, top5.sum, top5.count) 149 | top1_acc = sum(sum1.float()) / sum(cnt1.float()) 150 | top5_acc = sum(sum5.float()) / sum(cnt5.float()) 151 | 
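# --- Editor's note (added) ---------------------------------------------------
# torch_dist_sum (util/torch_dist_sum.py) all-reduces the per-rank sums and
# sample counts, so top1_acc / top5_acc above are computed over the entire
# validation set rather than a single GPU's shard.
# -----------------------------------------------------------------------------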
152 | best_acc = max(top1_acc, best_acc) 153 | best_acc5 = max(top5_acc, best_acc5) 154 | 155 | if rank == 0: 156 | print('Epoch:{} * Acc@1 {top1_acc:.3f} Acc@5 {top5_acc:.3f} Best_Acc@1 {best_acc:.3f} Best_Acc@5 {best_acc5:.3f}'.format(epoch, top1_acc=top1_acc, top5_acc=top5_acc, best_acc=best_acc, best_acc5=best_acc5)) 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /data/randaugment.py: -------------------------------------------------------------------------------- 1 | 2 | # code in this file is adpated from 3 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 4 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 5 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 6 | import logging 7 | import random 8 | 9 | import numpy as np 10 | import PIL 11 | import PIL.ImageOps 12 | import PIL.ImageEnhance 13 | import PIL.ImageDraw 14 | from PIL import Image 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | PARAMETER_MAX = 10 19 | 20 | 21 | def AutoContrast(img, **kwarg): 22 | return PIL.ImageOps.autocontrast(img) 23 | 24 | 25 | def Brightness(img, v, max_v, bias=0): 26 | v = _float_parameter(v, max_v) + bias 27 | return PIL.ImageEnhance.Brightness(img).enhance(v) 28 | 29 | 30 | def Color(img, v, max_v, bias=0): 31 | v = _float_parameter(v, max_v) + bias 32 | return PIL.ImageEnhance.Color(img).enhance(v) 33 | 34 | 35 | def Contrast(img, v, max_v, bias=0): 36 | v = _float_parameter(v, max_v) + bias 37 | return PIL.ImageEnhance.Contrast(img).enhance(v) 38 | 39 | 40 | def Cutout(img, v, max_v, bias=0): 41 | if v == 0: 42 | return img 43 | v = _float_parameter(v, max_v) + bias 44 | v = int(v * min(img.size)) 45 | return CutoutAbs(img, v) 46 | 47 | 48 | def CutoutAbs(img, v, **kwarg): 49 | w, h = img.size 50 | x0 = np.random.uniform(0, w) 51 | y0 = np.random.uniform(0, h) 52 | x0 = int(max(0, x0 - v / 2.)) 53 | y0 = int(max(0, y0 - v / 2.)) 54 | x1 = int(min(w, x0 + v)) 55 | y1 = int(min(h, y0 + v)) 56 | xy = (x0, y0, x1, y1) 57 | # gray 58 | color = (127, 127, 127) 59 | img = img.copy() 60 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 61 | return img 62 | 63 | 64 | def Equalize(img, **kwarg): 65 | return PIL.ImageOps.equalize(img) 66 | 67 | 68 | def Identity(img, **kwarg): 69 | return img 70 | 71 | 72 | def Invert(img, **kwarg): 73 | return PIL.ImageOps.invert(img) 74 | 75 | 76 | def Posterize(img, v, max_v, bias=0): 77 | v = _int_parameter(v, max_v) + bias 78 | return PIL.ImageOps.posterize(img, v) 79 | 80 | 81 | def Rotate(img, v, max_v, bias=0): 82 | v = _int_parameter(v, max_v) + bias 83 | if random.random() < 0.5: 84 | v = -v 85 | return img.rotate(v) 86 | 87 | 88 | def Sharpness(img, v, max_v, bias=0): 89 | v = _float_parameter(v, max_v) + bias 90 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 91 | 92 | 93 | def ShearX(img, v, max_v, bias=0): 94 | v = _float_parameter(v, max_v) + bias 95 | if random.random() < 0.5: 96 | v = -v 97 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 98 | 99 | 100 | def ShearY(img, v, max_v, bias=0): 101 | v = _float_parameter(v, max_v) + bias 102 | if random.random() < 0.5: 103 | v = -v 104 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 105 | 106 | 107 | def Solarize(img, v, max_v, bias=0): 108 | v = _int_parameter(v, max_v) + bias 109 | return PIL.ImageOps.solarize(img, 256 - v) 110 | 111 | 112 | 
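# --- Editor's note (added) ---------------------------------------------------
# Every op in this file receives a raw magnitude v in [0, PARAMETER_MAX]; the
# helpers _float_parameter / _int_parameter (defined further below) rescale it
# to [0, max_v], and `bias` shifts the result away from degenerate values
# (e.g. an enhance factor of 0 would produce a black image for Brightness).
# -----------------------------------------------------------------------------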
def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
113 |     v = _int_parameter(v, max_v) + bias
114 |     if random.random() < 0.5:
115 |         v = -v
116 |     img_np = np.array(img).astype(int)  # np.int was removed in NumPy 1.24; the builtin int is equivalent here
117 |     img_np = img_np + v
118 |     img_np = np.clip(img_np, 0, 255)
119 |     img_np = img_np.astype(np.uint8)
120 |     img = Image.fromarray(img_np)
121 |     return PIL.ImageOps.solarize(img, threshold)
122 |
123 |
124 | def TranslateX(img, v, max_v, bias=0):
125 |     v = _float_parameter(v, max_v) + bias
126 |     if random.random() < 0.5:
127 |         v = -v
128 |     v = int(v * img.size[0])
129 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
130 |
131 |
132 | def TranslateY(img, v, max_v, bias=0):
133 |     v = _float_parameter(v, max_v) + bias
134 |     if random.random() < 0.5:
135 |         v = -v
136 |     v = int(v * img.size[1])
137 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
138 |
139 |
140 | def _float_parameter(v, max_v):
141 |     return float(v) * max_v / PARAMETER_MAX
142 |
143 |
144 | def _int_parameter(v, max_v):
145 |     return int(v * max_v / PARAMETER_MAX)
146 |
147 |
148 | def fixmatch_augment_pool():
149 |     # FixMatch paper
150 |     augs = [(AutoContrast, None, None),
151 |             (Brightness, 0.9, 0.05),
152 |             (Color, 0.9, 0.05),
153 |             (Contrast, 0.9, 0.05),
154 |             (Equalize, None, None),
155 |             (Identity, None, None),
156 |             (Posterize, 4, 4),
157 |             (Rotate, 30, 0),
158 |             (Sharpness, 0.9, 0.05),
159 |             (ShearX, 0.3, 0),
160 |             (ShearY, 0.3, 0),
161 |             (Solarize, 256, 0),
162 |             (TranslateX, 0.3, 0),
163 |             (TranslateY, 0.3, 0)]
164 |     return augs
165 |
166 |
167 | def my_augment_pool():
168 |     # Test
169 |     augs = [(AutoContrast, None, None),
170 |             (Brightness, 1.8, 0.1),
171 |             (Color, 1.8, 0.1),
172 |             (Contrast, 1.8, 0.1),
173 |             (Cutout, 0.2, 0),
174 |             (Equalize, None, None),
175 |             (Invert, None, None),
176 |             (Posterize, 4, 4),
177 |             (Rotate, 30, 0),
178 |             (Sharpness, 1.8, 0.1),
179 |             (ShearX, 0.3, 0),
180 |             (ShearY, 0.3, 0),
181 |             (Solarize, 256, 0),
182 |             (SolarizeAdd, 110, 0),
183 |             (TranslateX, 0.45, 0),
184 |             (TranslateY, 0.45, 0)]
185 |     return augs
186 |
187 |
188 | class RandAugmentPC(object):
189 |     def __init__(self, n, m):
190 |         assert n >= 1
191 |         assert 1 <= m <= 10
192 |         self.n = n
193 |         self.m = m
194 |         self.augment_pool = my_augment_pool()
195 |
196 |     def __call__(self, img):
197 |         ops = random.choices(self.augment_pool, k=self.n)
198 |         for op, max_v, bias in ops:
199 |             prob = np.random.uniform(0.2, 0.8)
200 |             if random.random() + prob >= 1:
201 |                 img = op(img, v=self.m, max_v=max_v, bias=bias)
202 |         img = CutoutAbs(img, 16)
203 |         return img
204 |
205 |
206 | class RandAugment(object):
207 |     def __init__(self, n, m):
208 |         assert n >= 0
209 |         assert 1 <= m <= 10
210 |         self.n = n
211 |         self.m = m
212 |         self.augment_pool = fixmatch_augment_pool()
213 |     def __call__(self, img):
214 |         ops = random.choices(self.augment_pool, k=self.n)
215 |         for op, max_v, bias in ops:
216 |             v = np.random.randint(1, self.m)
217 |             if random.random() < 0.5:
218 |                 img = op(img, v=v, max_v=max_v, bias=bias)
219 |         return img
220 |
221 |
--------------------------------------------------------------------------------
/network/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6 |            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
7 |            'wide_resnet50_2', 'wide_resnet101_2']
8 |
9 |
10 | model_urls = {
11 |     'resnet18':
'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 17 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 18 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 19 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 24 | """3x3 convolution with padding""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=dilation, groups=groups, bias=False, dilation=dilation) 27 | 28 | 29 | def conv1x1(in_planes, out_planes, stride=1): 30 | """1x1 convolution""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | expansion = 1 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 38 | base_width=64, dilation=1, norm_layer=None): 39 | super(BasicBlock, self).__init__() 40 | if norm_layer is None: 41 | norm_layer = nn.BatchNorm2d 42 | if groups != 1 or base_width != 64: 43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 44 | if dilation > 1: 45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 47 | self.conv1 = conv3x3(inplanes, planes, stride) 48 | self.bn1 = norm_layer(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3(planes, planes) 51 | self.bn2 = norm_layer(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | identity = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu(out) 70 | 71 | return out 72 | 73 | 74 | class Bottleneck(nn.Module): 75 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 76 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 77 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 78 | # This variant is also known as ResNet V1.5 and improves accuracy according to 79 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
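    # Editor's note (added): concretely, conv2 below is built as
    # conv3x3(width, width, stride, ...), so the stride sits on the 3x3
    # convolution, which is the V1.5 placement described above.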
80 | 81 | expansion = 4 82 | 83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 84 | base_width=64, dilation=1, norm_layer=None): 85 | super(Bottleneck, self).__init__() 86 | if norm_layer is None: 87 | norm_layer = nn.BatchNorm2d 88 | width = int(planes * (base_width / 64.)) * groups 89 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 90 | self.conv1 = conv1x1(inplanes, width) 91 | self.bn1 = norm_layer(width) 92 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 93 | self.bn2 = norm_layer(width) 94 | self.conv3 = conv1x1(width, planes * self.expansion) 95 | self.bn3 = norm_layer(planes * self.expansion) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | identity = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv3(out) 112 | out = self.bn3(out) 113 | 114 | if self.downsample is not None: 115 | identity = self.downsample(x) 116 | 117 | out += identity 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | 123 | class ResNet(nn.Module): 124 | 125 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 126 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 127 | norm_layer=None): 128 | super(ResNet, self).__init__() 129 | if norm_layer is None: 130 | norm_layer = nn.BatchNorm2d 131 | self._norm_layer = norm_layer 132 | 133 | self.inplanes = 64 134 | self.dilation = 1 135 | if replace_stride_with_dilation is None: 136 | # each element in the tuple indicates if we should replace 137 | # the 2x2 stride with a dilated convolution instead 138 | replace_stride_with_dilation = [False, False, False] 139 | if len(replace_stride_with_dilation) != 3: 140 | raise ValueError("replace_stride_with_dilation should be None " 141 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 142 | self.groups = groups 143 | self.base_width = width_per_group 144 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 145 | bias=False) 146 | self.bn1 = norm_layer(self.inplanes) 147 | self.relu = nn.ReLU(inplace=True) 148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 149 | self.layer1 = self._make_layer(block, 64, layers[0]) 150 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 151 | dilate=replace_stride_with_dilation[0]) 152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 153 | dilate=replace_stride_with_dilation[1]) 154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 155 | dilate=replace_stride_with_dilation[2]) 156 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 157 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 158 | 159 | for m in self.modules(): 160 | if isinstance(m, nn.Conv2d): 161 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 162 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 163 | nn.init.constant_(m.weight, 1) 164 | nn.init.constant_(m.bias, 0) 165 | 166 | # Zero-initialize the last BN in each residual branch, 167 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch, so that the residual
        # branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        # x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    # if pretrained:
    #     state_dict = load_state_dict_from_url(model_urls[arch],
    #                                           progress=progress)
    #     model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)
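

# Example (added sketch): with the classification head removed, these constructors
# return feature extractors rather than classifiers, e.g.:
#   model = resnet50()
#   feats = model(torch.randn(2, 3, 224, 224))  # -> tensor of shape [2, 2048]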


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_

    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 it has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_

    The model is the same as ResNet except for the bottleneck number of channels,
    which is twice as large in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 it has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
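

# Note (added): the "wide" and "resnext" variants differ from plain ResNet only via the
# width/groups computation in Bottleneck: width = int(planes * (base_width / 64.)) * groups.
# For example, wide_resnet50_2 (width_per_group=128) gives a last-stage 3x3 width of
# int(512 * 128 / 64) * 1 = 1024 (vs. 512 in plain resnet50), while resnext50_32x4d
# (groups=32, width_per_group=4) gives int(512 * 4 / 64) * 32 = 1024, split across 32 groups.
--------------------------------------------------------------------------------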