├── .DS_Store ├── .gitignore ├── README.md ├── data └── ucf101.py ├── lib ├── .DS_Store ├── nn │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-37.pyc │ ├── modules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── batchnorm.cpython-37.pyc │ │ │ ├── comm.cpython-37.pyc │ │ │ └── replicate.cpython-37.pyc │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ ├── tests │ │ │ ├── test_numeric_batchnorm.py │ │ │ └── test_sync_batchnorm.py │ │ └── unittest.py │ └── parallel │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── data_parallel.cpython-37.pyc │ │ └── data_parallel.py └── radam.py ├── loss.py ├── models ├── .DS_Store ├── __init__.py ├── models.py └── resnet.py ├── pca.png ├── pca.py ├── train.py └── utils ├── .DS_Store └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | data/cifar-10-batches-py/ 4 | *.gz 5 | *.pth 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimCLR_pytorch 2 | PyTorch implementation of arxiv.org/pdf/2002.05709.pdf. 3 | 4 | This is a simple framework for contrastive learning of visual representations that comprises the following modules: 5 | 1. Sequential data augmentations (Crop, flip, ..., color distortion, Gaussian blur) 6 | 2. A base encoder network (ResNet18 is used by default in this implementation) 7 | 3. A projection head network (just a fully connected layer at the end of the base encoder) 8 | 4. A contrastive loss implemented as SimLoss 9 | 10 | To train, run the following line with your hyperparameters: 11 | 12 | ``` 13 | python train.py --lr 0.01 --tau 0.1 --batch_size 32 14 | ``` 15 | 16 | and to use PCA to visualize features of the layer before g(x), run the following: 17 | 18 | ``` 19 | python pca.py --ckpt /PATH/TO/YOUR/MODEL/WEIGHTS 20 | ``` 21 | 22 | The plot will be saved as a .png in your project's directory.
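For reference, `SimLoss` implements the NT-Xent objective from the paper: for a positive pair (i, j) among the 2N augmented views, l(i, j) = -log[ exp(sim(z_i, z_j)/tau) / sum_{k != i} exp(sim(z_i, z_k)/tau) ], averaged over all 2N views. A minimal usage sketch (illustrative; assumes a CUDA device, since the loss builds its mask on the GPU, and that the first half of the batch holds view 1 and the second half the matching view 2 of each image, which is the layout `SimLoss.reorder` expects):

```
import torch
from loss import SimLoss

criterion = SimLoss(tau=0.1)       # temperature, the --tau flag above
z = torch.randn(64, 128).cuda()    # 2N projections g(f(x)); rows [0:32] are view 1, [32:64] view 2
loss = criterion(z)                # scalar NT-Xent loss
```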
23 | -------------------------------------------------------------------------------- /data/ucf101.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import math 6 | import functools 7 | import json 8 | import copy 9 | 10 | from utils.utils import load_value_file 11 | 12 | 13 | def pil_loader(path): 14 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 15 | with open(path, 'rb') as f: 16 | with Image.open(f) as img: 17 | return img.convert('RGB') 18 | 19 | 20 | def accimage_loader(path): 21 | try: 22 | import accimage 23 | return accimage.Image(path) 24 | except IOError: 25 | # Potentially a decoding problem, fall back to PIL.Image 26 | return pil_loader(path) 27 | 28 | 29 | def get_default_image_loader(): 30 | from torchvision import get_image_backend 31 | if get_image_backend() == 'accimage': 32 | return accimage_loader 33 | else: 34 | return pil_loader 35 | 36 | 37 | def video_loader(video_dir_path, frame_indices, image_loader): 38 | video = [] 39 | for i in frame_indices: 40 | image_path = os.path.join(video_dir_path, 'image_{:05d}.jpg'.format(i)) 41 | if os.path.exists(image_path): 42 | video.append(image_loader(image_path)) 43 | else: 44 | return video 45 | 46 | return video 47 | 48 | 49 | def get_default_video_loader(): 50 | image_loader = get_default_image_loader() 51 | return functools.partial(video_loader, image_loader=image_loader) 52 | 53 | 54 | def load_annotation_data(data_file_path): 55 | with open(data_file_path, 'r') as data_file: 56 | return json.load(data_file) 57 | 58 | 59 | def get_class_labels(data): 60 | class_labels_map = {} 61 | index = 0 62 | for class_label in data['labels']: 63 | class_labels_map[class_label] = index 64 | index += 1 65 | return class_labels_map 66 | 67 | 68 | def get_video_names_and_annotations(data, subset): 69 | video_names = [] 70 | annotations = [] 71 | 72 | for key, value in data['database'].items(): 73 | this_subset = value['subset'] 74 | if this_subset == subset: 75 | label = value['annotations']['label'] 76 | video_names.append('{}/{}'.format(label, key)) 77 | annotations.append(value['annotations']) 78 | 79 | return video_names, annotations 80 | 81 | 82 | def make_dataset(root_path, annotation_path, subset, n_samples_for_each_video, 83 | sample_duration): 84 | data = load_annotation_data(annotation_path) 85 | video_names, annotations = get_video_names_and_annotations(data, subset) 86 | class_to_idx = get_class_labels(data) 87 | idx_to_class = {} 88 | for name, label in class_to_idx.items(): 89 | idx_to_class[label] = name 90 | 91 | dataset = [] 92 | for i in range(len(video_names)): 93 | if i % 1000 == 0: 94 | print('dataset loading [{}/{}]'.format(i, len(video_names))) 95 | 96 | video_path = os.path.join(root_path, video_names[i]) 97 | if not os.path.exists(video_path): 98 | continue 99 | 100 | n_frames_file_path = os.path.join(video_path, 'n_frames') 101 | n_frames = int(load_value_file(n_frames_file_path)) 102 | if n_frames <= 0: 103 | continue 104 | 105 | begin_t = 1 106 | end_t = n_frames 107 | sample = { 108 | 'video': video_path, 109 | 'segment': [begin_t, end_t], 110 | 'n_frames': n_frames, 111 | 'video_id': video_names[i].split('/')[1] 112 | } 113 | if len(annotations) != 0: 114 | sample['label'] = class_to_idx[annotations[i]['label']] 115 | else: 116 | sample['label'] = -1 117 | 118 | if n_samples_for_each_video == 1: 119 | sample['frame_indices'] = 
list(range(1, n_frames + 1)) 120 | dataset.append(sample) 121 | else: 122 | if n_samples_for_each_video > 1: 123 | step = max(1, 124 | math.ceil((n_frames - 1 - sample_duration) / 125 | (n_samples_for_each_video - 1))) 126 | else: 127 | step = sample_duration 128 | for j in range(1, n_frames, step): 129 | sample_j = copy.deepcopy(sample) 130 | sample_j['frame_indices'] = list( 131 | range(j, min(n_frames + 1, j + sample_duration))) 132 | dataset.append(sample_j) 133 | 134 | return dataset, idx_to_class 135 | 136 | 137 | class UCF101(data.Dataset): 138 | """ 139 | Args: 140 | root (string): Root directory path. 141 | spatial_transform (callable, optional): A function/transform that takes in a PIL image 142 | and returns a transformed version. E.g., ``transforms.RandomCrop`` 143 | temporal_transform (callable, optional): A function/transform that takes in a list of frame indices 144 | and returns a transformed version 145 | target_transform (callable, optional): A function/transform that takes in the 146 | target and transforms it. 147 | loader (callable, optional): A function to load a video given its path and frame indices. 148 | Attributes: 149 | classes (list): List of the class names. 150 | class_to_idx (dict): Dict with items (class_name, class_index). 151 | imgs (list): List of (image path, class_index) tuples 152 | """ 153 | 154 | def __init__(self, 155 | root_path, 156 | annotation_path, 157 | subset, 158 | n_samples_for_each_video=1, 159 | spatial_transform=None, 160 | temporal_transform=None, 161 | target_transform=None, 162 | sample_duration=16, 163 | get_loader=get_default_video_loader): 164 | self.data, self.class_names = make_dataset( 165 | root_path, annotation_path, subset, n_samples_for_each_video, 166 | sample_duration) 167 | 168 | self.spatial_transform = spatial_transform 169 | self.temporal_transform = temporal_transform 170 | self.target_transform = target_transform 171 | self.loader = get_loader() 172 | 173 | def __getitem__(self, index): 174 | """ 175 | Args: 176 | index (int): Index 177 | Returns: 178 | tuple: (image, target) where target is class_index of the target class.
179 | """ 180 | path = self.data[index]['video'] 181 | 182 | frame_indices = self.data[index]['frame_indices'] 183 | if self.temporal_transform is not None: 184 | frame_indices = self.temporal_transform(frame_indices) 185 | clip = self.loader(path, frame_indices) 186 | if self.spatial_transform is not None: 187 | self.spatial_transform.randomize_parameters() 188 | clip = [self.spatial_transform(img) for img in clip] 189 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) 190 | 191 | target = self.data[index] 192 | if self.target_transform is not None: 193 | target = self.target_transform(target) 194 | 195 | return clip, target 196 | 197 | def __len__(self): 198 | return len(self.data) 199 | -------------------------------------------------------------------------------- /lib/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/lib/.DS_Store -------------------------------------------------------------------------------- /lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /lib/nn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/lib/nn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
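# Usage sketch (added; an assumption based on the exports below and the examples in
# replicate.py): swap nn.BatchNorm2d for SynchronizedBatchNorm2d inside the model, then
# wrap it with DataParallelWithCallback so each replica's statistics are synchronized:
#   model = DataParallelWithCallback(model, device_ids=[0, 1])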
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /lib/nn/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/lib/nn/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/modules/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/lib/nn/modules/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/modules/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/lib/nn/modules/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/modules/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/lib/nn/modules/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/modules/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimension""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dimensions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | # custom batch norm statistics 49 | self._moving_average_fraction = 1.
- momentum 50 | self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features)) 51 | self.register_buffer('_tmp_running_var', torch.ones(self.num_features)) 52 | self.register_buffer('_running_iter', torch.ones(1)) 53 | self._tmp_running_mean = self.running_mean.clone() * self._running_iter 54 | self._tmp_running_var = self.running_var.clone() * self._running_iter 55 | 56 | def forward(self, input): 57 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 58 | if not (self._is_parallel and self.training): 59 | return F.batch_norm( 60 | input, self.running_mean, self.running_var, self.weight, self.bias, 61 | self.training, self.momentum, self.eps) 62 | 63 | # Resize the input to (B, C, -1). 64 | input_shape = input.size() 65 | input = input.view(input.size(0), self.num_features, -1) 66 | 67 | # Compute the sum and square-sum. 68 | sum_size = input.size(0) * input.size(2) 69 | input_sum = _sum_ft(input) 70 | input_ssum = _sum_ft(input ** 2) 71 | 72 | # Reduce-and-broadcast the statistics. 73 | if self._parallel_id == 0: 74 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 75 | else: 76 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 77 | 78 | # Compute the output. 79 | if self.affine: 80 | # MJY:: Fuse the multiplication for speed. 81 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 82 | else: 83 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 84 | 85 | # Reshape it. 86 | return output.view(input_shape) 87 | 88 | def __data_parallel_replicate__(self, ctx, copy_id): 89 | self._is_parallel = True 90 | self._parallel_id = copy_id 91 | 92 | # parallel_id == 0 means master device. 93 | if self._parallel_id == 0: 94 | ctx.sync_master = self._sync_master 95 | else: 96 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 97 | 98 | def _data_parallel_master(self, intermediates): 99 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 100 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 101 | 102 | to_reduce = [i[1][:2] for i in intermediates] 103 | to_reduce = [j for i in to_reduce for j in i] # flatten 104 | target_gpus = [i[1].sum.get_device() for i in intermediates] 105 | 106 | sum_size = sum([i[1].sum_size for i in intermediates]) 107 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 108 | 109 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 110 | 111 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 112 | 113 | outputs = [] 114 | for i, rec in enumerate(intermediates): 115 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 116 | 117 | return outputs 118 | 119 | def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0): 120 | """return *dest* by `dest := dest*alpha + delta*beta + bias`""" 121 | return dest * alpha + delta * beta + bias 122 | 123 | def _compute_mean_std(self, sum_, ssum, size): 124 | """Compute the mean and standard-deviation with sum and square-sum. This method 125 | also maintains the moving average on the master device.""" 126 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
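# (added note) With s1 = sum(x) and s2 = sum(x^2) over `size` elements per channel, the
# lines below compute mean = s1 / size, Var[x] = (s2 - s1 * mean) / size (biased) or
# / (size - 1) (unbiased), and return inv_std = clamp(Var_biased, eps) ** -0.5.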
127 | mean = sum_ / size 128 | sumvar = ssum - sum_ * mean 129 | unbias_var = sumvar / (size - 1) 130 | bias_var = sumvar / size 131 | 132 | self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction) 133 | self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction) 134 | self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction) 135 | 136 | self.running_mean = self._tmp_running_mean / self._running_iter 137 | self.running_var = self._tmp_running_var / self._running_iter 138 | 139 | return mean, bias_var.clamp(self.eps) ** -0.5 140 | 141 | 142 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 143 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 144 | mini-batch. 145 | 146 | .. math:: 147 | 148 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 149 | 150 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 151 | standard-deviation are reduced across all devices during training. 152 | 153 | For example, when one uses `nn.DataParallel` to wrap the network during 154 | training, PyTorch's implementation normalizes the tensor on each device using 155 | only the statistics on that device, which accelerates the computation and 156 | is also easy to implement, but the statistics might be inaccurate. 157 | Instead, in this synchronized version, the statistics will be computed 158 | over all training samples distributed on multiple devices. 159 | 160 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 161 | as the built-in PyTorch implementation. 162 | 163 | The mean and standard-deviation are calculated per-dimension over 164 | the mini-batches and gamma and beta are learnable parameter vectors 165 | of size C (where C is the input size). 166 | 167 | During training, this layer keeps a running estimate of its computed mean 168 | and variance. The running sum is kept with a default momentum of 0.001. 169 | 170 | During evaluation, this running mean/variance is used for normalization. 171 | 172 | Because the BatchNorm is done over the `C` dimension, computing statistics 173 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 174 | 175 | Args: 176 | num_features: num_features from an expected input of size 177 | `batch_size x num_features [x width]` 178 | eps: a value added to the denominator for numerical stability. 179 | Default: 1e-5 180 | momentum: the value used for the running_mean and running_var 181 | computation. Default: 0.001 182 | affine: a boolean value that when set to ``True``, gives the layer learnable 183 | affine parameters.
Default: ``True`` 184 | 185 | Shape: 186 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 187 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 188 | 189 | Examples: 190 | >>> # With Learnable Parameters 191 | >>> m = SynchronizedBatchNorm1d(100) 192 | >>> # Without Learnable Parameters 193 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 194 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 195 | >>> output = m(input) 196 | """ 197 | 198 | def _check_input_dim(self, input): 199 | if input.dim() != 2 and input.dim() != 3: 200 | raise ValueError('expected 2D or 3D input (got {}D input)' 201 | .format(input.dim())) 202 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 203 | 204 | 205 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 206 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 207 | of 3d inputs 208 | 209 | .. math:: 210 | 211 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 212 | 213 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 214 | standard-deviation are reduced across all devices during training. 215 | 216 | For example, when one uses `nn.DataParallel` to wrap the network during 217 | training, PyTorch's implementation normalizes the tensor on each device using 218 | only the statistics on that device, which accelerates the computation and 219 | is also easy to implement, but the statistics might be inaccurate. 220 | Instead, in this synchronized version, the statistics will be computed 221 | over all training samples distributed on multiple devices. 222 | 223 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 224 | as the built-in PyTorch implementation. 225 | 226 | The mean and standard-deviation are calculated per-dimension over 227 | the mini-batches and gamma and beta are learnable parameter vectors 228 | of size C (where C is the input size). 229 | 230 | During training, this layer keeps a running estimate of its computed mean 231 | and variance. The running sum is kept with a default momentum of 0.001. 232 | 233 | During evaluation, this running mean/variance is used for normalization. 234 | 235 | Because the BatchNorm is done over the `C` dimension, computing statistics 236 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 237 | 238 | Args: 239 | num_features: num_features from an expected input of 240 | size batch_size x num_features x height x width 241 | eps: a value added to the denominator for numerical stability. 242 | Default: 1e-5 243 | momentum: the value used for the running_mean and running_var 244 | computation. Default: 0.001 245 | affine: a boolean value that when set to ``True``, gives the layer learnable 246 | affine parameters.
Default: ``True`` 247 | 248 | Shape: 249 | - Input: :math:`(N, C, H, W)` 250 | - Output: :math:`(N, C, H, W)` (same shape as input) 251 | 252 | Examples: 253 | >>> # With Learnable Parameters 254 | >>> m = SynchronizedBatchNorm2d(100) 255 | >>> # Without Learnable Parameters 256 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 257 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 258 | >>> output = m(input) 259 | """ 260 | 261 | def _check_input_dim(self, input): 262 | if input.dim() != 4: 263 | raise ValueError('expected 4D input (got {}D input)' 264 | .format(input.dim())) 265 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 266 | 267 | 268 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 269 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 270 | of 4d inputs 271 | 272 | .. math:: 273 | 274 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 275 | 276 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 277 | standard-deviation are reduced across all devices during training. 278 | 279 | For example, when one uses `nn.DataParallel` to wrap the network during 280 | training, PyTorch's implementation normalizes the tensor on each device using 281 | only the statistics on that device, which accelerates the computation and 282 | is also easy to implement, but the statistics might be inaccurate. 283 | Instead, in this synchronized version, the statistics will be computed 284 | over all training samples distributed on multiple devices. 285 | 286 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 287 | as the built-in PyTorch implementation. 288 | 289 | The mean and standard-deviation are calculated per-dimension over 290 | the mini-batches and gamma and beta are learnable parameter vectors 291 | of size C (where C is the input size). 292 | 293 | During training, this layer keeps a running estimate of its computed mean 294 | and variance. The running sum is kept with a default momentum of 0.001. 295 | 296 | During evaluation, this running mean/variance is used for normalization. 297 | 298 | Because the BatchNorm is done over the `C` dimension, computing statistics 299 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 300 | or Spatio-temporal BatchNorm 301 | 302 | Args: 303 | num_features: num_features from an expected input of 304 | size batch_size x num_features x depth x height x width 305 | eps: a value added to the denominator for numerical stability. 306 | Default: 1e-5 307 | momentum: the value used for the running_mean and running_var 308 | computation. Default: 0.001 309 | affine: a boolean value that when set to ``True``, gives the layer learnable 310 | affine parameters.
Default: ``True`` 311 | 312 | Shape: 313 | - Input: :math:`(N, C, D, H, W)` 314 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 315 | 316 | Examples: 317 | >>> # With Learnable Parameters 318 | >>> m = SynchronizedBatchNorm3d(100) 319 | >>> # Without Learnable Parameters 320 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 321 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 322 | >>> output = m(input) 323 | """ 324 | 325 | def _check_input_dim(self, input): 326 | if input.dim() != 5: 327 | raise ValueError('expected 5D input (got {}D input)' 328 | .format(input.dim())) 329 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 330 | -------------------------------------------------------------------------------- /lib/nn/modules/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During replication, as data parallel triggers a callback on each module, all slave devices should 60 | call `register_slave(id)` to obtain a `SlavePipe` for communicating with the master. 61 | - During the forward pass, the master device invokes `run_master`; all messages from the slave devices 62 | are collected and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message 64 | to be passed back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register a slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually the device id. 84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
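Example (added, illustrative; mirroring the call pattern in batchnorm.py): pipe = sync_master.register_slave(copy_id) on each non-master copy, then mean, inv_std = pipe.run_slave(child_message), which blocks until the master replies.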
86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | Messages are first collected from each device (including the master device), and then 100 | a callback is invoked to compute the message to be sent back to each device 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belong to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /lib/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all module copies are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback.
53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by 55 | the original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
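# Note (added): the two test files below import from a top-level `sync_batchnorm`
# package, matching the upstream repo layout; in this repo the same modules live under
# lib/nn/modules, so the imports may need adjusting before the tests can run here.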
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /lib/nn/modules/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /lib/nn/parallel/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/lib/nn/parallel/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/parallel/__pycache__/data_parallel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/lib/nn/parallel/__pycache__/data_parallel.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | import collections 7 | from torch.nn.parallel._functions import Gather 8 | 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | 13 | def async_copy_to(obj, dev, main_stream=None): 14 | if torch.is_tensor(obj): 15 | v = obj.cuda(dev, non_blocking=True) 16 | if main_stream is not None: 17 | v.data.record_stream(main_stream) 18 | return v 19 | elif isinstance(obj, collections.Mapping): 20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 21 | elif isinstance(obj, collections.Sequence): 22 | return [async_copy_to(o, dev, main_stream) for o in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def dict_gather(outputs, target_device, dim=0): 28 | """ 29 | Gathers variables from different GPUs on a specified device 30 | (-1 means the CPU), with dictionary support. 
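    Example (illustrative): with per-GPU outputs [{'loss': t0}, {'loss': t1}],
    dict_gather(outputs, target_device=0) returns {'loss': <t0 and t1 gathered on device 0>}.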
31 | """ 32 | def gather_map(outputs): 33 | out = outputs[0] 34 | if torch.is_tensor(out): 35 | # MJY(20180330) HACK:: force nr_dims > 0 36 | if out.dim() == 0: 37 | outputs = [o.unsqueeze(0) for o in outputs] 38 | return Gather.apply(target_device, dim, *outputs) 39 | elif out is None: 40 | return None 41 | elif isinstance(out, collections.Mapping): 42 | return {k: gather_map([o[k] for o in outputs]) for k in out} 43 | elif isinstance(out, collections.Sequence): 44 | return type(out)(map(gather_map, zip(*outputs))) 45 | return gather_map(outputs) 46 | 47 | 48 | class DictGatherDataParallel(nn.DataParallel): 49 | def gather(self, outputs, output_device): 50 | return dict_gather(outputs, output_device, dim=self.dim) 51 | 52 | 53 | class UserScatteredDataParallel(DictGatherDataParallel): 54 | def scatter(self, inputs, kwargs, device_ids): 55 | assert len(inputs) == 1 56 | inputs = inputs[0] 57 | inputs = _async_copy_stream(inputs, device_ids) 58 | inputs = [[i] for i in inputs] 59 | assert len(kwargs) == 0 60 | kwargs = [{} for _ in range(len(inputs))] 61 | 62 | return inputs, kwargs 63 | 64 | 65 | def user_scattered_collate(batch): 66 | return batch 67 | 68 | 69 | def _async_copy(inputs, device_ids): 70 | nr_devs = len(device_ids) 71 | assert type(inputs) in (tuple, list) 72 | assert len(inputs) == nr_devs 73 | 74 | outputs = [] 75 | for i, dev in zip(inputs, device_ids): 76 | with cuda.device(dev): 77 | outputs.append(async_copy_to(i, dev)) 78 | 79 | return tuple(outputs) 80 | 81 | 82 | def _async_copy_stream(inputs, device_ids): 83 | nr_devs = len(device_ids) 84 | assert type(inputs) in (tuple, list) 85 | assert len(inputs) == nr_devs 86 | 87 | outputs = [] 88 | streams = [_get_stream(d) for d in device_ids] 89 | for i, dev, stream in zip(inputs, device_ids, streams): 90 | with cuda.device(dev): 91 | main_stream = cuda.current_stream() 92 | with cuda.stream(stream): 93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 94 | main_stream.wait_stream(stream) 95 | 96 | return outputs 97 | 98 | 99 | """Adapted from: torch/nn/parallel/_functions.py""" 100 | # background streams used for copying 101 | _streams = None 102 | 103 | 104 | def _get_stream(device): 105 | """Gets a background stream for copying between CPU and GPU""" 106 | global _streams 107 | if device == -1: 108 | return None 109 | if _streams is None: 110 | _streams = [None] * cuda.device_count() 111 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 112 | return _streams[device] 113 | -------------------------------------------------------------------------------- /lib/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 8 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 9 | self.buffer = [[None, None, None] for ind in range(10)] 10 | super(RAdam, self).__init__(params, defaults) 11 | print("Initialized RAdam") 12 | def __setstate__(self, state): 13 | super(RAdam, self).__setstate__(state) 14 | 15 | def step(self, closure=None): 16 | 17 | loss = None 18 | if closure is not None: 19 | loss = closure() 20 | 21 | for group in self.param_groups: 22 | 23 | for p in group['params']: 24 | if p.grad is None: 25 | continue 26 | grad = p.grad.data.float() 27 | if grad.is_sparse: 28 | raise RuntimeError('RAdam does not support sparse 
gradients') 29 | 30 | p_data_fp32 = p.data.float() 31 | 32 | state = self.state[p] 33 | 34 | if len(state) == 0: 35 | state['step'] = 0 36 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 37 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 38 | else: 39 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 40 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 41 | 42 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 43 | beta1, beta2 = group['betas'] 44 | 45 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 46 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 47 | 48 | state['step'] += 1 49 | buffered = self.buffer[int(state['step'] % 10)] 50 | if state['step'] == buffered[0]: 51 | N_sma, step_size = buffered[1], buffered[2] 52 | else: 53 | buffered[0] = state['step'] 54 | beta2_t = beta2 ** state['step'] 55 | N_sma_max = 2 / (1 - beta2) - 1 56 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 57 | buffered[1] = N_sma 58 | 59 | # more conservative since it's an approximated value 60 | if N_sma >= 5: 61 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 62 | else: 63 | step_size = group['lr'] / (1 - beta1 ** state['step']) 64 | buffered[2] = step_size 65 | 66 | if group['weight_decay'] != 0: 67 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 68 | 69 | # more conservative since it's an approximated value 70 | if N_sma >= 5: 71 | denom = exp_avg_sq.sqrt().add_(group['eps']) 72 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 73 | else: 74 | p_data_fp32.add_(-step_size, exp_avg) 75 | 76 | p.data.copy_(p_data_fp32) 77 | 78 | return loss 79 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | import numpy as np 5 | 6 | class SimLoss(nn.Module): 7 | def __init__(self, tau=1, eps=0.000001): 8 | super(SimLoss, self).__init__() 9 | self.tau = tau 10 | self.epsilon = eps 11 | 12 | def forward(self, batch): 13 | batch = self.reorder(batch) 14 | b = torch.mm(batch, batch.transpose(1,0)) 15 | norm = torch.norm(batch, p=2, dim=1).unsqueeze(1) 16 | norm = torch.mm(norm, norm.transpose(1,0)) 17 | den = (torch.ones(batch.shape[0]) - torch.eye(batch.shape[0])).cuda() * b / (norm + self.epsilon) # add eps for numerical stability 18 | den = torch.sum(torch.exp(den/self.tau), dim=1) 19 | 20 | num = torch.zeros(batch.shape[0]).float().cuda() 21 | for k in range(batch.shape[0]): 22 | i, j = (k//2)*2, (k//2)*2+1 23 | num[k] = b[i][j] / norm[i][j] 24 | loss = torch.sum(-torch.log(torch.exp(num/self.tau) / den))/batch.shape[0] 25 | 26 | return loss 27 | 28 | def reorder(self, batch): 29 | b = torch.zeros_like(batch) 30 | for i in range(batch.shape[0]//2): 31 | b[2*i] = batch[i] 32 | b[2*i+1] = batch[i+(batch.shape[0]//2)] 33 | return b 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/models/.DS_Store -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/models/__init__.py -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from . import resnet 8 | 9 | class Identity(nn.Module): 10 | def __init__(self): 11 | super(Identity, self).__init__() 12 | 13 | def forward(self, x): 14 | return x 15 | 16 | class ResidualBlock(nn.Module): 17 | def __init__(self, inchannel, outchannel, stride=1): 18 | super(ResidualBlock, self).__init__() 19 | self.left = nn.Sequential( 20 | nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False), 21 | nn.BatchNorm2d(outchannel), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False), 24 | nn.BatchNorm2d(outchannel) 25 | ) 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or inchannel != outchannel: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(outchannel) 31 | ) 32 | 33 | def forward(self, x): 34 | out = self.left(x) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | class ResNet(nn.Module): 40 | def __init__(self, ResidualBlock, num_classes=128, mode='train'): 41 | super(ResNet, self).__init__() 42 | self.inchannel = 64 43 | self.conv1 = nn.Sequential( 44 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False), 45 | nn.BatchNorm2d(64), 46 | nn.ReLU(inplace=True), 47 | ) 48 | self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1) 49 | self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2) 50 | self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2) 51 | self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2) 52 | self.fc = nn.Linear(512, num_classes) 53 | self.g = nn.Linear(num_classes, num_classes) 54 | self.mode = mode 55 | 56 | def make_layer(self, block, channels, num_blocks, stride): 57 | strides = [stride] + [1] * (num_blocks - 1) #strides=[1,1] 58 | layers = [] 59 | for stride in strides: 60 | layers.append(block(self.inchannel, channels, stride)) 61 | self.inchannel = channels 62 | return nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | out = self.conv1(x) 66 | out = self.layer1(out) 67 | out = self.layer2(out) 68 | out = self.layer3(out) 69 | out = self.layer4(out) 70 | out = F.avg_pool2d(out, 4) 71 | out = out.view(out.size(0), -1) 72 | out = F.relu(self.fc(out)) 73 | if self.mode == 'train': 74 | out = F.relu(self.g(out)) 75 | return out 76 | 77 | 78 | def ResNet18(mode='train'): 79 | 80 | return ResNet(ResidualBlock, mode=mode) 81 | 82 | class ClassifierModule(nn.Module): 83 | def __init__(self, latent_dim=512, embedding_dim=16, num_class=10): 84 | super(ClassifierModule, self).__init__() 85 | self.num_class = num_class 86 | self.latent_dim = latent_dim 87 | self.embedding_dim = embedding_dim 88 | 89 | self.avgpool = nn.AvgPool2d(2, 2) 90 | self.model = resnet.resnet18(pretrained=False) 91 | self.encoder = nn.Sequential(*list(self.model.children())[:-2]) 92 | self.linear = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim // 3), 93 | nn.ReLU(inplace=True), 94 | 
nn.Linear(self.latent_dim // 3, self.num_class)) 95 | def forward(self, x): 96 | out = self.encoder(x) 97 | out = torch.flatten(out, 1) 98 | out = self.linear(out) 99 | return out 100 | 101 | 102 | class ClassifierModuleDense(nn.Module): 103 | def __init__(self, latent_dim=8*8*10242, embedding_dim=120, num_class=2): 104 | super(ClassifierModuleDense, self).__init__() 105 | self.num_class = num_class 106 | self.latent_dim = latent_dim 107 | self.embedding_dim = embedding_dim 108 | 109 | self.avgpool = nn.AvgPool2d(2, 2) 110 | self.encoder = torchvision.models.densenet121(pretrained=False) 111 | self.layers = nn.Sequential(self.encoder.features.conv0, 112 | self.encoder.features.norm0, 113 | self.encoder.features.denseblock1, 114 | self.encoder.features.transition1, 115 | self.encoder.features.denseblock2, 116 | self.encoder.features.transition2, 117 | self.encoder.features.denseblock3, 118 | self.encoder.features.transition3, 119 | self.encoder.features.denseblock4, 120 | self.encoder.features.norm5) 121 | 122 | self.linear = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim // 4), 123 | nn.ReLU(inplace=True), 124 | nn.Linear(self.latent_dim // 4, self.num_class)) 125 | 126 | def forward(self, x): 127 | out = self.layers(x) 128 | out = self.avgpool(out) 129 | out = torch.flatten(out, 1) 130 | out = self.linear(out) 131 | return out 132 | 133 | class ConvLSTMCell(nn.Module): 134 | 135 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias): 136 | """ 137 | Initialize ConvLSTM cell. 138 | 139 | Parameters 140 | ---------- 141 | input_size: (int, int) 142 | Height and width of input tensor as (height, width). 143 | input_dim: int 144 | Number of channels of input tensor. 145 | hidden_dim: int 146 | Number of channels of hidden state. 147 | kernel_size: (int, int) 148 | Size of the convolutional kernel. 149 | bias: bool 150 | Whether or not to add the bias. 
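        Example
        -------
        An illustrative sketch (assumes a CUDA device, since forward() moves the
        hidden state to the GPU):
            cell = ConvLSTMCell((16, 16), input_dim=3, hidden_dim=64,
                                kernel_size=(3, 3), bias=True).cuda()
            h, c = cell.init_hidden(batch_size=8)
            h, c = cell(torch.randn(8, 3, 16, 16).cuda(), (h, c))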
151 | """ 152 | 153 | super(ConvLSTMCell, self).__init__() 154 | 155 | self.height, self.width = input_size 156 | self.input_dim = input_dim 157 | self.hidden_dim = hidden_dim 158 | 159 | self.kernel_size = kernel_size 160 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 161 | self.bias = bias 162 | 163 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 164 | out_channels=4 * self.hidden_dim, 165 | kernel_size=self.kernel_size, 166 | padding=self.padding, 167 | bias=self.bias) 168 | 169 | def forward(self, input_tensor, cur_state): 170 | 171 | h_cur, c_cur = cur_state 172 | h_cur = h_cur.cuda() 173 | c_cur = c_cur.cuda() 174 | 175 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 176 | 177 | combined_conv = self.conv(combined) 178 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 179 | i = torch.sigmoid(cc_i) 180 | f = torch.sigmoid(cc_f) 181 | o = torch.sigmoid(cc_o) 182 | g = torch.tanh(cc_g) 183 | 184 | c_next = f * c_cur + i * g 185 | h_next = o * torch.tanh(c_next) 186 | 187 | return h_next, c_next 188 | 189 | def init_hidden(self, batch_size): 190 | return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)),#.cuda(), 191 | Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)))#.cuda()) 192 | 193 | 194 | class ConvLSTM(nn.Module): 195 | 196 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers, 197 | batch_first=False, bias=True, return_all_layers=False): 198 | super(ConvLSTM, self).__init__() 199 | 200 | self._check_kernel_size_consistency(kernel_size) 201 | 202 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 203 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 204 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 205 | if not len(kernel_size) == len(hidden_dim) == num_layers: 206 | raise ValueError('Inconsistent list length.') 207 | 208 | self.height, self.width = input_size 209 | 210 | self.input_dim = input_dim 211 | self.hidden_dim = hidden_dim 212 | self.kernel_size = kernel_size 213 | self.num_layers = num_layers 214 | self.batch_first = batch_first 215 | self.bias = bias 216 | self.return_all_layers = return_all_layers 217 | 218 | cell_list = [] 219 | for i in range(0, self.num_layers): 220 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1] 221 | 222 | cell_list.append(ConvLSTMCell(input_size=(self.height, self.width), 223 | input_dim=cur_input_dim, 224 | hidden_dim=self.hidden_dim[i], 225 | kernel_size=self.kernel_size[i], 226 | bias=self.bias)) 227 | 228 | self.cell_list = nn.ModuleList(cell_list) 229 | 230 | def forward(self, input_tensor, hidden_state=None): 231 | """ 232 | 233 | Parameters 234 | ---------- 235 | input_tensor: todo 236 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 237 | hidden_state: todo 238 | None. 
239 | 
240 |         Returns
241 |         -------
242 |         last_state_list, layer_output
243 |         """
244 |         if not self.batch_first:
245 |             # (t, b, c, h, w) -> (b, t, c, h, w)
246 |             input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
247 | 
248 |         # Stateful ConvLSTM is not implemented yet
249 |         if hidden_state is not None:
250 |             raise NotImplementedError()
251 |         else:
252 |             hidden_state = self._init_hidden(batch_size=input_tensor.size(0))
253 | 
254 |         layer_output_list = []
255 |         last_state_list = []
256 | 
257 |         seq_len = input_tensor.size(1)
258 |         cur_layer_input = input_tensor
259 | 
260 |         for layer_idx in range(self.num_layers):
261 | 
262 |             h, c = hidden_state[layer_idx]
263 |             output_inner = []
264 |             for t in range(seq_len):
265 | 
266 |                 h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
267 |                                                  cur_state=[h, c])
268 |                 output_inner.append(h)
269 | 
270 |             layer_output = torch.stack(output_inner, dim=1)
271 |             cur_layer_input = layer_output
272 | 
273 |             layer_output_list.append(layer_output)
274 |             last_state_list.append([h, c])
275 | 
276 |         if not self.return_all_layers:
277 |             layer_output_list = layer_output_list[-1:]
278 |             last_state_list = last_state_list[-1:]
279 | 
280 |         return layer_output_list, last_state_list
281 | 
282 |     def _init_hidden(self, batch_size):
283 |         init_states = []
284 |         for i in range(self.num_layers):
285 |             init_states.append(self.cell_list[i].init_hidden(batch_size))
286 |         return init_states
287 | 
288 |     @staticmethod
289 |     def _check_kernel_size_consistency(kernel_size):
290 |         if not (isinstance(kernel_size, tuple) or
291 |                 (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
292 |             raise ValueError('`kernel_size` must be tuple or list of tuples')
293 | 
294 |     @staticmethod
295 |     def _extend_for_multilayer(param, num_layers):
296 |         if not isinstance(param, list):
297 |             param = [param] * num_layers
298 |         return param
299 | 
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import math
6 | from lib.nn import SynchronizedBatchNorm2d
7 | 
8 | try:
9 |     from urllib import urlretrieve
10 | except ImportError:
11 |     from urllib.request import urlretrieve
12 | 
13 | 
14 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']
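
# A minimal usage sketch (a hedged example, not part of the original file;
# it assumes this module is importable as models.resnet and that the
# pretrained weight URLs below are reachable):
#
#     from models.resnet import resnet18
#     net = resnet18(pretrained=True)             # fetches weights via load_url() below
#     logits = net(torch.randn(1, 3, 224, 224))   # -> shape (1, 1000)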
15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', 19 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 20 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | "3x3 convolution with padding" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = SynchronizedBatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = SynchronizedBatchNorm2d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = SynchronizedBatchNorm2d(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = SynchronizedBatchNorm2d(planes) 72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = SynchronizedBatchNorm2d(planes * 4) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, num_classes=1000): 104 | self.inplanes = 128 105 | super(ResNet, self).__init__() 106 | self.conv1 = conv3x3(3, 64, stride=2) 107 | self.bn1 = SynchronizedBatchNorm2d(64) 108 | self.relu1 = nn.ReLU(inplace=True) 109 | self.conv2 = conv3x3(64, 64) 110 | self.bn2 = SynchronizedBatchNorm2d(64) 111 | self.relu2 = nn.ReLU(inplace=True) 112 | self.conv3 = conv3x3(64, 128) 113 | self.bn3 = SynchronizedBatchNorm2d(128) 114 | self.relu3 = nn.ReLU(inplace=True) 115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 116 | 117 | self.layer1 = self._make_layer(block, 64, layers[0]) 118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 121 | self.avgpool = nn.AvgPool2d(7, stride=1) 122 | self.fc = nn.Linear(512 * block.expansion, num_classes) 123 | 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | n = m.kernel_size[0] * m.kernel_size[1] 
* m.out_channels 127 | m.weight.data.normal_(0, math.sqrt(2. / n)) 128 | elif isinstance(m, SynchronizedBatchNorm2d): 129 | m.weight.data.fill_(1) 130 | m.bias.data.zero_() 131 | 132 | def _make_layer(self, block, planes, blocks, stride=1): 133 | downsample = None 134 | if stride != 1 or self.inplanes != planes * block.expansion: 135 | downsample = nn.Sequential( 136 | nn.Conv2d(self.inplanes, planes * block.expansion, 137 | kernel_size=1, stride=stride, bias=False), 138 | SynchronizedBatchNorm2d(planes * block.expansion), 139 | ) 140 | 141 | layers = [] 142 | layers.append(block(self.inplanes, planes, stride, downsample)) 143 | self.inplanes = planes * block.expansion 144 | for i in range(1, blocks): 145 | layers.append(block(self.inplanes, planes)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def forward(self, x): 150 | x = self.relu1(self.bn1(self.conv1(x))) 151 | x = self.relu2(self.bn2(self.conv2(x))) 152 | x = self.relu3(self.bn3(self.conv3(x))) 153 | x = self.maxpool(x) 154 | 155 | x = self.layer1(x) 156 | x = self.layer2(x) 157 | x = self.layer3(x) 158 | x = self.layer4(x) 159 | 160 | x = self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | x = self.fc(x) 163 | 164 | return x 165 | 166 | def resnet18(pretrained=False, **kwargs): 167 | """Constructs a ResNet-18 model. 168 | 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | """ 172 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 173 | if pretrained: 174 | model.load_state_dict(load_url(model_urls['resnet18'])) 175 | return model 176 | 177 | ''' 178 | def resnet34(pretrained=False, **kwargs): 179 | """Constructs a ResNet-34 model. 180 | 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | model.load_state_dict(load_url(model_urls['resnet34'])) 187 | return model 188 | ''' 189 | 190 | def resnet50(pretrained=False, **kwargs): 191 | """Constructs a ResNet-50 model. 192 | 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 197 | if pretrained: 198 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 199 | return model 200 | 201 | 202 | def resnet101(pretrained=False, **kwargs): 203 | """Constructs a ResNet-101 model. 204 | 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 209 | if pretrained: 210 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 211 | return model 212 | 213 | # def resnet152(pretrained=False, **kwargs): 214 | # """Constructs a ResNet-152 model. 
215 | # 
216 | #     Args:
217 | #         pretrained (bool): If True, returns a model pre-trained on ImageNet
218 | #     """
219 | #     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
220 | #     if pretrained:
221 | #         model.load_state_dict(load_url(model_urls['resnet152']))
222 | #     return model
223 | 
224 | def load_url(url, model_dir='./pretrained', map_location=None):
225 |     if not os.path.exists(model_dir):
226 |         os.makedirs(model_dir)
227 |     filename = url.split('/')[-1]
228 |     cached_file = os.path.join(model_dir, filename)
229 |     if not os.path.exists(cached_file):
230 |         sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
231 |         urlretrieve(url, cached_file)
232 |     return torch.load(cached_file, map_location=map_location)
233 | 
--------------------------------------------------------------------------------
/pca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/pca.png
--------------------------------------------------------------------------------
/pca.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import time
3 | import numpy as np
4 | import pandas as pd
5 | import argparse
6 | 
7 | from sklearn.decomposition import PCA
8 | from sklearn.manifold import TSNE
9 | 
10 | import matplotlib.pyplot as plt
11 | from mpl_toolkits.mplot3d import Axes3D
12 | import seaborn as sns
13 | 
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torchvision
18 | import torchvision.transforms as transforms
19 | import torch.optim as optim
20 | import torch.utils.data as data
21 | from models.models import ResNet18, ConvLSTM, ClassifierModule, ClassifierModuleDense
22 | from utils import utils
23 | from lib import radam
24 | from loss import SimLoss
25 | 
26 | def dataloader(args):
27 |     if args.dataset.lower() == 'cifar10':
28 |         transform = transforms.Compose(
29 |             [transforms.ToTensor()])
30 |             #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
31 | 
32 |         testset = torchvision.datasets.CIFAR10(root='./data', train=False,
33 |                                                download=True, transform=transform)
34 |         testloader = torch.utils.data.DataLoader(testset, batch_size=1,
35 |                                                  shuffle=False, num_workers=2)
36 | 
37 | 
38 |     return testloader
39 | 
40 | def test(net, testloader, args):
41 |     net.eval()
42 |     a = torch.empty(testloader.__len__(), 128)  # one 128-d feature vector per test image
43 |     l = torch.empty(testloader.__len__())       # one class label per test image
44 |     with torch.no_grad():
45 |         for i, data in enumerate(testloader, 0):
46 |             inputs, labels = data
47 |             inputs = inputs.cuda()
48 |             outputs = net(inputs)
49 |             a[i] = outputs
50 |             l[i] = labels
51 | 
52 |     print(a.shape)
53 |     print(l.shape)
54 | 
55 |     feat_cols = ['pixel' + str(i) for i in range(a.shape[1])]
56 |     df = pd.DataFrame(a.numpy(), columns=feat_cols)
57 |     df['y'] = l
58 |     df['label'] = df['y'].apply(lambda i: str(i))
59 | 
60 |     pca = PCA(n_components=3)
61 |     pca_result = pca.fit_transform(df[feat_cols].values)
62 |     df['pca-one'] = pca_result[:,0]
63 |     df['pca-two'] = pca_result[:,1]
64 |     df['pca-three'] = pca_result[:,2]
65 | 
66 |     rndperm = np.random.permutation(df.shape[0])
67 | 
68 |     plt.figure(figsize=(16,10))
69 |     p = sns.scatterplot(
70 |         x="pca-one", y="pca-two",
71 |         hue="y",
72 |         palette=sns.color_palette("hls", 10),
73 |         data=df.loc[rndperm,:],
74 |         legend="full",
75 |         alpha=0.3
76 |     )
77 |     p.get_figure().savefig("./pca.png")
78 | 
79 | 
80 | 
81 | def checkpoint(net,
args, epoch_num): 82 | print('Saving checkpoints...') 83 | 84 | suffix_latest = 'epoch_{}.pth'.format(epoch_num) 85 | dict_net = net.state_dict() 86 | torch.save(dict_net, 87 | '{}/resnet_{}'.format(args.ckpt, suffix_latest)) 88 | 89 | if __name__ == "__main__": 90 | 91 | parser = argparse.ArgumentParser() 92 | # optimization related arguments 93 | parser.add_argument('--dataset', default='cifar10') 94 | parser.add_argument('--tau', default=0.1, type=float) 95 | parser.add_argument('--ckpt', default="/home/rexma/Desktop/JesseSun/simclr") 96 | args = parser.parse_args() 97 | 98 | print("Input arguments:") 99 | for key, val in vars(args).items(): 100 | print("{:16} {}".format(key, val)) 101 | 102 | args.num_class = 10 if args.dataset.lower() == 'cifar10' else 1000 103 | 104 | testloader = dataloader(args) 105 | 106 | net = ResNet18(mode='test').cuda() 107 | 108 | net.load_state_dict( 109 | torch.load(args.ckpt, map_location=lambda storage, loc: storage), strict=False) 110 | print("Loaded pretrained weights.") 111 | 112 | test(net, testloader, args) 113 | 114 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import torch.optim as optim 8 | import torch.utils.data as data 9 | from data.ucf101 import UCF101 10 | from models.models import ResNet18, ConvLSTM, ClassifierModule, ClassifierModuleDense 11 | from utils import utils 12 | from lib import radam 13 | from loss import SimLoss 14 | 15 | s=1 16 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 17 | data_augment = transforms.Compose([transforms.ToPILImage(), 18 | transforms.RandomResizedCrop(32), 19 | transforms.RandomHorizontalFlip(), 20 | transforms.RandomApply([color_jitter], p=0.8), 21 | transforms.RandomGrayscale(p=0.2), 22 | utils.GaussianBlur(), 23 | transforms.ToTensor()]) 24 | 25 | def dataloader(args): 26 | if args.dataset.lower() == 'cifar10': 27 | transform = transforms.Compose( 28 | [transforms.ToTensor()]) 29 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 30 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 31 | download=True, transform=transform) 32 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 33 | shuffle=True, num_workers=2) 34 | 35 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 36 | download=True, transform=transform) 37 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, 38 | shuffle=False, num_workers=2) 39 | 40 | 41 | return trainloader, testloader 42 | 43 | def optimizer(net, args): 44 | assert args.optimizer.lower() in ["sgd", "adam", "radam"], "Invalid Optimizer" 45 | 46 | if args.optimizer.lower() == "sgd": 47 | return optim.SGD(net.parameters(), lr=args.lr, momentum=args.beta1, nesterov=args.nesterov) 48 | elif args.optimizer.lower() == "adam": 49 | return optim.Adam(net.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) 50 | elif args.optimizer.lower() == "radam": 51 | return radam.RAdam(net.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) 52 | 53 | def test(net, epoch, criterion, testloader, args): 54 | net.eval() 55 | with torch.no_grad(): 56 | correct = 0 57 | for i, data in enumerate(testloader, 0): 58 | inputs, labels = data 59 | inputs = inputs.cuda() 60 | 
labels = labels.cuda()
61 |             outputs = net(inputs)
62 |             pred = F.softmax(outputs, 1)
63 |             _, pred = torch.max(pred, 1)
64 |             correct += torch.sum(pred == labels)
65 |     print("Test set accuracy: " + str(float(correct) / len(testloader.dataset)))
66 | 
67 | def train(net, epoch, criterion, optimizer, trainloader, args):
68 |     loss_meter = utils.AverageMeter()
69 |     net.train()
70 | 
71 |     for i, data in enumerate(trainloader, 0):
72 |         inputs, labels = data
73 |         optimizer.zero_grad()
74 |         inputs = inputs.cuda()
75 |         labels = labels.cuda()
76 |         outputs = net(inputs)
77 |         loss = criterion(outputs, labels)
78 |         loss.backward()
79 |         loss_meter.update(loss.item())
80 |         optimizer.step()
81 | 
82 | 
83 |         if i > 0:  # log the running average loss every minibatch
84 |             print('[Epoch %02d, Minibatch %05d] Loss: %.5f' %
85 |                   (epoch, i, loss_meter.average()))
86 | 
87 | 
88 | def SimCLR(net, epoch, criterion, optimizer, trainloader, args):
89 |     loss_meter = utils.AverageMeter()
90 |     net.train()
91 | 
92 |     for i, data in enumerate(trainloader, 0):
93 |         b, _ = data
94 |         optimizer.zero_grad()
95 |         x_1 = torch.zeros_like(b).cuda()
96 |         x_2 = torch.zeros_like(b).cuda()
97 |         # build two independently augmented views of every image in the batch
98 |         for idx, x in enumerate(b):
99 |             x_1[idx] = data_augment(x)
100 |             x_2[idx] = data_augment(x)
101 | 
102 |         out_1 = net(x_1)
103 |         out_2 = net(x_2)
104 | 
105 |         loss = criterion(torch.cat([out_1, out_2], dim=0))
106 |         loss.backward()
107 |         loss_meter.update(loss.item())
108 |         optimizer.step()
109 | 
110 |         if i % 100 == 0 and i > 0:
111 |             print('[Epoch %02d, Minibatch %05d] Loss: %.5f' %
112 |                   (epoch, i, loss_meter.average()))
113 | 
114 | def checkpoint(net, args, epoch_num):
115 |     print('Saving checkpoints...')
116 | 
117 |     suffix_latest = 'epoch_{}.pth'.format(epoch_num)
118 |     dict_net = net.state_dict()
119 |     torch.save(dict_net,
120 |                '{}/resnet_{}'.format(args.ckpt, suffix_latest))
121 | 
122 | if __name__ == "__main__":
123 | 
124 |     parser = argparse.ArgumentParser()
125 |     # optimization related arguments
126 |     parser.add_argument('--batch_size', default=4, type=int,
127 |                         help='input batch size')
128 |     parser.add_argument('--epoch', default=100, type=int,
129 |                         help='epochs to train for')
130 |     parser.add_argument('--dataset', default='cifar10')
131 |     parser.add_argument('--optimizer', default='sgd', help='optimizer')
132 |     parser.add_argument('--lr', default=0.001, type=float, help='LR')
133 |     parser.add_argument('--beta1', default=0.9, type=float,
134 |                         help='momentum for sgd, beta1 for adam')
135 |     parser.add_argument('--beta2', default=0.999, type=float)
136 |     parser.add_argument('--nesterov', action='store_true')
137 |     parser.add_argument('--tau', default=0.1, type=float)
138 |     parser.add_argument('--ckpt', default="/home/rexma/Desktop/JesseSun/simclr")
139 |     args = parser.parse_args()
140 | 
141 |     print("Input arguments:")
142 |     for key, val in vars(args).items():
143 |         print("{:16} {}".format(key, val))
144 | 
145 |     args.num_class = 10 if args.dataset.lower() == 'cifar10' else 1000
146 | 
147 |     trainloader, testloader = dataloader(args)
148 | 
149 |     net = ResNet18().cuda()
150 |     criterion = SimLoss(tau=args.tau).cuda()
151 |     optimizer = optimizer(net, args)
152 |     for epoch in range(1, args.epoch + 1):
153 |         SimCLR(net, epoch, criterion, optimizer, trainloader, args)
154 |         #test(net, epoch, criterion, testloader, args)
155 |         if epoch % 5 == 0:
156 |             checkpoint(net, args, epoch)
157 | 
158 |     print("Training completed!")
159 | 
--------------------------------------------------------------------------------
/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunjesse/pytorch-simclr/2f30807cda92cda8e2d189faf944ef93ae7e8efa/utils/.DS_Store
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import functools
4 | import fnmatch
5 | import numpy as np
6 | import cv2
7 | 
8 | class GaussianBlur(object):
9 |     """Blur a (numpy or PIL) image with probability 0.5, sigma drawn from U(min, max)."""
10 |     def __init__(self, min=0.1, max=2.0, kernel_size=9):
11 |         self.min = min
12 |         self.max = max
13 |         self.kernel_size = kernel_size
14 | 
15 |     def __call__(self, sample):
16 |         sample = np.array(sample)
17 | 
18 |         # blur the image with a 50% chance
19 |         prob = np.random.random_sample()
20 | 
21 |         if prob < 0.5:
22 |             sigma = (self.max - self.min) * np.random.random_sample() + self.min
23 |             sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
24 | 
25 |         return sample
26 | 
27 | def load_value_file(file_path):
28 |     with open(file_path, 'r') as input_file:
29 |         value = float(input_file.read().rstrip('\n\r'))
30 | 
31 |     return value
32 | 
33 | def find_recursive(root_dir, ext='.jpg'):
34 |     files = []
35 |     for root, dirnames, filenames in os.walk(root_dir):
36 |         for filename in fnmatch.filter(filenames, '*' + ext):
37 |             files.append(os.path.join(root, filename))
38 |     return files
39 | 
40 | 
41 | class AverageMeter(object):
42 |     """Computes and stores the average and current value"""
43 |     def __init__(self):
44 |         self.initialized = False
45 |         self.val = None
46 |         self.avg = None
47 |         self.sum = None
48 |         self.count = None
49 | 
50 |     def initialize(self, val, weight):
51 |         self.val = val
52 |         self.avg = val
53 |         self.sum = val * weight
54 |         self.count = weight
55 |         self.initialized = True
56 | 
57 |     def update(self, val, weight=1):
58 |         if not self.initialized:
59 |             self.initialize(val, weight)
60 |         else:
61 |             self.add(val, weight)
62 | 
63 |     def add(self, val, weight):
64 |         self.val = val
65 |         self.sum += val * weight
66 |         self.count += weight
67 |         self.avg = self.sum / self.count
68 | 
69 |     def value(self):
70 |         return self.val
71 | 
72 |     def average(self):
73 |         return self.avg
74 | 
75 | 
76 | def unique(ar, return_index=False, return_inverse=False, return_counts=False):
77 |     ar = np.asanyarray(ar).flatten()
78 | 
79 |     optional_indices = return_index or return_inverse
80 |     optional_returns = optional_indices or return_counts
81 | 
82 |     if ar.size == 0:
83 |         if not optional_returns:
84 |             ret = ar
85 |         else:
86 |             ret = (ar,)
87 |             if return_index:
88 |                 ret += (np.empty(0, bool),)
89 |             if return_inverse:
90 |                 ret += (np.empty(0, bool),)
91 |             if return_counts:
92 |                 ret += (np.empty(0, np.intp),)
93 |         return ret
94 |     if optional_indices:
95 |         perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
96 |         aux = ar[perm]
97 |     else:
98 |         ar.sort()
99 |         aux = ar
100 |     flag = np.concatenate(([True], aux[1:] != aux[:-1]))
101 | 
102 |     if not optional_returns:
103 |         ret = aux[flag]
104 |     else:
105 |         ret = (aux[flag],)
106 |         if return_index:
107 |             ret += (perm[flag],)
108 |         if return_inverse:
109 |             iflag = np.cumsum(flag) - 1
110 |             inv_idx = np.empty(ar.shape, dtype=np.intp)
111 |             inv_idx[perm] = iflag
112 |             ret += (inv_idx,)
113 |         if return_counts:
114 |             idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
115 |             ret += (np.diff(idx),)
116 |     return ret
117 | 
118 | 
119 | def colorEncode(labelmap, colors, mode='BGR'):
120 |     labelmap = labelmap.astype('int')
121 | 
labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 122 | dtype=np.uint8) 123 | for label in unique(labelmap): 124 | if label < 0: 125 | continue 126 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 127 | np.tile(colors[label], 128 | (labelmap.shape[0], labelmap.shape[1], 1)) 129 | 130 | if mode == 'BGR': 131 | return labelmap_rgb[:, :, ::-1] 132 | else: 133 | return labelmap_rgb 134 | 135 | 136 | def accuracy(preds, label): 137 | valid = (label >= 1) 138 | acc_sum = (valid * (preds == label)).sum() 139 | valid_sum = valid.sum() 140 | acc = float(acc_sum) / (valid_sum + 1e-10) 141 | return acc, valid_sum 142 | 143 | 144 | def intersectionAndUnion(imPred, imLab, numClass): 145 | imPred = np.asarray(imPred).copy() 146 | imLab = np.asarray(imLab).copy() 147 | 148 | imPred += 1 149 | imLab += 1 150 | # Remove classes from unlabeled pixels in gt image. 151 | # We should not penalize detections in unlabeled portions of the image. 152 | imPred = imPred * (imLab > 0) 153 | 154 | # Compute area intersection: 155 | intersection = imPred * (imPred == imLab) 156 | (area_intersection, _) = np.histogram( 157 | intersection, bins=numClass, range=(1, numClass)) 158 | 159 | # Compute area union: 160 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 161 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 162 | area_union = area_pred + area_lab - area_intersection 163 | #print("I: " + str(area_intersection)) 164 | #print("U: " + str(area_union)) 165 | return (area_intersection, area_union) 166 | 167 | 168 | class NotSupportedCliException(Exception): 169 | pass 170 | 171 | 172 | def process_range(xpu, inp): 173 | start, end = map(int, inp) 174 | if start > end: 175 | end, start = start, end 176 | return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1)) 177 | 178 | 179 | REGEX = [ 180 | (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]), 181 | (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]), 182 | (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'), 183 | functools.partial(process_range, 'gpu')), 184 | (re.compile(r'^(\d+)-(\d+)$'), 185 | functools.partial(process_range, 'gpu')), 186 | ] 187 | 188 | 189 | def parse_devices(input_devices): 190 | 191 | """Parse user's devices input str to standard format. 192 | e.g. [gpu0, gpu1, ...] 
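
    A sketch of the expected mapping (derived from the REGEX table above):

        parse_devices('gpu0')      -> ['gpu0']
        parse_devices('0,2')       -> ['gpu0', 'gpu2']
        parse_devices('gpu1-gpu3') -> ['gpu1', 'gpu2', 'gpu3']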
193 | 194 | """ 195 | ret = [] 196 | for d in input_devices.split(','): 197 | for regex, func in REGEX: 198 | m = regex.match(d.lower().strip()) 199 | if m: 200 | tmp = func(m.groups()) 201 | # prevent duplicate 202 | for x in tmp: 203 | if x not in ret: 204 | ret.append(x) 205 | break 206 | else: 207 | raise NotSupportedCliException( 208 | 'Can not recognize device: "{}"'.format(d)) 209 | return ret 210 | ''' 211 | def torch_pixelwise_gradient(f, *varargs, **kwargs): 212 | N = len(f.shape) 213 | n = len(varargs) 214 | 215 | if n == 0: 216 | dx = torch.tensor([1.0])*N 217 | elif n == 1: 218 | dx = torch.tensor([varargs[0]])*N 219 | elif n == N: 220 | dx = torch.tensor(varargs) 221 | else: 222 | raise SyntaxError( 223 | "invalid number of arguments") 224 | 225 | edge_order = kwargs.pop('edge_order', 1) 226 | if kwargs: 227 | raise TypeError('"{}" are not valid keyword arguments.'.format( 228 | '", "'.join(kwargs.keys()))) 229 | if edge_order > 2: 230 | raise ValueError("'edge_order' greater than 2 not supported") 231 | 232 | outvals = [] 233 | 234 | slice1 = [slice(None)]*N 235 | slice2 = [slice(None)]*N 236 | slice3 = [slice(None)]*N 237 | slice4 = [slice(None)]*N 238 | 239 | y = f.float() 240 | 241 | for axis in range(N): 242 | if y.shape[axis] < 2: 243 | raise ValueError( 244 | "Shape of array too small to calculate a numerical gradient, " 245 | "at least two elements are required.") 246 | 247 | if y.shape[axis] == 2 or edge_order == 1: 248 | out = f.new(x.size()).float() 249 | 250 | uniform_spacing = 251 | ''' 252 | --------------------------------------------------------------------------------