├── .gitignore
├── README.md
├── create_imagenet_subset.py
├── deepul_helper
│   ├── __init__.py
│   ├── batch_norm.py
│   ├── data.py
│   ├── demos.py
│   ├── lars.py
│   ├── layer_norm.py
│   ├── resnet.py
│   ├── seg_model.py
│   ├── tasks
│   │   ├── __init__.py
│   │   ├── context_encoder.py
│   │   ├── cpc.py
│   │   ├── rotation.py
│   │   └── simclr.py
│   ├── utils.py
│   └── visualize.py
├── environment.yml
├── palette.pkl
├── run
│   ├── run_cifar10_rotation.sh
│   ├── run_cifar10_simclr.sh
│   ├── run_imagenet100_cpc.sh
│   ├── run_imagenet100_rotation.sh
│   └── run_imagenet100_simclr.sh
├── sample_images
│   ├── chrom_ab_demo.png
│   ├── n01537544_19414.JPEG
│   ├── n01768244_3034.JPEG
│   ├── n03297495_2537.JPEG
│   ├── n03297495_3735.JPEG
│   ├── n03372029_42178.JPEG
│   ├── n03372029_46468.JPEG
│   ├── n03476684_24524.JPEG
│   ├── n04372370_39950.JPEG
│   ├── n11939491_52432.JPEG
│   └── sample2.JPEG
├── setup.py
├── train_segmentation.py
└── train_self_supervised_task.py

/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | __pycache__
3 | *.egg-info
4 | results
5 | *.txt
6 | logs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | 
3 | This is the repo for the CS294-158 self-supervised learning demos.
4 | 
5 | # Setting Up
6 | 
7 | The conda environment can be created using `environment.yml`. If you run into issues, you can create your own empty environment on Python 3.7.6 with the following packages (feel free to use a different CUDA version):
8 | * `conda install pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.1 -c pytorch`
9 | * `pip install requests`
10 | * `pip install opencv-python`
11 | * `pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git`
12 | 
13 | # Datasets
14 | 
15 | The demos have been tested on CIFAR10 and ImageNet subsets. By default, the scripts use "imagenet100" for ImageNet. You can recreate this dataset by downloading ImageNet into `data/imagenet` and running `python create_imagenet_subset.py 100` to create a subsampled 100-class ImageNet.
16 | 
17 | # Training
18 | 
19 | You can execute the scripts in the `run` folder to train models on different self-supervised tasks. Note that different models may use a different number of GPUs (maximum 4). By default, the scripts will use ALL available GPUs. You can limit the number of GPUs used (e.g. to 2) via `CUDA_VISIBLE_DEVICES=0,1 ./run/run_cifar10_simclr.sh`.
20 | 
21 | Contrastive Predictive Coding (CPC) is still a work in progress.
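
For example, an end-to-end ImageNet-100 SimCLR run (assuming ImageNet has already been downloaded into `data/imagenet`) would look like:

```bash
# Subsample 100 classes of ImageNet into data/imagenet100
python create_imagenet_subset.py 100
# Train SimCLR on the subset, restricted to the first two GPUs
CUDA_VISIBLE_DEVICES=0,1 ./run/run_imagenet100_simclr.sh
```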
22 | -------------------------------------------------------------------------------- /create_imagenet_subset.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import sys 3 | import os 4 | import os.path as osp 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from torchvision.datasets import ImageFolder 10 | 11 | 12 | n_classes = int(sys.argv[1]) 13 | print('Creating a subset of ImageNet with {} classes'.format(n_classes)) 14 | 15 | dset_dir = osp.join('data', 'imagenet') 16 | dset = ImageFolder(osp.join(dset_dir, 'train')) 17 | classes = dset.classes 18 | 19 | new_dset_dir = osp.join('data', 'imagenet{}'.format(n_classes)) 20 | classes_subset = np.random.choice(classes, size=n_classes, replace=False) 21 | 22 | os.makedirs(osp.join(new_dset_dir, 'train')) 23 | os.makedirs(osp.join(new_dset_dir, 'val')) 24 | 25 | for c in tqdm(classes_subset): 26 | src = osp.join(dset_dir, 'train', c) 27 | dst = osp.join(new_dset_dir, 'train', c) 28 | shutil.copytree(src, dst) 29 | 30 | src = osp.join(dset_dir, 'val', c) 31 | dst = osp.join(new_dset_dir, 'val', c) 32 | shutil.copytree(src, dst) 33 | -------------------------------------------------------------------------------- /deepul_helper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/deepul_helper/__init__.py -------------------------------------------------------------------------------- /deepul_helper/batch_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | 6 | # Standard BatchNorm with custom toggles of both learned scale and bias parameters 7 | # (regular PyTorch batch norm toggles both or none, but not bias / scale individually) 8 | 9 | class _NormBase(nn.Module): 10 | """Common base of _InstanceNorm and _BatchNorm""" 11 | _version = 2 12 | __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias', 13 | 'running_mean', 'running_var', 'num_batches_tracked', 14 | 'num_features', 'center', 'scale'] 15 | 16 | def __init__(self, num_features, eps=1e-5, momentum=0.1, center=True, scale=True, 17 | track_running_stats=True): 18 | super(_NormBase, self).__init__() 19 | self.num_features = num_features 20 | self.eps = eps 21 | self.momentum = momentum 22 | self.center = center 23 | self.scale = scale 24 | self.track_running_stats = track_running_stats 25 | if self.scale: 26 | self.weight = nn.Parameter(torch.Tensor(num_features)) 27 | else: 28 | self.register_parameter('weight', None) 29 | if self.center: 30 | self.bias = nn.Parameter(torch.Tensor(num_features)) 31 | else: 32 | self.register_parameter('bias', None) 33 | if self.track_running_stats: 34 | self.register_buffer('running_mean', torch.zeros(num_features)) 35 | self.register_buffer('running_var', torch.ones(num_features)) 36 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 37 | else: 38 | self.register_parameter('running_mean', None) 39 | self.register_parameter('running_var', None) 40 | self.register_parameter('num_batches_tracked', None) 41 | self.reset_parameters() 42 | 43 | def reset_running_stats(self): 44 | if self.track_running_stats: 45 | self.running_mean.zero_() 46 | self.running_var.fill_(1) 47 | self.num_batches_tracked.zero_() 48 | 49 | def reset_parameters(self): 50 | 
        self.reset_running_stats()
51 |         if self.scale:
52 |             init.ones_(self.weight)
53 |         if self.center:
54 |             init.zeros_(self.bias)
55 | 
56 |     def _check_input_dim(self, input):
57 |         raise NotImplementedError
58 | 
59 |     def extra_repr(self):
60 |         return '{num_features}, eps={eps}, momentum={momentum}, center={center}, scale={scale}, ' \
61 |                'track_running_stats={track_running_stats}'.format(**self.__dict__)
62 | 
63 |     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
64 |                               missing_keys, unexpected_keys, error_msgs):
65 |         version = local_metadata.get('version', None)
66 | 
67 |         if (version is None or version < 2) and self.track_running_stats:
68 |             # at version 2: added num_batches_tracked buffer
69 |             # this should have a default value of 0
70 |             num_batches_tracked_key = prefix + 'num_batches_tracked'
71 |             if num_batches_tracked_key not in state_dict:
72 |                 state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)
73 | 
74 |         super(_NormBase, self)._load_from_state_dict(
75 |             state_dict, prefix, local_metadata, strict,
76 |             missing_keys, unexpected_keys, error_msgs)
77 | 
78 | 
79 | class _BatchNorm(_NormBase):
80 | 
81 |     def __init__(self, num_features, eps=1e-5, momentum=0.1, center=True, scale=True,
82 |                  track_running_stats=True):
83 |         super(_BatchNorm, self).__init__(
84 |             num_features, eps, momentum, center, scale, track_running_stats)
85 | 
86 |     def forward(self, input):
87 |         self._check_input_dim(input)
88 | 
89 |         # exponential_average_factor is set to self.momentum
90 |         # (when it is available) only so that it gets updated
91 |         # in ONNX graph when this node is exported to ONNX.
92 |         if self.momentum is None:
93 |             exponential_average_factor = 0.0
94 |         else:
95 |             exponential_average_factor = self.momentum
96 | 
97 |         if self.training and self.track_running_stats:
98 |             # TODO: if statement only here to tell the jit to skip emitting this when it is None
99 |             if self.num_batches_tracked is not None:
100 |                 self.num_batches_tracked = self.num_batches_tracked + 1
101 |                 if self.momentum is None:  # use cumulative moving average
102 |                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)
103 |                 else:  # use exponential moving average
104 |                     exponential_average_factor = self.momentum
105 | 
106 |         return F.batch_norm(
107 |             input, self.running_mean, self.running_var, self.weight, self.bias,
108 |             self.training or not self.track_running_stats,
109 |             exponential_average_factor, self.eps)
110 | 
111 | 
112 | class BatchNorm1d(_BatchNorm):
113 |     r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
114 |     inputs with optional additional channel dimension) as described in the paper
115 |     `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
116 | 
117 |     .. math::
118 | 
119 |         y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
120 | 
121 |     The mean and standard-deviation are calculated per-dimension over
122 |     the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
123 |     of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
124 |     to 1 and the elements of :math:`\beta` are set to 0.
125 | 
126 |     Also by default, during training this layer keeps running estimates of its
127 |     computed mean and variance, which are then used for normalization during
128 |     evaluation. The running estimates are kept with a default :attr:`momentum`
129 |     of 0.1.
130 | 131 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 132 | keep running estimates, and batch statistics are instead used during 133 | evaluation time as well. 134 | 135 | .. note:: 136 | This :attr:`momentum` argument is different from one used in optimizer 137 | classes and the conventional notion of momentum. Mathematically, the 138 | update rule for running statistics here is 139 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, 140 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 141 | new observed value. 142 | 143 | Because the Batch Normalization is done over the `C` dimension, computing statistics 144 | on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. 145 | 146 | Args: 147 | num_features: :math:`C` from an expected input of size 148 | :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` 149 | eps: a value added to the denominator for numerical stability. 150 | Default: 1e-5 151 | momentum: the value used for the running_mean and running_var 152 | computation. Can be set to ``None`` for cumulative moving average 153 | (i.e. simple average). Default: 0.1 154 | center: a boolean value that when set to ``True``, this module has 155 | learnable center parameters. Default: ``True`` 156 | scale: a boolean value that when set to ``True``, this module has 157 | learnable scale parameters. Default: ``True`` 158 | track_running_stats: a boolean value that when set to ``True``, this 159 | module tracks the running mean and variance, and when set to ``False``, 160 | this module does not track such statistics and always uses batch 161 | statistics in both training and eval modes. Default: ``True`` 162 | 163 | Shape: 164 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 165 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 166 | 167 | Examples:: 168 | 169 | >>> # With Learnable Parameters 170 | >>> m = nn.BatchNorm1d(100) 171 | >>> # Without Learnable Parameters 172 | >>> m = nn.BatchNorm1d(100, center=False, scale=False) 173 | >>> input = torch.randn(20, 100) 174 | >>> output = m(input) 175 | 176 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 177 | https://arxiv.org/abs/1502.03167 178 | """ 179 | 180 | def _check_input_dim(self, input): 181 | if input.dim() != 2 and input.dim() != 3: 182 | raise ValueError('expected 2D or 3D input (got {}D input)' 183 | .format(input.dim())) 184 | 185 | 186 | class BatchNorm2d(_BatchNorm): 187 | r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs 188 | with additional channel dimension) as described in the paper 189 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 190 | 191 | .. math:: 192 | 193 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 194 | 195 | The mean and standard-deviation are calculated per-dimension over 196 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 197 | of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set 198 | to 1 and the elements of :math:`\beta` are set to 0. 199 | 200 | Also by default, during training this layer keeps running estimates of its 201 | computed mean and variance, which are then used for normalization during 202 | evaluation. The running estimates are kept with a default :attr:`momentum` 203 | of 0.1. 
204 | 205 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 206 | keep running estimates, and batch statistics are instead used during 207 | evaluation time as well. 208 | 209 | .. note:: 210 | This :attr:`momentum` argument is different from one used in optimizer 211 | classes and the conventional notion of momentum. Mathematically, the 212 | update rule for running statistics here is 213 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, 214 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 215 | new observed value. 216 | 217 | Because the Batch Normalization is done over the `C` dimension, computing statistics 218 | on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. 219 | 220 | Args: 221 | num_features: :math:`C` from an expected input of size 222 | :math:`(N, C, H, W)` 223 | eps: a value added to the denominator for numerical stability. 224 | Default: 1e-5 225 | momentum: the value used for the running_mean and running_var 226 | computation. Can be set to ``None`` for cumulative moving average 227 | (i.e. simple average). Default: 0.1 228 | center: a boolean value that when set to ``True``, this module has 229 | learnable center parameters. Default: ``True`` 230 | scale: a boolean value that when set to ``True``, this module has 231 | learnable scale parameters. Default: ``True`` 232 | track_running_stats: a boolean value that when set to ``True``, this 233 | module tracks the running mean and variance, and when set to ``False``, 234 | this module does not track such statistics and always uses batch 235 | statistics in both training and eval modes. Default: ``True`` 236 | 237 | Shape: 238 | - Input: :math:`(N, C, H, W)` 239 | - Output: :math:`(N, C, H, W)` (same shape as input) 240 | 241 | Examples:: 242 | 243 | >>> # With Learnable Parameters 244 | >>> m = nn.BatchNorm2d(100) 245 | >>> # Without Learnable Parameters 246 | >>> m = nn.BatchNorm2d(100, center=False, scale=False) 247 | >>> input = torch.randn(20, 100, 35, 45) 248 | >>> output = m(input) 249 | 250 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 251 | https://arxiv.org/abs/1502.03167 252 | """ 253 | 254 | def _check_input_dim(self, input): 255 | if input.dim() != 4: 256 | raise ValueError('expected 4D input (got {}D input)' 257 | .format(input.dim())) 258 | 259 | class SyncBatchNorm(_BatchNorm): 260 | r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs 261 | with additional channel dimension) as described in the paper 262 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 263 | 264 | .. math:: 265 | 266 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 267 | 268 | The mean and standard-deviation are calculated per-dimension over all 269 | mini-batches of the same process groups. :math:`\gamma` and :math:`\beta` 270 | are learnable parameter vectors of size `C` (where `C` is the input size). 271 | By default, the elements of :math:`\gamma` are sampled from 272 | :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0. 273 | 274 | Also by default, during training this layer keeps running estimates of its 275 | computed mean and variance, which are then used for normalization during 276 | evaluation. The running estimates are kept with a default :attr:`momentum` 277 | of 0.1. 
278 | 
279 |     If :attr:`track_running_stats` is set to ``False``, this layer then does not
280 |     keep running estimates, and batch statistics are instead used during
281 |     evaluation time as well.
282 | 
283 |     .. note::
284 |         This :attr:`momentum` argument is different from one used in optimizer
285 |         classes and the conventional notion of momentum. Mathematically, the
286 |         update rule for running statistics here is
287 |         :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
288 |         where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
289 |         new observed value.
290 | 
291 |     Because the Batch Normalization is done over the `C` dimension, computing statistics
292 |     on `(N, +)` slices, it's common terminology to call this Volumetric Batch Normalization
293 |     or Spatio-temporal Batch Normalization.
294 | 
295 |     Currently SyncBatchNorm only supports DistributedDataParallel with a single GPU per process. Use
296 |     torch.nn.SyncBatchNorm.convert_sync_batchnorm() to convert BatchNorm layers to SyncBatchNorm before wrapping
297 |     the network with DDP.
298 | 
299 |     Args:
300 |         num_features: :math:`C` from an expected input of size
301 |             :math:`(N, C, +)`
302 |         eps: a value added to the denominator for numerical stability.
303 |             Default: 1e-5
304 |         momentum: the value used for the running_mean and running_var
305 |             computation. Can be set to ``None`` for cumulative moving average
306 |             (i.e. simple average). Default: 0.1
307 |         center: a boolean value that when set to ``True``, this module has
308 |             learnable center parameters. Default: ``True``
309 |         scale: a boolean value that when set to ``True``, this module has
310 |             learnable scale parameters. Default: ``True``
311 |         track_running_stats: a boolean value that when set to ``True``, this
312 |             module tracks the running mean and variance, and when set to ``False``,
313 |             this module does not track such statistics and always uses batch
314 |             statistics in both training and eval modes. Default: ``True``
315 |         process_group: synchronization of stats happens within each process group
316 |             individually. Default behavior is synchronization across the whole
317 |             world
318 | 
319 |     Shape:
320 |         - Input: :math:`(N, C, +)`
321 |         - Output: :math:`(N, C, +)` (same shape as input)
322 | 
323 |     Examples::
324 | 
325 |         >>> # With Learnable Parameters
326 |         >>> m = nn.SyncBatchNorm(100)
327 |         >>> # creating process group (optional)
328 |         >>> # process_ids is a list of int identifying rank ids.
329 |         >>> process_group = torch.distributed.new_group(process_ids)
330 |         >>> # Without Learnable Parameters
331 |         >>> m = nn.SyncBatchNorm(100, center=False, scale=False, process_group=process_group)
332 |         >>> input = torch.randn(20, 100, 35, 45, 10)
333 |         >>> output = m(input)
334 | 
335 |         >>> # network is nn.BatchNorm layer
336 |         >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
337 |         >>> # only single gpu per process is currently supported
338 |         >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
339 |         >>>                         sync_bn_network,
340 |         >>>                         device_ids=[args.local_rank],
341 |         >>>                         output_device=args.local_rank)
342 | 
343 |     .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
344 |         https://arxiv.org/abs/1502.03167
345 |     """
346 | 
347 |     def __init__(self, num_features, eps=1e-5, momentum=0.1, center=True, scale=True,
348 |                  track_running_stats=True, process_group=None):
349 |         super(SyncBatchNorm, self).__init__(num_features, eps, momentum, center,
350 |                                             scale, track_running_stats)
351 |         self.process_group = process_group
352 |         # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
353 |         # under supported condition (single GPU per process)
354 |         self.ddp_gpu_size = None
355 | 
356 |     def _check_input_dim(self, input):
357 |         if input.dim() < 2:
358 |             raise ValueError('expected at least 2D input (got {}D input)'
359 |                              .format(input.dim()))
360 | 
361 |     def _specify_ddp_gpu_num(self, gpu_size):
362 |         if gpu_size > 1:
363 |             raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process')
364 |         self.ddp_gpu_size = gpu_size
365 | 
366 |     def forward(self, input):
367 |         # currently only GPU input is supported
368 |         if not input.is_cuda:
369 |             raise ValueError('SyncBatchNorm expected input tensor to be on GPU')
370 | 
371 |         self._check_input_dim(input)
372 | 
373 |         # exponential_average_factor is set to self.momentum
374 |         # (when it is available) only so that it gets updated
375 |         # in ONNX graph when this node is exported to ONNX.
376 |         if self.momentum is None:
377 |             exponential_average_factor = 0.0
378 |         else:
379 |             exponential_average_factor = self.momentum
380 | 
381 |         if self.training and self.track_running_stats:
382 |             self.num_batches_tracked = self.num_batches_tracked + 1
383 |             if self.momentum is None:  # use cumulative moving average
384 |                 exponential_average_factor = 1.0 / self.num_batches_tracked.item()
385 |             else:  # use exponential moving average
386 |                 exponential_average_factor = self.momentum
387 | 
388 |         need_sync = self.training or not self.track_running_stats
389 |         if need_sync:
390 |             process_group = torch.distributed.group.WORLD
391 |             if self.process_group:
392 |                 process_group = self.process_group
393 |             world_size = torch.distributed.get_world_size(process_group)
394 |             need_sync = world_size > 1
395 | 
396 |         # fallback to framework BN when synchronization is not necessary
397 |         if not need_sync:
398 |             return F.batch_norm(
399 |                 input, self.running_mean, self.running_var, self.weight, self.bias,
400 |                 self.training or not self.track_running_stats,
401 |                 exponential_average_factor, self.eps)
402 |         else:
403 |             if not self.ddp_gpu_size:
404 |                 raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')
405 | 
            # the sync autograd function is not imported at the top of this file
            from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm
406 |             return sync_batch_norm.apply(
407 |                 input, self.weight, self.bias, self.running_mean, self.running_var,
408 |                 self.eps, exponential_average_factor, process_group, world_size)
409 | 
410 |     @classmethod
411 |     def convert_sync_batchnorm(cls, module, process_group=None):
412 |         r"""Helper function to convert the `torch.nn.BatchNormND` layers in the model to
413 |         `SyncBatchNorm` layers.
414 | 
415 |         Args:
416 |             module (nn.Module): containing module
417 |             process_group (optional): process group to scope synchronization,
418 |                 default is the whole world
419 | 
420 |         Returns:
421 |             The original module with the converted `SyncBatchNorm` layers
422 | 
423 |         Example::
424 | 
425 |             >>> # Network with nn.BatchNorm layer
426 |             >>> module = torch.nn.Sequential(
427 |             >>>             torch.nn.Linear(20, 100),
428 |             >>>             torch.nn.BatchNorm1d(100)
429 |             >>>           ).cuda()
430 |             >>> # creating process group (optional)
431 |             >>> # process_ids is a list of int identifying rank ids.
432 |             >>> process_group = torch.distributed.new_group(process_ids)
433 |             >>> sync_bn_module = convert_sync_batchnorm(module, process_group)
434 | 
435 |         """
436 |         module_output = module
437 |         if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
438 |             module_output = cls(module.num_features,
439 |                                 module.eps, module.momentum,
440 |                                 module.affine,  # stock BN exposes one `affine` flag,
441 |                                 module.affine,  # mapped to both center and scale here
442 |                                 module.track_running_stats,
443 |                                 process_group)
444 |             if module.affine:
445 |                 module_output.weight.data = module.weight.data.clone(memory_format=torch.preserve_format).detach()
446 |                 # keep requires_grad unchanged
447 |                 module_output.weight.requires_grad = module.weight.requires_grad
448 |             if module.affine:
449 |                 module_output.bias.data = module.bias.data.clone(memory_format=torch.preserve_format).detach()
450 |                 # keep requires_grad unchanged
451 |                 module_output.bias.requires_grad = module.bias.requires_grad
452 |             module_output.running_mean = module.running_mean
453 |             module_output.running_var = module.running_var
454 |             module_output.num_batches_tracked = module.num_batches_tracked
455 |         for name, child in module.named_children():
456 |             module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
457 |         del module
458 |         return module_output
--------------------------------------------------------------------------------
/deepul_helper/data.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import random
3 | 
4 | import numpy as np
5 | import cv2
6 | 
7 | import torchvision.transforms.functional as F
8 | from torchvision import datasets
9 | from torchvision import transforms
10 | 
11 | 
12 | def get_transform(dataset, task, train=True):
13 |     transform = None
14 |     if task == 'context_encoder':
15 |         if dataset == 'cifar10':
16 |             transform = transforms.Compose([
17 |                 transforms.Resize(128),
18 |                 transforms.ToTensor(),
19 |                 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
20 |             ])
21 |         elif 'imagenet' in dataset:
22 |             transform = transforms.Compose([
23 |                 transforms.Resize(350),
24 |                 transforms.RandomCrop(128),
25 |                 transforms.ToTensor(),
26 |                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
27 |             ])
28 |     elif task == 'rotation':
29 |         if dataset == 'cifar10':
30 |             if train:
31 |                 transform = transforms.Compose([
32 |                     transforms.RandomCrop(32, padding=4),
33 |                     transforms.RandomHorizontalFlip(),
34 |                     transforms.ToTensor(),
35 |                     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
36 |                 ])
37 |             else:
38 |                 transform = transforms.Compose([
39 |                     transforms.ToTensor(),
40 |                     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
41 |                 ])
42 |         elif 'imagenet' in dataset:
43 |             if train:
44 |                 transform = transforms.Compose([
45 |                     transforms.RandomResizedCrop(224),
46 |                     transforms.RandomHorizontalFlip(),
47 |                     transforms.ToTensor(),
48 |                     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
49 
| ]) 50 | else: 51 | transform = transforms.Compose([ 52 | transforms.Resize(256), 53 | transforms.CenterCrop(224), 54 | transforms.ToTensor(), 55 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 56 | ]) 57 | elif task == 'cpc': 58 | if train: 59 | transform = transforms.Compose([ 60 | transforms.RandomResizedCrop(256), 61 | transforms.RandomHorizontalFlip(), 62 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 63 | transforms.RandomGrayscale(p=0.2), 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 66 | ]) 67 | else: 68 | transform = transforms.Compose([ 69 | transforms.Resize(256), 70 | transforms.CenterCrop(256), 71 | transforms.ToTensor(), 72 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 73 | ]) 74 | elif task == 'simclr': 75 | if dataset == 'cifar10': 76 | if train: 77 | transform = transforms.Compose([ 78 | transforms.RandomResizedCrop(32), 79 | transforms.RandomHorizontalFlip(p=0.5), 80 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 81 | transforms.RandomGrayscale(p=0.2), 82 | transforms.ToTensor(), 83 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 84 | ]) 85 | else: 86 | transform = transforms.Compose([ 87 | transforms.ToTensor(), 88 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 89 | ]) 90 | elif 'imagenet' in dataset: 91 | if train: 92 | transform = transforms.Compose([ 93 | transforms.RandomResizedCrop(128), 94 | transforms.RandomHorizontalFlip(), 95 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 96 | transforms.RandomGrayscale(p=0.2), 97 | GaussianBlur(kernel_size=11), 98 | transforms.ToTensor(), 99 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 100 | ]) 101 | else: 102 | transform = transforms.Compose([ 103 | transforms.Resize(128), 104 | transforms.CenterCrop(128), 105 | transforms.ToTensor(), 106 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 107 | ]) 108 | transform = SimCLRDataTransform(transform) 109 | elif task == 'segmentation': 110 | if train: 111 | transform = MultipleCompose([ 112 | MultipleRandomResizedCrop(128), 113 | MultipleRandomHorizontalFlip(), 114 | RepeatTransform(transforms.ToTensor()), 115 | GroupTransform([ 116 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 117 | SegTargetTransform()]) 118 | ]) 119 | else: 120 | transform = MultipleCompose([ 121 | RepeatTransform(transforms.Resize(128)), 122 | RepeatTransform(transforms.CenterCrop(128)), 123 | RepeatTransform(transforms.ToTensor()), 124 | GroupTransform([ 125 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 126 | SegTargetTransform()]) 127 | ]) 128 | else: 129 | raise Exception('Invalid task:', task) 130 | 131 | return transform 132 | 133 | 134 | def get_datasets(dataset, task): 135 | if 'imagenet' in dataset: 136 | train_dir = osp.join('data', dataset, 'train') 137 | val_dir = osp.join('data', dataset, 'val') 138 | train_dataset = datasets.ImageFolder( 139 | train_dir, 140 | get_transform(dataset, task, train=True) 141 | ) 142 | 143 | val_dataset = datasets.ImageFolder( 144 | val_dir, 145 | get_transform(dataset, task, train=False) 146 | ) 147 | 148 | return train_dataset, val_dataset, len(train_dataset.classes) 149 | elif dataset == 'cifar10': 150 | train_dset = datasets.CIFAR10(osp.join('data', dataset), train=True, 151 | transform=get_transform(dataset, task, train=True), 152 | 
download=True) 153 | test_dset = datasets.CIFAR10(osp.join('data', dataset), train=False, 154 | transform=get_transform(dataset, task, train=False), 155 | download=True) 156 | return train_dset, test_dset, len(train_dset.classes) 157 | elif dataset == 'pascalvoc2012': 158 | train_dset = datasets.VOCSegmentation(osp.join('data', dataset), image_set='train', 159 | transforms=get_transform(dataset, task, train=True), 160 | download=True) 161 | test_dset = datasets.VOCSegmentation(osp.join('data', dataset), image_set='val', 162 | transforms=get_transform(dataset, task, train=False), 163 | download=True) 164 | return train_dset, test_dset, 21 165 | else: 166 | raise Exception('Invalid dataset:', dataset) 167 | 168 | 169 | # https://github.com/sthalles/SimCLR/blob/master/data_aug/gaussian_blur.py 170 | class GaussianBlur(object): 171 | # Implements Gaussian blur as described in the SimCLR paper 172 | def __init__(self, kernel_size, min=0.1, max=2.0): 173 | self.min = min 174 | self.max = max 175 | # kernel size is set to be 10% of the image height/width 176 | self.kernel_size = kernel_size 177 | 178 | def __call__(self, sample): 179 | sample = np.array(sample) 180 | 181 | # blur the image with a 50% chance 182 | prob = np.random.random_sample() 183 | 184 | if prob < 0.5: 185 | sigma = (self.max - self.min) * np.random.random_sample() + self.min 186 | sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma) 187 | 188 | return sample 189 | 190 | 191 | class SimCLRDataTransform(object): 192 | def __init__(self, transform): 193 | self.transform = transform 194 | 195 | def __call__(self, sample): 196 | xi = self.transform(sample) 197 | xj = self.transform(sample) 198 | return xi, xj 199 | 200 | # Re-written torchvision transforms to support operations on multiple inputs 201 | # Needed to maintain consistency on random transforms with real images and their segmentations 202 | class MultipleCompose(object): 203 | def __init__(self, transforms): 204 | self.transforms = transforms 205 | 206 | def __call__(self, *inputs): 207 | for t in self.transforms: 208 | inputs = t(*inputs) 209 | return inputs 210 | 211 | 212 | class GroupTransform(object): 213 | """ Applies a list of transforms elementwise """ 214 | def __init__(self, transforms): 215 | self.transforms = transforms 216 | 217 | def __call__(self, *inputs): 218 | assert len(inputs) == len(self.transforms) 219 | outputs = [t(inp) for t, inp in zip(self.transforms, inputs)] 220 | return outputs 221 | 222 | class MultipleRandomResizedCrop(transforms.RandomResizedCrop): 223 | 224 | def __call__(self, *imgs): 225 | """ 226 | Args: 227 | imgs (List of PIL Image): Images to be cropped and resized. 228 | Assumes they are all the same size 229 | 230 | Returns: 231 | PIL Images: Randomly cropped and resized images. 232 | """ 233 | i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio) 234 | return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 235 | for img in imgs] 236 | 237 | class MultipleRandomHorizontalFlip(transforms.RandomHorizontalFlip): 238 | def __call__(self, *imgs): 239 | if random.random() < self.p: 240 | return [F.hflip(img) for img in imgs] 241 | return imgs 242 | 243 | class RepeatTransform(object): 244 | def __init__(self, transform): 245 | self.transform = transform 246 | 247 | def __call__(self, *inputs): 248 | return [self.transform(inp) for inp in inputs] 249 | 250 | class SegTargetTransform(object): 251 | def __call__(self, target): 252 | target *= 255. 
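        # ToTensor() scales the VOC mask PNG into [0, 1]; multiplying by 255
        # restores the integer class indices. Indices above 20 (e.g. the VOC
        # "void" boundary label, 255) are mapped to background below.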
253 | target[target > 20] = 0 254 | return target.long() 255 | -------------------------------------------------------------------------------- /deepul_helper/demos.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | import matplotlib.pyplot as plt 6 | 7 | import torch 8 | import torch.utils.data as data 9 | from torchvision import datasets 10 | from torchvision.utils import make_grid 11 | 12 | from deepul_helper.data import get_datasets 13 | from deepul_helper.tasks import * 14 | from deepul_helper.utils import accuracy, unnormalize, remove_module_state_dict, seg_idxs_to_color 15 | from deepul_helper.seg_model import SegmentationModel 16 | 17 | 18 | def load_model_and_data(task, dataset='cifar10'): 19 | train_dset, test_dset, n_classes = get_datasets(dataset, task) 20 | train_loader = data.DataLoader(train_dset, batch_size=128, num_workers=4, 21 | pin_memory=True, shuffle=True) 22 | test_loader = data.DataLoader(test_dset, batch_size=128, num_workers=4, 23 | pin_memory=True, shuffle=True) 24 | 25 | ckpt_pth = osp.join('results', f'{dataset}_{task}', 'model_best.pth.tar') 26 | ckpt = torch.load(ckpt_pth, map_location='cpu') 27 | 28 | if task == 'context_encoder': 29 | model = ContextEncoder(dataset, n_classes) 30 | elif task == 'rotation': 31 | model = RotationPrediction(dataset, n_classes) 32 | elif task == 'simclr': 33 | model = SimCLR(dataset, n_classes, None) 34 | model.load_state_dict(remove_module_state_dict(ckpt['state_dict'])) 35 | 36 | model.cuda() 37 | model.eval() 38 | 39 | linear_classifier = model.construct_classifier() 40 | linear_classifier.load_state_dict(remove_module_state_dict(ckpt['state_dict_linear'])) 41 | 42 | linear_classifier.cuda() 43 | linear_classifier.eval() 44 | 45 | return model, linear_classifier, train_loader, test_loader 46 | 47 | 48 | def evaluate_accuracy(model, linear_classifier, train_loader, test_loader): 49 | train_acc1, train_acc5 = evaluate_classifier(model, linear_classifier, train_loader) 50 | test_acc1, test_acc5 = evaluate_classifier(model, linear_classifier, test_loader) 51 | 52 | print('Train Set') 53 | print(f'Top 1 Accuracy: {train_acc1}, Top 5 Accuracy: {train_acc5}\n') 54 | print('Test Set') 55 | print(f'Top 1 Accuracy: {test_acc1}, Top 5 Accuracy: {test_acc5}\n') 56 | 57 | 58 | def evaluate_classifier(model, linear_classifier, loader): 59 | correct1, correct5 = 0, 0 60 | with torch.no_grad(): 61 | for images, target in loader: 62 | images = images_to_cuda(images) 63 | target = target.cuda(non_blocking=True) 64 | out, zs = model(images) 65 | 66 | logits = linear_classifier(zs) 67 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 68 | 69 | correct1 += acc1.item() * logits.shape[0] 70 | correct5 += acc5.item() * logits.shape[0] 71 | total = len(loader.dataset) 72 | 73 | return correct1 / total, correct5 / total 74 | 75 | 76 | def display_nearest_neighbors(task, model, loader, n_examples=4, k=16): 77 | with torch.no_grad(): 78 | all_images, all_zs = [], [] 79 | for i, (images, _) in enumerate(loader): 80 | images = images_to_cuda(images) 81 | if task == 'simclr': 82 | images = images[0] 83 | zs = model.encode(images) 84 | 85 | images = images.cpu() 86 | zs = zs.cpu() 87 | 88 | if i == 0: 89 | ref_zs = zs[:n_examples] 90 | ref_images = images[:n_examples] 91 | all_zs.append(zs[n_examples:]) 92 | all_images.append(images[n_examples:]) 93 | else: 94 | all_zs.append(zs) 95 | all_images.append(images) 96 | all_images = 
torch.cat(all_images, dim=0) 97 | all_zs = torch.cat(all_zs, dim=0) 98 | 99 | aa = (ref_zs ** 2).sum(dim=1).unsqueeze(dim=1) 100 | ab = torch.matmul(ref_zs, all_zs.t()) 101 | bb = (all_zs ** 2).sum(dim=1).unsqueeze(dim=0) 102 | dists = torch.sqrt(aa - 2 * ab + bb) 103 | 104 | idxs = torch.topk(dists, k, dim=1, largest=False)[1] 105 | sel_images = torch.index_select(all_images, 0, idxs.view(-1)) 106 | sel_images = unnormalize(sel_images.cpu(), 'cifar10') 107 | sel_images = sel_images.view(n_examples, k, *sel_images.shape[-3:]) 108 | 109 | ref_images = unnormalize(ref_images.cpu(), 'cifar10') 110 | ref_images = (ref_images.permute(0, 2, 3, 1) * 255.).numpy().astype('uint8') 111 | 112 | for i in range(n_examples): 113 | print(f'Image {i + 1}') 114 | plt.figure() 115 | plt.axis('off') 116 | plt.imshow(ref_images[i]) 117 | plt.show() 118 | 119 | grid_img = make_grid(sel_images[i], nrow=4) 120 | grid_img = (grid_img.permute(1, 2, 0) * 255.).numpy().astype('uint8') 121 | 122 | print(f'Top {k} Nearest Neighbors (in latent space)') 123 | plt.figure() 124 | plt.axis('off') 125 | plt.imshow(grid_img) 126 | plt.show() 127 | 128 | 129 | def images_to_cuda(images): 130 | if isinstance(images, (tuple, list)): 131 | images = [x.cuda(non_blocking=True) for x in images] 132 | else: 133 | images = images.cuda(non_blocking=True) 134 | return images 135 | 136 | 137 | def show_context_encoder_inpainting(): 138 | model, _, _, test_loader = load_model_and_data('context_encoder', 'cifar10') 139 | images = next(iter(test_loader))[0][:8] 140 | with torch.no_grad(): 141 | images = images.cuda(non_blocking=True) 142 | images_masked, images_recon = model.reconstruct(images) 143 | images_masked = unnormalize(images_masked.cpu(), 'cifar10') 144 | images_recon = unnormalize(images_recon.cpu(), 'cifar10') 145 | 146 | images = torch.stack((images_masked, images_recon), dim=1).flatten(end_dim=1) 147 | 148 | grid_img = make_grid(images, nrow=4) 149 | grid_img = (grid_img.permute(1, 2, 0) * 255.).numpy().astype('uint8') 150 | 151 | plt.figure() 152 | plt.axis('off') 153 | plt.imshow(grid_img) 154 | plt.show() 155 | 156 | 157 | def show_segmentation(): 158 | _, val_dset, n_classes = get_datasets('pascalvoc2012', 'segmentation') 159 | val_loader = data.DataLoader(val_dset, batch_size=128) 160 | 161 | pretrained_model = SimCLR('imagenet100', 100, None) 162 | ckpt = torch.load(osp.join('results', 'imagenet100_simclr', 'seg_model_best.pth.tar'), 163 | map_location='cpu') 164 | pretrained_model.load_state_dict(ckpt['pt_state_dict']) 165 | pretrained_model.cuda().eval() 166 | 167 | seg_model = SegmentationModel(n_classes) 168 | seg_model.load_state_dict(ckpt['state_dict']) 169 | seg_model.cuda().eval() 170 | 171 | images, target = next(iter(val_loader)) 172 | images, target = images[:12], target[:12] 173 | images = images.cuda(non_blocking=True) 174 | target = target.cuda(non_blocking=True).long().squeeze(1) 175 | features = pretrained_model.get_features(images) 176 | _, logits = seg_model(features, target) 177 | pred = torch.argmax(logits, dim=1) 178 | 179 | target = seg_idxs_to_color(target.cpu(), 'palette.pkl') 180 | pred = seg_idxs_to_color(pred.cpu(), 'palette.pkl') 181 | images = unnormalize(images.cpu(), 'imagenet') 182 | 183 | to_show = torch.stack((images, target, pred), dim=1).flatten(end_dim=1) 184 | to_show = make_grid(to_show, nrow=6, pad_value=1.) 
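    # convert the CHW grid in [0, 1] to an HWC uint8 array for matplotlib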
185 | to_show = (to_show.permute(1, 2, 0) * 255.).numpy().astype('uint8') 186 | 187 | plt.figure(figsize=(12, 12)) 188 | plt.axis('off') 189 | plt.imshow(to_show) 190 | plt.show() 191 | 192 | -------------------------------------------------------------------------------- /deepul_helper/lars.py: -------------------------------------------------------------------------------- 1 | """ From https://github.com/noahgolmant/pytorch-lars 2 | Layer-wise adaptive rate scaling for SGD in PyTorch! """ 3 | import torch 4 | from torch.optim.optimizer import Optimizer, required 5 | 6 | 7 | class LARS(Optimizer): 8 | r"""Implements layer-wise adaptive rate scaling for SGD. 9 | 10 | Args: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float): base learning rate (\gamma_0) 14 | momentum (float, optional): momentum factor (default: 0) ("m") 15 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 16 | ("\beta") 17 | eta (float, optional): LARS coefficient 18 | 19 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 20 | Large Batch Training of Convolutional Networks: 21 | https://arxiv.org/abs/1708.03888 22 | 23 | Example: 24 | >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3) 25 | >>> optimizer.zero_grad() 26 | >>> loss_fn(model(input), target).backward() 27 | >>> optimizer.step() 28 | """ 29 | def __init__(self, params, lr=required, momentum=.9, 30 | weight_decay=.0005, eta=0.001): 31 | if lr is not required and lr < 0.0: 32 | raise ValueError("Invalid learning rate: {}".format(lr)) 33 | if momentum < 0.0: 34 | raise ValueError("Invalid momentum value: {}".format(momentum)) 35 | if weight_decay < 0.0: 36 | raise ValueError("Invalid weight_decay value: {}" 37 | .format(weight_decay)) 38 | if eta < 0.0: 39 | raise ValueError("Invalid LARS coefficient value: {}".format(eta)) 40 | 41 | self.epoch = 0 42 | defaults = dict(lr=lr, momentum=momentum, 43 | weight_decay=weight_decay, 44 | eta=eta) 45 | super(LARS, self).__init__(params, defaults) 46 | 47 | def step(self, epoch=None, closure=None): 48 | """Performs a single optimization step. 49 | 50 | Arguments: 51 | closure (callable, optional): A closure that reevaluates the model 52 | and returns the loss. 53 | epoch: current epoch to calculate polynomial LR decay schedule. 54 | if None, uses self.epoch and increments it. 
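
        Each parameter group applies a layer-wise scaling: the code below computes
        local_lr = eta * ||w|| / (||grad w|| + weight_decay * ||w||)
        and uses local_lr * lr as the effective step size for the momentum update.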
55 |         """
56 |         loss = None
57 |         if closure is not None:
58 |             loss = closure()
59 | 
60 |         if epoch is None:
61 |             epoch = self.epoch
62 |             self.epoch += 1
63 | 
64 |         for group in self.param_groups:
65 |             weight_decay = group['weight_decay']
66 |             momentum = group['momentum']
67 |             eta = group['eta']
68 |             lr = group['lr']
69 | 
70 |             for p in group['params']:
71 |                 if p.grad is None:
72 |                     continue
73 | 
74 |                 param_state = self.state[p]
75 |                 d_p = p.grad.data
76 | 
77 |                 weight_norm = torch.norm(p.data)
78 |                 grad_norm = torch.norm(d_p)
79 | 
80 |                 # Compute local learning rate for this layer
81 |                 local_lr = eta * weight_norm / \
82 |                     (grad_norm + weight_decay * weight_norm)
83 | 
84 |                 # Update the momentum term
85 |                 actual_lr = local_lr * lr
86 | 
87 |                 if 'momentum_buffer' not in param_state:
88 |                     buf = param_state['momentum_buffer'] = \
89 |                         torch.zeros_like(p.data)
90 |                 else:
91 |                     buf = param_state['momentum_buffer']
92 |                 buf.mul_(momentum).add_(actual_lr, d_p + weight_decay * p.data)
93 |                 p.data.add_(-buf)
94 | 
95 |         return loss
--------------------------------------------------------------------------------
/deepul_helper/layer_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numbers
3 | from torch.nn.parameter import Parameter
4 | from torch.nn import Module
5 | import torch.nn.functional as F
6 | import torch.nn.init as init
7 | 
8 | 
9 | class LayerNorm(Module):
10 |     r"""Applies Layer Normalization over a mini-batch of inputs as described in
11 |     the paper `Layer Normalization`_ .
12 | 
13 |     .. math::
14 |         y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
15 | 
16 |     The mean and standard-deviation are calculated separately over the last
17 |     certain number dimensions which have to be of the shape specified by
18 |     :attr:`normalized_shape`.
19 |     :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
20 |     :attr:`normalized_shape` if :attr:`scale` and :attr:`center` are ``True``.
21 | 
22 |     .. note::
23 |         Unlike Batch Normalization and Instance Normalization, which applies
24 |         scalar scale and bias for each entire channel/plane with the
25 |         :attr:`affine` option, Layer Normalization applies per-element scale and
26 |         bias with :attr:`scale` / :attr:`center`.
27 | 
28 |     This layer uses statistics computed from input data in both training and
29 |     evaluation modes.
30 | 
31 |     Args:
32 |         normalized_shape (int or list or torch.Size): input shape from an expected input
33 |             of size
34 | 
35 |             .. math::
36 |                 [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
37 |                     \times \ldots \times \text{normalized\_shape}[-1]]
38 | 
39 |             If a single integer is used, it is treated as a singleton list, and this module will
40 |             normalize over the last dimension which is expected to be of that specific size.
41 |         eps: a value added to the denominator for numerical stability. Default: 1e-5
42 |         center: a boolean value that when set to ``True``, this module has
43 |             learnable per-element bias parameters initialized to zeros. Default: ``True``
44 |         scale: a boolean value that when set to ``True``, this module has
            learnable per-element weight parameters initialized to ones. Default: ``True``
45 | 
46 |     Shape:
47 |         - Input: :math:`(N, *)`
48 |         - Output: :math:`(N, *)` (same shape as input)
49 | 
50 |     Examples::
51 | 
52 |         >>> input = torch.randn(20, 5, 10, 10)
53 |         >>> # With Learnable Parameters
54 |         >>> m = nn.LayerNorm(input.size()[1:])
55 |         >>> # Without Learnable Parameters
56 |         >>> m = nn.LayerNorm(input.size()[1:], center=False, scale=False)
57 |         >>> # Normalize over last two dimensions
58 |         >>> m = nn.LayerNorm([10, 10])
59 |         >>> # Normalize over last dimension of size 10
60 |         >>> m = nn.LayerNorm(10)
61 |         >>> # Activating the module
62 |         >>> output = m(input)
63 | 
64 |     .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
65 |     """
66 |     __constants__ = ['normalized_shape', 'eps', 'center', 'scale']
67 | 
68 |     def __init__(self, normalized_shape, eps=1e-5, center=True, scale=True):
69 |         super(LayerNorm, self).__init__()
70 |         if isinstance(normalized_shape, numbers.Integral):
71 |             normalized_shape = (normalized_shape,)
72 |         self.normalized_shape = tuple(normalized_shape)
73 |         self.eps = eps
74 |         self.center = center
75 |         self.scale = scale
76 |         if self.center:
77 |             self.bias = Parameter(torch.Tensor(*normalized_shape))
78 |         else:
79 |             self.register_parameter('bias', None)
80 |         if self.scale:
81 |             self.weight = Parameter(torch.Tensor(*normalized_shape))
82 |         else:
83 |             self.register_parameter('weight', None)
84 |         self.reset_parameters()
85 | 
86 |     def reset_parameters(self):
87 |         if self.center:
88 |             init.zeros_(self.bias)
89 |         if self.scale:
90 |             init.ones_(self.weight)
91 | 
92 |     def forward(self, input):
93 |         return F.layer_norm(
94 |             input, self.normalized_shape, self.weight, self.bias, self.eps)
95 | 
96 |     def extra_repr(self):
97 |         return '{normalized_shape}, eps={eps}, ' \
98 |                'center={center}, scale={scale}'.format(**self.__dict__)
99 | 
--------------------------------------------------------------------------------
/deepul_helper/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.utils.checkpoint import checkpoint
5 | 
6 | from .batch_norm import BatchNorm1d, BatchNorm2d
7 | from .layer_norm import LayerNorm
8 | 
9 | class NormReLU(nn.Module):
10 | 
11 |     def __init__(self, input_size, relu=True, center=True, scale=True, norm_type='bn'):
12 |         super().__init__()
13 |         assert len(input_size) == 1 or len(input_size) == 3, f'Input size must be 1D or 3D, got {len(input_size)}D'
14 | 
15 |         self.relu = relu
16 |         if norm_type == 'bn':
17 |             bn_cls = BatchNorm1d if len(input_size) == 1 else BatchNorm2d
18 |             self.norm = bn_cls(input_size[0], center=center, scale=scale)
19 |         elif norm_type == 'ln':
20 |             self.norm = LayerNorm(input_size, center=center, scale=scale)
21 |         else:
22 |             self.norm = nn.Identity()
23 | 
24 |     def forward(self, x):
25 |         x = self.norm(x)
26 |         if self.relu:
27 |             x = F.relu(x, inplace=True)
28 |         return x
29 | 
30 | 
31 | def fixed_padding(inputs, kernel_size):
32 |     pad_total = kernel_size - 1
33 |     pad_beg = pad_total // 2
34 |     pad_end = pad_total - pad_beg
35 |     padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
36 |     return padded_inputs
37 | 
38 | 
39 | class Conv2dFixedPad(nn.Module):
40 | 
41 |     def __init__(self, in_channels, out_channels, kernel_size, stride):
42 |         super().__init__()
43 | 
44 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
45 |                               padding=(kernel_size // 2 if stride == 1 else 0), bias=False)
46 | 
47 |         self.stride = stride
48 |         self.kernel_size = kernel_size
49 | 
50 |     def forward(self, x):
51 |         if self.stride > 1:
52 |             x = fixed_padding(x, self.kernel_size)
53 |         return self.conv(x)
54 | 
55 | 
56 | class ResidualBlock(nn.Module):
57 | 
58 |     def __init__(self, input_size, filters, stride, use_projection=False, norm_type='bn'):
59 |         super().__init__()
60 | 
61 |         C, H, W = input_size
62 |         if use_projection:
63 |             self.proj_conv = Conv2dFixedPad(C, filters, kernel_size=1, stride=stride)
64 |             self.proj_bnr = NormReLU((filters, H // stride, W // stride),
65 |                                      relu=False, norm_type=norm_type)
66 | 
67 |         self.conv1 = Conv2dFixedPad(C, filters, kernel_size=3, stride=stride)
68 |         self.bnr1 = NormReLU((filters, H // stride, W // stride), norm_type=norm_type)
69 | 
70 |         self.conv2 = Conv2dFixedPad(filters, filters, kernel_size=3, stride=1)
71 |         self.bnr2 = NormReLU((filters, H // stride, W // stride), norm_type=norm_type)
72 | 
73 |         self.use_projection = use_projection
74 | 
75 |     def forward(self, x):
76 |         shortcut = x
77 |         if self.use_projection:
78 |             shortcut = self.proj_bnr(self.proj_conv(x))
79 |         x = self.bnr1(self.conv1(x))
80 |         x = self.bnr2(self.conv2(x))
81 | 
82 |         return F.relu(x + shortcut, inplace=True)
83 | 
84 | 
85 | class BottleneckBlock(nn.Module):
86 | 
87 |     def __init__(self, input_size, filters, stride, use_projection=False, norm_type='bn'):
88 |         super().__init__()
89 | 
90 |         C, H, W = input_size
91 |         if use_projection:
92 |             filters_out = 4 * filters
93 |             self.proj_conv = Conv2dFixedPad(C, filters_out, kernel_size=1, stride=stride)
94 |             self.proj_bnr = NormReLU((filters_out, H // stride, W // stride),
95 |                                      relu=False, norm_type=norm_type)
96 | 
97 |         self.conv1 = Conv2dFixedPad(C, filters, kernel_size=1, stride=1)
98 |         self.bnr1 = NormReLU((filters, H, W), norm_type=norm_type)
99 | 
100 |         self.conv2 = Conv2dFixedPad(filters, filters, kernel_size=3, stride=stride)
101 |         self.bnr2 = NormReLU((filters, H // stride, W // stride), norm_type=norm_type)
102 | 
103 |         self.conv3 = Conv2dFixedPad(filters, 4 * filters, kernel_size=1, stride=1)
104 |         self.bnr3 = NormReLU((4 * filters, H // stride, W // stride), norm_type=norm_type)
105 | 
106 |         self.use_projection = use_projection
107 | 
108 |     def forward(self, x):
109 |         shortcut = x
110 |         if self.use_projection:
111 |             shortcut = self.proj_bnr(self.proj_conv(x))
112 |         x = self.bnr1(self.conv1(x))
113 |         x = self.bnr2(self.conv2(x))
114 |         x = self.bnr3(self.conv3(x))
115 | 
116 |         return F.relu(x + shortcut, inplace=True)
117 | 
118 | 
119 | class BlockGroup(nn.Module):
120 | 
121 |     def __init__(self, input_size, filters, block_fn, blocks, stride, norm_type='bn'):
122 |         super().__init__()
123 | 
124 |         self.start_block = block_fn(input_size, filters, stride,
125 |                                     use_projection=True, norm_type=norm_type)
126 |         in_channels = filters * 4 if block_fn == BottleneckBlock else filters
127 |         input_size = (in_channels, input_size[1] // stride, input_size[2] // stride)
128 | 
129 |         self.blocks = []
130 |         for _ in range(1, blocks):
131 |             self.blocks.append(block_fn(input_size, filters, 1, norm_type=norm_type))
132 |         self.blocks = nn.Sequential(*self.blocks)
133 | 
134 |     def forward(self, x):
135 |         x = self.start_block(x)
136 |         x = self.blocks(x)
137 |         return x
138 | 
139 | 
140 | class ResNet(nn.Module):
141 | 
142 |     def __init__(self, input_size, block_fn, layers, width_multiplier, cifar_stem=False,
143 |                  norm_type='bn'):
144 |         super().__init__()
145 | 
146 |         C, H, W = input_size
147 |         if cifar_stem:
148 |             self.stem = nn.Sequential(
149 |                 Conv2dFixedPad(C, 64 * width_multiplier, kernel_size=3, stride=1),
150 |                 NormReLU((64 * width_multiplier, H, W), norm_type=norm_type)
151 |             )
152 |         else:
153 |             self.stem = nn.Sequential(
154 |                 Conv2dFixedPad(C, 64 * width_multiplier, kernel_size=7, stride=2),
155 |                 NormReLU((64 * width_multiplier, H // 2, W // 2), norm_type=norm_type),
156 |                 nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
157 |             )
158 |             H, W = H // 4, W // 4
159 | 
160 |         scalar = 4 if block_fn == BottleneckBlock else 1
161 | 
162 |         self.group1 = BlockGroup((64 * width_multiplier, H, W), 64 * width_multiplier,
163 |                                  block_fn=block_fn, blocks=layers[0], stride=1,
164 |                                  norm_type=norm_type)
165 |         self.group2 = BlockGroup((64 * width_multiplier * scalar, H, W), 128 * width_multiplier,
166 |                                  block_fn=block_fn, blocks=layers[1], stride=2,
167 |                                  norm_type=norm_type)
168 |         H, W = H // 2, W // 2
169 |         self.group3 = BlockGroup((128 * width_multiplier * scalar, H, W), 256 * width_multiplier,
170 |                                  block_fn=block_fn, blocks=layers[2], stride=2,
171 |                                  norm_type=norm_type)
172 |         H, W = H // 2, W // 2
173 |         self.group4 = BlockGroup((256 * width_multiplier * scalar, H, W), 512 * width_multiplier,
174 |                                  block_fn=block_fn, blocks=layers[3], stride=2,
175 |                                  norm_type=norm_type)
176 | 
177 |     def forward(self, x):
178 |         x = self.stem(x)
179 |         x = self.group1(x)
180 |         x = self.group2(x)
181 |         x = self.group3(x)
182 |         x = self.group4(x)
183 |         x = torch.mean(x, dim=[2, 3])  # global average pool -> (N, C)
184 |         return x
185 | 
186 |     # For semantic segmentation architectures (assumes the ImageNet stem: conv, norm, maxpool)
187 |     def get_features(self, x):
188 |         features = [x]
189 | 
190 |         x = self.stem[1](self.stem[0](x))
191 |         features.append(x)
192 | 
193 |         x = self.group1(self.stem[2](x))
194 |         features.append(x)
195 | 
196 |         x = self.group2(x)
197 |         features.append(x)
198 | 
199 |         x = self.group3(x)
200 |         features.append(x)
201 | 
202 |         x = self.group4(x)
203 |         features.append(x)
204 | 
205 |         return features
206 | 
207 | def resnet_v1(input_size, resnet_depth, width_multiplier, cifar_stem=False, norm_type='bn'):
208 |     model_params = {
209 |         18: {'block': ResidualBlock, 'layers': [2, 2, 2, 2]},
210 |         34: {'block': ResidualBlock, 'layers': [3, 4, 6, 3]},
211 |         50: {'block': BottleneckBlock, 'layers': [3, 4, 6, 3]},
212 |         101: {'block': BottleneckBlock, 'layers': [3, 4, 23, 3]},
213 |         152: {'block': BottleneckBlock, 'layers': [3, 8, 36, 3]},
214 |         200: {'block': BottleneckBlock, 'layers': [3, 24, 36, 3]}
215 |     }
216 | 
217 |     if resnet_depth not in model_params:
218 |         raise ValueError('Not a valid resnet_depth: {}'.format(resnet_depth))
219 | 
220 |     params = model_params[resnet_depth]
221 |     return ResNet(input_size, params['block'], params['layers'], width_multiplier,
222 |                   cifar_stem=cifar_stem, norm_type=norm_type)
--------------------------------------------------------------------------------
/deepul_helper/seg_model.py:
--------------------------------------------------------------------------------
1 | # Code adapted from https://github.com/qubvel/segmentation_models.pytorch
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from deepul_helper.resnet import NormReLU
7 | 
8 | class SegmentationModel(nn.Module):
9 |     metrics = ['Loss']
10 |     metrics_fmt = [':.4e']
11 | 
12 |     def __init__(self, n_classes):
13 |         super().__init__()
14 | 
15 |         decoder_channels = (512, 256, 128, 64, 32)
16 |         encoder_channels = (2048, 1024, 512, 256, 64)  # Starting from head (resnet 50)
17 | 
18 |         # Construct decoder blocks
19 |         in_channels = [encoder_channels[0]] + list(decoder_channels[:-1])
20 |         skip_channels = list(encoder_channels[1:]) + [0]
21 |         out_channels = decoder_channels
22 |         blocks = [
23 |             DecoderBlock(in_ch, skip_ch,
out_ch) 24 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 25 | ] 26 | self.dec_blocks = nn.ModuleList(blocks) 27 | 28 | # Segmentation head for output prediction 29 | self.seg_head = nn.Conv2d(decoder_channels[-1], n_classes, kernel_size=3, padding=1) 30 | 31 | def forward(self, features, targets): 32 | features = features[1:] # remove first skip with same spatial resolution 33 | features = features[::-1] # reverse channels to start from head of encoder 34 | 35 | skips = features[1:] 36 | x = features[0] 37 | for i, decoder_block in enumerate(self.dec_blocks): 38 | skip = skips[i] if i < len(skips) else None 39 | x = decoder_block(x, skip) 40 | 41 | logits = self.seg_head(x) 42 | loss = F.cross_entropy(logits, targets) 43 | 44 | return dict(Loss=loss), logits 45 | 46 | 47 | class DecoderBlock(nn.Module): 48 | def __init__( 49 | self, 50 | in_channels, 51 | skip_channels, 52 | out_channels, 53 | ): 54 | super().__init__() 55 | self.conv1 = nn.Sequential( 56 | nn.Conv2d(in_channels + skip_channels, out_channels, 57 | kernel_size=3, padding=1), 58 | NormReLU((out_channels, None, None)), # only care about channel dim for BN 59 | ) 60 | self.conv2 = nn.Sequential( 61 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 62 | NormReLU((out_channels, None, None)) 63 | ) 64 | 65 | def forward(self, x, skip=None): 66 | x = F.interpolate(x, scale_factor=2, mode="nearest") 67 | if skip is not None: 68 | x = torch.cat([x, skip], dim=1) 69 | x = self.conv1(x) 70 | x = self.conv2(x) 71 | return x 72 | -------------------------------------------------------------------------------- /deepul_helper/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .context_encoder import ContextEncoder 2 | from .rotation import RotationPrediction 3 | from .cpc import CPC 4 | from .simclr import SimCLR -------------------------------------------------------------------------------- /deepul_helper/tasks/context_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ContextEncoder(nn.Module): 7 | metrics = ['Loss'] 8 | metrics_fmt = [':.4e'] 9 | 10 | def __init__(self, dataset, n_classes): 11 | super().__init__() 12 | input_channels = 3 13 | 14 | self.latent_dim = 4000 15 | 16 | # Encodes the masked image 17 | self.encoder = nn.Sequential( 18 | # 128 x 128 Input 19 | nn.Conv2d(input_channels, 64, 4, stride=2, padding=1), # 64 x 64 20 | nn.BatchNorm2d(64), 21 | nn.LeakyReLU(0.2, inplace=True), 22 | nn.Conv2d(64, 64, 4, stride=2, padding=1), # 32 x 32 23 | nn.BatchNorm2d(64), 24 | nn.LeakyReLU(0.2, inplace=True), 25 | nn.Conv2d(64, 128, 4, stride=2, padding=1), # 16 x 16 26 | nn.BatchNorm2d(128), 27 | nn.LeakyReLU(0.2, inplace=True), 28 | nn.Conv2d(128, 256, 4, stride=2, padding=1), # 8 x 8 29 | nn.BatchNorm2d(256), 30 | nn.LeakyReLU(0.2, inplace=True), 31 | nn.Conv2d(256, 512, 4, stride=2, padding=1), # 4 x 4 32 | nn.BatchNorm2d(512), 33 | nn.LeakyReLU(0.2, inplace=True), 34 | nn.Conv2d(512, self.latent_dim, 4) # 1 x 1 35 | ) 36 | 37 | # Only reconstructs the masked part of the image and not the whole image 38 | self.decoder = nn.Sequential( 39 | nn.BatchNorm2d(self.latent_dim), 40 | nn.ReLU(inplace=True), 41 | nn.ConvTranspose2d(self.latent_dim, 512, 4, stride=1, padding=0), # 4 x 4 42 | nn.BatchNorm2d(512), 43 | nn.ReLU(inplace=True), 44 | nn.ConvTranspose2d(512, 256, 4, stride=2, 
padding=1), # 8 x 8 45 | nn.BatchNorm2d(256), 46 | nn.ReLU(inplace=True), 47 | nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # 16 x 16 48 | nn.BatchNorm2d(128), 49 | nn.ReLU(inplace=True), 50 | nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), # 32 x 32 51 | nn.BatchNorm2d(64), 52 | nn.ReLU(inplace=True), 53 | nn.ConvTranspose2d(64, input_channels, 4, stride=2, padding=1), # 64 x 64 54 | nn.Tanh() 55 | ) 56 | 57 | self.dataset = dataset 58 | self.n_classes = n_classes 59 | 60 | def construct_classifier(self): 61 | classifier = nn.Sequential( 62 | nn.Flatten(), 63 | nn.BatchNorm1d(self.latent_dim, affine=False), 64 | nn.Linear(self.latent_dim, self.n_classes) 65 | ) 66 | return classifier 67 | 68 | def forward(self, images): 69 | # Extract a 64 x 64 center from 128 x 128 image 70 | images_center = images[:, :, 32:32+64, 32:32+64].clone() 71 | images_masked = images.clone() 72 | # Mask out a 64 x 64 center with slight overlap 73 | images_masked[:, 0, 32+4:32+64-4, 32+4:32+64-4] = 2 * 117.0/255.0 - 1.0 74 | images_masked[:, 1, 32+4:32+64-4, 32+4:32+64-4] = 2 * 104.0/255.0 - 1.0 75 | images_masked[:, 2, 32+4:32+64-4, 32+4:32+64-4] = 2 * 123.0/255.0 - 1.0 76 | 77 | z = self.encoder(images_masked) 78 | center_recon = self.decoder(z) 79 | 80 | return dict(Loss=F.mse_loss(center_recon, images_center)), torch.flatten(z, 1) 81 | 82 | def encode(self, images): 83 | images_masked = images.clone() 84 | images_masked[:, 0, 32+4:32+64-4, 32+4:32+64-4] = 2 * 117.0/255.0 - 1.0 85 | images_masked[:, 1, 32+4:32+64-4, 32+4:32+64-4] = 2 * 104.0/255.0 - 1.0 86 | images_masked[:, 2, 32+4:32+64-4, 32+4:32+64-4] = 2 * 123.0/255.0 - 1.0 87 | return self.encoder(images_masked).flatten(start_dim=1) 88 | 89 | def reconstruct(self, images): 90 | images_center = images[:, :, 32:32+64, 32:32+64].clone() 91 | images_masked = images.clone() 92 | images_masked[:, 0, 32+4:32+64 - 4, 32+4:32+64-4] = 2 * 117.0/255.0 - 1.0 93 | images_masked[:, 1, 32+4:32+64 - 4, 32+4:32+64-4] = 2 * 104.0/255.0 - 1.0 94 | images_masked[:, 2, 32+4:32+64 - 4, 32+4:32+64-4] = 2 * 123.0/255.0 - 1.0 95 | 96 | z = self.encoder(images_masked) 97 | center_recon = self.decoder(z) 98 | 99 | images_recon = images_masked.clone() 100 | images_recon[:, :, 32:32+64, 32:32+64] = center_recon 101 | return images_masked, images_recon -------------------------------------------------------------------------------- /deepul_helper/tasks/cpc.py: -------------------------------------------------------------------------------- 1 | # WIP 2 | import math 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from deepul_helper.resnet import resnet_v1 10 | from deepul_helper.batch_norm import BatchNorm1d 11 | 12 | 13 | class CPC(nn.Module): 14 | latent_dim = 2048 15 | metrics = ['Loss'] 16 | metrics_fmt = [':.4e'] 17 | 18 | def __init__(self, dataset, n_classes): 19 | super().__init__() 20 | self.target_dim = 64 21 | self.emb_scale = 0.1 22 | self.steps_to_ignore = 2 23 | self.steps_to_predict = 3 24 | self.n_classes = n_classes 25 | 26 | self.encoder = resnet_v1((3, 64, 64), 50, 1, cifar_stem=False, norm_type='ln') 27 | self.pixelcnn = PixelCNN() 28 | 29 | self.z2target = nn.Conv2d(self.latent_dim, self.target_dim, (1, 1)) 30 | self.ctx2pred = nn.ModuleList([nn.Conv2d(self.latent_dim, self.target_dim, (1, 1)) 31 | for i in range(self.steps_to_ignore, self.steps_to_ignore + self.steps_to_predict)]) 32 | 33 | def construct_classifier(self): 34 | return nn.Sequential(BatchNorm1d(self.latent_dim, 
center=False), nn.Linear(self.latent_dim, self.n_classes)) 35 | 36 |     def forward(self, images): 37 |         batch_size = images.shape[0] 38 |         patches = images_to_cpc_patches(images).detach() # (N*49, C, 64, 64) 39 |         rnd = np.random.randint(low=0, high=16, size=(batch_size * 49,)) 40 |         for i in range(batch_size * 49): 41 |             r, c = rnd[i] // 4, rnd[i] % 4 42 |             patches[i, :, :r] = -1. 43 |             patches[i, :, :, :c] = -1. 44 |             patches[i, :, r + 60:] = -1. 45 |             patches[i, :, :, c + 60:] = -1. 46 | 47 |         latents = self.encoder(patches) # (N*49, latent_dim) 48 | 49 |         latents = latents.view(batch_size, 7, 7, -1).permute(0, 3, 1, 2).contiguous() # (N, latent_dim, 7, 7) 50 |         context = self.pixelcnn(latents) # (N, latent_dim, 7, 7) 51 | 52 |         col_dim, row_dim = 7, 7 53 |         targets = self.z2target(latents).permute(0, 2, 3, 1).contiguous().view(-1, self.target_dim) # (N*49, 64) 54 | 55 |         loss = 0. 56 |         for i in range(self.steps_to_ignore, self.steps_to_ignore + self.steps_to_predict): 57 |             col_dim_i = col_dim - i - 1 58 |             total_elements = batch_size * col_dim_i * row_dim 59 | 60 |             preds_i = self.ctx2pred[i - self.steps_to_ignore](context) # (N, 64, 7, 7) 61 |             preds_i = preds_i[:, :, :-(i+1), :] * self.emb_scale # (N, 64, H, 7) 62 |             preds_i = preds_i.permute(0, 2, 3, 1).contiguous() # (N, H, 7, 64) 63 |             preds_i = preds_i.view(-1, self.target_dim) # (N*H*7, 64) 64 | 65 |             logits = torch.matmul(preds_i, targets.t()) # (N*H*7, N*49) 66 | 67 |             b = np.arange(total_elements) // (col_dim_i * row_dim) 68 |             col = np.arange(total_elements) % (col_dim_i * row_dim) 69 |             labels = b * col_dim * row_dim + (i + 1) * row_dim + col 70 |             labels = torch.LongTensor(labels).to(logits.get_device()) 71 | 72 |             loss = loss + F.cross_entropy(logits, labels) 73 | 74 |         return dict(Loss=loss), latents.mean(dim=[2, 3]) 75 | 76 |     def encode(self, images): 77 |         batch_size = images.shape[0] 78 |         patches = images_to_cpc_patches(images) # (N*49, C, 64, 64) 79 |         latents = self.encoder(patches) # (N*49, latent_dim) 80 |         latents = latents.view(batch_size, 7, 7, -1) # (N, 7, 7, latent_dim) 81 |         return latents.mean(dim=[1, 2]) 82 | 83 | 84 | class PixelCNN(nn.Module): 85 |     """Following the PixelCNN architecture in A.2 of 86 |     https://arxiv.org/pdf/1905.09272.pdf""" 87 | 88 |     def __init__(self): 89 |         super().__init__() 90 |         latent_dim = 2048 91 | 92 |         self.net = nn.ModuleList() 93 |         for _ in range(5): 94 |             block = nn.Sequential( 95 |                 nn.Conv2d(latent_dim, 256, (1, 1)), 96 |                 nn.ReLU(), 97 |                 nn.ZeroPad2d((1, 1, 0, 0)), 98 |                 nn.Conv2d(256, 256, (1, 3)), 99 |                 nn.ReLU(), 100 |                 nn.ZeroPad2d((0, 0, 1, 0)), 101 |                 nn.Conv2d(256, 256, (2, 1)), 102 |                 nn.ReLU(), 103 |                 nn.Conv2d(256, latent_dim, (1, 1)) 104 |             ) 105 |             self.net.append(block) 106 | 107 |     def forward(self, x): 108 |         for i, block in enumerate(self.net): 109 |             x = block(x) + x 110 |             x = F.relu(x) 111 |         return x 112 | 113 | 114 | def images_to_cpc_patches(images): 115 |     """Converts (N, C, 256, 256) tensors to (N*49, C, 64, 64) patches 116 |     for CPC training""" 117 |     all_image_patches = [] 118 |     for r in range(7): 119 |         for c in range(7): 120 |             batch_patch = images[:, :, r*32:r*32+64, c*32:c*32+64] 121 |             all_image_patches.append(batch_patch) 122 |     # (N, 49, C, 64, 64) 123 |     image_patches_tensor = torch.stack(all_image_patches, dim=1) 124 |     return image_patches_tensor.view(-1, *image_patches_tensor.shape[-3:]) 125 | 126 | # def extract_image_patches(x, kernel, stride=1, dilation=1): 127 | #     # Do TF 'SAME' Padding 128 | #     b,c,h,w = x.shape 129 | #     h2 = math.ceil(h / stride) 130 | #     w2 = math.ceil(w / stride) 131 | # 
pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h 132 | #     pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w 133 | #     x = F.pad(x, (pad_row//2, pad_row - pad_row//2, pad_col//2, pad_col - pad_col//2)) 134 | #     # Extract patches 135 | #     patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride) 136 | #     patches = patches.permute(0,4,5,1,2,3).contiguous() 137 | #     return patches.view(b,-1,patches.shape[-2], patches.shape[-1]) 138 | -------------------------------------------------------------------------------- /deepul_helper/tasks/rotation.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class RotationPrediction(nn.Module): 9 |     metrics = ['Loss', 'Acc1'] 10 |     metrics_fmt = [':.4e', ':6.2f'] 11 | 12 |     def __init__(self, dataset, n_classes): 13 |         super().__init__() 14 |         if dataset == 'cifar10': 15 |             self.model = NetworkInNetwork() 16 |             self.latent_dim = 192 * 8 * 8 17 |             self.feat_layer = 'conv2' 18 |         elif 'imagenet' in dataset: 19 |             self.model = AlexNet() 20 |             self.latent_dim = 256 * 13 * 13 21 |             self.feat_layer = 'conv5' 22 |         else: 23 |             raise Exception('Unsupported dataset:', dataset) 24 |         self.dataset = dataset 25 |         self.n_classes = n_classes 26 | 27 |     def construct_classifier(self): 28 |         if self.dataset == 'cifar10': 29 |             classifier = nn.Sequential( 30 |                 nn.BatchNorm1d(self.latent_dim, affine=False), 31 |                 nn.Linear(self.latent_dim, self.n_classes) 32 |             ) 33 |         elif 'imagenet' in self.dataset: 34 |             classifier = nn.Sequential( 35 |                 nn.AdaptiveMaxPool2d((6, 6)), 36 |                 nn.BatchNorm2d(256, affine=False), 37 |                 nn.Flatten(), 38 |                 nn.Linear(256 * 6 * 6, self.n_classes) 39 |             ) 40 |         else: 41 |             raise Exception('Unsupported dataset:', self.dataset) 42 |         return classifier 43 | 44 |     def forward(self, images): 45 |         batch_size = images.shape[0] 46 |         images, targets = self._preprocess(images) 47 |         targets = targets.to(images.get_device()) 48 | 49 |         logits, zs = self.model(images, out_feat_keys=('classifier', self.feat_layer)) 50 |         loss = F.cross_entropy(logits, targets) 51 | 52 |         pred = logits.argmax(dim=-1) 53 |         correct = pred.eq(targets).float().sum() 54 |         acc = correct / targets.shape[0] * 100. 
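        # `zs` holds features for all 4N rotated copies in the order
        # (0, 90, 180, 270 degrees); the slice below keeps only the N un-rotated
        # inputs, whose features are used by the linear-probe classifier.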
55 | 56 |         zs = zs[:batch_size] 57 |         if self.dataset == 'cifar10': 58 |             zs = zs.flatten(start_dim=1) 59 | 60 |         return dict(Loss=loss, Acc1=acc), zs 61 | 62 |     def encode(self, images, flatten=True): 63 |         zs = self.model(images, out_feat_keys=(self.feat_layer,)) 64 |         return zs.flatten(start_dim=1) if flatten else zs 65 | 66 |     def _preprocess(self, images): 67 |         batch_size = images.shape[0] 68 |         images_90 = torch.flip(images.transpose(2, 3), (2,)) 69 |         images_180 = torch.flip(images, (2, 3)) 70 |         images_270 = torch.flip(images, (2,)).transpose(2, 3) 71 |         images_batch = torch.cat((images, images_90, images_180, images_270), dim=0) 72 |         targets = torch.arange(4).long().repeat(batch_size) 73 |         targets = targets.view(batch_size, 4).transpose(0, 1) 74 |         targets = targets.contiguous().view(-1) 75 |         return images_batch, targets 76 | 77 | 78 | 79 | # Code borrowed from https://github.com/gidariss/FeatureLearningRotNet 80 | 81 | # NetworkInNetwork 82 | class BasicBlock(nn.Module): 83 |     def __init__(self, in_planes, out_planes, kernel_size): 84 |         super(BasicBlock, self).__init__() 85 |         padding = (kernel_size-1) // 2 86 |         self.layers = nn.Sequential() 87 |         self.layers.add_module('Conv', nn.Conv2d(in_planes, out_planes, 88 |             kernel_size=kernel_size, stride=1, padding=padding, bias=False)) 89 |         self.layers.add_module('BatchNorm', nn.BatchNorm2d(out_planes)) 90 |         self.layers.add_module('ReLU', nn.ReLU(inplace=True)) 91 | 92 |     def forward(self, x): 93 |         return self.layers(x) 94 | 95 | 96 | class GlobalAveragePooling(nn.Module): 97 |     def __init__(self): 98 |         super(GlobalAveragePooling, self).__init__() 99 | 100 |     def forward(self, feat): 101 |         num_channels = feat.size(1) 102 |         return F.avg_pool2d(feat, (feat.size(2), feat.size(3))).view(-1, num_channels) 103 | 104 | 105 | class NetworkInNetwork(nn.Module): 106 |     def __init__(self): 107 |         super(NetworkInNetwork, self).__init__() 108 | 109 |         num_classes = 4 110 |         num_inchannels = 3 111 |         num_stages = 4 112 |         use_avg_on_conv3 = False 113 | 114 | 115 |         nChannels = 192 116 |         nChannels2 = 160 117 |         nChannels3 = 96 118 | 119 |         blocks = [nn.Sequential() for i in range(num_stages)] 120 |         # 1st block 121 |         blocks[0].add_module('Block1_ConvB1', BasicBlock(num_inchannels, nChannels, 5)) 122 |         blocks[0].add_module('Block1_ConvB2', BasicBlock(nChannels, nChannels2, 1)) 123 |         blocks[0].add_module('Block1_ConvB3', BasicBlock(nChannels2, nChannels3, 1)) 124 |         blocks[0].add_module('Block1_MaxPool', nn.MaxPool2d(kernel_size=3,stride=2,padding=1)) 125 | 126 |         # 2nd block 127 |         blocks[1].add_module('Block2_ConvB1', BasicBlock(nChannels3, nChannels, 5)) 128 |         blocks[1].add_module('Block2_ConvB2', BasicBlock(nChannels, nChannels, 1)) 129 |         blocks[1].add_module('Block2_ConvB3', BasicBlock(nChannels, nChannels, 1)) 130 |         blocks[1].add_module('Block2_AvgPool', nn.AvgPool2d(kernel_size=3,stride=2,padding=1)) 131 | 132 |         # 3rd block 133 |         blocks[2].add_module('Block3_ConvB1', BasicBlock(nChannels, nChannels, 3)) 134 |         blocks[2].add_module('Block3_ConvB2', BasicBlock(nChannels, nChannels, 1)) 135 |         blocks[2].add_module('Block3_ConvB3', BasicBlock(nChannels, nChannels, 1)) 136 | 137 |         if num_stages > 3 and use_avg_on_conv3: 138 |             blocks[2].add_module('Block3_AvgPool', nn.AvgPool2d(kernel_size=3,stride=2,padding=1)) 139 |         for s in range(3, num_stages): 140 |             blocks[s].add_module('Block'+str(s+1)+'_ConvB1', BasicBlock(nChannels, nChannels, 3)) 141 |             blocks[s].add_module('Block'+str(s+1)+'_ConvB2', BasicBlock(nChannels, nChannels, 1)) 142 |             blocks[s].add_module('Block'+str(s+1)+'_ConvB3', 
BasicBlock(nChannels, nChannels, 1)) 143 | 144 | # global average pooling and classifier 145 | blocks.append(nn.Sequential()) 146 | blocks[-1].add_module('GlobalAveragePooling', GlobalAveragePooling()) 147 | blocks[-1].add_module('Classifier', nn.Linear(nChannels, num_classes)) 148 | 149 | self._feature_blocks = nn.ModuleList(blocks) 150 | self.all_feat_names = ['conv'+str(s+1) for s in range(num_stages)] + ['classifier',] 151 | assert(len(self.all_feat_names) == len(self._feature_blocks)) 152 | 153 | def _parse_out_keys_arg(self, out_feat_keys): 154 | 155 | # By default return the features of the last layer / module. 156 | out_feat_keys = [self.all_feat_names[-1],] if out_feat_keys is None else out_feat_keys 157 | 158 | if len(out_feat_keys) == 0: 159 | raise ValueError('Empty list of output feature keys.') 160 | for f, key in enumerate(out_feat_keys): 161 | if key not in self.all_feat_names: 162 | raise ValueError('Feature with name {0} does not exist. Existing features: {1}.'.format(key, self.all_feat_names)) 163 | elif key in out_feat_keys[:f]: 164 | raise ValueError('Duplicate output feature key: {0}.'.format(key)) 165 | 166 | # Find the highest output feature in `out_feat_keys 167 | max_out_feat = max([self.all_feat_names.index(key) for key in out_feat_keys]) 168 | 169 | return out_feat_keys, max_out_feat 170 | 171 | def forward(self, x, out_feat_keys=None): 172 | """Forward an image `x` through the network and return the asked output features. 173 | Args: 174 | x: input image. 175 | out_feat_keys: a list/tuple with the feature names of the features 176 | that the function should return. By default the last feature of 177 | the network is returned. 178 | Return: 179 | out_feats: If multiple output features were asked then `out_feats` 180 | is a list with the asked output features placed in the same 181 | order as in `out_feat_keys`. If a single output feature was 182 | asked then `out_feats` is that output feature (and not a list). 183 | """ 184 | out_feat_keys, max_out_feat = self._parse_out_keys_arg(out_feat_keys) 185 | out_feats = [None] * len(out_feat_keys) 186 | 187 | feat = x 188 | for f in range(max_out_feat+1): 189 | feat = self._feature_blocks[f](feat) 190 | key = self.all_feat_names[f] 191 | if key in out_feat_keys: 192 | out_feats[out_feat_keys.index(key)] = feat 193 | 194 | out_feats = out_feats[0] if len(out_feats)==1 else out_feats 195 | return out_feats 196 | 197 | 198 | def weight_initialization(self): 199 | for m in self.modules(): 200 | if isinstance(m, nn.Conv2d): 201 | if m.weight.requires_grad: 202 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 203 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 204 | elif isinstance(m, nn.BatchNorm2d): 205 | if m.weight.requires_grad: 206 | m.weight.data.fill_(1) 207 | if m.bias.requires_grad: 208 | m.bias.data.zero_() 209 | elif isinstance(m, nn.Linear): 210 | if m.bias.requires_grad: 211 | m.bias.data.zero_() 212 | 213 | 214 | # AlexNet 215 | class AlexNet(nn.Module): 216 | def __init__(self): 217 | super(AlexNet, self).__init__() 218 | num_classes = 4 219 | 220 | conv1 = nn.Sequential( 221 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 222 | nn.BatchNorm2d(64), 223 | nn.ReLU(inplace=True), 224 | ) 225 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2) 226 | conv2 = nn.Sequential( 227 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 228 | nn.BatchNorm2d(192), 229 | nn.ReLU(inplace=True), 230 | ) 231 | pool2 = nn.MaxPool2d(kernel_size=3, stride=2) 232 | conv3 = nn.Sequential( 233 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 234 | nn.BatchNorm2d(384), 235 | nn.ReLU(inplace=True), 236 | ) 237 | conv4 = nn.Sequential( 238 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 239 | nn.BatchNorm2d(256), 240 | nn.ReLU(inplace=True), 241 | ) 242 | conv5 = nn.Sequential( 243 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 244 | nn.BatchNorm2d(256), 245 | nn.ReLU(inplace=True), 246 | ) 247 | pool5 = nn.MaxPool2d(kernel_size=3, stride=2) 248 | 249 | num_pool5_feats = 6 * 6 * 256 250 | fc_block = nn.Sequential( 251 | nn.Flatten(), 252 | nn.Linear(num_pool5_feats, 4096, bias=False), 253 | nn.BatchNorm1d(4096), 254 | nn.ReLU(inplace=True), 255 | nn.Linear(4096, 4096, bias=False), 256 | nn.BatchNorm1d(4096), 257 | nn.ReLU(inplace=True), 258 | ) 259 | classifier = nn.Sequential( 260 | nn.Linear(4096, num_classes), 261 | ) 262 | 263 | self._feature_blocks = nn.ModuleList([ 264 | conv1, 265 | pool1, 266 | conv2, 267 | pool2, 268 | conv3, 269 | conv4, 270 | conv5, 271 | pool5, 272 | fc_block, 273 | classifier, 274 | ]) 275 | self.all_feat_names = [ 276 | 'conv1', 277 | 'pool1', 278 | 'conv2', 279 | 'pool2', 280 | 'conv3', 281 | 'conv4', 282 | 'conv5', 283 | 'pool5', 284 | 'fc_block', 285 | 'classifier', 286 | ] 287 | assert(len(self.all_feat_names) == len(self._feature_blocks)) 288 | 289 | def _parse_out_keys_arg(self, out_feat_keys): 290 | 291 | # By default return the features of the last layer / module. 292 | out_feat_keys = [self.all_feat_names[-1],] if out_feat_keys is None else out_feat_keys 293 | 294 | if len(out_feat_keys) == 0: 295 | raise ValueError('Empty list of output feature keys.') 296 | for f, key in enumerate(out_feat_keys): 297 | if key not in self.all_feat_names: 298 | raise ValueError('Feature with name {0} does not exist. Existing features: {1}.'.format(key, self.all_feat_names)) 299 | elif key in out_feat_keys[:f]: 300 | raise ValueError('Duplicate output feature key: {0}.'.format(key)) 301 | 302 | # Find the highest output feature in `out_feat_keys 303 | max_out_feat = max([self.all_feat_names.index(key) for key in out_feat_keys]) 304 | 305 | return out_feat_keys, max_out_feat 306 | 307 | def forward(self, x, out_feat_keys=None): 308 | """Forward an image `x` through the network and return the asked output features. 309 | Args: 310 | x: input image. 311 | out_feat_keys: a list/tuple with the feature names of the features 312 | that the function should return. By default the last feature of 313 | the network is returned. 314 | Return: 315 | out_feats: If multiple output features were asked then `out_feats` 316 | is a list with the asked output features placed in the same 317 | order as in `out_feat_keys`. 
If a single output feature was 318 | asked then `out_feats` is that output feature (and not a list). 319 | """ 320 | out_feat_keys, max_out_feat = self._parse_out_keys_arg(out_feat_keys) 321 | out_feats = [None] * len(out_feat_keys) 322 | 323 | feat = x 324 | for f in range(max_out_feat+1): 325 | feat = self._feature_blocks[f](feat) 326 | key = self.all_feat_names[f] 327 | if key in out_feat_keys: 328 | out_feats[out_feat_keys.index(key)] = feat 329 | 330 | out_feats = out_feats[0] if len(out_feats)==1 else out_feats 331 | return out_feats 332 | 333 | def get_L1filters(self): 334 | convlayer = self._feature_blocks[0][0] 335 | batchnorm = self._feature_blocks[0][1] 336 | filters = convlayer.weight.data 337 | scalars = (batchnorm.weight.data / torch.sqrt(batchnorm.running_var + 1e-05)) 338 | filters = (filters * scalars.view(-1, 1, 1, 1).expand_as(filters)).cpu().clone() 339 | 340 | return filters 341 | -------------------------------------------------------------------------------- /deepul_helper/tasks/simclr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from deepul_helper.resnet import resnet_v1 6 | from deepul_helper.batch_norm import SyncBatchNorm, BatchNorm1d 7 | 8 | # Some code adapted from https://github.com/sthalles/SimCLR 9 | class SimCLR(nn.Module): 10 | metrics = ['Loss'] 11 | metrics_fmt = [':.4e'] 12 | 13 | def __init__(self, dataset, n_classes, dist=None): 14 | super().__init__() 15 | self.temperature = 0.5 16 | self.projection_dim = 128 17 | 18 | if dataset == 'cifar10': 19 | resnet = resnet_v1((3, 32, 32), 50, 1, cifar_stem=True) 20 | resnet = SyncBatchNorm.convert_sync_batchnorm(resnet) 21 | self.resnet = resnet 22 | self.latent_dim = 2048 23 | elif 'imagenet' in dataset: 24 | resnet = resnet_v1((3, 128, 128), 50, 1, cifar_stem=False) 25 | if dist is not None: 26 | resnet = nn.SyncBatchNorm.convert_sync_batchnorm(resnet) 27 | self.resnet = resnet 28 | self.latent_dim = 2048 29 | 30 | self.proj = nn.Sequential( 31 | nn.Linear(self.latent_dim, self.projection_dim, bias=False), 32 | BatchNorm1d(self.projection_dim), 33 | nn.ReLU(inplace=True), 34 | nn.Linear(self.projection_dim, self.projection_dim, bias=False), 35 | BatchNorm1d(self.projection_dim, center=False) 36 | ) 37 | 38 | self.dataset = dataset 39 | self.n_classes = n_classes 40 | self.dist = dist 41 | 42 | def construct_classifier(self): 43 | return nn.Sequential(nn.Linear(self.latent_dim, self.n_classes)) 44 | 45 | def forward(self, images): 46 | n = images[0].shape[0] 47 | xi, xj = images 48 | hi, hj = self.encode(xi), self.encode(xj) # (N, latent_dim) 49 | zi, zj = self.proj(hi), self.proj(hj) # (N, projection_dim) 50 | zi, zj = F.normalize(zi), F.normalize(zj) 51 | 52 | # Each training example has 2N - 2 negative samples 53 | # 2N total samples, but exclude the current and positive sample 54 | 55 | if self.dist is None: 56 | zis = [zi] 57 | zjs = [zj] 58 | else: 59 | zis = [torch.zeros_like(zi) for _ in range(self.dist.get_world_size())] 60 | zjs = [torch.zeros_like(zj) for _ in range(self.dist.get_world_size())] 61 | 62 | self.dist.all_gather(zis, zi) 63 | self.dist.all_gather(zjs, zj) 64 | 65 | z1 = torch.cat((zi, zj), dim=0) # (2N, projection_dim) 66 | z2 = torch.cat(zis + zjs, dim=0) # (2N * n_gpus, projection_dim) 67 | 68 | sim_matrix = torch.mm(z1, z2.t()) # (2N, 2N * n_gpus) 69 | sim_matrix = sim_matrix / self.temperature 70 | # Mask out same-sample terms 71 | n_gpus = 1 if self.dist 
is None else self.dist.get_world_size() 72 | rank = 0 if self.dist is None else self.dist.get_rank() 73 | sim_matrix[torch.arange(n), torch.arange(rank*n, (rank+1)*n)] = -float('inf') 74 | sim_matrix[torch.arange(n, 2*n), torch.arange((n_gpus+rank)*n, (n_gpus+rank+1)*n)] = -float('inf') 75 | 76 | targets = torch.cat((torch.arange((n_gpus+rank)*n, (n_gpus+rank+1)*n), 77 | torch.arange(rank*n, (rank+1)*n)), dim=0) 78 | targets = targets.to(sim_matrix.get_device()).long() 79 | 80 | loss = F.cross_entropy(sim_matrix, targets, reduction='sum') 81 | loss = loss / n 82 | return dict(Loss=loss), hi 83 | 84 | def encode(self, images): 85 | return self.resnet(images) 86 | 87 | def get_features(self, images): 88 | return self.resnet.get_features(images) 89 | 90 | -------------------------------------------------------------------------------- /deepul_helper/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import math 3 | import pickle 4 | from collections import OrderedDict, Counter 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def quantize(img, n_bits): 10 | n_colors = 2 ** n_bits 11 | # Quantize to integers from 0, ..., n_colors - 1 12 | img = torch.clamp(torch.floor((img * n_colors)), max=n_colors - 1) 13 | img /= n_colors - 1 # Scale to [0, 1] 14 | return img 15 | 16 | 17 | def remove_module_state_dict(state_dict): 18 | """Clean state_dict keys if original state dict was saved from DistributedDataParallel 19 | and loaded without""" 20 | new_state_dict = OrderedDict() 21 | for k, v in state_dict.items(): 22 | name = k[7:] 23 | new_state_dict[name] = v 24 | return new_state_dict 25 | 26 | 27 | def seg_idxs_to_color(segs, palette_fname='palette.pkl'): 28 | B, H, W = segs.shape 29 | 30 | with open(palette_fname, 'rb') as f: 31 | palette = pickle.load(f) 32 | palette = torch.FloatTensor(palette).view(256, 3) 33 | imgs = torch.index_select(palette, 0, segs.view(-1)).view(B, H, W, 3).permute(0, 3, 1, 2) / 255. 
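    # `index_select` maps every class index in `segs` to its RGB palette entry;
    # the result is reshaped to (B, H, W, 3), permuted to NCHW, and scaled to
    # [0, 1] so the tensors can be written out with torchvision's save_image.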
34 | return imgs 35 | 36 | 37 | def unnormalize(images, dataset): 38 | if dataset == 'cifar10': 39 | mu = [0.4914, 0.4822, 0.4465] 40 | stddev = [0.2023, 0.1994, 0.2010] 41 | else: 42 | mu = [0.485, 0.456, 0.406] 43 | stddev = [0.229, 0.224, 0.225] 44 | 45 | mu = torch.FloatTensor(mu).view(1, 3, 1, 1) 46 | stddev = torch.FloatTensor(stddev).view(1, 3, 1, 1) 47 | return images * stddev + mu 48 | 49 | 50 | def accuracy(output, target, topk=(1,)): 51 | with torch.no_grad(): 52 | maxk = max(topk) 53 | batch_size = target.size(0) 54 | 55 | _, pred = output.topk(maxk, 1, True, True) 56 | pred = pred.t() 57 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 58 | 59 | res = [] 60 | for k in topk: 61 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 62 | res.append(correct_k.mul_(100.0 / batch_size)) 63 | return res 64 | 65 | class AverageMeter(object): 66 | """Computes and stores the average and current value""" 67 | def __init__(self, name, fmt=':f'): 68 | self.name = name 69 | self.fmt = fmt 70 | self.reset() 71 | 72 | def reset(self): 73 | self.val = 0 74 | self.avg = 0 75 | self.sum = 0 76 | self.count = 0 77 | 78 | def update(self, val, n=1): 79 | self.val = val 80 | self.sum += val * n 81 | self.count += n 82 | self.avg = self.sum / self.count 83 | 84 | def __str__(self): 85 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 86 | return fmtstr.format(**self.__dict__) 87 | 88 | 89 | class ProgressMeter(object): 90 | def __init__(self, num_batches, meters, prefix=""): 91 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 92 | self.meters = meters 93 | self.prefix = prefix 94 | 95 | def display(self, batch): 96 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 97 | entries += [str(meter) for meter in self.meters] 98 | print('\t'.join(entries)) 99 | 100 | def _get_batch_fmtstr(self, num_batches): 101 | num_digits = len(str(num_batches // 1)) 102 | fmt = '{:' + str(num_digits) + 'd}' 103 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 104 | -------------------------------------------------------------------------------- /deepul_helper/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | from torchvision.utils import make_grid 6 | 7 | 8 | def plot_hist(data, bins=10, xlabel='x', ylabel='Probability', title='', density=None): 9 | bins = np.concatenate((np.arange(bins) - 0.5, [bins - 1 + 0.5])) 10 | 11 | plt.figure() 12 | plt.hist(data, bins=bins, density=True) 13 | 14 | if density: 15 | plt.plot(density[0], density[1], label='distribution') 16 | plt.legend() 17 | 18 | plt.xlabel(xlabel) 19 | plt.ylabel(ylabel) 20 | plt.title(title) 21 | plt.show() 22 | 23 | 24 | def plot_2d_dist(dist, title='Learned Distribution'): 25 | plt.figure() 26 | plt.imshow(dist) 27 | plt.title(title) 28 | plt.xlabel('x1') 29 | plt.ylabel('x0') 30 | plt.show() 31 | 32 | 33 | def plot_train_curves(epochs, train_losses, test_losses, title=''): 34 | x = np.linspace(0, epochs, len(train_losses)) 35 | plt.figure() 36 | plt.plot(x, train_losses, label='train_loss') 37 | if test_losses: 38 | plt.plot(x, test_losses, label='test_loss') 39 | plt.xlabel('Epoch') 40 | plt.ylabel('Loss') 41 | plt.title(title) 42 | plt.legend() 43 | plt.show() 44 | 45 | 46 | def plot_scatter_2d(points, title='', labels=None): 47 | plt.figure() 48 | if labels is not None: 49 | plt.scatter(points[:, 0], points[:, 1], c=labels, 50 | 
cmap=mpl.colors.ListedColormap(['red', 'blue', 'green', 'purple'])) 51 | else: 52 | plt.scatter(points[:, 0], points[:, 1]) 53 | plt.title(title) 54 | plt.show() 55 | 56 | 57 | def visualize_batch(batch_tensor, nrow=8, title='', figsize=None): 58 | batch_tensor = batch_tensor.clamp(min=0, max=1) 59 | grid_img = make_grid(batch_tensor, nrow=nrow) 60 | plt.figure(figsize=figsize) 61 | plt.title(title) 62 | plt.imshow(grid_img.permute(1, 2, 0)) 63 | plt.axis('off') 64 | plt.show() -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ssl 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _pytorch_select=0.1=cpu_0 8 | - blas=1.0=mkl 9 | - ca-certificates=2020.1.1=0 10 | - certifi=2020.4.5.1=py37_0 11 | - cffi=1.14.0=py37h2e261b9_0 12 | - cudatoolkit=10.1.243=h6bb024c_0 13 | - freetype=2.9.1=h8a8886c_1 14 | - intel-openmp=2020.1=217 15 | - jpeg=9b=h024ee3a_2 16 | - ld_impl_linux-64=2.33.1=h53a641e_7 17 | - libedit=3.1.20181209=hc058e9b_0 18 | - libffi=3.2.1=hd88cf55_4 19 | - libgcc-ng=9.1.0=hdf63c60_0 20 | - libgfortran-ng=7.3.0=hdf63c60_0 21 | - libpng=1.6.37=hbc83047_0 22 | - libstdcxx-ng=9.1.0=hdf63c60_0 23 | - libtiff=4.1.0=h2733197_0 24 | - mkl=2020.1=217 25 | - mkl-service=2.3.0=py37he904b0f_0 26 | - mkl_fft=1.0.15=py37ha843d7b_0 27 | - mkl_random=1.1.0=py37hd6b4f25_0 28 | - ncurses=6.2=he6710b0_1 29 | - ninja=1.9.0=py37hfd86e86_0 30 | - numpy=1.18.1=py37h4f9e942_0 31 | - numpy-base=1.18.1=py37hde5b4d6_1 32 | - olefile=0.46=py37_0 33 | - openssl=1.1.1g=h7b6447c_0 34 | - pillow=7.1.2=py37hb39fc2d_0 35 | - pip=20.0.2=py37_3 36 | - pycparser=2.20=py_0 37 | - python=3.7.6=h0371630_2 38 | - pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0 39 | - readline=7.0=h7b6447c_5 40 | - setuptools=46.4.0=py37_0 41 | - six=1.14.0=py37_0 42 | - sqlite=3.31.1=h62c20be_1 43 | - tk=8.6.8=hbc83047_0 44 | - torchvision=0.5.0=py37_cu101 45 | - wheel=0.34.2=py37_0 46 | - xz=5.2.5=h7b6447c_0 47 | - zlib=1.2.11=h7b6447c_3 48 | - zstd=1.3.7=h0b5b093_0 49 | - pip: 50 | - chardet==3.0.4 51 | - idna==2.9 52 | - opencv-python==4.2.0.34 53 | - requests==2.23.0 54 | - urllib3==1.25.9 55 | - warmup-scheduler==0.3.2 56 | prefix: /home/wilson/miniconda3/envs/ssl 57 | 58 | -------------------------------------------------------------------------------- /palette.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/palette.pkl -------------------------------------------------------------------------------- /run/run_cifar10_rotation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # 1 GPU 4 | python train_self_supervised_task.py -d cifar10 -t rotation 5 | -------------------------------------------------------------------------------- /run/run_cifar10_simclr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Trained on 4 Titan Xps. Turn down batch size to use less GPU memory. 
If your batch size is <= 256, then set -u 0 (no warmup) 4 | python train_self_supervised_task.py -d cifar10 -t simclr -b 512 -e 1000 -o lars --lr 1.0 -w 1e-6 -u 10 5 | -------------------------------------------------------------------------------- /run/run_imagenet100_cpc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # WIP 4 | python train_self_supervised_task.py -d imagenet100 -t cpc -b 64 -e 300 --lr 1e-3 -o adam 5 | -------------------------------------------------------------------------------- /run/run_imagenet100_rotation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # 1-2 GPUs 4 | python train_self_supervised_task.py -d imagenet100 -t rotation 5 | -------------------------------------------------------------------------------- /run/run_imagenet100_simclr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Trained on 4 Titan Xps. Turn down batch size to use less GPU memory. If your batch size is <= 256, then set -u 0 (no warmup) 4 | python train_self_supervised_task.py -d imagenet100 -t simclr -b 512 -e 300 -o lars --lr 0.3 -w 1e-6 -u 10 5 | -------------------------------------------------------------------------------- /sample_images/chrom_ab_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/chrom_ab_demo.png -------------------------------------------------------------------------------- /sample_images/n01537544_19414.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/n01537544_19414.JPEG -------------------------------------------------------------------------------- /sample_images/n01768244_3034.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/n01768244_3034.JPEG -------------------------------------------------------------------------------- /sample_images/n03297495_2537.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/n03297495_2537.JPEG -------------------------------------------------------------------------------- /sample_images/n03297495_3735.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/n03297495_3735.JPEG -------------------------------------------------------------------------------- /sample_images/n03372029_42178.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/n03372029_42178.JPEG -------------------------------------------------------------------------------- /sample_images/n03372029_46468.JPEG: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/n03372029_46468.JPEG -------------------------------------------------------------------------------- /sample_images/n03476684_24524.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/n03476684_24524.JPEG -------------------------------------------------------------------------------- /sample_images/n04372370_39950.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/n04372370_39950.JPEG -------------------------------------------------------------------------------- /sample_images/n11939491_52432.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/n11939491_52432.JPEG -------------------------------------------------------------------------------- /sample_images/sample2.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilson1yan/cs294-158-ssl/d67490a74ba40d3c7b14a9f54fdaaf4cb3d434ea/sample_images/sample2.JPEG -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | 4 | setup( 5 |     name='deepul_helper', 6 |     version='0.1.0', 7 |     packages=find_packages(), 8 |     license='MIT License', 9 | ) -------------------------------------------------------------------------------- /train_segmentation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import time 5 | import shutil 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.distributed as dist 11 | import torch.optim.lr_scheduler as lr_scheduler 12 | from torchvision.utils import save_image 13 | 14 | from deepul_helper.utils import AverageMeter, ProgressMeter, remove_module_state_dict, seg_idxs_to_color, unnormalize 15 | from deepul_helper.data import get_datasets 16 | from deepul_helper.seg_model import SegmentationModel 17 | from deepul_helper.tasks import * 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | # Currently only works for SimCLR 22 | parser.add_argument('-d', '--dataset', type=str, default='pascalvoc2012', 23 |                     help='default: pascalvoc2012') 24 | parser.add_argument('-t', '--pretrained_dir', type=str, default='results/imagenet100_simclr', 25 |                     help='directory of the pretrained model (default: results/imagenet100_simclr)') 26 | 27 | # Training parameters 28 | parser.add_argument('-b', '--batch_size', type=int, default=16, help='default: 16') 29 | parser.add_argument('-e', '--epochs', type=int, default=1000, help='default: 1000') 30 | parser.add_argument('-o', '--optimizer', type=str, default='adam', help='sgd|adam (default: adam)') 31 | parser.add_argument('--lr', type=float, default=1e-4, help='default: 1e-4') 32 | parser.add_argument('-m', '--momentum', type=float, default=0.9, help='default: 0.9') 33 | parser.add_argument('-w', '--weight_decay', type=float, 
default=5e-4, help='default: 5e-4') 34 | parser.add_argument('-i', '--log_interval', type=int, default=10, help='default: 10') 35 | parser.add_argument('-f', '--fine_tuning', action='store_true', help='fine-tune the pretrained model') 36 | 37 | best_loss = float('inf') 38 | 39 | def main(): 40 | global best_loss 41 | 42 | args = parser.parse_args() 43 | assert osp.exists(args.pretrained_dir) 44 | 45 | args.seg_dir = osp.join(args.pretrained_dir, 'segmentation') 46 | if not osp.exists(args.seg_dir): 47 | os.makedirs(args.seg_dir) 48 | 49 | train_dataset, val_dataset, n_classes = get_datasets(args.dataset, 'segmentation') 50 | train_loader = torch.utils.data.DataLoader( 51 | train_dataset, batch_size=args.batch_size, num_workers=16, 52 | pin_memory=True 53 | ) 54 | 55 | val_loader = torch.utils.data.DataLoader( 56 | val_dataset, batch_size=args.batch_size, num_workers=16, 57 | pin_memory=True 58 | ) 59 | 60 | # Currently only supports using SimCLR 61 | pretrained_model = SimCLR('imagenet100', 100, None) 62 | ckpt = torch.load(osp.join(args.pretrained_dir, 'model_best.pth.tar'), map_location='cpu') 63 | state_dict = remove_module_state_dict(ckpt['state_dict']) 64 | pretrained_model.load_state_dict(state_dict) 65 | pretrained_model.cuda() 66 | if not args.fine_tuning: 67 | pretrained_model.eval() 68 | print(f"Loaded pretrained model at Epoch {ckpt['epoch']} Acc {ckpt['best_acc']:.2f}") 69 | 70 | model = SegmentationModel(n_classes) 71 | 72 | args.metrics = model.metrics 73 | args.metrics_fmt = model.metrics_fmt 74 | 75 | torch.backends.cudnn.benchmark = True 76 | model.cuda() 77 | 78 | params = list(model.parameters()) 79 | if args.fine_tuning: 80 | params += list(pretrained_model.parameters()) 81 | if args.optimizer == 'sgd': 82 | optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, 83 | weight_decay=args.weight_decay, nesterov=True) 84 | elif args.optimizer == 'adam': 85 | optimizer = torch.optim.Adam(params, lr=args.lr, betas=(args.momentum, 0.999), 86 | weight_decay=args.weight_decay) 87 | else: 88 | raise Exception('Unsupported optimizer', args.optimizer) 89 | 90 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, 0, -1) 91 | 92 | for epoch in range(args.epochs): 93 | train(train_loader, pretrained_model, model, optimizer, epoch, args) 94 | val_loss, val_acc, val_miou = validate(val_loader, pretrained_model, model, args, dist) 95 | 96 | scheduler.step() 97 | 98 | is_best = val_loss < best_loss 99 | best_loss = min(val_loss, best_loss) 100 | save_checkpoint({ 101 | 'epoch': epoch + 1, 102 | 'state_dict': model.state_dict(), 103 | 'optimizer': optimizer.state_dict(), 104 | 'scheduler': scheduler.state_dict(), 105 | 'pt_state_dict': pretrained_model.state_dict(), 106 | 'best_loss': best_loss, 107 | 'best_acc': val_acc, 108 | 'best_miou': val_miou 109 | }, is_best, args) 110 | 111 | # Save segmentation samples to visualize 112 | if epoch % 10 == 0: 113 | with torch.no_grad(): 114 | images, target = next(iter(val_loader)) 115 | images, target = images[:33], target[:33] 116 | images = images.cuda(non_blocking=True) 117 | target = target.cuda(non_blocking=True).long().squeeze(1) 118 | features = pretrained_model.get_features(images) 119 | _, logits = model(features, target) 120 | pred = torch.argmax(logits, dim=1) 121 | 122 | target = seg_idxs_to_color(target.cpu()) 123 | pred = seg_idxs_to_color(pred.cpu()) 124 | images = unnormalize(images.cpu(), 'imagenet') 125 | 126 | to_save = torch.stack((images, target, pred), dim=1).flatten(end_dim=1) 127 | 
save_image(to_save, osp.join(args.seg_dir, f'epoch{epoch}.png'), nrow=10, pad_value=1.) 128 | 129 | 130 | def train(train_loader, pretrained_model, model, optimizer, epoch, args): 131 | batch_time = AverageMeter('Time', ':6.3f') 132 | data_time = AverageMeter('Data', ':6.3f') 133 | top1 = AverageMeter('PixelAcc@1', ':6.2f') 134 | top3 = AverageMeter('PixelAcc@3', ':6.2f') 135 | miou = AverageMeter('mIOU', ':6.2f') 136 | avg_meters = {k: AverageMeter(k, fmt) 137 | for k, fmt in zip(args.metrics, args.metrics_fmt)} 138 | progress = ProgressMeter( 139 | len(train_loader), 140 | [batch_time, data_time, top1, top3, miou] + list(avg_meters.values()), 141 | prefix="Epoch: [{}]".format(epoch) 142 | ) 143 | 144 | # switch to train mode 145 | model.train() 146 | if args.fine_tuning: 147 | pretrained_model.train() 148 | 149 | end = time.time() 150 | for i, (images, target) in enumerate(train_loader): 151 | # measure data loading time 152 | data_time.update(time.time() - end) 153 | 154 | # compute loss 155 | bs = images.shape[0] 156 | images = images.cuda(non_blocking=True) 157 | target = target.cuda(non_blocking=True).squeeze(1).long() 158 | 159 | features = pretrained_model.get_features(images) 160 | if not args.fine_tuning: 161 | features = [f.detach() for f in features] 162 | 163 | out, logits = model(features, target) 164 | for k, v in out.items(): 165 | avg_meters[k].update(v.item(), bs) 166 | 167 | # compute gradient and optimizer step for ssl task 168 | optimizer.zero_grad() 169 | out['Loss'].backward() 170 | optimizer.step() 171 | 172 | miou.update(compute_mIOU(logits, target), bs) 173 | acc1, acc3 = accuracy(logits, target, topk=(1, 3)) 174 | top1.update(acc1[0], bs) 175 | top3.update(acc3[0], bs) 176 | 177 | # measure elapsed time 178 | batch_time.update(time.time() - end) 179 | end = time.time() 180 | 181 | if i % args.log_interval == 0: 182 | progress.display(i) 183 | 184 | 185 | def validate(val_loader, pretrained_model, model, args, dist): 186 | batch_time = AverageMeter('Time', ':6.3f') 187 | data_time = AverageMeter('Data', ':6.3f') 188 | top1 = AverageMeter('PixelAcc@1', ':6.2f') 189 | top3 = AverageMeter('PixelAcc@3', ':6.2f') 190 | miou = AverageMeter('mIOU', ':6.2f') 191 | avg_meters = {k: AverageMeter(k, fmt) 192 | for k, fmt in zip(args.metrics, args.metrics_fmt)} 193 | progress = ProgressMeter( 194 | len(val_loader), 195 | [batch_time, data_time, top1, top3, miou] + list(avg_meters.values()), 196 | prefix="Test: " 197 | ) 198 | 199 | # switch to evaluate mode 200 | model.eval() 201 | pretrained_model.eval() 202 | 203 | with torch.no_grad(): 204 | end = time.time() 205 | for i, (images, target) in enumerate(val_loader): 206 | # compute and measure loss 207 | bs = images.shape[0] 208 | images = images.cuda(non_blocking=True) 209 | target = target.cuda(non_blocking=True).squeeze(1).long() 210 | 211 | features = pretrained_model.get_features(images) 212 | out, logits = model(features, target) 213 | for k, v in out.items(): 214 | avg_meters[k].update(v.item(), bs) 215 | 216 | miou.update(compute_mIOU(logits, target), bs) 217 | acc1, acc3 = accuracy(logits, target, topk=(1, 3)) 218 | top1.update(acc1[0], bs) 219 | top3.update(acc3[0], bs) 220 | 221 | # measure elapsed time 222 | batch_time.update(time.time() - end) 223 | end = time.time() 224 | 225 | if i % args.log_interval == 0: 226 | progress.display(i) 227 | 228 | print_str = f' * PixelAcc@1 {top1.avg:.3f} PixelAcc@3 {top3.avg:.3f} mIOU {miou.avg:.3f}' 229 | for k, v in avg_meters.items(): 230 | print_str += f' {k} {v.avg:.3f}' 
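    # Note: main() selects the best checkpoint by the validation loss returned
    # below, not by pixel accuracy or mIOU.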
231 |     print(print_str) 232 | 233 |     return avg_meters['Loss'].avg, top1.avg, miou.avg 234 | 235 | 236 | def save_checkpoint(state, is_best, args, filename='seg_checkpoint.pth.tar'): 237 |     filename = osp.join(args.pretrained_dir, filename) 238 |     torch.save(state, filename) 239 |     if is_best: 240 |         shutil.copyfile(filename, osp.join(args.pretrained_dir, 'seg_model_best.pth.tar')) 241 | 242 | 243 | def accuracy(logits, target, topk=(1,)): 244 |     # Assumes logits (B, n_classes, H, W), target (B, H, W) 245 |     B, n_classes, H, W = logits.shape 246 |     logits = logits.permute(0, 2, 3, 1).contiguous().view(-1, n_classes) 247 |     target = target.view(-1) 248 |     with torch.no_grad(): 249 |         maxk = max(topk) 250 |         batch_size = target.size(0) 251 | 252 |         _, pred = logits.topk(maxk, 1, True, True) 253 |         pred = pred.t() 254 |         correct = pred.eq(target.view(1, -1).expand_as(pred)) 255 | 256 |         res = [] 257 |         for k in topk: 258 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 259 |             res.append(correct_k.mul_(100.0 / (B * H * W))) 260 |         return res 261 | 262 | def compute_mIOU(logits, target): 263 |     # Assumes logits (B, n_classes, H, W), target (B, H, W) 264 |     n_classes = logits.shape[1] 265 |     # Cast to float since torch.histc does not support integer tensors 266 |     pred = torch.argmax(logits, dim=1).float() 267 |     target = target.float() 268 | 269 |     # Ignore background class 0 270 |     intersection = pred * (pred == target).float() 271 |     area_intersection = torch.histc(intersection, bins=n_classes - 1, min=1, max=n_classes - 1) 272 | 273 |     area_pred = torch.histc(pred, bins=n_classes - 1, min=1, max=n_classes - 1) 274 |     area_target = torch.histc(target, bins=n_classes - 1, min=1, max=n_classes - 1) 275 |     area_union = area_pred + area_target - area_intersection 276 | 277 |     return torch.mean(area_intersection / (area_union + 1e-10)) * 100. 278 | 279 | 280 | if __name__ == '__main__': 281 |     main() 282 | -------------------------------------------------------------------------------- /train_self_supervised_task.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import time 5 | import shutil 6 | from warmup_scheduler import GradualWarmupScheduler 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.distributed as dist 12 | import torch.multiprocessing as mp 13 | import torch.optim.lr_scheduler as lr_scheduler 14 | 15 | from deepul_helper.tasks import * 16 | from deepul_helper.utils import AverageMeter, ProgressMeter, accuracy 17 | from deepul_helper.data import get_datasets 18 | from deepul_helper.lars import LARS 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('-d', '--dataset', type=str, default='cifar10', help='cifar10|imagenet* (default: cifar10)') 23 | parser.add_argument('-t', '--task', type=str, default='rotation', 24 |                     help='context_encoder|rotation|cpc|simclr (default: rotation)') 25 | 26 | # Training parameters 27 | parser.add_argument('-b', '--batch_size', type=int, default=128, help='batch size total across all GPUs (default: 128)') 28 | parser.add_argument('-e', '--epochs', type=int, default=200, help='default: 200') 29 | parser.add_argument('-o', '--optimizer', type=str, default='sgd', help='sgd|lars|adam (default: sgd)') 30 | parser.add_argument('--lr', type=float, default=0.1, help='default: 0.1') 31 | parser.add_argument('-m', '--momentum', type=float, default=0.9, help='default: 0.9') 32 | parser.add_argument('-w', '--weight_decay', type=float, default=5e-4, help='default: 5e-4') 33 | parser.add_argument('-u', '--warmup_epochs', type=int, 
default=0, 34 |                     help='# of warmup epochs. If > 0, the scheduler warms up from lr to lr * batch_size / 256.') 35 | 36 | parser.add_argument('-p', '--port', type=int, default=23456, help='tcp port for distributed training (default: 23456)') 37 | parser.add_argument('-i', '--log_interval', type=int, default=10, help='default: 10') 38 | 39 | 40 | best_loss = float('inf') 41 | best_acc = 0.0 42 | 43 | def main(): 44 |     args = parser.parse_args() 45 |     assert args.task in ['context_encoder', 'rotation', 'cpc', 'simclr'] 46 | 47 |     args.output_dir = osp.join('results', f"{args.dataset}_{args.task}") 48 |     if not osp.exists(args.output_dir): 49 |         os.makedirs(args.output_dir) 50 | 51 |     ngpus = torch.cuda.device_count() 52 |     mp.spawn(main_worker, nprocs=ngpus, args=(ngpus, args), join=True) 53 | 54 | 55 | def main_worker(gpu, ngpus, args): 56 |     global best_loss 57 | 58 |     print(f'Starting process on GPU: {gpu}') 59 |     dist.init_process_group(backend='nccl', init_method=f'tcp://localhost:{args.port}', 60 |                             world_size=ngpus, rank=gpu) 61 |     total_batch_size = args.batch_size 62 |     args.batch_size = args.batch_size // ngpus 63 | 64 |     train_dataset, val_dataset, n_classes = get_datasets(args.dataset, args.task) 65 |     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 66 |     train_loader = torch.utils.data.DataLoader( 67 |         train_dataset, batch_size=args.batch_size, num_workers=16, 68 |         pin_memory=True, sampler=train_sampler, drop_last=True 69 |     ) 70 | 71 |     val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 72 |     val_loader = torch.utils.data.DataLoader( 73 |         val_dataset, batch_size=args.batch_size, num_workers=16, 74 |         pin_memory=True, drop_last=True, sampler=val_sampler 75 |     ) 76 | 77 |     if args.task == 'context_encoder': 78 |         model = ContextEncoder(args.dataset, n_classes) 79 |     elif args.task == 'rotation': 80 |         model = RotationPrediction(args.dataset, n_classes) 81 |     elif args.task == 'cpc': 82 |         model = CPC(args.dataset, n_classes) 83 |     elif args.task == 'simclr': 84 |         model = SimCLR(args.dataset, n_classes, dist) 85 |     else: 86 |         raise Exception('Invalid task:', args.task) 87 |     args.metrics = model.metrics 88 |     args.metrics_fmt = model.metrics_fmt 89 | 90 |     torch.backends.cudnn.benchmark = True 91 |     torch.cuda.set_device(gpu) 92 |     model.cuda(gpu) 93 | 94 |     args.gpu = gpu 95 | 96 |     linear_classifier = model.construct_classifier().cuda(gpu) 97 |     linear_classifier = torch.nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[gpu], find_unused_parameters=True) 98 |     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu], find_unused_parameters=True) 99 | 100 |     if args.optimizer == 'sgd': 101 |         optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, 102 |                                     weight_decay=args.weight_decay, nesterov=True) 103 |         optimizer_linear = torch.optim.SGD(linear_classifier.parameters(), lr=args.lr, 104 |                                            momentum=args.momentum, nesterov=True) 105 |     elif args.optimizer == 'adam': 106 |         optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(args.momentum, 0.999), 107 |                                      weight_decay=args.weight_decay) 108 |         optimizer_linear = torch.optim.Adam(linear_classifier.parameters(), lr=args.lr, 109 |                                             betas=(args.momentum, 0.999)) 110 |     elif args.optimizer == 'lars': 111 |         optimizer = LARS(model.parameters(), lr=args.lr, momentum=args.momentum, 112 |                          weight_decay=args.weight_decay) 113 |         optimizer_linear = LARS(linear_classifier.parameters(), lr=args.lr, 114 |                                 momentum=args.momentum) 115 |     else: 116 |         raise 
Exception('Unsupported optimizer', args.optimizer) 117 | 118 |     # Minimize SSL task loss, maximize linear classification accuracy 119 |     scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, 0, -1) 120 |     scheduler_linear = lr_scheduler.CosineAnnealingLR(optimizer_linear, args.epochs, 0, -1) 121 |     if args.warmup_epochs > 0: 122 |         scheduler = GradualWarmupScheduler(optimizer, multiplier=total_batch_size / 256., 123 |                                            total_epoch=args.warmup_epochs, after_scheduler=scheduler) 124 |         scheduler_linear = GradualWarmupScheduler(optimizer_linear, multiplier=total_batch_size / 256., 125 |                                                   total_epoch=args.warmup_epochs, 126 |                                                   after_scheduler=scheduler_linear) 127 | 128 |     for epoch in range(args.epochs): 129 |         train_sampler.set_epoch(epoch) 130 | 131 |         train(train_loader, model, linear_classifier, 132 |               optimizer, optimizer_linear, epoch, args) 133 | 134 |         val_loss, val_acc = validate(val_loader, model, linear_classifier, args, dist) 135 | 136 |         scheduler.step() 137 |         scheduler_linear.step() 138 | 139 |         if dist.get_rank() == 0: 140 |             is_best = val_loss < best_loss 141 |             best_loss = min(val_loss, best_loss) 142 |             save_checkpoint({ 143 |                 'epoch': epoch + 1, 144 |                 'state_dict': model.state_dict(), 145 |                 'optimizer': optimizer.state_dict(), 146 |                 'scheduler': scheduler.state_dict(), 147 |                 'state_dict_linear': linear_classifier.state_dict(), 148 |                 'optimizer_linear': optimizer_linear.state_dict(), 149 |                 'scheduler_linear': scheduler_linear.state_dict(), 150 |                 'best_loss': best_loss, 151 |                 'best_acc': val_acc 152 |             }, is_best, args) 153 | 154 | 155 | def train(train_loader, model, linear_classifier, optimizer, 156 |           optimizer_linear, epoch, args): 157 |     batch_time = AverageMeter('Time', ':6.3f') 158 |     data_time = AverageMeter('Data', ':6.3f') 159 |     top1 = AverageMeter('LinearAcc@1', ':6.2f') 160 |     top5 = AverageMeter('LinearAcc@5', ':6.2f') 161 |     avg_meters = {k: AverageMeter(k, fmt) 162 |                   for k, fmt in zip(args.metrics, args.metrics_fmt)} 163 |     progress = ProgressMeter( 164 |         len(train_loader), 165 |         [batch_time, data_time, top1, top5] + list(avg_meters.values()), 166 |         prefix="Epoch: [{}]".format(epoch) 167 |     ) 168 | 169 |     # switch to train mode 170 |     model.train() 171 |     linear_classifier.train() 172 | 173 |     end = time.time() 174 |     for i, (images, target) in enumerate(train_loader): 175 |         # measure data loading time 176 |         data_time.update(time.time() - end) 177 | 178 |         # compute loss 179 |         if isinstance(images, (tuple, list)): 180 |             # Special case for SimCLR which returns a tuple of 2 image batches 181 |             bs = images[0].shape[0] 182 |             images = [x.cuda(args.gpu, non_blocking=True) 183 |                       for x in images] 184 |         else: 185 |             bs = images.shape[0] 186 |             images = images.cuda(args.gpu, non_blocking=True) 187 |         target = target.cuda(args.gpu, non_blocking=True) 188 |         out, zs = model(images) 189 |         zs = zs.detach() 190 |         for k, v in out.items(): 191 |             avg_meters[k].update(v.item(), bs) 192 | 193 |         # compute gradient and optimizer step for ssl task 194 |         optimizer.zero_grad() 195 |         out['Loss'].backward() 196 |         optimizer.step() 197 | 198 |         # compute gradient and optimizer step for classifier 199 |         logits = linear_classifier(zs) 200 |         loss = F.cross_entropy(logits, target) 201 | 202 |         acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 203 |         top1.update(acc1[0], bs) 204 |         top5.update(acc5[0], bs) 205 | 206 |         optimizer_linear.zero_grad() 207 |         loss.backward() 208 |         optimizer_linear.step() 209 | 210 |         # measure elapsed time 211 |         batch_time.update(time.time() - end) 212 |         end = time.time() 213 | 214 |         if i % 
args.log_interval == 0: 215 | progress.display(i) 216 | 217 | 218 | def validate(val_loader, model, linear_classifier, args, dist): 219 | batch_time = AverageMeter('Time', ':6.3f') 220 | data_time = AverageMeter('Data', ':6.3f') 221 | top1 = AverageMeter('LinearAcc@1', ':6.2f') 222 | top5 = AverageMeter('LinearAcc@5', ':6.2f') 223 | avg_meters = {k: AverageMeter(k, fmt) 224 | for k, fmt in zip(args.metrics, args.metrics_fmt)} 225 | progress = ProgressMeter( 226 | len(val_loader), 227 | [batch_time, data_time, top1, top5] + list(avg_meters.values()), 228 | prefix="Test: " 229 | ) 230 | 231 | # switch to evaluate mode 232 | model.eval() 233 | linear_classifier.eval() 234 | 235 | with torch.no_grad(): 236 | end = time.time() 237 | for i, (images, target) in enumerate(val_loader): 238 | # compute and measure loss 239 | if isinstance(images, (tuple, list)): 240 | # Special case for SimCLR which returns a tuple of 2 image batches 241 | bs = images[0].shape[0] 242 | images = [x.cuda(args.gpu, non_blocking=True) 243 | for x in images] 244 | else: 245 | bs = images.shape[0] 246 | images = images.cuda(args.gpu, non_blocking=True) 247 | target = target.cuda(args.gpu, non_blocking=True) 248 | out, zs = model(images) 249 | for k, v in out.items(): 250 | avg_meters[k].update(v.item(), bs) 251 | 252 | logits = linear_classifier(zs) 253 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 254 | top1.update(acc1[0], bs) 255 | top5.update(acc5[0], bs) 256 | 257 | # measure elapsed time 258 | batch_time.update(time.time() - end) 259 | end = time.time() 260 | 261 | if i % args.log_interval == 0: 262 | progress.display(i) 263 | 264 | data = torch.FloatTensor([avg_meters['Loss'].avg, top1.avg, top5.avg] + [v.avg for v in avg_meters.values()]) 265 | data = data.cuda(args.gpu) 266 | gather_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 267 | dist.all_gather(gather_list, data) 268 | data = torch.stack(gather_list, dim=0).mean(0).cpu().numpy() 269 | 270 | if dist.get_rank() == 0: 271 | print_str = f' * LinearAcc@1 {data[1]:.3f} LinearAcc@5 {data[2]:.3f}' 272 | for i, (k, v) in enumerate(avg_meters.items()): 273 | print_str += f' {k} {data[i+3]:.3f}' 274 | print(print_str) 275 | 276 | dist.barrier() 277 | return data[0], data[1] 278 | 279 | 280 | def save_checkpoint(state, is_best, args, filename='checkpoint.pth.tar'): 281 | filename = osp.join(args.output_dir, filename) 282 | torch.save(state, filename) 283 | if is_best: 284 | shutil.copyfile(filename, osp.join(args.output_dir, 'model_best.pth.tar')) 285 | 286 | 287 | if __name__ == '__main__': 288 | main() 289 | --------------------------------------------------------------------------------
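Editor's appendix: a minimal, untested sketch (not part of the original repo) of how a checkpoint written by train_self_supervised_task.py could be reloaded for single-GPU feature extraction. The checkpoint path and class count below are illustrative assumptions; remove_module_state_dict is needed because the state dict is saved from a DistributedDataParallel wrapper.

import torch

from deepul_helper.tasks import SimCLR
from deepul_helper.utils import remove_module_state_dict

# Hypothetical path -- substitute the output_dir produced by your own run
ckpt = torch.load('results/cifar10_simclr/model_best.pth.tar', map_location='cpu')
model = SimCLR('cifar10', n_classes=10, dist=None)
model.load_state_dict(remove_module_state_dict(ckpt['state_dict']))
model.eval()

with torch.no_grad():
    batch = torch.randn(8, 3, 32, 32)  # stand-in for a normalized CIFAR-10 batch
    feats = model.encode(batch)        # (8, 2048) ResNet-50 features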