├── README.md
├── clsa
│   ├── builder.py
│   ├── clsa_augs.py
│   └── loader.py
├── detection
│   ├── README.md
│   ├── configs
│   │   ├── Base-RCNN-C4-BN.yaml
│   │   ├── coco_R_50_C4_2x.yaml
│   │   ├── coco_R_50_C4_2x_moco.yaml
│   │   ├── pascal_voc_R_50_C4_24k.yaml
│   │   └── pascal_voc_R_50_C4_24k_moco.yaml
│   ├── convert-pretrain-to-detectron2.py
│   └── train_net.py
├── main_clsa.py
├── main_lincls.py
└── moco
    ├── __init__.py
    ├── builder.py
    └── loader.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Unofficial implementation of Contrastive Learning with Stronger Augmentations
2 | WIP!!
3 |
4 | Current results (linear evaluation protocol on ImageNet, top-1 accuracy):
5 |
6 | | Train epochs | Single | Mul-5 | MoCo-v2 |
7 | |---|---|---|---|
8 | | 40 | 55.4% | 60.2% | 56.9% |
9 | | 200 | 66.5% | 68.3% | 67.6% |
11 |
12 |
13 | This is an unofficial PyTorch implementation of the CLSA paper [Contrastive Learning with Stronger Augmentations](https://openreview.net/forum?id=KJSC_AsN14).
14 |
15 | Note: This implementation is mostly adapted from the official MoCo implementation at https://github.com/facebookresearch/moco
16 | This repo aims to make minimal modifications to that code.
17 |
18 |
19 |
20 | ### Preparation
21 | Note: This section is copied from the MoCo repo.
22 |
23 | Install PyTorch and download the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet).
24 |
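The scripts load data with `torchvision.datasets.ImageFolder`, so ImageNet is expected in the usual layout (a sketch; the class sub-folders are the standard synset ids):

```
[your imagenet-folder]
├── train
│   ├── n01440764
│   │   └── *.JPEG
│   └── ...
└── val
    ├── n01440764
    │   └── *.JPEG
    └── ...
```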
25 |
26 |
27 | ### Unsupervised Training
28 |
29 | This implementation only supports **multi-gpu**, **DistributedDataParallel** training, which is faster and simpler; single-gpu or DataParallel training is not supported.
30 |
31 | To do unsupervised pre-training of a ResNet-50 model on ImageNet on an 8-gpu machine, run:
32 | ```
33 | python main_clsa.py \
34 | -a resnet50 \
35 | --lr 0.03 \
36 | --batch-size 256 \
37 | --mlp --aug-plus --cos \
38 | --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
39 | [your imagenet-folder with train and val folders]
40 | ```
41 | This script uses the default hyper-parameters described in the CLSA paper.
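
By default (`--num_res 1`) the script trains the CLSA-Single variant, which adds a single stronger-augmented view at 96x96 resolution. To train the CLSA-Mul variant with all five resolutions (96/128/160/192/224), add:
```
--num_res 5
```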
42 |
43 |
44 | ### Linear Classification
45 | Note: This section is copied from the MoCo repo.
46 |
47 | With a pre-trained model, to train a supervised linear classifier on frozen features/weights on an 8-gpu machine, run:
48 | ```
49 | python main_lincls.py \
50 | -a resnet50 \
51 | --lr 30.0 \
52 | --batch-size 256 \
53 | --pretrained [your checkpoint path]/checkpoint_0199.pth.tar \
54 | --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
55 | [your imagenet-folder with train and val folders]
56 | ```
57 |
58 | ### TODO:
59 | 1. ImageNet-1K CLSA-Single-200epoch pretraining: Running
60 | 2. ImageNet-1K CLSA-Mul-200epoch pretraining: Running
61 | 3. Evaluate CLSA-Single/-Mul with the ImageNet linear protocol
62 | 4. Evaluate CLSA-Single/-Mul on VOC07 detection
63 |
64 |
65 |
--------------------------------------------------------------------------------
/clsa/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torchvision import models
5 |
6 | class CLSA(nn.Module):
7 | """
8 |     Build a MoCo-like model with a query encoder, a key encoder, and a queue, following these two papers:
9 | https://arxiv.org/abs/1911.05722
10 | https://openreview.net/forum?id=KJSC_AsN14
11 | """
12 | def __init__(self, base_encoder=models.resnet50, dim=2048, K=65536, m=0.999, T=0.2, mlp=True, ratio=1.0):
13 | """
14 | dim: feature dimension (default: 2048)
15 | K: queue size; number of negative keys (default: 65536)
16 | m: moco momentum of updating key encoder (default: 0.999)
17 |         T: softmax temperature (default: 0.2)
18 |         ratio: the coefficient for reweighting the ddm loss, i.e., beta in the paper (default: 1.0)
19 | """
20 | super(CLSA, self).__init__()
21 |
22 | self.K = K
23 | self.m = m
24 | self.T = T
25 | self.ratio = ratio
26 |
27 | # create the encoders
28 | # num_classes is the output fc dimension
29 | self.encoder_q = base_encoder(num_classes=dim)
30 | self.encoder_k = base_encoder(num_classes=dim)
31 |
32 | if mlp: # hack: brute-force replacement
33 | dim_mlp = self.encoder_q.fc.weight.shape[1]
34 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
35 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)
36 |
37 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
38 | param_k.data.copy_(param_q.data) # initialize
39 | param_k.requires_grad = False # not update by gradient
40 |
41 | # create the queue
42 | self.register_buffer("queue", torch.randn(dim, K))
43 | self.queue = nn.functional.normalize(self.queue, dim=0)
44 |
45 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
46 |
47 | self.criterion = nn.CrossEntropyLoss()
48 |
49 | @torch.no_grad()
50 | def _momentum_update_key_encoder(self):
51 | """
52 | Momentum update of the key encoder
53 | """
54 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
55 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
56 |
57 | @torch.no_grad()
58 | def _dequeue_and_enqueue(self, keys):
59 | # gather keys before updating queue
60 | keys = concat_all_gather(keys)
61 |
62 | batch_size = keys.shape[0]
63 |
64 | ptr = int(self.queue_ptr)
65 | assert self.K % batch_size == 0 # for simplicity
66 |
67 | # replace the keys at ptr (dequeue and enqueue)
68 | self.queue[:, ptr:ptr + batch_size] = keys.T
69 | ptr = (ptr + batch_size) % self.K # move pointer
70 |
71 | self.queue_ptr[0] = ptr
72 |
73 | @torch.no_grad()
74 | def _batch_shuffle_ddp(self, x):
75 | """
76 | Batch shuffle, for making use of BatchNorm.
77 | *** Only support DistributedDataParallel (DDP) model. ***
78 | """
79 | # gather from all gpus
80 | batch_size_this = x.shape[0]
81 | x_gather = concat_all_gather(x)
82 | batch_size_all = x_gather.shape[0]
83 |
84 | num_gpus = batch_size_all // batch_size_this
85 |
86 | # random shuffle index
87 | idx_shuffle = torch.randperm(batch_size_all).cuda()
88 |
89 | # broadcast to all gpus
90 | torch.distributed.broadcast(idx_shuffle, src=0)
91 |
92 | # index for restoring
93 | idx_unshuffle = torch.argsort(idx_shuffle)
94 |
95 | # shuffled index for this gpu
96 | gpu_idx = torch.distributed.get_rank()
97 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
98 |
99 | return x_gather[idx_this], idx_unshuffle
100 |
101 | @torch.no_grad()
102 | def _batch_unshuffle_ddp(self, x, idx_unshuffle):
103 | """
104 | Undo batch shuffle.
105 | *** Only support DistributedDataParallel (DDP) model. ***
106 | """
107 | # gather from all gpus
108 | batch_size_this = x.shape[0]
109 | x_gather = concat_all_gather(x)
110 | batch_size_all = x_gather.shape[0]
111 |
112 | num_gpus = batch_size_all // batch_size_this
113 |
114 | # restored index for this gpu
115 | gpu_idx = torch.distributed.get_rank()
116 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
117 |
118 | return x_gather[idx_this]
119 |
120 | def forward(self, im_q, im_k, img_stronger_aug_list):
121 | """
122 | Input:
123 |             img_stronger_aug_list = [img_res_96, img_res_128, ...]; img_res_96.shape = [N, 3, 96, 96]
124 |         Output:
125 |             logits [N, 1+K], labels [N], and the total loss (contrastive + ratio * ddm)
126 | """
127 |
128 |
129 | # compute query features
130 | q = self.encoder_q(im_q) # queries: NxC
131 | q = nn.functional.normalize(q, dim=1)
132 |
133 | # compute key features
134 | with torch.no_grad(): # no gradient to keys
135 | self._momentum_update_key_encoder() # update the key encoder
136 |
137 | # shuffle for making use of BN
138 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
139 |
140 | k = self.encoder_k(im_k) # keys: NxC
141 | k = nn.functional.normalize(k, dim=1)
142 |
143 | # undo shuffle
144 | k = self._batch_unshuffle_ddp(k, idx_unshuffle)
145 |
146 | # compute logits
147 | # Einstein sum is more intuitive
148 | # positive logits: Nx1
149 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
150 | # negative logits: NxK
151 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
152 |
153 | # logits: Nx(1+K)
154 | logits = torch.cat([l_pos, l_neg], dim=1)
155 |
156 | # apply temperature
157 | logits /= self.T
158 |
159 | # labels: positive key indicators
160 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
161 |
163 | loss_contrastive = self.criterion(logits, labels)
164 |
165 | # compute ddm loss below
166 | # get P(Zk, Zi')
167 | p_weak = nn.functional.softmax(logits, dim=-1)
168 | loss_ddm = 0
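        # DDM (distribution divergence minimization) term: for each stronger view,
        # pull its contrastive distribution P(z|q_s) towards the weak view's P(z|q),
        # accumulating the cross-entropy H(p_weak, p_strong) =
        #   mean_n( -sum_j p_weak[n, j] * log p_strong[n, j] )
        # over all stronger views.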
169 | for img_s in img_stronger_aug_list:
170 | q_s = self.encoder_q(img_s)
171 | q_s = nn.functional.normalize(q_s, dim=1)
172 | # compute logits using the same set of code above
173 |
174 | l_pos_stronger_aug = torch.einsum('nc,nc->n', [q_s, k]).unsqueeze(-1)
175 | # negative logits: NxK
176 | l_neg_stronger_aug = torch.einsum('nc,ck->nk', [q_s, self.queue.clone().detach()])
177 |
178 | # logits: Nx(1+K)
179 | logits_s = torch.cat([l_pos_stronger_aug, l_neg_stronger_aug], dim=1)
180 | logits_s /= self.T
181 |
182 | # compute nll loss below as -P(q, k) * log(P(q_s, k))
183 | log_p_s = nn.functional.log_softmax(logits_s, dim=-1)
184 |
185 | nll = -1.0 * torch.einsum('nk,nk->n', [p_weak, log_p_s])
186 | loss_ddm = loss_ddm + torch.mean(nll) # average over the batch dimension
187 |
188 | loss = loss_contrastive + self.ratio * loss_ddm
189 |
192 | # dequeue and enqueue
193 | self._dequeue_and_enqueue(k)
194 |
196 | return logits, labels, loss
197 |
198 |
199 | # utils
200 | @torch.no_grad()
201 | def concat_all_gather(tensor):
202 | """
203 | Performs all_gather operation on the provided tensors.
204 | *** Warning ***: torch.distributed.all_gather has no gradient.
205 | """
206 | tensors_gather = [torch.ones_like(tensor)
207 | for _ in range(torch.distributed.get_world_size())]
208 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
209 |
210 | output = torch.cat(tensors_gather, dim=0)
211 | return output
212 |
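# Shape sketch (illustrative only; real training runs under DDP, see main_clsa.py):
#   model = CLSA(models.resnet50, dim=2048, K=65536, m=0.999, T=0.2, mlp=True)
#   logits, labels, loss = model(im_q, im_k, [im_s_96, im_s_128])
#   logits: [N, 1+K]   labels: [N] (all zeros)   loss: scalar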
--------------------------------------------------------------------------------
/clsa/clsa_augs.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | import random
4 |
5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
6 | import numpy as np
7 | import torch
8 | from torchvision.transforms.transforms import Compose
9 |
10 | random_mirror = True
11 |
12 |
13 | def ShearX(img, v): # [-0.3, 0.3]
14 | assert -0.3 <= v <= 0.3
15 | if random_mirror and random.random() > 0.5:
16 | v = -v
17 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
18 |
19 |
20 | def ShearY(img, v): # [-0.3, 0.3]
21 | assert -0.3 <= v <= 0.3
22 | if random_mirror and random.random() > 0.5:
23 | v = -v
24 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
25 |
26 |
27 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.30, 0.30]
28 | assert -0.30 <= v <= 0.30
29 | if random_mirror and random.random() > 0.5:
30 | v = -v
31 | v = v * img.size[0]
32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
33 |
34 |
35 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.30, 0.30]
36 | assert -0.30 <= v <= 0.30
37 | if random_mirror and random.random() > 0.5:
38 | v = -v
39 | v = v * img.size[1]
40 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
41 |
42 |
43 | def TranslateXAbs(img, v):  # translate horizontally by v pixels, v in [0, 10]
44 |     assert 0 <= v <= 10
45 | if random.random() > 0.5:
46 | v = -v
47 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
48 |
49 |
50 | def TranslateYAbs(img, v):  # translate vertically by v pixels, v in [0, 10]
51 |     assert 0 <= v <= 10
52 | if random.random() > 0.5:
53 | v = -v
54 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
55 |
56 |
57 | def Rotate(img, v): # [-30, 30]
58 | assert -30 <= v <= 30
59 | if random_mirror and random.random() > 0.5:
60 | v = -v
61 | return img.rotate(v)
62 |
63 |
64 | def AutoContrast(img, _):
65 | return PIL.ImageOps.autocontrast(img)
66 |
67 |
68 | def Invert(img, _):
69 | return PIL.ImageOps.invert(img)
70 |
71 |
72 | def Equalize(img, _):
73 | return PIL.ImageOps.equalize(img)
74 |
75 |
76 | def Flip(img, _): # not from the paper
77 | return PIL.ImageOps.mirror(img)
78 |
79 |
80 | def Solarize(img, v): # [0, 256]
81 | assert 0 <= v <= 256
82 | return PIL.ImageOps.solarize(img, v)
83 |
84 |
85 | def Posterize(img, v): # [4, 8]
86 | assert 4 <= v <= 8
87 | v = int(v)
88 | return PIL.ImageOps.posterize(img, v)
89 |
90 |
91 | def Posterize2(img, v): # [0, 4]
92 | assert 0 <= v <= 4
93 | v = int(v)
94 | return PIL.ImageOps.posterize(img, v)
95 |
96 | # for the augs below, mag=1.0 gives back the original image
97 | def Contrast(img, v): # [0.05,1.95]
98 | assert 0.05 <= v <= 1.95
99 | return PIL.ImageEnhance.Contrast(img).enhance(v)
100 |
101 |
102 | def Color(img, v): # [0.05,1.95]
103 | assert 0.05 <= v <= 1.95
104 | return PIL.ImageEnhance.Color(img).enhance(v)
105 |
106 |
107 | def Brightness(img, v): # [0.05,1.95]
108 | assert 0.05 <= v <= 1.95
109 | return PIL.ImageEnhance.Brightness(img).enhance(v)
110 |
111 |
112 | def Sharpness(img, v): # [0.05,1.95]
113 | assert 0.05 <= v <= 1.95
114 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
115 |
116 |
117 | def Cutout(img, v):  # v is the cutout size as a fraction of image width, in [0, 0.2]
118 | assert 0.0 <= v <= 0.2
119 | if v <= 0.:
120 | return img
121 |
122 | v = v * img.size[0]
123 | return CutoutAbs(img, v)
124 |
125 |
126 | def CutoutAbs(img, v):  # v is the cutout size in pixels
127 | # assert 0 <= v <= 20
128 | if v < 0:
129 | return img
130 | w, h = img.size
131 | x0 = np.random.uniform(w)
132 | y0 = np.random.uniform(h)
133 |
134 | x0 = int(max(0, x0 - v / 2.))
135 | y0 = int(max(0, y0 - v / 2.))
136 | x1 = min(w, x0 + v)
137 | y1 = min(h, y0 + v)
138 |
139 | xy = (x0, y0, x1, y1)
140 | color = (125, 123, 114)
141 | # color = (0, 0, 0)
142 | img = img.copy()
143 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
144 | return img
145 |
146 |
147 | def SamplePairing(imgs): # [0, 0.4]
148 | def f(img1, v):
149 | i = np.random.choice(len(imgs))
150 | img2 = PIL.Image.fromarray(imgs[i])
151 | return PIL.Image.blend(img1, img2, v)
152 |
153 | return f
154 |
155 |
156 | def augment_list():
157 | # 14 augs and their magnitude range
158 | l = [
159 | (ShearX, -0.3, 0.3), # 0
160 | (ShearY, -0.3, 0.3), # 1
161 | (TranslateX, -0.3, 0.3), # 2
162 | (TranslateY, -0.3, 0.3), # 3
163 | (Rotate, -30, 30), # 4
164 | (AutoContrast, 0, 1), # 5
165 | (Invert, 0, 1), # 6
166 | (Equalize, 0, 1), # 7
167 | (Solarize, 0, 256), # 8
168 | (Posterize, 4, 8), # 9
169 | (Contrast, 0.05, 1.95), # 10
170 | (Color, 0.05, 1.95), # 11
171 | (Brightness, 0.05, 1.95), # 12
172 | (Sharpness, 0.05, 1.95), # 13
173 | ]
174 |
175 | return l
176 |
177 |
178 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}
179 |
180 | def get_augment(name):
181 | return augment_dict[name]
182 |
183 | def apply_augment_with_rand_mag(img: PIL.Image.Image, name: str) -> PIL.Image.Image:
184 | augment_fn, low, high = get_augment(name)
185 | mag = np.random.uniform(low, high)
186 | return augment_fn(img.copy(), mag)
187 |
188 |
189 | class CLSAAug(object):
190 |
191 | def __init__(self, num_of_times=5):
192 | '''
193 |         params: num_of_times: number of rounds; each round applies one randomly chosen augmentation with probability 0.5
194 | '''
195 | self.num_of_times = num_of_times
196 | self.aug_names = list(augment_dict.keys())
197 |
198 | print('Augmentation List:')
199 | for aug_name in self.aug_names:
200 | print('{} with magnitude of {} ~ {}'.format(aug_name, augment_dict[aug_name][1], augment_dict[aug_name][2]))
201 |
202 |
203 | def __call__(self, img):
204 | for i in range(self.num_of_times):
205 | if np.random.rand() > 0.5:
206 | aug_name = random.choice(self.aug_names)
207 | img = apply_augment_with_rand_mag(img, aug_name)
208 |
209 | return img
210 |
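# Minimal sanity-check sketch (illustrative only, not used in training): apply
# the stochastic stronger augmentation to a random RGB image.
if __name__ == "__main__":
    rand_img = PIL.Image.fromarray(
        np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8))
    aug = CLSAAug(num_of_times=5)
    out = aug(rand_img)
    print(type(out), out.size)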
211 |
--------------------------------------------------------------------------------
/clsa/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from PIL import ImageFilter
3 | import random
4 | import torchvision.transforms as transforms
5 |
6 |
7 | class CLSAMultiResolutionTransform(object):
8 |     def __init__(self, base_transform, stronger_transform, num_res=5):
9 |         '''
10 |         Note: a RandomResizedCrop at each target resolution is applied here, so stronger_transform should not include its own RandomResizedCrop
11 |         '''
12 | resolutions = [96, 128, 160, 192, 224]
13 |
14 | self.res = resolutions[:num_res]
15 | self.resize_crop_ops = [transforms.RandomResizedCrop(res, scale=(0.2, 1.)) for res in self.res]
16 | self.num_res = num_res
17 |
18 | self.base_transform = base_transform
19 |         self.stronger_transform = stronger_transform
20 |
21 | def __call__(self, x):
22 | q = self.base_transform(x)
23 | k = self.base_transform(x)
24 |
25 | q_stronger_augs = []
26 | for resize_crop_op in self.resize_crop_ops:
27 |             q_s = self.stronger_transform(resize_crop_op(x))
28 | q_stronger_augs.append(q_s)
29 |
30 | return [q, k, q_stronger_augs]
31 |
32 |
33 |
34 | class GaussianBlur(object):
35 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
36 |
37 | def __init__(self, sigma=[.1, 2.]):
38 | self.sigma = sigma
39 |
40 | def __call__(self, x):
41 | sigma = random.uniform(self.sigma[0], self.sigma[1])
42 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
43 | return x
44 |
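# Illustrative wiring sketch (mirrors main_clsa.py, where the base and stronger
# pipelines are richer): check the shapes returned for a dummy image.
if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    base = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.ToTensor(),
    ])
    stronger = transforms.Compose([transforms.ToTensor()])
    t = CLSAMultiResolutionTransform(base_transform=base,
                                     stronger_transform=stronger, num_res=5)
    img = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
    q, k, q_strongs = t(img)
    # q, k: [3, 224, 224]; q_strongs: one crop at each of 96/128/160/192/224
    print(q.shape, k.shape, [s.shape for s in q_strongs])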
--------------------------------------------------------------------------------
/detection/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## MoCo: Transferring to Detection
3 |
4 | The `train_net.py` script reproduces the object detection experiments on Pascal VOC and COCO.
5 |
6 | ### Instructions
7 |
8 | 1. Install [detectron2](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md).
9 |
10 | 1. Convert a pre-trained MoCo model to detectron2's format:
11 | ```
12 | python3 convert-pretrain-to-detectron2.py input.pth.tar output.pkl
13 | ```
14 |
15 | 1. Put the dataset under the "./datasets" directory,
16 |    following the [directory structure](https://github.com/facebookresearch/detectron2/tree/master/datasets)
17 |    required by detectron2.
18 |
19 | 1. Run training:
20 | ```
21 | python train_net.py --config-file configs/pascal_voc_R_50_C4_24k_moco.yaml \
22 | --num-gpus 8 MODEL.WEIGHTS ./output.pkl
23 | ```
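
1. (Optional) Evaluate a trained model using detectron2's standard `--eval-only` flag (a sketch; assumes the default `./output` directory):
```
python train_net.py --config-file configs/pascal_voc_R_50_C4_24k_moco.yaml \
  --num-gpus 8 --eval-only MODEL.WEIGHTS output/model_final.pth
```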
24 |
25 | ### Results
26 |
27 | Below are the results on Pascal VOC 2007 test, fine-tuned on 2007+2012 trainval for 24k iterations using Faster R-CNN with an R50-C4 backbone:
28 |
29 | | pretrain | AP50 | AP | AP75 |
30 | |---|---|---|---|
31 | | ImageNet-1M, supervised | 81.3 | 53.5 | 58.8 |
32 | | ImageNet-1M, MoCo v1, 200ep | 81.5 | 55.9 | 62.6 |
33 | | ImageNet-1M, MoCo v2, 200ep | 82.4 | 57.0 | 63.6 |
34 | | ImageNet-1M, MoCo v2, 800ep | 82.5 | 57.4 | 64.0 |
35 |
61 | ***Note:*** These results are means of 5 trials. Variation on Pascal VOC is large: the std of AP50, AP, AP75 is expected to be about 0.2, 0.2, 0.4 in most cases. We recommend running 5 trials and computing the means.
62 |
--------------------------------------------------------------------------------
/detection/configs/Base-RCNN-C4-BN.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "GeneralizedRCNN"
3 | RPN:
4 | PRE_NMS_TOPK_TEST: 6000
5 | POST_NMS_TOPK_TEST: 1000
6 | ROI_HEADS:
7 | NAME: "Res5ROIHeadsExtraNorm"
8 | BACKBONE:
9 | FREEZE_AT: 0
10 | RESNETS:
11 | NORM: "SyncBN"
12 | TEST:
13 | PRECISE_BN:
14 | ENABLED: True
15 | SOLVER:
16 | IMS_PER_BATCH: 16
17 | BASE_LR: 0.02
18 |
--------------------------------------------------------------------------------
/detection/configs/coco_R_50_C4_2x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-C4-BN.yaml"
2 | MODEL:
3 | MASK_ON: True
4 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
5 | INPUT:
6 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
7 | MIN_SIZE_TEST: 800
8 | DATASETS:
9 | TRAIN: ("coco_2017_train",)
10 | TEST: ("coco_2017_val",)
11 | SOLVER:
12 | STEPS: (120000, 160000)
13 | MAX_ITER: 180000
14 |
--------------------------------------------------------------------------------
/detection/configs/coco_R_50_C4_2x_moco.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "coco_R_50_C4_2x.yaml"
2 | MODEL:
3 | PIXEL_MEAN: [123.675, 116.280, 103.530]
4 | PIXEL_STD: [58.395, 57.120, 57.375]
5 | WEIGHTS: "See Instructions"
6 | RESNETS:
7 | STRIDE_IN_1X1: False
8 | INPUT:
9 | FORMAT: "RGB"
10 |
--------------------------------------------------------------------------------
/detection/configs/pascal_voc_R_50_C4_24k.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "Base-RCNN-C4-BN.yaml"
2 | MODEL:
3 | MASK_ON: False
4 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
5 | ROI_HEADS:
6 | NUM_CLASSES: 20
7 | INPUT:
8 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
9 | MIN_SIZE_TEST: 800
10 | DATASETS:
11 | TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
12 | TEST: ('voc_2007_test',)
13 | SOLVER:
14 | STEPS: (18000, 22000)
15 | MAX_ITER: 24000
16 | WARMUP_ITERS: 100
17 |
--------------------------------------------------------------------------------
/detection/configs/pascal_voc_R_50_C4_24k_moco.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "pascal_voc_R_50_C4_24k.yaml"
2 | MODEL:
3 | PIXEL_MEAN: [123.675, 116.280, 103.530]
4 | PIXEL_STD: [58.395, 57.120, 57.375]
5 | WEIGHTS: "See Instructions"
6 | RESNETS:
7 | STRIDE_IN_1X1: False
8 | INPUT:
9 | FORMAT: "RGB"
10 |
--------------------------------------------------------------------------------
/detection/convert-pretrain-to-detectron2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | import pickle as pkl
5 | import sys
6 | import torch
7 |
8 | if __name__ == "__main__":
9 | input = sys.argv[1]
10 |
11 | obj = torch.load(input, map_location="cpu")
12 | obj = obj["state_dict"]
13 |
14 | newmodel = {}
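    # rename torchvision ResNet keys to detectron2's convention:
    #   layer{t} -> res{t+1}, bn{t} -> conv{t}.norm,
    #   downsample.0 -> shortcut, downsample.1 -> shortcut.norm;
    # keys without "layer" (the stem's conv1/bn1) get a "stem." prefix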
15 | for k, v in obj.items():
16 | if not k.startswith("module.encoder_q."):
17 | continue
18 | old_k = k
19 | k = k.replace("module.encoder_q.", "")
20 | if "layer" not in k:
21 | k = "stem." + k
22 | for t in [1, 2, 3, 4]:
23 | k = k.replace("layer{}".format(t), "res{}".format(t + 1))
24 | for t in [1, 2, 3]:
25 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t))
26 | k = k.replace("downsample.0", "shortcut")
27 | k = k.replace("downsample.1", "shortcut.norm")
28 | print(old_k, "->", k)
29 | newmodel[k] = v.numpy()
30 |
31 | res = {"model": newmodel, "__author__": "MOCO", "matching_heuristics": True}
32 |
33 | with open(sys.argv[2], "wb") as f:
34 | pkl.dump(res, f)
35 |
--------------------------------------------------------------------------------
/detection/train_net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | import os
5 |
6 | from detectron2.checkpoint import DetectionCheckpointer
7 | from detectron2.config import get_cfg
8 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
9 | from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator
10 | from detectron2.layers import get_norm
11 | from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
12 |
13 |
14 | @ROI_HEADS_REGISTRY.register()
15 | class Res5ROIHeadsExtraNorm(Res5ROIHeads):
16 | """
17 |     As described in the MoCo paper, there is an extra BN layer
18 | following the res5 stage.
19 | """
20 | def _build_res5_block(self, cfg):
21 | seq, out_channels = super()._build_res5_block(cfg)
22 | norm = cfg.MODEL.RESNETS.NORM
23 | norm = get_norm(norm, out_channels)
24 | seq.add_module("norm", norm)
25 | return seq, out_channels
26 |
27 |
28 | class Trainer(DefaultTrainer):
29 | @classmethod
30 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
31 | if output_folder is None:
32 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
33 | if "coco" in dataset_name:
34 | return COCOEvaluator(dataset_name, cfg, True, output_folder)
35 | else:
36 | assert "voc" in dataset_name
37 | return PascalVOCDetectionEvaluator(dataset_name)
38 |
39 |
40 | def setup(args):
41 | cfg = get_cfg()
42 | cfg.merge_from_file(args.config_file)
43 | cfg.merge_from_list(args.opts)
44 | cfg.freeze()
45 | default_setup(cfg, args)
46 | return cfg
47 |
48 |
49 | def main(args):
50 | cfg = setup(args)
51 |
52 | if args.eval_only:
53 | model = Trainer.build_model(cfg)
54 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
55 | cfg.MODEL.WEIGHTS, resume=args.resume
56 | )
57 | res = Trainer.test(cfg, model)
58 | return res
59 |
60 | trainer = Trainer(cfg)
61 | trainer.resume_or_load(resume=args.resume)
62 | return trainer.train()
63 |
64 |
65 | if __name__ == "__main__":
66 | args = default_argument_parser().parse_args()
67 | print("Command Line Args:", args)
68 | launch(
69 | main,
70 | args.num_gpus,
71 | num_machines=args.num_machines,
72 | machine_rank=args.machine_rank,
73 | dist_url=args.dist_url,
74 | args=(args,),
75 | )
76 |
--------------------------------------------------------------------------------
/main_clsa.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | import argparse
4 | import builtins
5 | import math
6 | import os
7 | import random
8 | import shutil
9 | import time
10 | import warnings
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.parallel
15 | import torch.backends.cudnn as cudnn
16 | import torch.distributed as dist
17 | import torch.optim
18 | import torch.multiprocessing as mp
19 | import torch.utils.data
20 | import torch.utils.data.distributed
21 | import torchvision.transforms as transforms
22 | import torchvision.datasets as datasets
23 | import torchvision.models as models
24 |
25 | import moco.loader
26 | import moco.builder
27 |
28 | import clsa.clsa_augs
29 | import clsa.builder
30 | import clsa.loader
31 |
32 |
33 | model_names = sorted(name for name in models.__dict__
34 | if name.islower() and not name.startswith("__")
35 | and callable(models.__dict__[name]))
36 |
37 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
38 | parser.add_argument('data', metavar='DIR',
39 | help='path to dataset')
40 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
41 | choices=model_names,
42 | help='model architecture: ' +
43 | ' | '.join(model_names) +
44 | ' (default: resnet50)')
45 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
46 | help='number of data loading workers (default: 32)')
47 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
48 | help='number of total epochs to run')
49 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
50 | help='manual epoch number (useful on restarts)')
51 | parser.add_argument('-b', '--batch-size', default=256, type=int,
52 | metavar='N',
53 | help='mini-batch size (default: 256), this is the total '
54 | 'batch size of all GPUs on the current node when '
55 | 'using Data Parallel or Distributed Data Parallel')
56 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
57 | metavar='LR', help='initial learning rate', dest='lr')
58 | parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
59 | help='learning rate schedule (when to drop lr by 10x)')
60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
61 | help='momentum of SGD solver')
62 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
63 | metavar='W', help='weight decay (default: 1e-4)',
64 | dest='weight_decay')
65 | parser.add_argument('-p', '--print-freq', default=10, type=int,
66 | metavar='N', help='print frequency (default: 10)')
67 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
68 | help='path to latest checkpoint (default: none)')
69 | parser.add_argument('--world-size', default=-1, type=int,
70 | help='number of nodes for distributed training')
71 | parser.add_argument('--rank', default=-1, type=int,
72 | help='node rank for distributed training')
73 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
74 | help='url used to set up distributed training')
75 | parser.add_argument('--dist-backend', default='nccl', type=str,
76 | help='distributed backend')
77 | parser.add_argument('--seed', default=None, type=int,
78 | help='seed for initializing training. ')
79 | parser.add_argument('--gpu', default=None, type=int,
80 | help='GPU id to use.')
81 | parser.add_argument('--multiprocessing-distributed', action='store_true',
82 | help='Use multi-processing distributed training to launch '
83 | 'N processes per node, which has N GPUs. This is the '
84 | 'fastest way to use PyTorch for either single node or '
85 | 'multi node data parallel training')
86 |
87 | # moco specific configs:
88 | parser.add_argument('--moco-dim', default=2048, type=int,
89 |                     help='feature dimension (default: 2048)')
90 | parser.add_argument('--moco-k', default=65536, type=int,
91 | help='queue size; number of negative keys (default: 65536)')
92 | parser.add_argument('--moco-m', default=0.999, type=float,
93 | help='moco momentum of updating key encoder (default: 0.999)')
94 | parser.add_argument('--moco-t', default=0.2, type=float,
95 |                     help='softmax temperature (default: 0.2)')
96 |
97 | # options for moco v2
98 | parser.add_argument('--mlp', action='store_true',
99 | help='use mlp head')
100 | parser.add_argument('--aug-plus', action='store_true',
101 | help='use moco v2 data augmentation')
102 | parser.add_argument('--cos', action='store_true',
103 | help='use cosine lr schedule')
104 |
105 | # additional hyper-params for clsa
106 |
107 | parser.add_argument('--ratio', default=1.0, type=float,
108 |                     help='the reweighting coefficient for the ddm loss (beta in the paper)')
109 | parser.add_argument('--num_res', default=1, type=int,
110 |                     help='number of resolutions for stronger augs (1 for CLSA-Single, 5 for CLSA-Mul)')
111 |
112 |
113 | def main():
114 | args = parser.parse_args()
115 |
116 | if args.seed is not None:
117 | random.seed(args.seed)
118 | torch.manual_seed(args.seed)
119 | cudnn.deterministic = True
120 | warnings.warn('You have chosen to seed training. '
121 | 'This will turn on the CUDNN deterministic setting, '
122 | 'which can slow down your training considerably! '
123 | 'You may see unexpected behavior when restarting '
124 | 'from checkpoints.')
125 |
126 | if args.gpu is not None:
127 | warnings.warn('You have chosen a specific GPU. This will completely '
128 | 'disable data parallelism.')
129 |
130 | if args.dist_url == "env://" and args.world_size == -1:
131 | args.world_size = int(os.environ["WORLD_SIZE"])
132 |
133 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
134 |
135 | ngpus_per_node = torch.cuda.device_count()
136 | if args.multiprocessing_distributed:
137 | # Since we have ngpus_per_node processes per node, the total world_size
138 | # needs to be adjusted accordingly
139 | args.world_size = ngpus_per_node * args.world_size
140 | # Use torch.multiprocessing.spawn to launch distributed processes: the
141 | # main_worker process function
142 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
143 | else:
144 | # Simply call main_worker function
145 | main_worker(args.gpu, ngpus_per_node, args)
146 |
147 |
148 | def main_worker(gpu, ngpus_per_node, args):
149 | args.gpu = gpu
150 |
151 | # suppress printing if not master
152 | if args.multiprocessing_distributed and args.gpu != 0:
153 | def print_pass(*args):
154 | pass
155 | builtins.print = print_pass
156 |
157 | if args.gpu is not None:
158 | print("Use GPU: {} for training".format(args.gpu))
159 |
160 | if args.distributed:
161 | if args.dist_url == "env://" and args.rank == -1:
162 | args.rank = int(os.environ["RANK"])
163 | if args.multiprocessing_distributed:
164 | # For multiprocessing distributed training, rank needs to be the
165 | # global rank among all the processes
166 | args.rank = args.rank * ngpus_per_node + gpu
167 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
168 | world_size=args.world_size, rank=args.rank)
169 | # create model
170 | print("=> creating model '{}'".format(args.arch))
171 |
172 | model = clsa.builder.CLSA(
173 | models.__dict__[args.arch],
174 | args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp, args.ratio
175 | )
176 | print(model)
177 |
178 | if args.distributed:
179 | # For multiprocessing distributed, DistributedDataParallel constructor
180 | # should always set the single device scope, otherwise,
181 | # DistributedDataParallel will use all available devices.
182 | if args.gpu is not None:
183 | torch.cuda.set_device(args.gpu)
184 | model.cuda(args.gpu)
185 | # When using a single GPU per process and per
186 | # DistributedDataParallel, we need to divide the batch size
187 | # ourselves based on the total number of GPUs we have
188 | args.batch_size = int(args.batch_size / ngpus_per_node)
189 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
190 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
191 | else:
192 | model.cuda()
193 | # DistributedDataParallel will divide and allocate batch_size to all
194 | # available GPUs if device_ids are not set
195 | model = torch.nn.parallel.DistributedDataParallel(model)
196 | elif args.gpu is not None:
197 | torch.cuda.set_device(args.gpu)
198 | model = model.cuda(args.gpu)
199 | # comment out the following line for debugging
200 | raise NotImplementedError("Only DistributedDataParallel is supported.")
201 | else:
202 | # AllGather implementation (batch shuffle, queue update, etc.) in
203 | # this code only supports DistributedDataParallel.
204 | raise NotImplementedError("Only DistributedDataParallel is supported.")
205 |
206 | # define loss function (criterion) and optimizer
207 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
208 |
209 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
210 | momentum=args.momentum,
211 | weight_decay=args.weight_decay)
212 |
213 | # optionally resume from a checkpoint
214 | if args.resume:
215 | if os.path.isfile(args.resume):
216 | print("=> loading checkpoint '{}'".format(args.resume))
217 | if args.gpu is None:
218 | checkpoint = torch.load(args.resume)
219 | else:
220 | # Map model to be loaded to specified single gpu.
221 | loc = 'cuda:{}'.format(args.gpu)
222 | checkpoint = torch.load(args.resume, map_location=loc)
223 | args.start_epoch = checkpoint['epoch']
224 | model.load_state_dict(checkpoint['state_dict'])
225 | optimizer.load_state_dict(checkpoint['optimizer'])
226 | print("=> loaded checkpoint '{}' (epoch {})"
227 | .format(args.resume, checkpoint['epoch']))
228 | else:
229 | print("=> no checkpoint found at '{}'".format(args.resume))
230 |
231 | cudnn.benchmark = True
232 |
233 | # Data loading code
234 | traindir = os.path.join(args.data, 'train')
235 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
236 | std=[0.229, 0.224, 0.225])
237 | if args.aug_plus:
238 | # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
239 | augmentation = [
240 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
241 | transforms.RandomApply([
242 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
243 | ], p=0.8),
244 | transforms.RandomGrayscale(p=0.2),
245 | transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5),
246 | transforms.RandomHorizontalFlip(),
247 | transforms.ToTensor(),
248 | normalize
249 | ]
250 | else:
251 | # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978
252 | augmentation = [
253 | transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
254 | transforms.RandomGrayscale(p=0.2),
255 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
256 | transforms.RandomHorizontalFlip(),
257 | transforms.ToTensor(),
258 | normalize
259 | ]
260 |     # construct the CLSA data loader with multi-resolution stronger augmentations
261 | augmentation = transforms.Compose(augmentation)
262 |
263 |     stronger_aug = clsa.clsa_augs.CLSAAug(num_of_times=5)  # number of times the random augmentation is repeated
264 | stronger_aug = transforms.Compose([stronger_aug, transforms.ToTensor(), normalize])
265 | train_dataset = datasets.ImageFolder(
266 | traindir,
267 |         clsa.loader.CLSAMultiResolutionTransform(base_transform=augmentation,
268 |                                                  stronger_transform=stronger_aug, num_res=args.num_res))
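    # each dataset sample is [q, k, [q_s at each resolution]]; the default
    # collate_fn batches this into two tensors plus a list of tensors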
269 |
270 | if args.distributed:
271 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
272 | else:
273 | train_sampler = None
274 |
275 | train_loader = torch.utils.data.DataLoader(
276 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
277 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
278 |
279 | for epoch in range(args.start_epoch, args.epochs):
280 | if args.distributed:
281 | train_sampler.set_epoch(epoch)
282 | adjust_learning_rate(optimizer, epoch, args)
283 |
284 | # train for one epoch
285 | train(train_loader, model, criterion, optimizer, epoch, args)
286 |
287 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
288 | and args.rank % ngpus_per_node == 0):
289 | save_checkpoint({
290 | 'epoch': epoch + 1,
291 | 'arch': args.arch,
292 | 'state_dict': model.state_dict(),
293 | 'optimizer' : optimizer.state_dict(),
294 | }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch))
295 |
296 |
297 | def train(train_loader, model, criterion, optimizer, epoch, args):
298 | batch_time = AverageMeter('Time', ':6.3f')
299 | data_time = AverageMeter('Data', ':6.3f')
300 | losses = AverageMeter('Loss', ':.4e')
301 | top1 = AverageMeter('Acc@1', ':6.2f')
302 | top5 = AverageMeter('Acc@5', ':6.2f')
303 | progress = ProgressMeter(
304 | len(train_loader),
305 | [batch_time, data_time, losses, top1, top5],
306 | prefix="Epoch: [{}]".format(epoch))
307 |
308 | # switch to train mode
309 | model.train()
310 |
311 | end = time.time()
312 | for i, (images, _) in enumerate(train_loader):
313 | # measure data loading time
314 | data_time.update(time.time() - end)
315 |
316 |         if args.gpu is not None:
317 |             images[0] = images[0].cuda(args.gpu, non_blocking=True)
318 |             images[1] = images[1].cuda(args.gpu, non_blocking=True)
            # images[2] is a list of stronger-augmented views; move each to the GPU too
            images[2] = [im.cuda(args.gpu, non_blocking=True) for im in images[2]]
319 |
320 | # compute output
321 | im_q, im_k, img_stronger_aug_list = images[0], images[1], images[2]
322 | output, target, loss = model(im_q, im_k, img_stronger_aug_list)
323 |
324 |         # the loss is already computed inside CLSA.forward (criterion is unused here)
325 |
326 | # acc1/acc5 are (K+1)-way contrast classifier accuracy
327 | # measure accuracy and record loss
328 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
329 | losses.update(loss.item(), images[0].size(0))
330 | top1.update(acc1[0], images[0].size(0))
331 | top5.update(acc5[0], images[0].size(0))
332 |
333 | # compute gradient and do SGD step
334 | optimizer.zero_grad()
335 | loss.backward()
336 | optimizer.step()
337 |
338 | # measure elapsed time
339 | batch_time.update(time.time() - end)
340 | end = time.time()
341 |
342 | if i % args.print_freq == 0:
343 | progress.display(i)
344 |
345 |
346 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
347 | torch.save(state, filename)
348 | if is_best:
349 | shutil.copyfile(filename, 'model_best.pth.tar')
350 |
351 |
352 | class AverageMeter(object):
353 | """Computes and stores the average and current value"""
354 | def __init__(self, name, fmt=':f'):
355 | self.name = name
356 | self.fmt = fmt
357 | self.reset()
358 |
359 | def reset(self):
360 | self.val = 0
361 | self.avg = 0
362 | self.sum = 0
363 | self.count = 0
364 |
365 | def update(self, val, n=1):
366 | self.val = val
367 | self.sum += val * n
368 | self.count += n
369 | self.avg = self.sum / self.count
370 |
371 | def __str__(self):
372 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
373 | return fmtstr.format(**self.__dict__)
374 |
375 |
376 | class ProgressMeter(object):
377 | def __init__(self, num_batches, meters, prefix=""):
378 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
379 | self.meters = meters
380 | self.prefix = prefix
381 |
382 | def display(self, batch):
383 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
384 | entries += [str(meter) for meter in self.meters]
385 | print('\t'.join(entries))
386 |
387 | def _get_batch_fmtstr(self, num_batches):
388 | num_digits = len(str(num_batches // 1))
389 | fmt = '{:' + str(num_digits) + 'd}'
390 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
391 |
392 |
393 | def adjust_learning_rate(optimizer, epoch, args):
394 | """Decay the learning rate based on schedule"""
395 | lr = args.lr
396 | if args.cos: # cosine lr schedule
397 | lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
398 | else: # stepwise lr schedule
399 | for milestone in args.schedule:
400 | lr *= 0.1 if epoch >= milestone else 1.
401 | for param_group in optimizer.param_groups:
402 | param_group['lr'] = lr
403 |
404 |
405 | def accuracy(output, target, topk=(1,)):
406 | """Computes the accuracy over the k top predictions for the specified values of k"""
407 | with torch.no_grad():
408 | maxk = max(topk)
409 | batch_size = target.size(0)
410 |
411 | _, pred = output.topk(maxk, 1, True, True)
412 | pred = pred.t()
413 | correct = pred.eq(target.view(1, -1).expand_as(pred))
414 |
415 | res = []
416 | for k in topk:
417 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
418 | res.append(correct_k.mul_(100.0 / batch_size))
419 | return res
420 |
421 |
422 | if __name__ == '__main__':
423 | main()
424 |
--------------------------------------------------------------------------------
/main_lincls.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | import argparse
4 | import builtins
5 | import os
6 | import random
7 | import shutil
8 | import time
9 | import warnings
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | import torch.backends.cudnn as cudnn
15 | import torch.distributed as dist
16 | import torch.optim
17 | import torch.multiprocessing as mp
18 | import torch.utils.data
19 | import torch.utils.data.distributed
20 | import torchvision.transforms as transforms
21 | import torchvision.datasets as datasets
22 | import torchvision.models as models
23 |
24 | model_names = sorted(name for name in models.__dict__
25 | if name.islower() and not name.startswith("__")
26 | and callable(models.__dict__[name]))
27 |
28 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
29 | parser.add_argument('data', metavar='DIR',
30 | help='path to dataset')
31 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
32 | choices=model_names,
33 | help='model architecture: ' +
34 | ' | '.join(model_names) +
35 | ' (default: resnet50)')
36 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
37 | help='number of data loading workers (default: 32)')
38 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
39 | help='number of total epochs to run')
40 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
41 | help='manual epoch number (useful on restarts)')
42 | parser.add_argument('-b', '--batch-size', default=256, type=int,
43 | metavar='N',
44 | help='mini-batch size (default: 256), this is the total '
45 | 'batch size of all GPUs on the current node when '
46 | 'using Data Parallel or Distributed Data Parallel')
47 | parser.add_argument('--lr', '--learning-rate', default=30., type=float,
48 | metavar='LR', help='initial learning rate', dest='lr')
49 | parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int,
50 | help='learning rate schedule (when to drop lr by a ratio)')
51 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
52 | help='momentum')
53 | parser.add_argument('--wd', '--weight-decay', default=0., type=float,
54 | metavar='W', help='weight decay (default: 0.)',
55 | dest='weight_decay')
56 | parser.add_argument('-p', '--print-freq', default=10, type=int,
57 | metavar='N', help='print frequency (default: 10)')
58 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
59 | help='path to latest checkpoint (default: none)')
60 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
61 | help='evaluate model on validation set')
62 | parser.add_argument('--world-size', default=-1, type=int,
63 | help='number of nodes for distributed training')
64 | parser.add_argument('--rank', default=-1, type=int,
65 | help='node rank for distributed training')
66 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
67 | help='url used to set up distributed training')
68 | parser.add_argument('--dist-backend', default='nccl', type=str,
69 | help='distributed backend')
70 | parser.add_argument('--seed', default=None, type=int,
71 | help='seed for initializing training. ')
72 | parser.add_argument('--gpu', default=None, type=int,
73 | help='GPU id to use.')
74 | parser.add_argument('--multiprocessing-distributed', action='store_true',
75 | help='Use multi-processing distributed training to launch '
76 | 'N processes per node, which has N GPUs. This is the '
77 | 'fastest way to use PyTorch for either single node or '
78 | 'multi node data parallel training')
79 |
80 | parser.add_argument('--pretrained', default='', type=str,
81 | help='path to moco pretrained checkpoint')
82 |
83 | best_acc1 = 0
84 |
85 |
86 | def main():
87 | args = parser.parse_args()
88 |
89 | if args.seed is not None:
90 | random.seed(args.seed)
91 | torch.manual_seed(args.seed)
92 | cudnn.deterministic = True
93 | warnings.warn('You have chosen to seed training. '
94 | 'This will turn on the CUDNN deterministic setting, '
95 | 'which can slow down your training considerably! '
96 | 'You may see unexpected behavior when restarting '
97 | 'from checkpoints.')
98 |
99 | if args.gpu is not None:
100 | warnings.warn('You have chosen a specific GPU. This will completely '
101 | 'disable data parallelism.')
102 |
103 | if args.dist_url == "env://" and args.world_size == -1:
104 | args.world_size = int(os.environ["WORLD_SIZE"])
105 |
106 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
107 |
108 | ngpus_per_node = torch.cuda.device_count()
109 | if args.multiprocessing_distributed:
110 | # Since we have ngpus_per_node processes per node, the total world_size
111 | # needs to be adjusted accordingly
112 | args.world_size = ngpus_per_node * args.world_size
113 | # Use torch.multiprocessing.spawn to launch distributed processes: the
114 | # main_worker process function
115 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
116 | else:
117 | # Simply call main_worker function
118 | main_worker(args.gpu, ngpus_per_node, args)
119 |
120 |
121 | def main_worker(gpu, ngpus_per_node, args):
122 | global best_acc1
123 | args.gpu = gpu
124 |
125 | # suppress printing if not master
126 | if args.multiprocessing_distributed and args.gpu != 0:
127 | def print_pass(*args):
128 | pass
129 | builtins.print = print_pass
130 |
131 | if args.gpu is not None:
132 | print("Use GPU: {} for training".format(args.gpu))
133 |
134 | if args.distributed:
135 | if args.dist_url == "env://" and args.rank == -1:
136 | args.rank = int(os.environ["RANK"])
137 | if args.multiprocessing_distributed:
138 | # For multiprocessing distributed training, rank needs to be the
139 | # global rank among all the processes
140 | args.rank = args.rank * ngpus_per_node + gpu
141 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
142 | world_size=args.world_size, rank=args.rank)
143 | # create model
144 | print("=> creating model '{}'".format(args.arch))
145 | model = models.__dict__[args.arch]()
146 |
147 | # freeze all layers but the last fc
148 | for name, param in model.named_parameters():
149 | if name not in ['fc.weight', 'fc.bias']:
150 | param.requires_grad = False
151 | # init the fc layer
152 | model.fc.weight.data.normal_(mean=0.0, std=0.01)
153 | model.fc.bias.data.zero_()
154 |
155 | # load from pre-trained, before DistributedDataParallel constructor
156 | if args.pretrained:
157 | if os.path.isfile(args.pretrained):
158 | print("=> loading checkpoint '{}'".format(args.pretrained))
159 | checkpoint = torch.load(args.pretrained, map_location="cpu")
160 |
161 | # rename moco pre-trained keys
162 | state_dict = checkpoint['state_dict']
163 | for k in list(state_dict.keys()):
164 | # retain only encoder_q up to before the embedding layer
165 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
166 | # remove prefix
167 | state_dict[k[len("module.encoder_q."):]] = state_dict[k]
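                    # e.g. 'module.encoder_q.conv1.weight' -> 'conv1.weight'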
168 | # delete renamed or unused k
169 | del state_dict[k]
170 |
171 | args.start_epoch = 0
172 | msg = model.load_state_dict(state_dict, strict=False)
173 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
174 |
175 | print("=> loaded pre-trained model '{}'".format(args.pretrained))
176 | else:
177 | print("=> no checkpoint found at '{}'".format(args.pretrained))
178 |
179 | if args.distributed:
180 | # For multiprocessing distributed, DistributedDataParallel constructor
181 | # should always set the single device scope, otherwise,
182 | # DistributedDataParallel will use all available devices.
183 | if args.gpu is not None:
184 | torch.cuda.set_device(args.gpu)
185 | model.cuda(args.gpu)
186 | # When using a single GPU per process and per
187 | # DistributedDataParallel, we need to divide the batch size
188 | # ourselves based on the total number of GPUs we have
189 | args.batch_size = int(args.batch_size / ngpus_per_node)
190 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
191 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
192 | else:
193 | model.cuda()
194 | # DistributedDataParallel will divide and allocate batch_size to all
195 | # available GPUs if device_ids are not set
196 | model = torch.nn.parallel.DistributedDataParallel(model)
197 | elif args.gpu is not None:
198 | torch.cuda.set_device(args.gpu)
199 | model = model.cuda(args.gpu)
200 | else:
201 | # DataParallel will divide and allocate batch_size to all available GPUs
202 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
203 | model.features = torch.nn.DataParallel(model.features)
204 | model.cuda()
205 | else:
206 | model = torch.nn.DataParallel(model).cuda()
207 |
208 | # define loss function (criterion) and optimizer
209 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
210 |
211 | # optimize only the linear classifier
212 | parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
213 | assert len(parameters) == 2 # fc.weight, fc.bias
214 | optimizer = torch.optim.SGD(parameters, args.lr,
215 | momentum=args.momentum,
216 | weight_decay=args.weight_decay)
217 |
218 | # optionally resume from a checkpoint
219 | if args.resume:
220 | if os.path.isfile(args.resume):
221 | print("=> loading checkpoint '{}'".format(args.resume))
222 | if args.gpu is None:
223 | checkpoint = torch.load(args.resume)
224 | else:
225 | # Map model to be loaded to specified single gpu.
226 | loc = 'cuda:{}'.format(args.gpu)
227 | checkpoint = torch.load(args.resume, map_location=loc)
228 | args.start_epoch = checkpoint['epoch']
229 | best_acc1 = checkpoint['best_acc1']
230 | if args.gpu is not None:
231 | # best_acc1 may be from a checkpoint from a different GPU
232 | best_acc1 = best_acc1.to(args.gpu)
233 | model.load_state_dict(checkpoint['state_dict'])
234 | optimizer.load_state_dict(checkpoint['optimizer'])
235 | print("=> loaded checkpoint '{}' (epoch {})"
236 | .format(args.resume, checkpoint['epoch']))
237 | else:
238 | print("=> no checkpoint found at '{}'".format(args.resume))
239 |
240 | cudnn.benchmark = True
241 |
242 | # Data loading code
243 | traindir = os.path.join(args.data, 'train')
244 | valdir = os.path.join(args.data, 'val')
245 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
246 | std=[0.229, 0.224, 0.225])
247 |
248 | train_dataset = datasets.ImageFolder(
249 | traindir,
250 | transforms.Compose([
251 | transforms.RandomResizedCrop(224),
252 | transforms.RandomHorizontalFlip(),
253 | transforms.ToTensor(),
254 | normalize,
255 | ]))
256 |
257 | if args.distributed:
258 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
259 | else:
260 | train_sampler = None
261 |
262 | train_loader = torch.utils.data.DataLoader(
263 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
264 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
265 |
266 | val_loader = torch.utils.data.DataLoader(
267 | datasets.ImageFolder(valdir, transforms.Compose([
268 | transforms.Resize(256),
269 | transforms.CenterCrop(224),
270 | transforms.ToTensor(),
271 | normalize,
272 | ])),
273 | batch_size=args.batch_size, shuffle=False,
274 | num_workers=args.workers, pin_memory=True)
275 |
276 | if args.evaluate:
277 | validate(val_loader, model, criterion, args)
278 | return
279 |
280 | for epoch in range(args.start_epoch, args.epochs):
281 | if args.distributed:
282 | train_sampler.set_epoch(epoch)
283 | adjust_learning_rate(optimizer, epoch, args)
284 |
285 | # train for one epoch
286 | train(train_loader, model, criterion, optimizer, epoch, args)
287 |
288 | # evaluate on validation set
289 | acc1 = validate(val_loader, model, criterion, args)
290 |
291 | # remember best acc@1 and save checkpoint
292 | is_best = acc1 > best_acc1
293 | best_acc1 = max(acc1, best_acc1)
294 |
295 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
296 | and args.rank % ngpus_per_node == 0):
297 | save_checkpoint({
298 | 'epoch': epoch + 1,
299 | 'arch': args.arch,
300 | 'state_dict': model.state_dict(),
301 | 'best_acc1': best_acc1,
302 | 'optimizer' : optimizer.state_dict(),
303 | }, is_best)
304 | if epoch == args.start_epoch:
305 | sanity_check(model.state_dict(), args.pretrained)
306 |
307 |
308 | def train(train_loader, model, criterion, optimizer, epoch, args):
309 | batch_time = AverageMeter('Time', ':6.3f')
310 | data_time = AverageMeter('Data', ':6.3f')
311 | losses = AverageMeter('Loss', ':.4e')
312 | top1 = AverageMeter('Acc@1', ':6.2f')
313 | top5 = AverageMeter('Acc@5', ':6.2f')
314 | progress = ProgressMeter(
315 | len(train_loader),
316 | [batch_time, data_time, losses, top1, top5],
317 | prefix="Epoch: [{}]".format(epoch))
318 |
319 | """
320 | Switch to eval mode:
321 | Under the protocol of linear classification on frozen features/models,
322 | it is not legitimate to change any part of the pre-trained model.
323 | BatchNorm in train mode may revise running mean/std (even if it receives
324 | no gradient), which are part of the model parameters too.
325 | """
326 | model.eval()
327 |
328 | end = time.time()
329 | for i, (images, target) in enumerate(train_loader):
330 | # measure data loading time
331 | data_time.update(time.time() - end)
332 |
333 | if args.gpu is not None:
334 | images = images.cuda(args.gpu, non_blocking=True)
335 | target = target.cuda(args.gpu, non_blocking=True)
336 |
337 | # compute output
338 | output = model(images)
339 | loss = criterion(output, target)
340 |
341 | # measure accuracy and record loss
342 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
343 | losses.update(loss.item(), images.size(0))
344 | top1.update(acc1[0], images.size(0))
345 | top5.update(acc5[0], images.size(0))
346 |
347 | # compute gradient and do SGD step
348 | optimizer.zero_grad()
349 | loss.backward()
350 | optimizer.step()
351 |
352 | # measure elapsed time
353 | batch_time.update(time.time() - end)
354 | end = time.time()
355 |
356 | if i % args.print_freq == 0:
357 | progress.display(i)
358 |
359 |
360 | def validate(val_loader, model, criterion, args):
361 | batch_time = AverageMeter('Time', ':6.3f')
362 | losses = AverageMeter('Loss', ':.4e')
363 | top1 = AverageMeter('Acc@1', ':6.2f')
364 | top5 = AverageMeter('Acc@5', ':6.2f')
365 | progress = ProgressMeter(
366 | len(val_loader),
367 | [batch_time, losses, top1, top5],
368 | prefix='Test: ')
369 |
370 | # switch to evaluate mode
371 | model.eval()
372 |
373 | with torch.no_grad():
374 | end = time.time()
375 | for i, (images, target) in enumerate(val_loader):
376 | if args.gpu is not None:
377 | images = images.cuda(args.gpu, non_blocking=True)
378 | target = target.cuda(args.gpu, non_blocking=True)
379 |
380 | # compute output
381 | output = model(images)
382 | loss = criterion(output, target)
383 |
384 | # measure accuracy and record loss
385 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
386 | losses.update(loss.item(), images.size(0))
387 | top1.update(acc1[0], images.size(0))
388 | top5.update(acc5[0], images.size(0))
389 |
390 | # measure elapsed time
391 | batch_time.update(time.time() - end)
392 | end = time.time()
393 |
394 | if i % args.print_freq == 0:
395 | progress.display(i)
396 |
397 | # TODO: this should also be done with the ProgressMeter
398 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
399 | .format(top1=top1, top5=top5))
400 |
401 | return top1.avg
402 |
403 |
404 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
405 | torch.save(state, filename)
406 | if is_best:
407 | shutil.copyfile(filename, 'model_best.pth.tar')
408 |
409 |
410 | def sanity_check(state_dict, pretrained_weights):
411 | """
412 | Linear classifier should not change any weights other than the linear layer.
413 |     This sanity check asserts that nothing else was changed (e.g., BN running stats updated).
414 | """
415 | print("=> loading '{}' for sanity check".format(pretrained_weights))
416 | checkpoint = torch.load(pretrained_weights, map_location="cpu")
417 | state_dict_pre = checkpoint['state_dict']
418 |
419 | for k in list(state_dict.keys()):
420 | # only ignore fc layer
421 | if 'fc.weight' in k or 'fc.bias' in k:
422 | continue
423 |
424 | # name in pretrained model
425 |         k_pre = 'module.encoder_q.' + (k[len('module.'):]
426 |                 if k.startswith('module.') else k)
427 |
428 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
429 | '{} is changed in linear classifier training.'.format(k)
430 |
431 | print("=> sanity check passed.")
432 |
433 |
434 | class AverageMeter(object):
435 | """Computes and stores the average and current value"""
436 | def __init__(self, name, fmt=':f'):
437 | self.name = name
438 | self.fmt = fmt
439 | self.reset()
440 |
441 | def reset(self):
442 | self.val = 0
443 | self.avg = 0
444 | self.sum = 0
445 | self.count = 0
446 |
447 | def update(self, val, n=1):
448 | self.val = val
449 | self.sum += val * n
450 | self.count += n
451 | self.avg = self.sum / self.count
452 |
453 | def __str__(self):
454 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
455 | return fmtstr.format(**self.__dict__)
456 |
457 |
458 | class ProgressMeter(object):
459 | def __init__(self, num_batches, meters, prefix=""):
460 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
461 | self.meters = meters
462 | self.prefix = prefix
463 |
464 | def display(self, batch):
465 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
466 | entries += [str(meter) for meter in self.meters]
467 | print('\t'.join(entries))
468 |
469 | def _get_batch_fmtstr(self, num_batches):
470 |         num_digits = len(str(num_batches))
471 | fmt = '{:' + str(num_digits) + 'd}'
472 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
473 |
474 |
475 | def adjust_learning_rate(optimizer, epoch, args):
476 | """Decay the learning rate based on schedule"""
477 | lr = args.lr
478 | for milestone in args.schedule:
479 | lr *= 0.1 if epoch >= milestone else 1.
480 | for param_group in optimizer.param_groups:
481 | param_group['lr'] = lr
482 |
483 |
484 | def accuracy(output, target, topk=(1,)):
485 | """Computes the accuracy over the k top predictions for the specified values of k"""
486 | with torch.no_grad():
487 | maxk = max(topk)
488 | batch_size = target.size(0)
489 |
490 | _, pred = output.topk(maxk, 1, True, True)
491 | pred = pred.t()
492 | correct = pred.eq(target.view(1, -1).expand_as(pred))
493 |
494 | res = []
495 | for k in topk:
496 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: correct may be non-contiguous
497 | res.append(correct_k.mul_(100.0 / batch_size))
498 | return res
499 |
500 |
501 | if __name__ == '__main__':
502 | main()
503 |
--------------------------------------------------------------------------------
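
A quick sketch (not part of the repo; toy values assumed) of what the two helpers at the bottom of main_lincls.py compute — the milestone LR decay in `adjust_learning_rate` and the top-k accuracy in `accuracy`:

```
import torch

# Mirrors adjust_learning_rate: multiply the base LR by 0.1 at each milestone passed.
def step_lr(base_lr, epoch, schedule):
    lr = base_lr
    for milestone in schedule:
        lr *= 0.1 if epoch >= milestone else 1.
    return lr

print(step_lr(30.0, 59, [60, 80]))  # 30.0 -- before the first milestone
print(step_lr(30.0, 60, [60, 80]))  # 3.0  -- decayed once
print(step_lr(30.0, 80, [60, 80]))  # ~0.3 -- decayed twice

# Mirrors accuracy() on a toy batch: 2 samples, 4 classes.
output = torch.tensor([[0.1, 0.9, 0.0, 0.0],     # top-1 prediction: class 1 (correct)
                       [0.8, 0.1, 0.05, 0.05]])  # top-1 prediction: class 0 (wrong)
target = torch.tensor([1, 2])
_, pred = output.topk(2, 1, True, True)  # top-2 class indices per sample
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
top1 = correct[:1].reshape(-1).float().sum() * 100.0 / target.size(0)
print(top1)  # tensor(50.) -- only the first sample's top-1 hit
```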
/moco/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
--------------------------------------------------------------------------------
/moco/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class MoCo(nn.Module):
7 | """
8 | Build a MoCo model with: a query encoder, a key encoder, and a queue
9 | https://arxiv.org/abs/1911.05722
10 | """
11 | def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
12 | """
13 | dim: feature dimension (default: 128)
14 | K: queue size; number of negative keys (default: 65536)
15 | m: moco momentum of updating key encoder (default: 0.999)
16 | T: softmax temperature (default: 0.07)
17 | """
18 | super(MoCo, self).__init__()
19 |
20 | self.K = K
21 | self.m = m
22 | self.T = T
23 |
24 | # create the encoders
25 | # num_classes is the output fc dimension
26 | self.encoder_q = base_encoder(num_classes=dim)
27 | self.encoder_k = base_encoder(num_classes=dim)
28 |
29 | if mlp: # hack: brute-force replacement
30 | dim_mlp = self.encoder_q.fc.weight.shape[1]
31 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
32 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)
33 |
34 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
35 | param_k.data.copy_(param_q.data) # initialize
36 | param_k.requires_grad = False # not update by gradient
37 |
38 | # create the queue
39 | self.register_buffer("queue", torch.randn(dim, K))
40 | self.queue = nn.functional.normalize(self.queue, dim=0)
41 |
42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
43 |
44 | @torch.no_grad()
45 | def _momentum_update_key_encoder(self):
46 | """
47 | Momentum update of the key encoder
48 | """
49 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
50 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
51 |
52 | @torch.no_grad()
53 | def _dequeue_and_enqueue(self, keys):
54 | # gather keys before updating queue
55 | keys = concat_all_gather(keys)
56 |
57 | batch_size = keys.shape[0]
58 |
59 | ptr = int(self.queue_ptr)
60 | assert self.K % batch_size == 0 # for simplicity
61 |
62 | # replace the keys at ptr (dequeue and enqueue)
63 | self.queue[:, ptr:ptr + batch_size] = keys.T
64 | ptr = (ptr + batch_size) % self.K # move pointer
65 |
66 | self.queue_ptr[0] = ptr
67 |
68 | @torch.no_grad()
69 | def _batch_shuffle_ddp(self, x):
70 | """
71 | Batch shuffle, for making use of BatchNorm.
72 | *** Only support DistributedDataParallel (DDP) model. ***
73 | """
74 | # gather from all gpus
75 | batch_size_this = x.shape[0]
76 | x_gather = concat_all_gather(x)
77 | batch_size_all = x_gather.shape[0]
78 |
79 | num_gpus = batch_size_all // batch_size_this
80 |
81 | # random shuffle index
82 | idx_shuffle = torch.randperm(batch_size_all).cuda()
83 |
84 | # broadcast to all gpus
85 | torch.distributed.broadcast(idx_shuffle, src=0)
86 |
87 | # index for restoring
88 | idx_unshuffle = torch.argsort(idx_shuffle)
89 |
90 | # shuffled index for this gpu
91 | gpu_idx = torch.distributed.get_rank()
92 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
93 |
94 | return x_gather[idx_this], idx_unshuffle
95 |
96 | @torch.no_grad()
97 | def _batch_unshuffle_ddp(self, x, idx_unshuffle):
98 | """
99 | Undo batch shuffle.
100 | *** Only support DistributedDataParallel (DDP) model. ***
101 | """
102 | # gather from all gpus
103 | batch_size_this = x.shape[0]
104 | x_gather = concat_all_gather(x)
105 | batch_size_all = x_gather.shape[0]
106 |
107 | num_gpus = batch_size_all // batch_size_this
108 |
109 | # restored index for this gpu
110 | gpu_idx = torch.distributed.get_rank()
111 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
112 |
113 | return x_gather[idx_this]
114 |
115 | def forward(self, im_q, im_k):
116 | """
117 | Input:
118 | im_q: a batch of query images
119 | im_k: a batch of key images
120 | Output:
121 | logits, targets
122 | """
123 |
124 | # compute query features
125 | q = self.encoder_q(im_q) # queries: NxC
126 | q = nn.functional.normalize(q, dim=1)
127 |
128 | # compute key features
129 | with torch.no_grad(): # no gradient to keys
130 | self._momentum_update_key_encoder() # update the key encoder
131 |
132 | # shuffle for making use of BN
133 | im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)
134 |
135 | k = self.encoder_k(im_k) # keys: NxC
136 | k = nn.functional.normalize(k, dim=1)
137 |
138 | # undo shuffle
139 | k = self._batch_unshuffle_ddp(k, idx_unshuffle)
140 |
141 | # compute logits
142 | # Einstein sum is more intuitive
143 | # positive logits: Nx1
144 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
145 | # negative logits: NxK
146 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
147 |
148 | # logits: Nx(1+K)
149 | logits = torch.cat([l_pos, l_neg], dim=1)
150 |
151 | # apply temperature
152 | logits /= self.T
153 |
154 | # labels: positive key indicators
155 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
156 |
157 | # dequeue and enqueue
158 | self._dequeue_and_enqueue(k)
159 |
160 | return logits, labels
161 |
162 |
163 | # utils
164 | @torch.no_grad()
165 | def concat_all_gather(tensor):
166 | """
167 | Performs all_gather operation on the provided tensors.
168 | *** Warning ***: torch.distributed.all_gather has no gradient.
169 | """
170 | tensors_gather = [torch.ones_like(tensor)
171 | for _ in range(torch.distributed.get_world_size())]
172 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
173 |
174 | output = torch.cat(tensors_gather, dim=0)
175 | return output
176 |
--------------------------------------------------------------------------------
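
A minimal single-process sketch (toy sizes; the real model runs under DDP with `concat_all_gather`) of the three mechanisms in builder.py — the Nx(1+K) logits, the circular queue update, and the shuffle/unshuffle index arithmetic behind shuffling BN:

```
import torch
import torch.nn.functional as F

N, C, K = 4, 128, 16                          # batch, feature dim, queue size
q = F.normalize(torch.randn(N, C), dim=1)     # queries
k = F.normalize(torch.randn(N, C), dim=1)     # keys
queue = F.normalize(torch.randn(C, K), dim=0)

l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # Nx1, one positive per query
l_neg = torch.einsum('nc,ck->nk', [q, queue])           # NxK, negatives from the queue
logits = torch.cat([l_pos, l_neg], dim=1) / 0.07        # Nx(1+K), temperature-scaled
labels = torch.zeros(N, dtype=torch.long)               # positive key sits at column 0
assert logits.shape == (N, 1 + K)

# Circular queue: the newest keys overwrite N columns starting at ptr.
ptr = 0
queue[:, ptr:ptr + N] = k.T
ptr = (ptr + N) % K          # 4 -> 8 -> 12 -> 0; wraps cleanly because K % N == 0

# Shuffle-BN index arithmetic: argsort of the shuffle permutation inverts it.
idx_shuffle = torch.randperm(2 * N)        # a 2-"GPU" global batch
idx_unshuffle = torch.argsort(idx_shuffle)
x = torch.arange(2 * N)
assert (x[idx_shuffle][idx_unshuffle] == x).all()
```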
/moco/loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from PIL import ImageFilter
3 | import random
4 | import torchvision.transforms as transforms
5 |
6 |
7 | class TwoCropsTransform:
8 | """Take two random crops of one image as the query and key."""
9 |
10 | def __init__(self, base_transform):
11 | self.base_transform = base_transform
12 |
13 | def __call__(self, x):
14 | q = self.base_transform(x)
15 | k = self.base_transform(x)
16 | return [q, k]
17 |
18 |
19 | class GaussianBlur(object):
20 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
21 |
22 | def __init__(self, sigma=[.1, 2.]):
23 | self.sigma = sigma
24 |
25 | def __call__(self, x):
26 | sigma = random.uniform(self.sigma[0], self.sigma[1])
27 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
28 | return x
--------------------------------------------------------------------------------
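
For context, a usage sketch showing how `TwoCropsTransform` and `GaussianBlur` are typically composed — this follows MoCo v2's published augmentation recipe, so the exact pipeline built in main_clsa.py may differ:

```
import torchvision.transforms as transforms
from moco.loader import TwoCropsTransform, GaussianBlur

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
augmentation = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
train_transform = TwoCropsTransform(augmentation)
# train_transform(pil_image) returns [q, k]: two independently augmented
# views of the same image, used as the query/key pair.
```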