├── README.md
├── clsa
│   ├── builder.py
│   ├── clsa_augs.py
│   └── loader.py
├── detection
│   ├── README.md
│   ├── configs
│   │   ├── Base-RCNN-C4-BN.yaml
│   │   ├── coco_R_50_C4_2x.yaml
│   │   ├── coco_R_50_C4_2x_moco.yaml
│   │   ├── pascal_voc_R_50_C4_24k.yaml
│   │   └── pascal_voc_R_50_C4_24k_moco.yaml
│   ├── convert-pretrain-to-detectron2.py
│   └── train_net.py
├── main_clsa.py
├── main_lincls.py
└── moco
    ├── __init__.py
    ├── builder.py
    └── loader.py

/README.md:
--------------------------------------------------------------------------------
1 | ## Unofficial implementation of Contrastive Learning with Stronger Augmentations
2 | WIP!!
3 | 
4 | Current results (linear evaluation protocol on ImageNet):
5 | 
6 | | Train epochs | Single | Mul-5 | MoCo-v2 |
7 | |---|---|---|---|
8 | | 40 | 55.4% | 60.2% | 56.9% |
9 | | 200 | 66.5% | 68.3% | 67.6% |
10 | 
11 | 
12 | This is an unofficial PyTorch implementation of the CLSA paper: [Contrastive Learning with Stronger Augmentations](https://openreview.net/forum?id=KJSC_AsN14).
13 | 
14 | Note: This implementation is mostly adapted from the official MoCo implementation at https://github.com/facebookresearch/moco.
15 | This repo aims to keep the modifications to that code minimal.
16 | 
17 | 
18 | 
19 | ### Preparation
20 | Note: This section is copied from moco's repo
21 | 
22 | Install PyTorch and prepare the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet).
23 | 
24 | 
25 | 
26 | ### Unsupervised Training
27 | 
28 | This implementation only supports **multi-gpu**, **DistributedDataParallel** training, which is faster and simpler; single-gpu or DataParallel training is not supported.
29 | 
30 | To do unsupervised pre-training of a ResNet-50 model on ImageNet on an 8-gpu machine, run:
31 | ```
32 | python main_clsa.py \
33 |   -a resnet50 \
34 |   --lr 0.03 \
35 |   --batch-size 256 \
36 |   --mlp --aug-plus --cos \
37 |   --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
38 |   [your imagenet-folder with train and val folders]
39 | ```
40 | This script uses all the default hyper-parameters described in the CLSA paper; the CLSA-specific flags are summarized at the end of this README.
41 | 
42 | 
43 | ### Linear Classification
44 | Note: This section is copied from moco's repo
45 | 
46 | With a pre-trained model, to train a supervised linear classifier on frozen features/weights on an 8-gpu machine, run:
47 | ```
48 | python main_lincls.py \
49 |   -a resnet50 \
50 |   --lr 30.0 \
51 |   --batch-size 256 \
52 |   --pretrained [your checkpoint path]/checkpoint_0199.pth.tar \
53 |   --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
54 |   [your imagenet-folder with train and val folders]
55 | ```
56 | 
57 | ### TODO:
58 | 1. ImageNet-1K CLSA-Single-200epoch pretraining: Running
59 | 2. ImageNet-1K CLSA-Mul-200epoch pretraining: Running
60 | 3. Evaluate CLSA-Single/-Mul on ImageNet Linear Protocol
61 | 4. Evaluate CLSA-Single/-Mul on VOC07 Det
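62 | 
63 | ### CLSA-specific options
64 | 
65 | Besides the standard MoCo v2 flags, `main_clsa.py` adds two CLSA-specific arguments: `--ratio`, the weight of the DDM loss (beta in the paper, default 1.0), and `--num_res`, how many of the resolutions 96/128/160/192/224 are used for the stronger-augmented views (default 1; the Mul-5 results above correspond to 5). A sketch of a CLSA-Mul style run, assuming the same setup as above:
66 | ```
67 | python main_clsa.py \
68 |   -a resnet50 \
69 |   --lr 0.03 \
70 |   --batch-size 256 \
71 |   --mlp --aug-plus --cos \
72 |   --ratio 1.0 --num_res 5 \
73 |   --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
74 |   [your imagenet-folder with train and val folders]
75 | ```
76 | The overall training objective is `loss_contrastive + ratio * loss_ddm` (see `clsa/builder.py`), where the DDM term is the cross-entropy between the weak view's similarity distribution over the positive key and the queue and that of each stronger view.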
77 | 
78 | 

--------------------------------------------------------------------------------
/clsa/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from torchvision import models
5 | 
6 | class CLSA(nn.Module):
7 |     """
8 |     Build a MoCo-like model with a query encoder, a key encoder, and a queue, following these two papers:
9 |     https://arxiv.org/abs/1911.05722
10 |     https://openreview.net/forum?id=KJSC_AsN14
11 |     """
12 |     def __init__(self, base_encoder=models.resnet50, dim=2048, K=65536, m=0.999, T=0.2, mlp=True, ratio=1.0):
13 |         """
14 |         dim: feature dimension (default: 2048)
15 |         K: queue size; number of negative keys (default: 65536)
16 |         m: moco momentum of updating key encoder (default: 0.999)
17 |         T: softmax temperature (default: 0.2)
18 |         ratio: the coefficient for reweighting the DDM loss, i.e., beta in the paper (default: 1.0)
19 |         """
20 |         super(CLSA, self).__init__()
21 | 
22 |         self.K = K
23 |         self.m = m
24 |         self.T = T
25 |         self.ratio = ratio
26 | 
27 |         # create the encoders
28 |         # num_classes is the output fc dimension
29 |         self.encoder_q = base_encoder(num_classes=dim)
30 |         self.encoder_k = base_encoder(num_classes=dim)
31 | 
32 |         if mlp:  # hack: brute-force replacement
33 |             dim_mlp = self.encoder_q.fc.weight.shape[1]
34 |             self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
35 |             self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)
36 | 
37 |         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
38 |             param_k.data.copy_(param_q.data)  # initialize
39 |             param_k.requires_grad = False  # not update by gradient
40 | 
41 |         # create the queue
42 |         self.register_buffer("queue", torch.randn(dim, K))
43 |         self.queue = nn.functional.normalize(self.queue, dim=0)
44 | 
45 |         self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
46 | 
47 |         self.criterion = nn.CrossEntropyLoss()
48 | 
49 |     @torch.no_grad()
50 |     def _momentum_update_key_encoder(self):
51 |         """
52 |         Momentum update of the key encoder
53 |         """
54 |         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
55 |             param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
56 | 
57 |     @torch.no_grad()
58 |     def _dequeue_and_enqueue(self, keys):
59 |         # gather keys before updating queue
60 |         keys = concat_all_gather(keys)
61 | 
62 |         batch_size = keys.shape[0]
63 | 
64 |         ptr = int(self.queue_ptr)
65 |         assert self.K % batch_size == 0  # for simplicity
66 | 
67 |         # replace the keys at ptr (dequeue and enqueue)
68 |         self.queue[:, ptr:ptr + batch_size] = keys.T
69 |         ptr = (ptr + batch_size) % self.K  # move pointer
70 | 
71 |         self.queue_ptr[0] = ptr
72 | 
73 |     @torch.no_grad()
74 |     def _batch_shuffle_ddp(self, x):
75 |         """
76 |         Batch shuffle, for making use of BatchNorm.
77 |         *** Only support DistributedDataParallel (DDP) model. ***
78 |         """
79 |         # gather from all gpus
80 |         batch_size_this = x.shape[0]
81 |         x_gather = concat_all_gather(x)
82 |         batch_size_all = x_gather.shape[0]
83 | 
84 |         num_gpus = batch_size_all // batch_size_this
85 | 
86 |         # random shuffle index
87 |         idx_shuffle = torch.randperm(batch_size_all).cuda()
88 | 
89 |         # broadcast to all gpus
90 |         torch.distributed.broadcast(idx_shuffle, src=0)
91 | 
92 |         # index for restoring
93 |         idx_unshuffle = torch.argsort(idx_shuffle)
94 | 
95 |         # shuffled index for this gpu
96 |         gpu_idx = torch.distributed.get_rank()
97 |         idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
98 | 
99 |         return x_gather[idx_this], idx_unshuffle
100 | 
101 |     @torch.no_grad()
102 |     def _batch_unshuffle_ddp(self, x, idx_unshuffle):
103 |         """
104 |         Undo batch shuffle.
105 |         *** Only support DistributedDataParallel (DDP) model. ***
106 |         """
107 |         # gather from all gpus
108 |         batch_size_this = x.shape[0]
109 |         x_gather = concat_all_gather(x)
110 |         batch_size_all = x_gather.shape[0]
111 | 
112 |         num_gpus = batch_size_all // batch_size_this
113 | 
114 |         # restored index for this gpu
115 |         gpu_idx = torch.distributed.get_rank()
116 |         idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
117 | 
118 |         return x_gather[idx_this]
119 | 
120 |     def forward(self, im_q, im_k, img_stronger_aug_list):
121 |         """
122 |         Input:
123 |             im_q: a batch of query images
124 |             im_k: a batch of key images
125 |             img_stronger_aug_list: stronger-augmented views at increasing resolutions,
126 |                 e.g. [img_res_96, img_res_128, ...] with img_res_96.shape = [N, 3, 96, 96]
127 |         Output:
128 |             logits, labels, loss
129 |         """
130 | 
131 | 
132 |         # compute query features
133 |         q = self.encoder_q(im_q)  # queries: NxC
134 |         q = nn.functional.normalize(q, dim=1)
135 | 
136 |         # compute key features
137 |         with torch.no_grad():  # no gradient to keys
138 |             self._momentum_update_key_encoder()  # update the key encoder
139 | 
140 |             # shuffle for making use of BN
141 |             im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
142 | 
143 |             k = self.encoder_k(im_k)  # keys: NxC
144 |             k = nn.functional.normalize(k, dim=1)
145 | 
146 |             # undo shuffle
147 |             k = self._batch_unshuffle_ddp(k, idx_unshuffle)
148 | 
149 |         # compute logits
150 |         # Einstein sum is more intuitive
151 |         # positive logits: Nx1
152 |         l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
153 |         # negative logits: NxK
154 |         l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
155 | 
156 |         # logits: Nx(1+K)
157 |         logits = torch.cat([l_pos, l_neg], dim=1)
158 | 
159 |         # apply temperature
160 |         logits /= self.T
161 | 
162 |         # labels: positive key indicators
163 |         labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
164 | 
165 |         #losses = dict()
166 |         loss_contrastive = self.criterion(logits, labels)
167 | 
168 |         # compute the DDM loss below: for each stronger view, take the cross-entropy
169 |         # between the weak view's similarity distribution and the stronger view's
170 |         # log-distribution over the positive key and the queue
171 |         # get P(Zk, Zi')
172 |         p_weak = nn.functional.softmax(logits, dim=-1)
173 |         loss_ddm = 0
174 |         for img_s in img_stronger_aug_list:
175 |             q_s = self.encoder_q(img_s)
176 |             q_s = nn.functional.normalize(q_s, dim=1)
177 |             # compute logits using the same set of code above
178 | 
179 |             l_pos_stronger_aug = torch.einsum('nc,nc->n', [q_s, k]).unsqueeze(-1)
180 |             # negative logits: NxK
181 |             l_neg_stronger_aug = torch.einsum('nc,ck->nk', [q_s, self.queue.clone().detach()])
182 | 
183 |             # logits: Nx(1+K)
184 |             logits_s = torch.cat([l_pos_stronger_aug, l_neg_stronger_aug], dim=1)
185 |             logits_s /= self.T
186 | 
187 |             # compute nll loss below as -P(q, k) * log(P(q_s, k))
188 |             log_p_s = nn.functional.log_softmax(logits_s, dim=-1)
189 | 
190 |             nll = -1.0 * torch.einsum('nk,nk->n', [p_weak, log_p_s])
191 |             loss_ddm = loss_ddm + torch.mean(nll)  # average over the batch dimension
192 | 
193 |         # total objective: InfoNCE plus the DDM term weighted by self.ratio (beta in the paper)
194 |         loss = loss_contrastive + self.ratio * loss_ddm
195 | 
196 |         #losses['loss'] = loss
197 | 
198 |         # dequeue and enqueue
199 |         self._dequeue_and_enqueue(k)
200 | 
201 |         #return logits, labels
202 |         return logits, labels, loss
203 | 
204 | 
205 | # utils
206 | @torch.no_grad()
207 | def concat_all_gather(tensor):
208 |     """
209 |     Performs all_gather operation on the provided tensors.
210 |     *** Warning ***: torch.distributed.all_gather has no gradient.
211 |     """
212 |     tensors_gather = [torch.ones_like(tensor)
213 |         for _ in range(torch.distributed.get_world_size())]
214 |     torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
215 | 
216 |     output = torch.cat(tensors_gather, dim=0)
217 |     return output
218 | 

--------------------------------------------------------------------------------
/clsa/clsa_augs.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | import random
4 | 
5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
6 | import numpy as np
7 | import torch
8 | from torchvision.transforms.transforms import Compose
9 | 
10 | random_mirror = True
11 | 
12 | 
13 | def ShearX(img, v):  # [-0.3, 0.3]
14 |     assert -0.3 <= v <= 0.3
15 |     if random_mirror and random.random() > 0.5:
16 |         v = -v
17 |     return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
18 | 
19 | 
20 | def ShearY(img, v):  # [-0.3, 0.3]
21 |     assert -0.3 <= v <= 0.3
22 |     if random_mirror and random.random() > 0.5:
23 |         v = -v
24 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
25 | 
26 | 
27 | def TranslateX(img, v):  # [-150, 150] => percentage: [-0.30, 0.30]
28 |     assert -0.30 <= v <= 0.30
29 |     if random_mirror and random.random() > 0.5:
30 |         v = -v
31 |     v = v * img.size[0]
32 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
33 | 
34 | 
35 | def TranslateY(img, v):  # [-150, 150] => percentage: [-0.30, 0.30]
36 |     assert -0.30 <= v <= 0.30
37 |     if random_mirror and random.random() > 0.5:
38 |         v = -v
39 |     v = v * img.size[1]
40 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
41 | 
42 | 
43 | def TranslateXAbs(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
44 |     assert 0 <= v <= 10
45 |     if random.random() > 0.5:
46 |         v = -v
47 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
48 | 
49 | 
50 | def TranslateYAbs(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
51 |     assert 0 <= v <= 10
52 |     if random.random() > 0.5:
53 |         v = -v
54 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
55 | 
56 | 
57 | def Rotate(img, v):  # [-30, 30]
58 |     assert -30 <= v <= 30
59 |     if random_mirror and random.random() > 0.5:
60 |         v = -v
61 |     return img.rotate(v)
62 | 
63 | 
64 | def AutoContrast(img, _):
65 |     return PIL.ImageOps.autocontrast(img)
66 | 
67 | 
68 | def Invert(img, _):
69 |     return PIL.ImageOps.invert(img)
70 | 
71 | 
72 | def Equalize(img, _):
73 |     return PIL.ImageOps.equalize(img)
74 | 
75 | 
76 | def Flip(img, _):  # not from the paper
77 |     return PIL.ImageOps.mirror(img)
78 | 
79 | 
80 | def Solarize(img, v):  # [0, 256]
81 |     assert 0 <= v <= 256
82 |     return PIL.ImageOps.solarize(img, v)
83 | 
84 | 
85 | def Posterize(img, v):  # [4, 8]
86 |     assert 4 <= v <= 8
87 |     v = int(v)
88 |     return PIL.ImageOps.posterize(img, v)
89 | 
90 | 
91 | def Posterize2(img, v):  # [0, 4]
92 |     assert 0 <= v <= 4
93 |     v = int(v)
94 |     return PIL.ImageOps.posterize(img, v)
95 | 
96 | # For the augs below, mag=1.0 gives the original image
97 | def Contrast(img, v):  # [0.05,1.95]
98 |     assert 0.05 <= v <= 1.95
99 |     return PIL.ImageEnhance.Contrast(img).enhance(v)
100 | 
101 | 
102 | def Color(img, v):  # [0.05,1.95]
103 |     assert 0.05 <= v <= 1.95
104 |     return PIL.ImageEnhance.Color(img).enhance(v)
105 | 
106 | 
107 | def Brightness(img, v):  # [0.05,1.95]
108 |     assert 0.05 <= v <= 1.95
109 |     return PIL.ImageEnhance.Brightness(img).enhance(v)
110 | 
111 | 
112 | def Sharpness(img, v):  # [0.05,1.95]
113 |     assert 0.05 <= v <= 1.95
114 |     return PIL.ImageEnhance.Sharpness(img).enhance(v)
115 | 
116 | 
117 | def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]
118 |     assert 0.0 <= v <= 0.2
119 |     if v <= 0.:
120 |         return img
121 | 
122 |     v = v * img.size[0]
123 |     return CutoutAbs(img, v)
124 | 
125 | 
126 | def CutoutAbs(img, v):  # [0, 60] => percentage: [0, 0.2]
127 |     # assert 0 <= v <= 20
128 |     if v < 0:
129 |         return img
130 |     w, h = img.size
131 |     x0 = np.random.uniform(w)
132 |     y0 = np.random.uniform(h)
133 | 
134 |     x0 = int(max(0, x0 - v / 2.))
135 |     y0 = int(max(0, y0 - v / 2.))
136 |     x1 = min(w, x0 + v)
137 |     y1 = min(h, y0 + v)
138 | 
139 |     xy = (x0, y0, x1, y1)
140 |     color = (125, 123, 114)
141 |     # color = (0, 0, 0)
142 |     img = img.copy()
143 |     PIL.ImageDraw.Draw(img).rectangle(xy, color)
144 |     return img
145 | 
146 | 
147 | def SamplePairing(imgs):  # [0, 0.4]
148 |     def f(img1, v):
149 |         i = np.random.choice(len(imgs))
150 |         img2 = PIL.Image.fromarray(imgs[i])
151 |         return PIL.Image.blend(img1, img2, v)
152 | 
153 |     return f
154 | 
155 | 
156 | def augment_list():
157 |     # 14 augs and their magnitude range
158 |     l = [
159 |         (ShearX, -0.3, 0.3),  # 0
160 |         (ShearY, -0.3, 0.3),  # 1
161 |         (TranslateX, -0.3, 0.3),  # 2
162 |         (TranslateY, -0.3, 0.3),  # 3
163 |         (Rotate, -30, 30),  # 4
164 |         (AutoContrast, 0, 1),  # 5
165 |         (Invert, 0, 1),  # 6
166 |         (Equalize, 0, 1),  # 7
167 |         (Solarize, 0, 256),  # 8
168 |         (Posterize, 4, 8),  # 9
169 |         (Contrast, 0.05, 1.95),  # 10
170 |         (Color, 0.05, 1.95),  # 11
171 |         (Brightness, 0.05, 1.95),  # 12
172 |         (Sharpness, 0.05, 1.95),  # 13
173 |     ]
174 | 
175 |     return l
176 | 
177 | 
178 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}
179 | 
180 | def get_augment(name):
181 |     return augment_dict[name]
182 | 
183 | def apply_augment_with_rand_mag(img: PIL.Image, name: str) -> PIL.Image:
184 |     augment_fn, low, high = get_augment(name)
185 |     mag = np.random.uniform(low, high)
186 |     return augment_fn(img.copy(), mag)
187 | 
188 | 
189 | class CLSAAug(object):
190 | 
191 |     def __init__(self, num_of_times=5):
192 |         '''
193 |         params: num_of_times: how many times the augmentation is repeated
194 |         '''
195 |         self.num_of_times = num_of_times
196 |         self.aug_names = list(augment_dict.keys())
197 | 
198 |         print('Augmentation List:')
199 |         for aug_name in self.aug_names:
200 |             print('{} with magnitude of {} ~ {}'.format(aug_name, augment_dict[aug_name][1], augment_dict[aug_name][2]))
201 | 
202 | 
203 |     def __call__(self, img):
204 |         for i in range(self.num_of_times):
205 |             if np.random.rand() > 0.5:
206 |                 aug_name = random.choice(self.aug_names)
207 |                 img = apply_augment_with_rand_mag(img, aug_name)
208 | 
209 |         return img
210 | 
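211 | 
212 | # A minimal usage sketch (illustrative only, not part of the original file);
213 | # `example.jpg` stands in for any RGB image on disk:
214 | #
215 | #   aug = CLSAAug(num_of_times=5)
216 | #   img = PIL.Image.open('example.jpg').convert('RGB')
217 | #   img_strong = aug(img)  # a PIL image with up to 5 randomly chosen ops applied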
--------------------------------------------------------------------------------
/clsa/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from PIL import ImageFilter
3 | import random
4 | import torchvision.transforms as transforms
5 | 
6 | 
7 | class CLSAMultiResolutionTransform(object):
8 |     def __init__(self, base_transform, stronger_transform, num_res=5):
9 |         '''
10 |         Note: the per-resolution RandomResizedCrop is applied inside this class,
11 |         so stronger_transform does not need to include it
12 |         '''
13 |         resolutions = [96, 128, 160, 192, 224]
14 | 
15 |         self.res = resolutions[:num_res]
16 |         self.resize_crop_ops = [transforms.RandomResizedCrop(res, scale=(0.2, 1.)) for res in self.res]
17 |         self.num_res = num_res
18 | 
19 |         self.base_transform = base_transform
20 |         self.stronger_transform = stronger_transform
21 | 
22 |     def __call__(self, x):
23 |         q = self.base_transform(x)
24 |         k = self.base_transform(x)
25 | 
26 |         q_stronger_augs = []
27 |         for resize_crop_op in self.resize_crop_ops:
28 |             q_s = self.stronger_transform(resize_crop_op(x))
29 |             q_stronger_augs.append(q_s)
30 | 
31 |         return [q, k, q_stronger_augs]
32 | 
33 | 
34 | 
35 | class GaussianBlur(object):
36 |     """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
37 | 
38 |     def __init__(self, sigma=[.1, 2.]):
39 |         self.sigma = sigma
40 | 
41 |     def __call__(self, x):
42 |         sigma = random.uniform(self.sigma[0], self.sigma[1])
43 |         x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
44 |         return x
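45 | 
46 | # A minimal usage sketch (illustrative only, not part of the original file):
47 | # `weak_aug` would be a weak MoCo v2-style pipeline and `strong_aug` a
48 | # CLSAAug-based one, both ending in ToTensor (see main_clsa.py):
49 | #
50 | #   transform = CLSAMultiResolutionTransform(base_transform=weak_aug,
51 | #                                            stronger_transform=strong_aug,
52 | #                                            num_res=5)
53 | #   q, k, strong_views = transform(pil_image)
54 | #   # q, k: two weakly augmented views; strong_views: five tensors at
55 | #   # resolutions 96/128/160/192/224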
pretrainAP50APAP75
ImageNet-1M, supervised81.353.558.8
ImageNet-1M, MoCo v1, 200ep81.555.962.6
ImageNet-1M, MoCo v2, 200ep82.457.063.6
ImageNet-1M, MoCo v2, 800ep82.557.464.0
35 | 
36 | ***Note:*** These results are means of 5 trials. Variation on Pascal VOC is large: the std of AP50, AP, AP75 is expected to be 0.2, 0.2, 0.4 in most cases. We recommend running 5 trials and computing means.

--------------------------------------------------------------------------------
/detection/configs/Base-RCNN-C4-BN.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 |   META_ARCHITECTURE: "GeneralizedRCNN"
3 |   RPN:
4 |     PRE_NMS_TOPK_TEST: 6000
5 |     POST_NMS_TOPK_TEST: 1000
6 |   ROI_HEADS:
7 |     NAME: "Res5ROIHeadsExtraNorm"
8 |   BACKBONE:
9 |     FREEZE_AT: 0
10 |   RESNETS:
11 |     NORM: "SyncBN"
12 | TEST:
13 |   PRECISE_BN:
14 |     ENABLED: True
15 | SOLVER:
16 |   IMS_PER_BATCH: 16
17 |   BASE_LR: 0.02

--------------------------------------------------------------------------------
/detection/configs/coco_R_50_C4_2x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-C4-BN.yaml"
2 | MODEL:
3 |   MASK_ON: True
4 |   WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
5 | INPUT:
6 |   MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
7 |   MIN_SIZE_TEST: 800
8 | DATASETS:
9 |   TRAIN: ("coco_2017_train",)
10 |   TEST: ("coco_2017_val",)
11 | SOLVER:
12 |   STEPS: (120000, 160000)
13 |   MAX_ITER: 180000

--------------------------------------------------------------------------------
/detection/configs/coco_R_50_C4_2x_moco.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "coco_R_50_C4_2x.yaml"
2 | MODEL:
3 |   PIXEL_MEAN: [123.675, 116.280, 103.530]
4 |   PIXEL_STD: [58.395, 57.120, 57.375]
5 |   WEIGHTS: "See Instructions"
6 |   RESNETS:
7 |     STRIDE_IN_1X1: False
8 | INPUT:
9 |   FORMAT: "RGB"

--------------------------------------------------------------------------------
/detection/configs/pascal_voc_R_50_C4_24k.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-C4-BN.yaml"
2 | MODEL:
3 |   MASK_ON: False
4 |   WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
5 |   ROI_HEADS:
6 |     NUM_CLASSES: 20
7 | INPUT:
8 |   MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
9 |   MIN_SIZE_TEST: 800
10 | DATASETS:
11 |   TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
12 |   TEST: ('voc_2007_test',)
13 | SOLVER:
14 |   STEPS: (18000, 22000)
15 |   MAX_ITER: 24000
16 |   WARMUP_ITERS: 100

--------------------------------------------------------------------------------
/detection/configs/pascal_voc_R_50_C4_24k_moco.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "pascal_voc_R_50_C4_24k.yaml"
2 | MODEL:
3 |   PIXEL_MEAN: [123.675, 116.280, 103.530]
4 |   PIXEL_STD: [58.395, 57.120, 57.375]
5 |   WEIGHTS: "See Instructions"
6 |   RESNETS:
7 |     STRIDE_IN_1X1: False
8 | INPUT:
9 |   FORMAT: "RGB"

--------------------------------------------------------------------------------
/detection/convert-pretrain-to-detectron2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import pickle as pkl 5 | import sys 6 | import torch 7 | 8 | if __name__ == "__main__": 9 | input = sys.argv[1] 10 | 11 | obj = torch.load(input, map_location="cpu") 12 | obj = obj["state_dict"] 13 | 14 | newmodel = {} 15 | for k, v in obj.items(): 16 | if not k.startswith("module.encoder_q."): 17 | continue 18 | old_k = k 19 | k = k.replace("module.encoder_q.", "") 20 | if "layer" not in k: 21 | k = "stem." + k 22 | for t in [1, 2, 3, 4]: 23 | k = k.replace("layer{}".format(t), "res{}".format(t + 1)) 24 | for t in [1, 2, 3]: 25 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) 26 | k = k.replace("downsample.0", "shortcut") 27 | k = k.replace("downsample.1", "shortcut.norm") 28 | print(old_k, "->", k) 29 | newmodel[k] = v.numpy() 30 | 31 | res = {"model": newmodel, "__author__": "MOCO", "matching_heuristics": True} 32 | 33 | with open(sys.argv[2], "wb") as f: 34 | pkl.dump(res, f) 35 | -------------------------------------------------------------------------------- /detection/train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import os 5 | 6 | from detectron2.checkpoint import DetectionCheckpointer 7 | from detectron2.config import get_cfg 8 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch 9 | from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator 10 | from detectron2.layers import get_norm 11 | from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads 12 | 13 | 14 | @ROI_HEADS_REGISTRY.register() 15 | class Res5ROIHeadsExtraNorm(Res5ROIHeads): 16 | """ 17 | As described in the MOCO paper, there is an extra BN layer 18 | following the res5 stage. 
19 | """ 20 | def _build_res5_block(self, cfg): 21 | seq, out_channels = super()._build_res5_block(cfg) 22 | norm = cfg.MODEL.RESNETS.NORM 23 | norm = get_norm(norm, out_channels) 24 | seq.add_module("norm", norm) 25 | return seq, out_channels 26 | 27 | 28 | class Trainer(DefaultTrainer): 29 | @classmethod 30 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 31 | if output_folder is None: 32 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 33 | if "coco" in dataset_name: 34 | return COCOEvaluator(dataset_name, cfg, True, output_folder) 35 | else: 36 | assert "voc" in dataset_name 37 | return PascalVOCDetectionEvaluator(dataset_name) 38 | 39 | 40 | def setup(args): 41 | cfg = get_cfg() 42 | cfg.merge_from_file(args.config_file) 43 | cfg.merge_from_list(args.opts) 44 | cfg.freeze() 45 | default_setup(cfg, args) 46 | return cfg 47 | 48 | 49 | def main(args): 50 | cfg = setup(args) 51 | 52 | if args.eval_only: 53 | model = Trainer.build_model(cfg) 54 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 55 | cfg.MODEL.WEIGHTS, resume=args.resume 56 | ) 57 | res = Trainer.test(cfg, model) 58 | return res 59 | 60 | trainer = Trainer(cfg) 61 | trainer.resume_or_load(resume=args.resume) 62 | return trainer.train() 63 | 64 | 65 | if __name__ == "__main__": 66 | args = default_argument_parser().parse_args() 67 | print("Command Line Args:", args) 68 | launch( 69 | main, 70 | args.num_gpus, 71 | num_machines=args.num_machines, 72 | machine_rank=args.machine_rank, 73 | dist_url=args.dist_url, 74 | args=(args,), 75 | ) 76 | -------------------------------------------------------------------------------- /main_clsa.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | import argparse 4 | import builtins 5 | import math 6 | import os 7 | import random 8 | import shutil 9 | import time 10 | import warnings 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | import torch.optim 18 | import torch.multiprocessing as mp 19 | import torch.utils.data 20 | import torch.utils.data.distributed 21 | import torchvision.transforms as transforms 22 | import torchvision.datasets as datasets 23 | import torchvision.models as models 24 | 25 | import moco.loader 26 | import moco.builder 27 | 28 | import clsa.clsa_augs 29 | import clsa.builder 30 | import clsa.loader 31 | 32 | 33 | model_names = sorted(name for name in models.__dict__ 34 | if name.islower() and not name.startswith("__") 35 | and callable(models.__dict__[name])) 36 | 37 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 38 | parser.add_argument('data', metavar='DIR', 39 | help='path to dataset') 40 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 41 | choices=model_names, 42 | help='model architecture: ' + 43 | ' | '.join(model_names) + 44 | ' (default: resnet50)') 45 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 46 | help='number of data loading workers (default: 32)') 47 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 48 | help='number of total epochs to run') 49 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 50 | help='manual epoch number (useful on restarts)') 51 | parser.add_argument('-b', '--batch-size', default=256, type=int, 52 | metavar='N', 53 | help='mini-batch size (default: 256), this is the total ' 54 | 'batch size of all GPUs on the current node when ' 55 | 'using Data Parallel or Distributed Data Parallel') 56 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float, 57 | metavar='LR', help='initial learning rate', dest='lr') 58 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, 59 | help='learning rate schedule (when to drop lr by 10x)') 60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 61 | help='momentum of SGD solver') 62 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 63 | metavar='W', help='weight decay (default: 1e-4)', 64 | dest='weight_decay') 65 | parser.add_argument('-p', '--print-freq', default=10, type=int, 66 | metavar='N', help='print frequency (default: 10)') 67 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 68 | help='path to latest checkpoint (default: none)') 69 | parser.add_argument('--world-size', default=-1, type=int, 70 | help='number of nodes for distributed training') 71 | parser.add_argument('--rank', default=-1, type=int, 72 | help='node rank for distributed training') 73 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 74 | help='url used to set up distributed training') 75 | parser.add_argument('--dist-backend', default='nccl', type=str, 76 | help='distributed backend') 77 | parser.add_argument('--seed', default=None, type=int, 78 | help='seed for initializing training. ') 79 | parser.add_argument('--gpu', default=None, type=int, 80 | help='GPU id to use.') 81 | parser.add_argument('--multiprocessing-distributed', action='store_true', 82 | help='Use multi-processing distributed training to launch ' 83 | 'N processes per node, which has N GPUs. 
This is the '
84 |                          'fastest way to use PyTorch for either single node or '
85 |                          'multi node data parallel training')
86 | 
87 | # moco specific configs:
88 | parser.add_argument('--moco-dim', default=2048, type=int,
89 |                     help='feature dimension (default: 2048)')
90 | parser.add_argument('--moco-k', default=65536, type=int,
91 |                     help='queue size; number of negative keys (default: 65536)')
92 | parser.add_argument('--moco-m', default=0.999, type=float,
93 |                     help='moco momentum of updating key encoder (default: 0.999)')
94 | parser.add_argument('--moco-t', default=0.2, type=float,
95 |                     help='softmax temperature (default: 0.2)')
96 | 
97 | # options for moco v2
98 | parser.add_argument('--mlp', action='store_true',
99 |                     help='use mlp head')
100 | parser.add_argument('--aug-plus', action='store_true',
101 |                     help='use moco v2 data augmentation')
102 | parser.add_argument('--cos', action='store_true',
103 |                     help='use cosine lr schedule')
104 | 
105 | # additional hyper-params for CLSA
106 | 
107 | parser.add_argument('--ratio', default=1.0, type=float,
108 |                     help='the reweighting coefficient for the DDM loss (beta in the paper)')
109 | parser.add_argument('--num_res', default=1, type=int,
110 |                     help='The number of resolutions for stronger augs')
111 | 
112 | 
113 | def main():
114 |     args = parser.parse_args()
115 | 
116 |     if args.seed is not None:
117 |         random.seed(args.seed)
118 |         torch.manual_seed(args.seed)
119 |         cudnn.deterministic = True
120 |         warnings.warn('You have chosen to seed training. '
121 |                       'This will turn on the CUDNN deterministic setting, '
122 |                       'which can slow down your training considerably! '
123 |                       'You may see unexpected behavior when restarting '
124 |                       'from checkpoints.')
125 | 
126 |     if args.gpu is not None:
127 |         warnings.warn('You have chosen a specific GPU. This will completely '
128 |                       'disable data parallelism.')
129 | 
130 |     if args.dist_url == "env://" and args.world_size == -1:
131 |         args.world_size = int(os.environ["WORLD_SIZE"])
132 | 
133 |     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
134 | 
135 |     ngpus_per_node = torch.cuda.device_count()
136 |     if args.multiprocessing_distributed:
137 |         # Since we have ngpus_per_node processes per node, the total world_size
138 |         # needs to be adjusted accordingly
139 |         args.world_size = ngpus_per_node * args.world_size
140 |         # Use torch.multiprocessing.spawn to launch distributed processes: the
141 |         # main_worker process function
142 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
143 |     else:
144 |         # Simply call main_worker function
145 |         main_worker(args.gpu, ngpus_per_node, args)
146 | 
147 | 
148 | def main_worker(gpu, ngpus_per_node, args):
149 |     args.gpu = gpu
150 | 
151 |     # suppress printing if not master
152 |     if args.multiprocessing_distributed and args.gpu != 0:
153 |         def print_pass(*args):
154 |             pass
155 |         builtins.print = print_pass
156 | 
157 |     if args.gpu is not None:
158 |         print("Use GPU: {} for training".format(args.gpu))
159 | 
160 |     if args.distributed:
161 |         if args.dist_url == "env://" and args.rank == -1:
162 |             args.rank = int(os.environ["RANK"])
163 |         if args.multiprocessing_distributed:
164 |             # For multiprocessing distributed training, rank needs to be the
165 |             # global rank among all the processes
166 |             args.rank = args.rank * ngpus_per_node + gpu
167 |         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
168 |                                 world_size=args.world_size, rank=args.rank)
169 |     # create model
170 |     print("=> creating model '{}'".format(args.arch))
171 | 
172 |     model = clsa.builder.CLSA(
173 | models.__dict__[args.arch], 174 | args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp, args.ratio 175 | ) 176 | print(model) 177 | 178 | if args.distributed: 179 | # For multiprocessing distributed, DistributedDataParallel constructor 180 | # should always set the single device scope, otherwise, 181 | # DistributedDataParallel will use all available devices. 182 | if args.gpu is not None: 183 | torch.cuda.set_device(args.gpu) 184 | model.cuda(args.gpu) 185 | # When using a single GPU per process and per 186 | # DistributedDataParallel, we need to divide the batch size 187 | # ourselves based on the total number of GPUs we have 188 | args.batch_size = int(args.batch_size / ngpus_per_node) 189 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 190 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 191 | else: 192 | model.cuda() 193 | # DistributedDataParallel will divide and allocate batch_size to all 194 | # available GPUs if device_ids are not set 195 | model = torch.nn.parallel.DistributedDataParallel(model) 196 | elif args.gpu is not None: 197 | torch.cuda.set_device(args.gpu) 198 | model = model.cuda(args.gpu) 199 | # comment out the following line for debugging 200 | raise NotImplementedError("Only DistributedDataParallel is supported.") 201 | else: 202 | # AllGather implementation (batch shuffle, queue update, etc.) in 203 | # this code only supports DistributedDataParallel. 204 | raise NotImplementedError("Only DistributedDataParallel is supported.") 205 | 206 | # define loss function (criterion) and optimizer 207 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 208 | 209 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 210 | momentum=args.momentum, 211 | weight_decay=args.weight_decay) 212 | 213 | # optionally resume from a checkpoint 214 | if args.resume: 215 | if os.path.isfile(args.resume): 216 | print("=> loading checkpoint '{}'".format(args.resume)) 217 | if args.gpu is None: 218 | checkpoint = torch.load(args.resume) 219 | else: 220 | # Map model to be loaded to specified single gpu. 
221 |                 loc = 'cuda:{}'.format(args.gpu)
222 |                 checkpoint = torch.load(args.resume, map_location=loc)
223 |             args.start_epoch = checkpoint['epoch']
224 |             model.load_state_dict(checkpoint['state_dict'])
225 |             optimizer.load_state_dict(checkpoint['optimizer'])
226 |             print("=> loaded checkpoint '{}' (epoch {})"
227 |                   .format(args.resume, checkpoint['epoch']))
228 |         else:
229 |             print("=> no checkpoint found at '{}'".format(args.resume))
230 | 
231 |     cudnn.benchmark = True
232 | 
233 |     # Data loading code
234 |     traindir = os.path.join(args.data, 'train')
235 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
236 |                                      std=[0.229, 0.224, 0.225])
237 |     if args.aug_plus:
238 |         # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
239 |         augmentation = [
240 |             transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
241 |             transforms.RandomApply([
242 |                 transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
243 |             ], p=0.8),
244 |             transforms.RandomGrayscale(p=0.2),
245 |             transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5),
246 |             transforms.RandomHorizontalFlip(),
247 |             transforms.ToTensor(),
248 |             normalize
249 |         ]
250 |     else:
251 |         # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978
252 |         augmentation = [
253 |             transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
254 |             transforms.RandomGrayscale(p=0.2),
255 |             transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
256 |             transforms.RandomHorizontalFlip(),
257 |             transforms.ToTensor(),
258 |             normalize
259 |         ]
260 |     # add two additional lines to construct the CLSA data loader with multi-resolution stronger augs
261 |     augmentation = transforms.Compose(augmentation)
262 | 
263 |     stronger_aug = clsa.clsa_augs.CLSAAug(num_of_times=5)  # number of repeated random-aug steps
264 |     stronger_aug = transforms.Compose([stronger_aug, transforms.ToTensor(), normalize])
265 |     train_dataset = datasets.ImageFolder(
266 |         traindir,
267 |         clsa.loader.CLSAMultiResolutionTransform(base_transform=augmentation,
268 |                                                  stronger_transform=stronger_aug, num_res=args.num_res))
269 | 
270 |     if args.distributed:
271 |         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
272 |     else:
273 |         train_sampler = None
274 | 
275 |     train_loader = torch.utils.data.DataLoader(
276 |         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
277 |         num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
278 | 
279 |     for epoch in range(args.start_epoch, args.epochs):
280 |         if args.distributed:
281 |             train_sampler.set_epoch(epoch)
282 |         adjust_learning_rate(optimizer, epoch, args)
283 | 
284 |         # train for one epoch
285 |         train(train_loader, model, criterion, optimizer, epoch, args)
286 | 
287 |         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
288 |                 and args.rank % ngpus_per_node == 0):
289 |             save_checkpoint({
290 |                 'epoch': epoch + 1,
291 |                 'arch': args.arch,
292 |                 'state_dict': model.state_dict(),
293 |                 'optimizer' : optimizer.state_dict(),
294 |             }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch))
295 | 
296 | 
297 | def train(train_loader, model, criterion, optimizer, epoch, args):
298 |     batch_time = AverageMeter('Time', ':6.3f')
299 |     data_time = AverageMeter('Data', ':6.3f')
300 |     losses = AverageMeter('Loss', ':.4e')
301 |     top1 = AverageMeter('Acc@1', ':6.2f')
302 |     top5 = AverageMeter('Acc@5', ':6.2f')
303 |     progress = ProgressMeter(
304 |         len(train_loader),
305 |         [batch_time, data_time, losses, top1, top5],
306 |         prefix="Epoch: [{}]".format(epoch))
307 | 
308 |     # 
switch to train mode 309 | model.train() 310 | 311 | end = time.time() 312 | for i, (images, _) in enumerate(train_loader): 313 | # measure data loading time 314 | data_time.update(time.time() - end) 315 | 316 | if args.gpu is not None: 317 | images[0] = images[0].cuda(args.gpu, non_blocking=True) 318 | images[1] = images[1].cuda(args.gpu, non_blocking=True) 319 | 320 | # compute output 321 | im_q, im_k, img_stronger_aug_list = images[0], images[1], images[2] 322 | output, target, loss = model(im_q, im_k, img_stronger_aug_list) 323 | 324 | #loss = criterion(output, target) 325 | 326 | # acc1/acc5 are (K+1)-way contrast classifier accuracy 327 | # measure accuracy and record loss 328 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 329 | losses.update(loss.item(), images[0].size(0)) 330 | top1.update(acc1[0], images[0].size(0)) 331 | top5.update(acc5[0], images[0].size(0)) 332 | 333 | # compute gradient and do SGD step 334 | optimizer.zero_grad() 335 | loss.backward() 336 | optimizer.step() 337 | 338 | # measure elapsed time 339 | batch_time.update(time.time() - end) 340 | end = time.time() 341 | 342 | if i % args.print_freq == 0: 343 | progress.display(i) 344 | 345 | 346 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 347 | torch.save(state, filename) 348 | if is_best: 349 | shutil.copyfile(filename, 'model_best.pth.tar') 350 | 351 | 352 | class AverageMeter(object): 353 | """Computes and stores the average and current value""" 354 | def __init__(self, name, fmt=':f'): 355 | self.name = name 356 | self.fmt = fmt 357 | self.reset() 358 | 359 | def reset(self): 360 | self.val = 0 361 | self.avg = 0 362 | self.sum = 0 363 | self.count = 0 364 | 365 | def update(self, val, n=1): 366 | self.val = val 367 | self.sum += val * n 368 | self.count += n 369 | self.avg = self.sum / self.count 370 | 371 | def __str__(self): 372 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 373 | return fmtstr.format(**self.__dict__) 374 | 375 | 376 | class ProgressMeter(object): 377 | def __init__(self, num_batches, meters, prefix=""): 378 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 379 | self.meters = meters 380 | self.prefix = prefix 381 | 382 | def display(self, batch): 383 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 384 | entries += [str(meter) for meter in self.meters] 385 | print('\t'.join(entries)) 386 | 387 | def _get_batch_fmtstr(self, num_batches): 388 | num_digits = len(str(num_batches // 1)) 389 | fmt = '{:' + str(num_digits) + 'd}' 390 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 391 | 392 | 393 | def adjust_learning_rate(optimizer, epoch, args): 394 | """Decay the learning rate based on schedule""" 395 | lr = args.lr 396 | if args.cos: # cosine lr schedule 397 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) 398 | else: # stepwise lr schedule 399 | for milestone in args.schedule: 400 | lr *= 0.1 if epoch >= milestone else 1. 
401 | for param_group in optimizer.param_groups: 402 | param_group['lr'] = lr 403 | 404 | 405 | def accuracy(output, target, topk=(1,)): 406 | """Computes the accuracy over the k top predictions for the specified values of k""" 407 | with torch.no_grad(): 408 | maxk = max(topk) 409 | batch_size = target.size(0) 410 | 411 | _, pred = output.topk(maxk, 1, True, True) 412 | pred = pred.t() 413 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 414 | 415 | res = [] 416 | for k in topk: 417 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 418 | res.append(correct_k.mul_(100.0 / batch_size)) 419 | return res 420 | 421 | 422 | if __name__ == '__main__': 423 | main() 424 | -------------------------------------------------------------------------------- /main_lincls.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import argparse 4 | import builtins 5 | import os 6 | import random 7 | import shutil 8 | import time 9 | import warnings 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.distributed as dist 16 | import torch.optim 17 | import torch.multiprocessing as mp 18 | import torch.utils.data 19 | import torch.utils.data.distributed 20 | import torchvision.transforms as transforms 21 | import torchvision.datasets as datasets 22 | import torchvision.models as models 23 | 24 | model_names = sorted(name for name in models.__dict__ 25 | if name.islower() and not name.startswith("__") 26 | and callable(models.__dict__[name])) 27 | 28 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 29 | parser.add_argument('data', metavar='DIR', 30 | help='path to dataset') 31 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', 32 | choices=model_names, 33 | help='model architecture: ' + 34 | ' | '.join(model_names) + 35 | ' (default: resnet50)') 36 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 37 | help='number of data loading workers (default: 32)') 38 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 39 | help='number of total epochs to run') 40 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 41 | help='manual epoch number (useful on restarts)') 42 | parser.add_argument('-b', '--batch-size', default=256, type=int, 43 | metavar='N', 44 | help='mini-batch size (default: 256), this is the total ' 45 | 'batch size of all GPUs on the current node when ' 46 | 'using Data Parallel or Distributed Data Parallel') 47 | parser.add_argument('--lr', '--learning-rate', default=30., type=float, 48 | metavar='LR', help='initial learning rate', dest='lr') 49 | parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int, 50 | help='learning rate schedule (when to drop lr by a ratio)') 51 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 52 | help='momentum') 53 | parser.add_argument('--wd', '--weight-decay', default=0., type=float, 54 | metavar='W', help='weight decay (default: 0.)', 55 | dest='weight_decay') 56 | parser.add_argument('-p', '--print-freq', default=10, type=int, 57 | metavar='N', help='print frequency (default: 10)') 58 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 59 | help='path to latest checkpoint (default: none)') 60 | parser.add_argument('-e', '--evaluate', dest='evaluate', 
action='store_true', 61 | help='evaluate model on validation set') 62 | parser.add_argument('--world-size', default=-1, type=int, 63 | help='number of nodes for distributed training') 64 | parser.add_argument('--rank', default=-1, type=int, 65 | help='node rank for distributed training') 66 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 67 | help='url used to set up distributed training') 68 | parser.add_argument('--dist-backend', default='nccl', type=str, 69 | help='distributed backend') 70 | parser.add_argument('--seed', default=None, type=int, 71 | help='seed for initializing training. ') 72 | parser.add_argument('--gpu', default=None, type=int, 73 | help='GPU id to use.') 74 | parser.add_argument('--multiprocessing-distributed', action='store_true', 75 | help='Use multi-processing distributed training to launch ' 76 | 'N processes per node, which has N GPUs. This is the ' 77 | 'fastest way to use PyTorch for either single node or ' 78 | 'multi node data parallel training') 79 | 80 | parser.add_argument('--pretrained', default='', type=str, 81 | help='path to moco pretrained checkpoint') 82 | 83 | best_acc1 = 0 84 | 85 | 86 | def main(): 87 | args = parser.parse_args() 88 | 89 | if args.seed is not None: 90 | random.seed(args.seed) 91 | torch.manual_seed(args.seed) 92 | cudnn.deterministic = True 93 | warnings.warn('You have chosen to seed training. ' 94 | 'This will turn on the CUDNN deterministic setting, ' 95 | 'which can slow down your training considerably! ' 96 | 'You may see unexpected behavior when restarting ' 97 | 'from checkpoints.') 98 | 99 | if args.gpu is not None: 100 | warnings.warn('You have chosen a specific GPU. This will completely ' 101 | 'disable data parallelism.') 102 | 103 | if args.dist_url == "env://" and args.world_size == -1: 104 | args.world_size = int(os.environ["WORLD_SIZE"]) 105 | 106 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 107 | 108 | ngpus_per_node = torch.cuda.device_count() 109 | if args.multiprocessing_distributed: 110 | # Since we have ngpus_per_node processes per node, the total world_size 111 | # needs to be adjusted accordingly 112 | args.world_size = ngpus_per_node * args.world_size 113 | # Use torch.multiprocessing.spawn to launch distributed processes: the 114 | # main_worker process function 115 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 116 | else: 117 | # Simply call main_worker function 118 | main_worker(args.gpu, ngpus_per_node, args) 119 | 120 | 121 | def main_worker(gpu, ngpus_per_node, args): 122 | global best_acc1 123 | args.gpu = gpu 124 | 125 | # suppress printing if not master 126 | if args.multiprocessing_distributed and args.gpu != 0: 127 | def print_pass(*args): 128 | pass 129 | builtins.print = print_pass 130 | 131 | if args.gpu is not None: 132 | print("Use GPU: {} for training".format(args.gpu)) 133 | 134 | if args.distributed: 135 | if args.dist_url == "env://" and args.rank == -1: 136 | args.rank = int(os.environ["RANK"]) 137 | if args.multiprocessing_distributed: 138 | # For multiprocessing distributed training, rank needs to be the 139 | # global rank among all the processes 140 | args.rank = args.rank * ngpus_per_node + gpu 141 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 142 | world_size=args.world_size, rank=args.rank) 143 | # create model 144 | print("=> creating model '{}'".format(args.arch)) 145 | model = models.__dict__[args.arch]() 146 | 147 | # freeze all layers but the last 
fc 148 | for name, param in model.named_parameters(): 149 | if name not in ['fc.weight', 'fc.bias']: 150 | param.requires_grad = False 151 | # init the fc layer 152 | model.fc.weight.data.normal_(mean=0.0, std=0.01) 153 | model.fc.bias.data.zero_() 154 | 155 | # load from pre-trained, before DistributedDataParallel constructor 156 | if args.pretrained: 157 | if os.path.isfile(args.pretrained): 158 | print("=> loading checkpoint '{}'".format(args.pretrained)) 159 | checkpoint = torch.load(args.pretrained, map_location="cpu") 160 | 161 | # rename moco pre-trained keys 162 | state_dict = checkpoint['state_dict'] 163 | for k in list(state_dict.keys()): 164 | # retain only encoder_q up to before the embedding layer 165 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 166 | # remove prefix 167 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 168 | # delete renamed or unused k 169 | del state_dict[k] 170 | 171 | args.start_epoch = 0 172 | msg = model.load_state_dict(state_dict, strict=False) 173 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 174 | 175 | print("=> loaded pre-trained model '{}'".format(args.pretrained)) 176 | else: 177 | print("=> no checkpoint found at '{}'".format(args.pretrained)) 178 | 179 | if args.distributed: 180 | # For multiprocessing distributed, DistributedDataParallel constructor 181 | # should always set the single device scope, otherwise, 182 | # DistributedDataParallel will use all available devices. 183 | if args.gpu is not None: 184 | torch.cuda.set_device(args.gpu) 185 | model.cuda(args.gpu) 186 | # When using a single GPU per process and per 187 | # DistributedDataParallel, we need to divide the batch size 188 | # ourselves based on the total number of GPUs we have 189 | args.batch_size = int(args.batch_size / ngpus_per_node) 190 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 191 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 192 | else: 193 | model.cuda() 194 | # DistributedDataParallel will divide and allocate batch_size to all 195 | # available GPUs if device_ids are not set 196 | model = torch.nn.parallel.DistributedDataParallel(model) 197 | elif args.gpu is not None: 198 | torch.cuda.set_device(args.gpu) 199 | model = model.cuda(args.gpu) 200 | else: 201 | # DataParallel will divide and allocate batch_size to all available GPUs 202 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 203 | model.features = torch.nn.DataParallel(model.features) 204 | model.cuda() 205 | else: 206 | model = torch.nn.DataParallel(model).cuda() 207 | 208 | # define loss function (criterion) and optimizer 209 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 210 | 211 | # optimize only the linear classifier 212 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 213 | assert len(parameters) == 2 # fc.weight, fc.bias 214 | optimizer = torch.optim.SGD(parameters, args.lr, 215 | momentum=args.momentum, 216 | weight_decay=args.weight_decay) 217 | 218 | # optionally resume from a checkpoint 219 | if args.resume: 220 | if os.path.isfile(args.resume): 221 | print("=> loading checkpoint '{}'".format(args.resume)) 222 | if args.gpu is None: 223 | checkpoint = torch.load(args.resume) 224 | else: 225 | # Map model to be loaded to specified single gpu. 
226 | loc = 'cuda:{}'.format(args.gpu) 227 | checkpoint = torch.load(args.resume, map_location=loc) 228 | args.start_epoch = checkpoint['epoch'] 229 | best_acc1 = checkpoint['best_acc1'] 230 | if args.gpu is not None: 231 | # best_acc1 may be from a checkpoint from a different GPU 232 | best_acc1 = best_acc1.to(args.gpu) 233 | model.load_state_dict(checkpoint['state_dict']) 234 | optimizer.load_state_dict(checkpoint['optimizer']) 235 | print("=> loaded checkpoint '{}' (epoch {})" 236 | .format(args.resume, checkpoint['epoch'])) 237 | else: 238 | print("=> no checkpoint found at '{}'".format(args.resume)) 239 | 240 | cudnn.benchmark = True 241 | 242 | # Data loading code 243 | traindir = os.path.join(args.data, 'train') 244 | valdir = os.path.join(args.data, 'val') 245 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 246 | std=[0.229, 0.224, 0.225]) 247 | 248 | train_dataset = datasets.ImageFolder( 249 | traindir, 250 | transforms.Compose([ 251 | transforms.RandomResizedCrop(224), 252 | transforms.RandomHorizontalFlip(), 253 | transforms.ToTensor(), 254 | normalize, 255 | ])) 256 | 257 | if args.distributed: 258 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 259 | else: 260 | train_sampler = None 261 | 262 | train_loader = torch.utils.data.DataLoader( 263 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 264 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 265 | 266 | val_loader = torch.utils.data.DataLoader( 267 | datasets.ImageFolder(valdir, transforms.Compose([ 268 | transforms.Resize(256), 269 | transforms.CenterCrop(224), 270 | transforms.ToTensor(), 271 | normalize, 272 | ])), 273 | batch_size=args.batch_size, shuffle=False, 274 | num_workers=args.workers, pin_memory=True) 275 | 276 | if args.evaluate: 277 | validate(val_loader, model, criterion, args) 278 | return 279 | 280 | for epoch in range(args.start_epoch, args.epochs): 281 | if args.distributed: 282 | train_sampler.set_epoch(epoch) 283 | adjust_learning_rate(optimizer, epoch, args) 284 | 285 | # train for one epoch 286 | train(train_loader, model, criterion, optimizer, epoch, args) 287 | 288 | # evaluate on validation set 289 | acc1 = validate(val_loader, model, criterion, args) 290 | 291 | # remember best acc@1 and save checkpoint 292 | is_best = acc1 > best_acc1 293 | best_acc1 = max(acc1, best_acc1) 294 | 295 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 296 | and args.rank % ngpus_per_node == 0): 297 | save_checkpoint({ 298 | 'epoch': epoch + 1, 299 | 'arch': args.arch, 300 | 'state_dict': model.state_dict(), 301 | 'best_acc1': best_acc1, 302 | 'optimizer' : optimizer.state_dict(), 303 | }, is_best) 304 | if epoch == args.start_epoch: 305 | sanity_check(model.state_dict(), args.pretrained) 306 | 307 | 308 | def train(train_loader, model, criterion, optimizer, epoch, args): 309 | batch_time = AverageMeter('Time', ':6.3f') 310 | data_time = AverageMeter('Data', ':6.3f') 311 | losses = AverageMeter('Loss', ':.4e') 312 | top1 = AverageMeter('Acc@1', ':6.2f') 313 | top5 = AverageMeter('Acc@5', ':6.2f') 314 | progress = ProgressMeter( 315 | len(train_loader), 316 | [batch_time, data_time, losses, top1, top5], 317 | prefix="Epoch: [{}]".format(epoch)) 318 | 319 | """ 320 | Switch to eval mode: 321 | Under the protocol of linear classification on frozen features/models, 322 | it is not legitimate to change any part of the pre-trained model. 
323 | BatchNorm in train mode may revise running mean/std (even if it receives 324 | no gradient), which are part of the model parameters too. 325 | """ 326 | model.eval() 327 | 328 | end = time.time() 329 | for i, (images, target) in enumerate(train_loader): 330 | # measure data loading time 331 | data_time.update(time.time() - end) 332 | 333 | if args.gpu is not None: 334 | images = images.cuda(args.gpu, non_blocking=True) 335 | target = target.cuda(args.gpu, non_blocking=True) 336 | 337 | # compute output 338 | output = model(images) 339 | loss = criterion(output, target) 340 | 341 | # measure accuracy and record loss 342 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 343 | losses.update(loss.item(), images.size(0)) 344 | top1.update(acc1[0], images.size(0)) 345 | top5.update(acc5[0], images.size(0)) 346 | 347 | # compute gradient and do SGD step 348 | optimizer.zero_grad() 349 | loss.backward() 350 | optimizer.step() 351 | 352 | # measure elapsed time 353 | batch_time.update(time.time() - end) 354 | end = time.time() 355 | 356 | if i % args.print_freq == 0: 357 | progress.display(i) 358 | 359 | 360 | def validate(val_loader, model, criterion, args): 361 | batch_time = AverageMeter('Time', ':6.3f') 362 | losses = AverageMeter('Loss', ':.4e') 363 | top1 = AverageMeter('Acc@1', ':6.2f') 364 | top5 = AverageMeter('Acc@5', ':6.2f') 365 | progress = ProgressMeter( 366 | len(val_loader), 367 | [batch_time, losses, top1, top5], 368 | prefix='Test: ') 369 | 370 | # switch to evaluate mode 371 | model.eval() 372 | 373 | with torch.no_grad(): 374 | end = time.time() 375 | for i, (images, target) in enumerate(val_loader): 376 | if args.gpu is not None: 377 | images = images.cuda(args.gpu, non_blocking=True) 378 | target = target.cuda(args.gpu, non_blocking=True) 379 | 380 | # compute output 381 | output = model(images) 382 | loss = criterion(output, target) 383 | 384 | # measure accuracy and record loss 385 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 386 | losses.update(loss.item(), images.size(0)) 387 | top1.update(acc1[0], images.size(0)) 388 | top5.update(acc5[0], images.size(0)) 389 | 390 | # measure elapsed time 391 | batch_time.update(time.time() - end) 392 | end = time.time() 393 | 394 | if i % args.print_freq == 0: 395 | progress.display(i) 396 | 397 | # TODO: this should also be done with the ProgressMeter 398 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 399 | .format(top1=top1, top5=top5)) 400 | 401 | return top1.avg 402 | 403 | 404 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 405 | torch.save(state, filename) 406 | if is_best: 407 | shutil.copyfile(filename, 'model_best.pth.tar') 408 | 409 | 410 | def sanity_check(state_dict, pretrained_weights): 411 | """ 412 | Linear classifier should not change any weights other than the linear layer. 413 | This sanity check asserts nothing wrong happens (e.g., BN stats updated). 414 | """ 415 | print("=> loading '{}' for sanity check".format(pretrained_weights)) 416 | checkpoint = torch.load(pretrained_weights, map_location="cpu") 417 | state_dict_pre = checkpoint['state_dict'] 418 | 419 | for k in list(state_dict.keys()): 420 | # only ignore fc layer 421 | if 'fc.weight' in k or 'fc.bias' in k: 422 | continue 423 | 424 | # name in pretrained model 425 | k_pre = 'module.encoder_q.' + k[len('module.'):] \ 426 | if k.startswith('module.') else 'module.encoder_q.' 
434 | class AverageMeter(object):
435 |     """Computes and stores the average and current value"""
436 |     def __init__(self, name, fmt=':f'):
437 |         self.name = name
438 |         self.fmt = fmt
439 |         self.reset()
440 | 
441 |     def reset(self):
442 |         self.val = 0
443 |         self.avg = 0
444 |         self.sum = 0
445 |         self.count = 0
446 | 
447 |     def update(self, val, n=1):
448 |         self.val = val
449 |         self.sum += val * n
450 |         self.count += n
451 |         self.avg = self.sum / self.count
452 | 
453 |     def __str__(self):
454 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
455 |         return fmtstr.format(**self.__dict__)
456 | 
457 | 
458 | class ProgressMeter(object):
459 |     def __init__(self, num_batches, meters, prefix=""):
460 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
461 |         self.meters = meters
462 |         self.prefix = prefix
463 | 
464 |     def display(self, batch):
465 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
466 |         entries += [str(meter) for meter in self.meters]
467 |         print('\t'.join(entries))
468 | 
469 |     def _get_batch_fmtstr(self, num_batches):
470 |         num_digits = len(str(num_batches))
471 |         fmt = '{:' + str(num_digits) + 'd}'
472 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
473 | 
474 | 
475 | def adjust_learning_rate(optimizer, epoch, args):
476 |     """Decay the learning rate based on schedule"""
477 |     lr = args.lr
478 |     for milestone in args.schedule:
479 |         lr *= 0.1 if epoch >= milestone else 1.
480 |     for param_group in optimizer.param_groups:
481 |         param_group['lr'] = lr
482 | 
483 | 
484 | def accuracy(output, target, topk=(1,)):
485 |     """Computes the accuracy over the k top predictions for the specified values of k"""
486 |     with torch.no_grad():
487 |         maxk = max(topk)
488 |         batch_size = target.size(0)
489 | 
490 |         _, pred = output.topk(maxk, 1, True, True)
491 |         pred = pred.t()
492 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
493 | 
494 |         res = []
495 |         for k in topk:
496 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the slice is non-contiguous after t()
497 |             res.append(correct_k.mul_(100.0 / batch_size))
498 |         return res
499 | 
500 | 
501 | if __name__ == '__main__':
502 |     main()
503 | 
--------------------------------------------------------------------------------
/moco/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | 
--------------------------------------------------------------------------------
/moco/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import torch
3 | import torch.nn as nn
4 | 
5 | 
6 | class MoCo(nn.Module):
7 |     """
8 |     Build a MoCo model with: a query encoder, a key encoder, and a queue
9 |     https://arxiv.org/abs/1911.05722
10 |     """
11 |     def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
12 |         """
13 |         dim: feature dimension (default: 128)
14 |         K: queue size; number of negative keys (default: 65536)
15 |         m: moco momentum of updating key encoder (default: 0.999)
16 |         T: softmax temperature (default: 0.07)
17 |         """
18 |         super(MoCo, self).__init__()
19 | 
20 |         self.K = K
21 |         self.m = m
22 |         self.T = T
23 | 
24 |         # create the encoders
25 |         # num_classes is the output fc dimension
26 |         self.encoder_q = base_encoder(num_classes=dim)
27 |         self.encoder_k = base_encoder(num_classes=dim)
28 | 
29 |         if mlp:  # hack: brute-force replacement
30 |             dim_mlp = self.encoder_q.fc.weight.shape[1]
31 |             self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
32 |             self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)
33 | 
34 |         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
35 |             param_k.data.copy_(param_q.data)  # initialize
36 |             param_k.requires_grad = False  # not updated by gradient
37 | 
38 |         # create the queue
39 |         self.register_buffer("queue", torch.randn(dim, K))
40 |         self.queue = nn.functional.normalize(self.queue, dim=0)
41 | 
42 |         self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
43 | 
44 |     @torch.no_grad()
45 |     def _momentum_update_key_encoder(self):
46 |         """
47 |         Momentum update of the key encoder
48 |         """
49 |         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
50 |             param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
51 | 
52 |     @torch.no_grad()
53 |     def _dequeue_and_enqueue(self, keys):
54 |         # gather keys before updating queue
55 |         keys = concat_all_gather(keys)
56 | 
57 |         batch_size = keys.shape[0]
58 | 
59 |         ptr = int(self.queue_ptr)
60 |         assert self.K % batch_size == 0  # for simplicity
61 | 
62 |         # replace the keys at ptr (dequeue and enqueue)
63 |         self.queue[:, ptr:ptr + batch_size] = keys.T
64 |         ptr = (ptr + batch_size) % self.K  # move pointer
65 | 
66 |         self.queue_ptr[0] = ptr
67 | 
68 |     @torch.no_grad()
69 |     def _batch_shuffle_ddp(self, x):
70 |         """
71 |         Batch shuffle, for making use of BatchNorm.
72 |         *** Only supports DistributedDataParallel (DDP) models. ***
73 |         """
74 |         # gather from all gpus
75 |         batch_size_this = x.shape[0]
76 |         x_gather = concat_all_gather(x)
77 |         batch_size_all = x_gather.shape[0]
78 | 
79 |         num_gpus = batch_size_all // batch_size_this
80 | 
81 |         # random shuffle index
82 |         idx_shuffle = torch.randperm(batch_size_all).cuda()
83 | 
84 |         # broadcast to all gpus
85 |         torch.distributed.broadcast(idx_shuffle, src=0)
86 | 
87 |         # index for restoring
88 |         idx_unshuffle = torch.argsort(idx_shuffle)
89 | 
90 |         # shuffled index for this gpu
91 |         gpu_idx = torch.distributed.get_rank()
92 |         idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
93 | 
94 |         return x_gather[idx_this], idx_unshuffle
95 | 
96 |     @torch.no_grad()
97 |     def _batch_unshuffle_ddp(self, x, idx_unshuffle):
98 |         """
99 |         Undo batch shuffle.
100 |         *** Only supports DistributedDataParallel (DDP) models. ***
101 |         """
102 |         # gather from all gpus
103 |         batch_size_this = x.shape[0]
104 |         x_gather = concat_all_gather(x)
105 |         batch_size_all = x_gather.shape[0]
106 | 
107 |         num_gpus = batch_size_all // batch_size_this
108 | 
109 |         # restored index for this gpu
110 |         gpu_idx = torch.distributed.get_rank()
111 |         idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
112 | 
113 |         return x_gather[idx_this]
114 | 
115 |     def forward(self, im_q, im_k):
116 |         """
117 |         Input:
118 |             im_q: a batch of query images
119 |             im_k: a batch of key images
120 |         Output:
121 |             logits, targets
122 |         """
123 | 
124 |         # compute query features
125 |         q = self.encoder_q(im_q)  # queries: NxC
126 |         q = nn.functional.normalize(q, dim=1)
127 | 
128 |         # compute key features
129 |         with torch.no_grad():  # no gradient to keys
130 |             self._momentum_update_key_encoder()  # update the key encoder
131 | 
132 |             # shuffle for making use of BN
133 |             im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
134 | 
135 |             k = self.encoder_k(im_k)  # keys: NxC
136 |             k = nn.functional.normalize(k, dim=1)
137 | 
138 |             # undo shuffle
139 |             k = self._batch_unshuffle_ddp(k, idx_unshuffle)
140 | 
141 |         # compute logits
142 |         # Einstein sum is more intuitive
143 |         # positive logits: Nx1
144 |         l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
145 |         # negative logits: NxK
146 |         l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
147 | 
148 |         # logits: Nx(1+K)
149 |         logits = torch.cat([l_pos, l_neg], dim=1)
150 | 
151 |         # apply temperature
152 |         logits /= self.T
153 | 
154 |         # labels: positive key indicators
155 |         labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
156 | 
157 |         # dequeue and enqueue
158 |         self._dequeue_and_enqueue(k)
159 | 
160 |         return logits, labels
161 | 
162 | 
163 | # utils
164 | @torch.no_grad()
165 | def concat_all_gather(tensor):
166 |     """
167 |     Performs all_gather operation on the provided tensors.
168 |     *** Warning ***: torch.distributed.all_gather has no gradient.
169 |     """
170 |     tensors_gather = [torch.ones_like(tensor)
171 |                       for _ in range(torch.distributed.get_world_size())]
172 |     torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
173 | 
174 |     output = torch.cat(tensors_gather, dim=0)
175 |     return output
176 | 
--------------------------------------------------------------------------------
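For orientation, a minimal sketch (not part of the repo) of one pre-training step driven by the `MoCo` module above. It assumes `torch.distributed` is already initialized, since `_batch_shuffle_ddp` and `concat_all_gather` require DDP; the optimizer settings are illustrative:

```
import torch
import torch.nn as nn
import torchvision.models as models

from moco.builder import MoCo

model = MoCo(models.resnet50, dim=128, K=65536, m=0.999, T=0.07, mlp=True).cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.03,
                            momentum=0.9, weight_decay=1e-4)

# im_q / im_k: two augmentations of the same image batch (see moco/loader.py below)
im_q = torch.randn(32, 3, 224, 224).cuda()
im_k = torch.randn(32, 3, 224, 224).cuda()

logits, labels = model(im_q=im_q, im_k=im_k)  # logits: Nx(1+K), labels: all zeros
loss = criterion(logits, labels)              # InfoNCE written as cross-entropy

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

In this repo the real training loop lives in `main_clsa.py`; the sketch only shows how the module's outputs plug into the loss.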
/moco/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from PIL import ImageFilter
3 | import random
4 | import torchvision.transforms as transforms
5 | 
6 | 
7 | class TwoCropsTransform:
8 |     """Take two random crops of one image as the query and key."""
9 | 
10 |     def __init__(self, base_transform):
11 |         self.base_transform = base_transform
12 | 
13 |     def __call__(self, x):
14 |         q = self.base_transform(x)
15 |         k = self.base_transform(x)
16 |         return [q, k]
17 | 
18 | 
19 | class GaussianBlur(object):
20 |     """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709"""
21 | 
22 |     def __init__(self, sigma=(0.1, 2.0)):
23 |         self.sigma = sigma
24 | 
25 |     def __call__(self, x):
26 |         sigma = random.uniform(self.sigma[0], self.sigma[1])
27 |         x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
28 |         return x
--------------------------------------------------------------------------------
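For reference, a plausible composition of the two helpers above into a MoCo v2-style two-crop pipeline (the `--aug-plus` recipe). This is a sketch: the exact pipeline used for training is defined in `main_clsa.py`, and the CLSA-specific stronger augmentations live in `clsa/clsa_augs.py`:

```
import torchvision.transforms as transforms
from moco.loader import TwoCropsTransform, GaussianBlur

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

augmentation = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([GaussianBlur((0.1, 2.0))], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

train_transform = TwoCropsTransform(augmentation)
# train_transform(img) -> [q, k]: two independently augmented views of img
```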