├── LICENSE.txt
├── README.md
├── data_aug
│   ├── contrastive_learning_dataset.py
│   ├── gaussian_blur.py
│   └── view_generator.py
├── env.yml
├── exceptions
│   └── exceptions.py
├── feature_eval
│   └── mini_batch_logistic_regression_evaluator.ipynb
├── models
│   └── resnet_simclr.py
├── requirements.txt
├── run.py
├── simclr.py
└── utils.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Thalles Silva
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
2 | [![DOI](https://zenodo.org/badge/241184407.svg)](https://zenodo.org/badge/latestdoi/241184407)
3 | 
4 | 
5 | ### Blog post with full documentation: [Exploring SimCLR: A Simple Framework for Contrastive Learning of Visual Representations](https://sthalles.github.io/simple-self-supervised-learning/)
6 | 
7 | ![Image of SimCLR Arch](https://sthalles.github.io/assets/contrastive-self-supervised/cover.png)
8 | 
9 | ### See also [PyTorch Implementation for BYOL - Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning](https://github.com/sthalles/PyTorch-BYOL).
10 | 
11 | ## Installation
12 | 
13 | ```
14 | $ conda env create --name simclr --file env.yml
15 | $ conda activate simclr
16 | $ python run.py
17 | ```
18 | 
19 | ## Config file
20 | 
21 | Before running SimCLR, make sure you choose the correct running configuration. You can change the configuration by passing command-line arguments to ```run.py```:
22 | 
23 | ```
24 | $ python run.py -data ./datasets --dataset-name stl10 --log-every-n-steps 100 --epochs 100
25 | ```
26 | 
27 | If you want to run it on CPU (for debugging purposes), use the ```--disable-cuda``` option.
28 | 
29 | For 16-bit precision GPU training, there is **no** need to install [NVIDIA apex](https://github.com/NVIDIA/apex). Just use the ```--fp16-precision``` flag and this implementation will use [PyTorch's built-in AMP training](https://pytorch.org/docs/stable/notes/amp_examples.html).
30 | 
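31 | For reference, the snippet below sketches the AMP pattern that ```simclr.py``` applies internally (a ```GradScaler``` plus ```autocast``` around the forward pass). The tiny linear model and random batch are stand-ins for illustration only:
32 | 
33 | ```python
34 | import torch
35 | from torch.cuda.amp import GradScaler, autocast
36 | 
37 | model = torch.nn.Linear(8, 2).cuda()                 # stand-in for ResNetSimCLR
38 | optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
39 | criterion = torch.nn.CrossEntropyLoss()
40 | scaler = GradScaler(enabled=True)                    # becomes a no-op with enabled=False
41 | 
42 | x = torch.randn(4, 8, device='cuda')                 # dummy batch
43 | y = torch.zeros(4, dtype=torch.long, device='cuda')  # dummy targets
44 | 
45 | with autocast(enabled=True):                         # run the forward pass in fp16 where safe
46 |     loss = criterion(model(x), y)
47 | optimizer.zero_grad()
48 | scaler.scale(loss).backward()                        # scale the loss to avoid fp16 gradient underflow
49 | scaler.step(optimizer)                               # unscale gradients, then take the optimizer step
50 | scaler.update()
51 | ```
52 | 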
53 | ## Feature Evaluation
54 | 
55 | Feature evaluation is done using a linear evaluation protocol.
56 | 
57 | First, we learn features using SimCLR on the ```STL10 unsupervised``` set. Then, we train a linear classifier on top of the frozen features from SimCLR. The linear model is trained on features extracted from the ```STL10 train``` set and evaluated on the ```STL10 test``` set.
58 | 
59 | Check the [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb) notebook for reproducibility.
60 | 
61 | Note that SimCLR benefits from **longer training**.
62 | 
63 | | Linear Classification | Dataset | Feature Extractor | Architecture | Feature dimensionality | Projection Head dimensionality | Epochs | Top1 % |
64 | |----------------------------|---------|-------------------|--------------|------------------------|--------------------------------|--------|--------|
65 | | Logistic Regression (Adam) | STL10   | SimCLR            | [ResNet-18](https://drive.google.com/open?id=14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF) | 512 | 128 | 100 | 74.45 |
66 | | Logistic Regression (Adam) | CIFAR10 | SimCLR            | [ResNet-18](https://drive.google.com/open?id=1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C) | 512 | 128 | 100 | 69.82 |
67 | | Logistic Regression (Adam) | STL10   | SimCLR            | [ResNet-50](https://drive.google.com/open?id=1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu) | 2048 | 128 | 50 | 70.075 |
68 | 
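69 | The notebook boils down to the following steps, shown here as a condensed sketch (the checkpoint filename is a placeholder and the 10-class head matches STL10/CIFAR10; adapt both to your run):
70 | 
71 | ```python
72 | import torch
73 | import torchvision
74 | 
75 | # a SimCLR checkpoint produced by run.py (filename is a placeholder)
76 | checkpoint = torch.load('checkpoint_0100.pth.tar', map_location='cpu')
77 | state_dict = checkpoint['state_dict']
78 | 
79 | # keep the backbone weights only, stripping the 'backbone.' prefix
80 | state_dict = {k[len('backbone.'):]: v for k, v in state_dict.items()
81 |               if k.startswith('backbone.') and not k.startswith('backbone.fc')}
82 | 
83 | model = torchvision.models.resnet18(pretrained=False, num_classes=10)
84 | log = model.load_state_dict(state_dict, strict=False)
85 | assert log.missing_keys == ['fc.weight', 'fc.bias']  # only the new head is untrained
86 | 
87 | # freeze everything except the linear head, then fine-tune fc on labeled data
88 | for name, param in model.named_parameters():
89 |     param.requires_grad = name in ['fc.weight', 'fc.bias']
90 | ```
91 | 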
--------------------------------------------------------------------------------
/data_aug/contrastive_learning_dataset.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms, datasets
2 | 
3 | from data_aug.gaussian_blur import GaussianBlur
4 | from data_aug.view_generator import ContrastiveLearningViewGenerator
5 | from exceptions.exceptions import InvalidDatasetSelection
6 | 
7 | 
8 | class ContrastiveLearningDataset:
9 |     def __init__(self, root_folder):
10 |         self.root_folder = root_folder
11 | 
12 |     @staticmethod
13 |     def get_simclr_pipeline_transform(size, s=1):
14 |         """Return a set of data augmentation transformations as described in the SimCLR paper."""
15 |         color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
16 |         data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
17 |                                               transforms.RandomHorizontalFlip(),
18 |                                               transforms.RandomApply([color_jitter], p=0.8),
19 |                                               transforms.RandomGrayscale(p=0.2),
20 |                                               GaussianBlur(kernel_size=int(0.1 * size)),
21 |                                               transforms.ToTensor()])
22 |         return data_transforms
23 | 
24 |     def get_dataset(self, name, n_views):
25 |         valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
26 |                                                               transform=ContrastiveLearningViewGenerator(
27 |                                                                   self.get_simclr_pipeline_transform(32),
28 |                                                                   n_views),
29 |                                                               download=True),
30 | 
31 |                           'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
32 |                                                           transform=ContrastiveLearningViewGenerator(
33 |                                                               self.get_simclr_pipeline_transform(96),
34 |                                                               n_views),
35 |                                                           download=True)}
36 | 
37 |         try:
38 |             dataset_fn = valid_datasets[name]
39 |         except KeyError:
40 |             raise InvalidDatasetSelection()
41 |         else:
42 |             return dataset_fn()
43 | 
--------------------------------------------------------------------------------
/data_aug/gaussian_blur.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torchvision.transforms import transforms
5 | 
6 | np.random.seed(0)
7 | 
8 | 
9 | class GaussianBlur(object):
10 |     """Blur a single image on CPU with a randomly sampled sigma."""
11 |     def __init__(self, kernel_size):
12 |         radius = kernel_size // 2
13 |         kernel_size = radius * 2 + 1
14 |         self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
15 |                                 stride=1, padding=0, bias=False, groups=3)
16 |         self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
17 |                                 stride=1, padding=0, bias=False, groups=3)
18 |         self.k = kernel_size
19 |         self.r = radius
20 | 
21 |         self.blur = nn.Sequential(
22 |             nn.ReflectionPad2d(radius),
23 |             self.blur_h,
24 |             self.blur_v
25 |         )
26 | 
27 |         self.pil_to_tensor = transforms.ToTensor()
28 |         self.tensor_to_pil = transforms.ToPILImage()
29 | 
30 |     def __call__(self, img):
31 |         img = self.pil_to_tensor(img).unsqueeze(0)
32 | 
33 |         sigma = np.random.uniform(0.1, 2.0)
34 |         x = np.arange(-self.r, self.r + 1)
35 |         x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
36 |         x = x / x.sum()
37 |         x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
38 | 
39 |         self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
40 |         self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
41 | 
42 |         with torch.no_grad():
43 |             img = self.blur(img)
44 |             img = img.squeeze()
45 | 
46 |         img = self.tensor_to_pil(img)
47 | 
48 |         return img
--------------------------------------------------------------------------------
/data_aug/view_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | np.random.seed(0)
4 | 
5 | 
6 | class ContrastiveLearningViewGenerator(object):
7 |     """Generate n_views randomly augmented views of the same image."""
8 | 
9 |     def __init__(self, base_transform, n_views=2):
10 |         self.base_transform = base_transform
11 |         self.n_views = n_views
12 | 
13 |     def __call__(self, x):
14 |         return [self.base_transform(x) for _ in range(self.n_views)]
15 | 
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: simclr
2 | channels:
3 |   - pytorch
4 |   - anaconda
5 |   - conda-forge
6 |   - defaults
7 | dependencies:
8 |   - cudatoolkit=10.1
9 |   - numpy=1.18.1
10 |   - opencv=3.4.2
11 |   - pillow=7.0
12 |   - pip=20.0
13 |   - python=3.7.6
14 |   - pytorch=1.6.0
15 |   - torchvision=0.7
16 |   - tensorboard=2.1
17 |   - matplotlib=3.1.3
18 |   - scikit-learn=0.22.1
19 |   - pyyaml=5.3.1
20 |   - nvidia-apex=0.1
21 | 
22 | 
--------------------------------------------------------------------------------
/exceptions/exceptions.py:
--------------------------------------------------------------------------------
1 | class BaseSimCLRException(Exception):
2 |     """Base exception"""
3 | 
4 | 
5 | class InvalidBackboneError(BaseSimCLRException):
6 |     """Raised when the choice of backbone Convnet is invalid."""
7 | 
8 | 
9 | class InvalidDatasetSelection(BaseSimCLRException):
10 |     """Raised when the choice of dataset is invalid."""
11 | 
--------------------------------------------------------------------------------
/feature_eval/mini_batch_logistic_regression_evaluator.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |   "nbformat": 4,
3 |   "nbformat_minor": 0,
4 |   "metadata": {
5 |     "kernelspec": {
6 |       "display_name": "pytorch",
7 |       "language": "python",
8 |       "name": "pytorch"
9 |     },
10 |     "language_info": {
11 |       "codemirror_mode": {
12 |         "name": "ipython",
13 |         "version": 3
14 |       },
15 |       "file_extension": ".py",
16 |       "mimetype": "text/x-python",
17 |       "name": "python",
18 |       "nbconvert_exporter": "python",
19 |       "pygments_lexer": "ipython3",
20 | 
"version": "3.6.6" 21 | }, 22 | "colab": { 23 | "name": "Copy of mini-batch-logistic-regression-evaluator.ipynb", 24 | "provenance": [], 25 | "include_colab_link": true 26 | }, 27 | "accelerator": "GPU", 28 | "widgets": { 29 | "application/vnd.jupyter.widget-state+json": { 30 | "149b9ce8fb68473a837a77431c12281a": { 31 | "model_module": "@jupyter-widgets/controls", 32 | "model_name": "HBoxModel", 33 | "state": { 34 | "_view_name": "HBoxView", 35 | "_dom_classes": [], 36 | "_model_name": "HBoxModel", 37 | "_view_module": "@jupyter-widgets/controls", 38 | "_model_module_version": "1.5.0", 39 | "_view_count": null, 40 | "_view_module_version": "1.5.0", 41 | "box_style": "", 42 | "layout": "IPY_MODEL_88cd3db2831e4c13a4a634709700d6b2", 43 | "_model_module": "@jupyter-widgets/controls", 44 | "children": [ 45 | "IPY_MODEL_a88c31d74f5c40a2b24bcff5a35d216c", 46 | "IPY_MODEL_60c6150177694717a622936b830427b5" 47 | ] 48 | } 49 | }, 50 | "88cd3db2831e4c13a4a634709700d6b2": { 51 | "model_module": "@jupyter-widgets/base", 52 | "model_name": "LayoutModel", 53 | "state": { 54 | "_view_name": "LayoutView", 55 | "grid_template_rows": null, 56 | "right": null, 57 | "justify_content": null, 58 | "_view_module": "@jupyter-widgets/base", 59 | "overflow": null, 60 | "_model_module_version": "1.2.0", 61 | "_view_count": null, 62 | "flex_flow": null, 63 | "width": null, 64 | "min_width": null, 65 | "border": null, 66 | "align_items": null, 67 | "bottom": null, 68 | "_model_module": "@jupyter-widgets/base", 69 | "top": null, 70 | "grid_column": null, 71 | "overflow_y": null, 72 | "overflow_x": null, 73 | "grid_auto_flow": null, 74 | "grid_area": null, 75 | "grid_template_columns": null, 76 | "flex": null, 77 | "_model_name": "LayoutModel", 78 | "justify_items": null, 79 | "grid_row": null, 80 | "max_height": null, 81 | "align_content": null, 82 | "visibility": null, 83 | "align_self": null, 84 | "height": null, 85 | "min_height": null, 86 | "padding": null, 87 | "grid_auto_rows": null, 88 | "grid_gap": null, 89 | "max_width": null, 90 | "order": null, 91 | "_view_module_version": "1.2.0", 92 | "grid_template_areas": null, 93 | "object_position": null, 94 | "object_fit": null, 95 | "grid_auto_columns": null, 96 | "margin": null, 97 | "display": null, 98 | "left": null 99 | } 100 | }, 101 | "a88c31d74f5c40a2b24bcff5a35d216c": { 102 | "model_module": "@jupyter-widgets/controls", 103 | "model_name": "FloatProgressModel", 104 | "state": { 105 | "_view_name": "ProgressView", 106 | "style": "IPY_MODEL_dba019efadee4fdc8c799f309b9a7e70", 107 | "_dom_classes": [], 108 | "description": "", 109 | "_model_name": "FloatProgressModel", 110 | "bar_style": "info", 111 | "max": 1, 112 | "_view_module": "@jupyter-widgets/controls", 113 | "_model_module_version": "1.5.0", 114 | "value": 1, 115 | "_view_count": null, 116 | "_view_module_version": "1.5.0", 117 | "orientation": "horizontal", 118 | "min": 0, 119 | "description_tooltip": null, 120 | "_model_module": "@jupyter-widgets/controls", 121 | "layout": "IPY_MODEL_5901c2829a554c8ebbd5926610088041" 122 | } 123 | }, 124 | "60c6150177694717a622936b830427b5": { 125 | "model_module": "@jupyter-widgets/controls", 126 | "model_name": "HTMLModel", 127 | "state": { 128 | "_view_name": "HTMLView", 129 | "style": "IPY_MODEL_957362a11d174407979cf17012bf9208", 130 | "_dom_classes": [], 131 | "description": "", 132 | "_model_name": "HTMLModel", 133 | "placeholder": "​", 134 | "_view_module": "@jupyter-widgets/controls", 135 | "_model_module_version": "1.5.0", 136 | "value": " 2640404480/? 
[00:51<00:00, 32685718.58it/s]", 137 | "_view_count": null, 138 | "_view_module_version": "1.5.0", 139 | "description_tooltip": null, 140 | "_model_module": "@jupyter-widgets/controls", 141 | "layout": "IPY_MODEL_a4f82234388e4701a02a9f68a177193a" 142 | } 143 | }, 144 | "dba019efadee4fdc8c799f309b9a7e70": { 145 | "model_module": "@jupyter-widgets/controls", 146 | "model_name": "ProgressStyleModel", 147 | "state": { 148 | "_view_name": "StyleView", 149 | "_model_name": "ProgressStyleModel", 150 | "description_width": "initial", 151 | "_view_module": "@jupyter-widgets/base", 152 | "_model_module_version": "1.5.0", 153 | "_view_count": null, 154 | "_view_module_version": "1.2.0", 155 | "bar_color": null, 156 | "_model_module": "@jupyter-widgets/controls" 157 | } 158 | }, 159 | "5901c2829a554c8ebbd5926610088041": { 160 | "model_module": "@jupyter-widgets/base", 161 | "model_name": "LayoutModel", 162 | "state": { 163 | "_view_name": "LayoutView", 164 | "grid_template_rows": null, 165 | "right": null, 166 | "justify_content": null, 167 | "_view_module": "@jupyter-widgets/base", 168 | "overflow": null, 169 | "_model_module_version": "1.2.0", 170 | "_view_count": null, 171 | "flex_flow": null, 172 | "width": null, 173 | "min_width": null, 174 | "border": null, 175 | "align_items": null, 176 | "bottom": null, 177 | "_model_module": "@jupyter-widgets/base", 178 | "top": null, 179 | "grid_column": null, 180 | "overflow_y": null, 181 | "overflow_x": null, 182 | "grid_auto_flow": null, 183 | "grid_area": null, 184 | "grid_template_columns": null, 185 | "flex": null, 186 | "_model_name": "LayoutModel", 187 | "justify_items": null, 188 | "grid_row": null, 189 | "max_height": null, 190 | "align_content": null, 191 | "visibility": null, 192 | "align_self": null, 193 | "height": null, 194 | "min_height": null, 195 | "padding": null, 196 | "grid_auto_rows": null, 197 | "grid_gap": null, 198 | "max_width": null, 199 | "order": null, 200 | "_view_module_version": "1.2.0", 201 | "grid_template_areas": null, 202 | "object_position": null, 203 | "object_fit": null, 204 | "grid_auto_columns": null, 205 | "margin": null, 206 | "display": null, 207 | "left": null 208 | } 209 | }, 210 | "957362a11d174407979cf17012bf9208": { 211 | "model_module": "@jupyter-widgets/controls", 212 | "model_name": "DescriptionStyleModel", 213 | "state": { 214 | "_view_name": "StyleView", 215 | "_model_name": "DescriptionStyleModel", 216 | "description_width": "", 217 | "_view_module": "@jupyter-widgets/base", 218 | "_model_module_version": "1.5.0", 219 | "_view_count": null, 220 | "_view_module_version": "1.2.0", 221 | "_model_module": "@jupyter-widgets/controls" 222 | } 223 | }, 224 | "a4f82234388e4701a02a9f68a177193a": { 225 | "model_module": "@jupyter-widgets/base", 226 | "model_name": "LayoutModel", 227 | "state": { 228 | "_view_name": "LayoutView", 229 | "grid_template_rows": null, 230 | "right": null, 231 | "justify_content": null, 232 | "_view_module": "@jupyter-widgets/base", 233 | "overflow": null, 234 | "_model_module_version": "1.2.0", 235 | "_view_count": null, 236 | "flex_flow": null, 237 | "width": null, 238 | "min_width": null, 239 | "border": null, 240 | "align_items": null, 241 | "bottom": null, 242 | "_model_module": "@jupyter-widgets/base", 243 | "top": null, 244 | "grid_column": null, 245 | "overflow_y": null, 246 | "overflow_x": null, 247 | "grid_auto_flow": null, 248 | "grid_area": null, 249 | "grid_template_columns": null, 250 | "flex": null, 251 | "_model_name": "LayoutModel", 252 | "justify_items": null, 253 | 
"grid_row": null, 254 | "max_height": null, 255 | "align_content": null, 256 | "visibility": null, 257 | "align_self": null, 258 | "height": null, 259 | "min_height": null, 260 | "padding": null, 261 | "grid_auto_rows": null, 262 | "grid_gap": null, 263 | "max_width": null, 264 | "order": null, 265 | "_view_module_version": "1.2.0", 266 | "grid_template_areas": null, 267 | "object_position": null, 268 | "object_fit": null, 269 | "grid_auto_columns": null, 270 | "margin": null, 271 | "display": null, 272 | "left": null 273 | } 274 | } 275 | } 276 | } 277 | }, 278 | "cells": [ 279 | { 280 | "cell_type": "markdown", 281 | "metadata": { 282 | "id": "view-in-github", 283 | "colab_type": "text" 284 | }, 285 | "source": [ 286 | "\"Open" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "metadata": { 292 | "id": "YUemQib7ZE4D" 293 | }, 294 | "source": [ 295 | "import torch\n", 296 | "import sys\n", 297 | "import numpy as np\n", 298 | "import os\n", 299 | "import yaml\n", 300 | "import matplotlib.pyplot as plt\n", 301 | "import torchvision" 302 | ], 303 | "execution_count": 10, 304 | "outputs": [] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "metadata": { 309 | "id": "WSgRE1CcLqdS", 310 | "colab": { 311 | "base_uri": "https://localhost:8080/" 312 | }, 313 | "outputId": "48a2ae15-f672-495b-8d43-9a23b85fa3b8" 314 | }, 315 | "source": [ 316 | "!pip install gdown" 317 | ], 318 | "execution_count": 11, 319 | "outputs": [ 320 | { 321 | "output_type": "stream", 322 | "text": [ 323 | "Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n", 324 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.15.0)\n", 325 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.23.0)\n", 326 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.41.1)\n", 327 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2020.12.5)\n", 328 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n", 329 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n", 330 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.10)\n" 331 | ], 332 | "name": "stdout" 333 | } 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "metadata": { 339 | "id": "NOIJEui1ZziV" 340 | }, 341 | "source": [ 342 | "def get_file_id_by_model(folder_name):\n", 343 | " file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',\n", 344 | " 'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',\n", 345 | " 'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu'}\n", 346 | " return file_id.get(folder_name, \"Model not found.\")" 347 | ], 348 | "execution_count": 12, 349 | "outputs": [] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "metadata": { 354 | "id": "G7YMxsvEZMrX", 355 | "colab": { 356 | "base_uri": "https://localhost:8080/" 357 | }, 358 | "outputId": "59475430-69d2-45a2-b61b-ae755d5d6e88" 359 | }, 360 | "source": [ 361 | "folder_name = 'resnet50_50-epochs_stl10'\n", 362 | "file_id = get_file_id_by_model(folder_name)\n", 363 | "print(folder_name, file_id)" 364 | ], 365 | "execution_count": 13, 366 | "outputs": [ 367 | { 368 | "output_type": "stream", 
369 | "text": [ 370 | "resnet50_50-epochs_stl10 1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu\n" 371 | ], 372 | "name": "stdout" 373 | } 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "metadata": { 379 | "id": "PWZ8fet_YoJm", 380 | "colab": { 381 | "base_uri": "https://localhost:8080/" 382 | }, 383 | "outputId": "fbaeb858-221b-4d1b-dd90-001a6e713b75" 384 | }, 385 | "source": [ 386 | "# download and extract model files\n", 387 | "os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n", 388 | "os.system('unzip {}'.format(folder_name))\n", 389 | "!ls" 390 | ], 391 | "execution_count": 14, 392 | "outputs": [ 393 | { 394 | "output_type": "stream", 395 | "text": [ 396 | "checkpoint_0040.pth.tar\n", 397 | "config.yml\n", 398 | "events.out.tfevents.1610927742.4cb2c837708d.2694093.0\n", 399 | "resnet50_50-epochs_stl10.zip\n", 400 | "sample_data\n", 401 | "training.log\n" 402 | ], 403 | "name": "stdout" 404 | } 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "metadata": { 410 | "id": "3_nypQVEv-hn" 411 | }, 412 | "source": [ 413 | "from torch.utils.data import DataLoader\n", 414 | "import torchvision.transforms as transforms\n", 415 | "from torchvision import datasets" 416 | ], 417 | "execution_count": 15, 418 | "outputs": [] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "metadata": { 423 | "id": "lDfbL3w_Z0Od", 424 | "colab": { 425 | "base_uri": "https://localhost:8080/" 426 | }, 427 | "outputId": "7532966e-1c4a-4641-c928-4cda14c53389" 428 | }, 429 | "source": [ 430 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 431 | "print(\"Using device:\", device)" 432 | ], 433 | "execution_count": 16, 434 | "outputs": [ 435 | { 436 | "output_type": "stream", 437 | "text": [ 438 | "Using device: cuda\n" 439 | ], 440 | "name": "stdout" 441 | } 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "metadata": { 447 | "id": "BfIPl0G6_RrT" 448 | }, 449 | "source": [ 450 | "def get_stl10_data_loaders(download, shuffle=False, batch_size=256):\n", 451 | " train_dataset = datasets.STL10('./data', split='train', download=download,\n", 452 | " transform=transforms.ToTensor())\n", 453 | "\n", 454 | " train_loader = DataLoader(train_dataset, batch_size=batch_size,\n", 455 | " num_workers=0, drop_last=False, shuffle=shuffle)\n", 456 | " \n", 457 | " test_dataset = datasets.STL10('./data', split='test', download=download,\n", 458 | " transform=transforms.ToTensor())\n", 459 | "\n", 460 | " test_loader = DataLoader(test_dataset, batch_size=2*batch_size,\n", 461 | " num_workers=10, drop_last=False, shuffle=shuffle)\n", 462 | " return train_loader, test_loader\n", 463 | "\n", 464 | "def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):\n", 465 | " train_dataset = datasets.CIFAR10('./data', train=True, download=download,\n", 466 | " transform=transforms.ToTensor())\n", 467 | "\n", 468 | " train_loader = DataLoader(train_dataset, batch_size=batch_size,\n", 469 | " num_workers=0, drop_last=False, shuffle=shuffle)\n", 470 | " \n", 471 | " test_dataset = datasets.CIFAR10('./data', train=False, download=download,\n", 472 | " transform=transforms.ToTensor())\n", 473 | "\n", 474 | " test_loader = DataLoader(test_dataset, batch_size=2*batch_size,\n", 475 | " num_workers=10, drop_last=False, shuffle=shuffle)\n", 476 | " return train_loader, test_loader" 477 | ], 478 | "execution_count": 17, 479 | "outputs": [] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "metadata": { 484 | "id": "6N8lYkbmDTaK" 485 | }, 486 | "source": [ 487 | "with open(os.path.join('./config.yml')) 
as file:\n", 488 | " config = yaml.load(file)" 489 | ], 490 | "execution_count": 18, 491 | "outputs": [] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "metadata": { 496 | "id": "a18lPD-tIle6" 497 | }, 498 | "source": [ 499 | "if config.arch == 'resnet18':\n", 500 | " model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)\n", 501 | "elif config.arch == 'resnet50':\n", 502 | " model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)" 503 | ], 504 | "execution_count": 19, 505 | "outputs": [] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "metadata": { 510 | "id": "4AIfgq41GuTT" 511 | }, 512 | "source": [ 513 | "checkpoint = torch.load('checkpoint_0040.pth.tar', map_location=device)\n", 514 | "state_dict = checkpoint['state_dict']\n", 515 | "\n", 516 | "for k in list(state_dict.keys()):\n", 517 | "\n", 518 | " if k.startswith('backbone.'):\n", 519 | " if k.startswith('backbone') and not k.startswith('backbone.fc'):\n", 520 | " # remove prefix\n", 521 | " state_dict[k[len(\"backbone.\"):]] = state_dict[k]\n", 522 | " del state_dict[k]" 523 | ], 524 | "execution_count": 21, 525 | "outputs": [] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "metadata": { 530 | "id": "VVjA83PPJYWl" 531 | }, 532 | "source": [ 533 | "log = model.load_state_dict(state_dict, strict=False)\n", 534 | "assert log.missing_keys == ['fc.weight', 'fc.bias']" 535 | ], 536 | "execution_count": 22, 537 | "outputs": [] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "metadata": { 542 | "id": "_GC0a14uWRr6", 543 | "colab": { 544 | "base_uri": "https://localhost:8080/", 545 | "height": 117, 546 | "referenced_widgets": [ 547 | "149b9ce8fb68473a837a77431c12281a", 548 | "88cd3db2831e4c13a4a634709700d6b2", 549 | "a88c31d74f5c40a2b24bcff5a35d216c", 550 | "60c6150177694717a622936b830427b5", 551 | "dba019efadee4fdc8c799f309b9a7e70", 552 | "5901c2829a554c8ebbd5926610088041", 553 | "957362a11d174407979cf17012bf9208", 554 | "a4f82234388e4701a02a9f68a177193a" 555 | ] 556 | }, 557 | "outputId": "4c2558db-921c-425e-f947-6cc746d8c749" 558 | }, 559 | "source": [ 560 | "if config.dataset_name == 'cifar10':\n", 561 | " train_loader, test_loader = get_cifar10_data_loaders(download=True)\n", 562 | "elif config.dataset_name == 'stl10':\n", 563 | " train_loader, test_loader = get_stl10_data_loaders(download=True)\n", 564 | "print(\"Dataset:\", config.dataset_name)" 565 | ], 566 | "execution_count": 23, 567 | "outputs": [ 568 | { 569 | "output_type": "stream", 570 | "text": [ 571 | "Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz\n" 572 | ], 573 | "name": "stdout" 574 | }, 575 | { 576 | "output_type": "display_data", 577 | "data": { 578 | "application/vnd.jupyter.widget-view+json": { 579 | "model_id": "149b9ce8fb68473a837a77431c12281a", 580 | "version_minor": 0, 581 | "version_major": 2 582 | }, 583 | "text/plain": [ 584 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" 585 | ] 586 | }, 587 | "metadata": { 588 | "tags": [] 589 | } 590 | }, 591 | { 592 | "output_type": "stream", 593 | "text": [ 594 | "Extracting ./data/stl10_binary.tar.gz to ./data\n", 595 | "Files already downloaded and verified\n", 596 | "Dataset: stl10\n" 597 | ], 598 | "name": "stdout" 599 | } 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "metadata": { 605 | "id": "pYT_KsM0Mnnr" 606 | }, 607 | "source": [ 608 | "# freeze all layers but the last fc\n", 609 | "for name, param in model.named_parameters():\n", 610 | " if 
name not in ['fc.weight', 'fc.bias']:\n", 611 | " param.requires_grad = False\n", 612 | "\n", 613 | "parameters = list(filter(lambda p: p.requires_grad, model.parameters()))\n", 614 | "assert len(parameters) == 2 # fc.weight, fc.bias" 615 | ], 616 | "execution_count": 24, 617 | "outputs": [] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "metadata": { 622 | "id": "aPVh1S_eMRDU" 623 | }, 624 | "source": [ 625 | "optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)\n", 626 | "criterion = torch.nn.CrossEntropyLoss().to(device)" 627 | ], 628 | "execution_count": 25, 629 | "outputs": [] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "metadata": { 634 | "id": "edr6RhP2PdVq" 635 | }, 636 | "source": [ 637 | "def accuracy(output, target, topk=(1,)):\n", 638 | " \"\"\"Computes the accuracy over the k top predictions for the specified values of k\"\"\"\n", 639 | " with torch.no_grad():\n", 640 | " maxk = max(topk)\n", 641 | " batch_size = target.size(0)\n", 642 | "\n", 643 | " _, pred = output.topk(maxk, 1, True, True)\n", 644 | " pred = pred.t()\n", 645 | " correct = pred.eq(target.view(1, -1).expand_as(pred))\n", 646 | "\n", 647 | " res = []\n", 648 | " for k in topk:\n", 649 | " correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n", 650 | " res.append(correct_k.mul_(100.0 / batch_size))\n", 651 | " return res" 652 | ], 653 | "execution_count": 26, 654 | "outputs": [] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "metadata": { 659 | "id": "qOder0dAMI7X", 660 | "colab": { 661 | "base_uri": "https://localhost:8080/" 662 | }, 663 | "outputId": "5f723b91-5a5e-43eb-ca01-a9b5ae2f1346" 664 | }, 665 | "source": [ 666 | "epochs = 100\n", 667 | "for epoch in range(epochs):\n", 668 | " top1_train_accuracy = 0\n", 669 | " for counter, (x_batch, y_batch) in enumerate(train_loader):\n", 670 | " x_batch = x_batch.to(device)\n", 671 | " y_batch = y_batch.to(device)\n", 672 | "\n", 673 | " logits = model(x_batch)\n", 674 | " loss = criterion(logits, y_batch)\n", 675 | " top1 = accuracy(logits, y_batch, topk=(1,))\n", 676 | " top1_train_accuracy += top1[0]\n", 677 | "\n", 678 | " optimizer.zero_grad()\n", 679 | " loss.backward()\n", 680 | " optimizer.step()\n", 681 | "\n", 682 | " top1_train_accuracy /= (counter + 1)\n", 683 | " top1_accuracy = 0\n", 684 | " top5_accuracy = 0\n", 685 | " for counter, (x_batch, y_batch) in enumerate(test_loader):\n", 686 | " x_batch = x_batch.to(device)\n", 687 | " y_batch = y_batch.to(device)\n", 688 | "\n", 689 | " logits = model(x_batch)\n", 690 | " \n", 691 | " top1, top5 = accuracy(logits, y_batch, topk=(1,5))\n", 692 | " top1_accuracy += top1[0]\n", 693 | " top5_accuracy += top5[0]\n", 694 | " \n", 695 | " top1_accuracy /= (counter + 1)\n", 696 | " top5_accuracy /= (counter + 1)\n", 697 | " print(f\"Epoch {epoch}\\tTop1 Train accuracy {top1_train_accuracy.item()}\\tTop1 Test accuracy: {top1_accuracy.item()}\\tTop5 test acc: {top5_accuracy.item()}\")" 698 | ], 699 | "execution_count": 27, 700 | "outputs": [ 701 | { 702 | "output_type": "stream", 703 | "text": [ 704 | "Epoch 0\tTop1 Train accuracy 28.7109375\tTop1 Test accuracy: 43.75\tTop5 test acc: 93.837890625\n", 705 | "Epoch 1\tTop1 Train accuracy 49.37959671020508\tTop1 Test accuracy: 52.8662109375\tTop5 test acc: 95.439453125\n", 706 | "Epoch 2\tTop1 Train accuracy 55.257354736328125\tTop1 Test accuracy: 56.45263671875\tTop5 test acc: 95.91796875\n", 707 | "Epoch 3\tTop1 Train accuracy 57.51838302612305\tTop1 Test accuracy: 57.39013671875\tTop5 test acc: 96.19384765625\n", 
708 | "Epoch 4\tTop1 Train accuracy 58.727020263671875\tTop1 Test accuracy: 58.2568359375\tTop5 test acc: 96.435546875\n", 709 | "Epoch 5\tTop1 Train accuracy 59.677162170410156\tTop1 Test accuracy: 58.7353515625\tTop5 test acc: 96.50390625\n", 710 | "Epoch 6\tTop1 Train accuracy 60.065486907958984\tTop1 Test accuracy: 59.17724609375\tTop5 test acc: 96.708984375\n", 711 | "Epoch 7\tTop1 Train accuracy 60.612361907958984\tTop1 Test accuracy: 59.482421875\tTop5 test acc: 96.74560546875\n", 712 | "Epoch 8\tTop1 Train accuracy 60.827205657958984\tTop1 Test accuracy: 59.66064453125\tTop5 test acc: 96.77490234375\n", 713 | "Epoch 9\tTop1 Train accuracy 61.100643157958984\tTop1 Test accuracy: 60.09521484375\tTop5 test acc: 96.82373046875\n", 714 | "Epoch 10\tTop1 Train accuracy 61.52803421020508\tTop1 Test accuracy: 60.3466796875\tTop5 test acc: 96.82861328125\n", 715 | "Epoch 11\tTop1 Train accuracy 61.80147171020508\tTop1 Test accuracy: 60.6640625\tTop5 test acc: 96.8896484375\n", 716 | "Epoch 12\tTop1 Train accuracy 62.09444046020508\tTop1 Test accuracy: 60.96435546875\tTop5 test acc: 96.99462890625\n", 717 | "Epoch 13\tTop1 Train accuracy 62.541358947753906\tTop1 Test accuracy: 61.13037109375\tTop5 test acc: 97.0068359375\n", 718 | "Epoch 14\tTop1 Train accuracy 62.853858947753906\tTop1 Test accuracy: 61.34033203125\tTop5 test acc: 97.01904296875\n", 719 | "Epoch 15\tTop1 Train accuracy 62.951515197753906\tTop1 Test accuracy: 61.5673828125\tTop5 test acc: 96.99951171875\n", 720 | "Epoch 16\tTop1 Train accuracy 63.400733947753906\tTop1 Test accuracy: 61.806640625\tTop5 test acc: 97.0361328125\n", 721 | "Epoch 17\tTop1 Train accuracy 63.66958236694336\tTop1 Test accuracy: 61.98974609375\tTop5 test acc: 97.0849609375\n", 722 | "Epoch 18\tTop1 Train accuracy 63.82583236694336\tTop1 Test accuracy: 62.265625\tTop5 test acc: 97.07275390625\n", 723 | "Epoch 19\tTop1 Train accuracy 64.1187973022461\tTop1 Test accuracy: 62.412109375\tTop5 test acc: 97.09716796875\n", 724 | "Epoch 20\tTop1 Train accuracy 64.2750473022461\tTop1 Test accuracy: 62.56591796875\tTop5 test acc: 97.12158203125\n", 725 | "Epoch 21\tTop1 Train accuracy 64.4140625\tTop1 Test accuracy: 62.724609375\tTop5 test acc: 97.20703125\n", 726 | "Epoch 22\tTop1 Train accuracy 64.53125\tTop1 Test accuracy: 62.90771484375\tTop5 test acc: 97.255859375\n", 727 | "Epoch 23\tTop1 Train accuracy 64.6484375\tTop1 Test accuracy: 62.95654296875\tTop5 test acc: 97.29248046875\n", 728 | "Epoch 24\tTop1 Train accuracy 64.86328125\tTop1 Test accuracy: 63.12255859375\tTop5 test acc: 97.35595703125\n", 729 | "Epoch 25\tTop1 Train accuracy 65.1344223022461\tTop1 Test accuracy: 63.330078125\tTop5 test acc: 97.40478515625\n", 730 | "Epoch 26\tTop1 Train accuracy 65.3297348022461\tTop1 Test accuracy: 63.3984375\tTop5 test acc: 97.44873046875\n", 731 | "Epoch 27\tTop1 Train accuracy 65.4469223022461\tTop1 Test accuracy: 63.34228515625\tTop5 test acc: 97.412109375\n", 732 | "Epoch 28\tTop1 Train accuracy 65.6227035522461\tTop1 Test accuracy: 63.48876953125\tTop5 test acc: 97.412109375\n", 733 | "Epoch 29\tTop1 Train accuracy 65.85478210449219\tTop1 Test accuracy: 63.56201171875\tTop5 test acc: 97.42431640625\n", 734 | "Epoch 30\tTop1 Train accuracy 66.06732940673828\tTop1 Test accuracy: 63.67431640625\tTop5 test acc: 97.4560546875\n", 735 | "Epoch 31\tTop1 Train accuracy 66.20404815673828\tTop1 Test accuracy: 63.80859375\tTop5 test acc: 97.48046875\n", 736 | "Epoch 32\tTop1 Train accuracy 66.24080657958984\tTop1 Test accuracy: 63.92578125\tTop5 test acc: 
97.5048828125\n", 737 | "Epoch 33\tTop1 Train accuracy 66.58777618408203\tTop1 Test accuracy: 63.9990234375\tTop5 test acc: 97.529296875\n", 738 | "Epoch 34\tTop1 Train accuracy 66.70496368408203\tTop1 Test accuracy: 64.1455078125\tTop5 test acc: 97.51708984375\n", 739 | "Epoch 35\tTop1 Train accuracy 66.80261993408203\tTop1 Test accuracy: 64.20654296875\tTop5 test acc: 97.529296875\n", 740 | "Epoch 36\tTop1 Train accuracy 66.91980743408203\tTop1 Test accuracy: 64.32861328125\tTop5 test acc: 97.51708984375\n", 741 | "Epoch 37\tTop1 Train accuracy 66.93933868408203\tTop1 Test accuracy: 64.3896484375\tTop5 test acc: 97.51708984375\n", 742 | "Epoch 38\tTop1 Train accuracy 66.97840118408203\tTop1 Test accuracy: 64.47021484375\tTop5 test acc: 97.529296875\n", 743 | "Epoch 39\tTop1 Train accuracy 67.11282348632812\tTop1 Test accuracy: 64.53125\tTop5 test acc: 97.56591796875\n", 744 | "Epoch 40\tTop1 Train accuracy 67.24954223632812\tTop1 Test accuracy: 64.6044921875\tTop5 test acc: 97.6025390625\n", 745 | "Epoch 41\tTop1 Train accuracy 67.34949493408203\tTop1 Test accuracy: 64.62890625\tTop5 test acc: 97.59033203125\n", 746 | "Epoch 42\tTop1 Train accuracy 67.42761993408203\tTop1 Test accuracy: 64.7265625\tTop5 test acc: 97.6025390625\n", 747 | "Epoch 43\tTop1 Train accuracy 67.52527618408203\tTop1 Test accuracy: 64.84375\tTop5 test acc: 97.61474609375\n", 748 | "Epoch 44\tTop1 Train accuracy 67.58386993408203\tTop1 Test accuracy: 64.87548828125\tTop5 test acc: 97.61474609375\n", 749 | "Epoch 45\tTop1 Train accuracy 67.64246368408203\tTop1 Test accuracy: 64.9365234375\tTop5 test acc: 97.626953125\n", 750 | "Epoch 46\tTop1 Train accuracy 67.75735473632812\tTop1 Test accuracy: 65.0341796875\tTop5 test acc: 97.66357421875\n", 751 | "Epoch 47\tTop1 Train accuracy 67.85501098632812\tTop1 Test accuracy: 65.1318359375\tTop5 test acc: 97.7001953125\n", 752 | "Epoch 48\tTop1 Train accuracy 67.89407348632812\tTop1 Test accuracy: 65.1318359375\tTop5 test acc: 97.73681640625\n", 753 | "Epoch 49\tTop1 Train accuracy 67.95266723632812\tTop1 Test accuracy: 65.15625\tTop5 test acc: 97.73681640625\n", 754 | "Epoch 50\tTop1 Train accuracy 68.01126098632812\tTop1 Test accuracy: 65.21728515625\tTop5 test acc: 97.76123046875\n", 755 | "Epoch 51\tTop1 Train accuracy 68.05032348632812\tTop1 Test accuracy: 65.29052734375\tTop5 test acc: 97.7490234375\n", 756 | "Epoch 52\tTop1 Train accuracy 68.05032348632812\tTop1 Test accuracy: 65.3564453125\tTop5 test acc: 97.78564453125\n", 757 | "Epoch 53\tTop1 Train accuracy 68.20657348632812\tTop1 Test accuracy: 65.3759765625\tTop5 test acc: 97.7978515625\n", 758 | "Epoch 54\tTop1 Train accuracy 68.28469848632812\tTop1 Test accuracy: 65.45654296875\tTop5 test acc: 97.822265625\n", 759 | "Epoch 55\tTop1 Train accuracy 68.41912078857422\tTop1 Test accuracy: 65.46875\tTop5 test acc: 97.8466796875\n", 760 | "Epoch 56\tTop1 Train accuracy 68.45818328857422\tTop1 Test accuracy: 65.5615234375\tTop5 test acc: 97.85888671875\n", 761 | "Epoch 57\tTop1 Train accuracy 68.61443328857422\tTop1 Test accuracy: 65.56640625\tTop5 test acc: 97.87109375\n", 762 | "Epoch 58\tTop1 Train accuracy 68.71208953857422\tTop1 Test accuracy: 65.5859375\tTop5 test acc: 97.90771484375\n", 763 | "Epoch 59\tTop1 Train accuracy 68.69255828857422\tTop1 Test accuracy: 65.64697265625\tTop5 test acc: 97.919921875\n", 764 | "Epoch 60\tTop1 Train accuracy 68.80744934082031\tTop1 Test accuracy: 65.64697265625\tTop5 test acc: 97.93212890625\n", 765 | "Epoch 61\tTop1 Train accuracy 68.94416809082031\tTop1 Test accuracy: 
65.72021484375\tTop5 test acc: 97.93212890625\n", 766 | "Epoch 62\tTop1 Train accuracy 69.04182434082031\tTop1 Test accuracy: 65.76904296875\tTop5 test acc: 97.919921875\n", 767 | "Epoch 63\tTop1 Train accuracy 69.06135559082031\tTop1 Test accuracy: 65.84228515625\tTop5 test acc: 97.90771484375\n", 768 | "Epoch 64\tTop1 Train accuracy 69.19807434082031\tTop1 Test accuracy: 65.93505859375\tTop5 test acc: 97.90771484375\n", 769 | "Epoch 65\tTop1 Train accuracy 69.23713684082031\tTop1 Test accuracy: 65.95947265625\tTop5 test acc: 97.9150390625\n", 770 | "Epoch 66\tTop1 Train accuracy 69.25666809082031\tTop1 Test accuracy: 66.0888671875\tTop5 test acc: 97.939453125\n", 771 | "Epoch 67\tTop1 Train accuracy 69.31526184082031\tTop1 Test accuracy: 66.02783203125\tTop5 test acc: 97.939453125\n", 772 | "Epoch 68\tTop1 Train accuracy 69.43014526367188\tTop1 Test accuracy: 66.07666015625\tTop5 test acc: 97.9638671875\n", 773 | "Epoch 69\tTop1 Train accuracy 69.48873901367188\tTop1 Test accuracy: 66.12060546875\tTop5 test acc: 97.9638671875\n", 774 | "Epoch 70\tTop1 Train accuracy 69.50827026367188\tTop1 Test accuracy: 66.083984375\tTop5 test acc: 97.95166015625\n", 775 | "Epoch 71\tTop1 Train accuracy 69.60592651367188\tTop1 Test accuracy: 66.1572265625\tTop5 test acc: 97.9638671875\n", 776 | "Epoch 72\tTop1 Train accuracy 69.68635559082031\tTop1 Test accuracy: 66.2060546875\tTop5 test acc: 97.95166015625\n", 777 | "Epoch 73\tTop1 Train accuracy 69.78170776367188\tTop1 Test accuracy: 66.2744140625\tTop5 test acc: 97.92724609375\n", 778 | "Epoch 74\tTop1 Train accuracy 69.84030151367188\tTop1 Test accuracy: 66.31591796875\tTop5 test acc: 97.92724609375\n", 779 | "Epoch 75\tTop1 Train accuracy 69.89889526367188\tTop1 Test accuracy: 66.328125\tTop5 test acc: 97.9150390625\n", 780 | "Epoch 76\tTop1 Train accuracy 69.93795776367188\tTop1 Test accuracy: 66.41357421875\tTop5 test acc: 97.92724609375\n", 781 | "Epoch 77\tTop1 Train accuracy 69.95748901367188\tTop1 Test accuracy: 66.41357421875\tTop5 test acc: 97.9150390625\n", 782 | "Epoch 78\tTop1 Train accuracy 70.01608276367188\tTop1 Test accuracy: 66.474609375\tTop5 test acc: 97.9150390625\n", 783 | "Epoch 79\tTop1 Train accuracy 69.99655151367188\tTop1 Test accuracy: 66.53564453125\tTop5 test acc: 97.939453125\n", 784 | "Epoch 80\tTop1 Train accuracy 70.01608276367188\tTop1 Test accuracy: 66.56005859375\tTop5 test acc: 97.939453125\n", 785 | "Epoch 81\tTop1 Train accuracy 70.09420776367188\tTop1 Test accuracy: 66.56494140625\tTop5 test acc: 97.939453125\n", 786 | "Epoch 82\tTop1 Train accuracy 70.11373901367188\tTop1 Test accuracy: 66.650390625\tTop5 test acc: 97.939453125\n", 787 | "Epoch 83\tTop1 Train accuracy 70.19186401367188\tTop1 Test accuracy: 66.71142578125\tTop5 test acc: 97.92724609375\n", 788 | "Epoch 84\tTop1 Train accuracy 70.26998901367188\tTop1 Test accuracy: 66.7236328125\tTop5 test acc: 97.90283203125\n", 789 | "Epoch 85\tTop1 Train accuracy 70.32858276367188\tTop1 Test accuracy: 66.73583984375\tTop5 test acc: 97.90283203125\n", 790 | "Epoch 86\tTop1 Train accuracy 70.32858276367188\tTop1 Test accuracy: 66.748046875\tTop5 test acc: 97.890625\n", 791 | "Epoch 87\tTop1 Train accuracy 70.46530151367188\tTop1 Test accuracy: 66.7724609375\tTop5 test acc: 97.890625\n", 792 | "Epoch 88\tTop1 Train accuracy 70.52389526367188\tTop1 Test accuracy: 66.78466796875\tTop5 test acc: 97.90283203125\n", 793 | "Epoch 89\tTop1 Train accuracy 70.56295776367188\tTop1 Test accuracy: 66.78466796875\tTop5 test acc: 97.890625\n", 794 | "Epoch 90\tTop1 Train 
accuracy 70.68014526367188\tTop1 Test accuracy: 66.83349609375\tTop5 test acc: 97.87841796875\n", 795 | "Epoch 91\tTop1 Train accuracy 70.77780151367188\tTop1 Test accuracy: 66.826171875\tTop5 test acc: 97.87841796875\n", 796 | "Epoch 92\tTop1 Train accuracy 70.81686401367188\tTop1 Test accuracy: 66.88720703125\tTop5 test acc: 97.87841796875\n", 797 | "Epoch 93\tTop1 Train accuracy 70.85592651367188\tTop1 Test accuracy: 66.8994140625\tTop5 test acc: 97.87841796875\n", 798 | "Epoch 94\tTop1 Train accuracy 70.91452026367188\tTop1 Test accuracy: 66.9482421875\tTop5 test acc: 97.890625\n", 799 | "Epoch 95\tTop1 Train accuracy 71.03170776367188\tTop1 Test accuracy: 66.98486328125\tTop5 test acc: 97.890625\n", 800 | "Epoch 96\tTop1 Train accuracy 71.09030151367188\tTop1 Test accuracy: 67.001953125\tTop5 test acc: 97.91015625\n", 801 | "Epoch 97\tTop1 Train accuracy 71.09030151367188\tTop1 Test accuracy: 67.0263671875\tTop5 test acc: 97.91015625\n", 802 | "Epoch 98\tTop1 Train accuracy 71.12936401367188\tTop1 Test accuracy: 67.06298828125\tTop5 test acc: 97.89794921875\n", 803 | "Epoch 99\tTop1 Train accuracy 71.12936401367188\tTop1 Test accuracy: 67.0751953125\tTop5 test acc: 97.8857421875\n" 804 | ], 805 | "name": "stdout" 806 | } 807 | ] 808 | }, 809 | { 810 | "cell_type": "code", 811 | "metadata": { 812 | "id": "dtYqHZirMNZk" 813 | }, 814 | "source": [ 815 | "" 816 | ], 817 | "execution_count": 27, 818 | "outputs": [] 819 | } 820 | ] 821 | } -------------------------------------------------------------------------------- /models/resnet_simclr.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.models as models 3 | 4 | from exceptions.exceptions import InvalidBackboneError 5 | 6 | 7 | class ResNetSimCLR(nn.Module): 8 | 9 | def __init__(self, base_model, out_dim): 10 | super(ResNetSimCLR, self).__init__() 11 | self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim), 12 | "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)} 13 | 14 | self.backbone = self._get_basemodel(base_model) 15 | dim_mlp = self.backbone.fc.in_features 16 | 17 | # add mlp projection head 18 | self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc) 19 | 20 | def _get_basemodel(self, model_name): 21 | try: 22 | model = self.resnet_dict[model_name] 23 | except KeyError: 24 | raise InvalidBackboneError( 25 | "Invalid backbone architecture. 
Check the config file and pass one of: resnet18 or resnet50") 26 | else: 27 | return model 28 | 29 | def forward(self, x): 30 | return self.backbone(x) 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | absl-py=0.9.0=pypi_0 6 | blas=1.0=mkl 7 | bzip2=1.0.8=h516909a_2 8 | ca-certificates=2019.11.28=hecc5488_0 9 | cachetools=4.0.0=pypi_0 10 | cairo=1.14.12=h80bd089_1005 11 | certifi=2019.11.28=py37hc8dfbb8_1 12 | chardet=3.0.4=pypi_0 13 | cudatoolkit=10.1.243=h6bb024c_0 14 | ffmpeg=4.0.2=ha0c5888_2 15 | fontconfig=2.13.1=he4413a7_1000 16 | freeglut=3.0.0=hf484d3e_1005 17 | freetype=2.9.1=h8a8886c_1 18 | gettext=0.19.8.1=hc5be6a0_1002 19 | glib=2.56.2=had28632_1001 20 | gmp=6.1.2=hf484d3e_1000 21 | gnutls=3.5.19=h2a4e5f8_1 22 | google-auth=1.11.3=pypi_0 23 | google-auth-oauthlib=0.4.1=pypi_0 24 | graphite2=1.3.13=hf484d3e_1000 25 | grpcio=1.27.2=pypi_0 26 | harfbuzz=1.9.0=he243708_1001 27 | hdf5=1.10.2=hc401514_3 28 | icu=58.2=hf484d3e_1000 29 | idna=2.9=pypi_0 30 | intel-openmp=2020.0=166 31 | jasper=2.0.14=h07fcdf6_1 32 | jpeg=9b=h024ee3a_2 33 | ld_impl_linux-64=2.33.1=h53a641e_7 34 | libedit=3.1.20181209=hc058e9b_0 35 | libffi=3.2.1=hd88cf55_4 36 | libgcc-ng=9.1.0=hdf63c60_0 37 | libgfortran=3.0.0=1 38 | libgfortran-ng=7.3.0=hdf63c60_0 39 | libglu=9.0.0=hf484d3e_1000 40 | libiconv=1.15=h516909a_1005 41 | libopencv=3.4.2=hb342d67_1 42 | libpng=1.6.37=hbc83047_0 43 | libstdcxx-ng=9.1.0=hdf63c60_0 44 | libtiff=4.1.0=h2733197_0 45 | libuuid=2.32.1=h14c3975_1000 46 | libxcb=1.13=h14c3975_1002 47 | libxml2=2.9.9=h13577e0_2 48 | markdown=3.2.1=pypi_0 49 | mkl=2020.0=166 50 | mkl-service=2.3.0=py37he904b0f_0 51 | mkl_fft=1.0.15=py37ha843d7b_0 52 | mkl_random=1.1.0=py37hd6b4f25_0 53 | ncurses=6.2=he6710b0_0 54 | nettle=3.3=0 55 | ninja=1.9.0=py37hfd86e86_0 56 | numpy=1.18.1=py37h4f9e942_0 57 | numpy-base=1.18.1=py37hde5b4d6_1 58 | oauthlib=3.1.0=pypi_0 59 | olefile=0.46=py37_0 60 | opencv=3.4.2=py37h6fd60c2_1 61 | openh264=1.8.0=hdbcaa40_1000 62 | openssl=1.1.1d=h516909a_0 63 | pcre=8.44=he1b5a44_0 64 | pillow=7.0.0=py37hb39fc2d_0 65 | pip=20.0.2=py37_1 66 | pixman=0.34.0=h14c3975_1003 67 | protobuf=3.11.3=pypi_0 68 | pthread-stubs=0.4=h14c3975_1001 69 | py-opencv=3.4.2=py37hb342d67_1 70 | pyasn1=0.4.8=pypi_0 71 | pyasn1-modules=0.2.8=pypi_0 72 | python=3.7.6=h0371630_2 73 | python_abi=3.7=1_cp37m 74 | pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0 75 | pyyaml=5.3=pypi_0 76 | readline=7.0=h7b6447c_5 77 | requests=2.23.0=pypi_0 78 | requests-oauthlib=1.3.0=pypi_0 79 | rsa=4.0=pypi_0 80 | setuptools=46.0.0=py37_0 81 | six=1.14.0=py37_0 82 | sqlite=3.31.1=h7b6447c_0 83 | tensorboard=2.1.1=pypi_0 84 | tk=8.6.8=hbc83047_0 85 | torchvision=0.5.0=py37_cu101 86 | urllib3=1.25.8=pypi_0 87 | werkzeug=1.0.0=pypi_0 88 | wheel=0.34.2=py37_0 89 | x264=1!152.20180806=h14c3975_0 90 | xorg-fixesproto=5.0=h14c3975_1002 91 | xorg-inputproto=2.3.2=h14c3975_1002 92 | xorg-kbproto=1.0.7=h14c3975_1002 93 | xorg-libice=1.0.10=h516909a_0 94 | xorg-libsm=1.2.3=h84519dc_1000 95 | xorg-libx11=1.6.9=h516909a_0 96 | xorg-libxau=1.0.9=h14c3975_0 97 | xorg-libxdmcp=1.1.3=h516909a_0 98 | xorg-libxext=1.3.4=h516909a_0 99 | xorg-libxfixes=5.0.3=h516909a_1004 100 | xorg-libxi=1.7.10=h516909a_0 101 | xorg-libxrender=0.9.10=h516909a_1002 102 | 
xorg-renderproto=0.11.1=h14c3975_1002
103 | xorg-xextproto=7.3.0=h14c3975_1002
104 | xorg-xproto=7.0.31=h14c3975_1007
105 | xz=5.2.4=h14c3975_4
106 | zlib=1.2.11=h7b6447c_3
107 | zstd=1.3.7=h0b5b093_0
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.backends.cudnn as cudnn
4 | from torchvision import models
5 | from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
6 | from models.resnet_simclr import ResNetSimCLR
7 | from simclr import SimCLR
8 | 
9 | model_names = sorted(name for name in models.__dict__
10 |                      if name.islower() and not name.startswith("__")
11 |                      and callable(models.__dict__[name]))
12 | 
13 | parser = argparse.ArgumentParser(description='PyTorch SimCLR')
14 | parser.add_argument('-data', metavar='DIR', default='./datasets',
15 |                     help='path to dataset')
16 | parser.add_argument('-dataset-name', default='stl10',
17 |                     help='dataset name', choices=['stl10', 'cifar10'])
18 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
19 |                     choices=model_names,
20 |                     help='model architecture: ' +
21 |                          ' | '.join(model_names) +
22 |                          ' (default: resnet18)')
23 | parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
24 |                     help='number of data loading workers (default: 12)')
25 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
26 |                     help='number of total epochs to run')
27 | parser.add_argument('-b', '--batch-size', default=256, type=int,
28 |                     metavar='N',
29 |                     help='mini-batch size (default: 256), this is the total '
30 |                          'batch size of all GPUs on the current node when '
31 |                          'using Data Parallel or Distributed Data Parallel')
32 | parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
33 |                     metavar='LR', help='initial learning rate', dest='lr')
34 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
35 |                     metavar='W', help='weight decay (default: 1e-4)',
36 |                     dest='weight_decay')
37 | parser.add_argument('--seed', default=None, type=int,
38 |                     help='seed for initializing training.')
39 | parser.add_argument('--disable-cuda', action='store_true',
40 |                     help='Disable CUDA')
41 | parser.add_argument('--fp16-precision', action='store_true',
42 |                     help='Whether or not to use 16-bit precision GPU training.')
43 | 
44 | parser.add_argument('--out_dim', default=128, type=int,
45 |                     help='feature dimension (default: 128)')
46 | parser.add_argument('--log-every-n-steps', default=100, type=int,
47 |                     help='Log every n steps')
48 | parser.add_argument('--temperature', default=0.07, type=float,
49 |                     help='softmax temperature (default: 0.07)')
50 | parser.add_argument('--n-views', default=2, type=int, metavar='N',
51 |                     help='Number of views for contrastive learning training.')
52 | parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')
53 | 
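54 | # Example invocation (see the README for the full list of options):
55 | #   $ python run.py -data ./datasets --dataset-name stl10 --epochs 100 --fp16-precision
56 | 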
57 | 
58 | def main():
59 |     args = parser.parse_args()
60 |     if args.seed is not None:
61 |         torch.manual_seed(args.seed)  # make runs reproducible when a seed is given
62 |     assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."
63 |     # check if gpu training is available
64 |     if not args.disable_cuda and torch.cuda.is_available():
65 |         args.device = torch.device('cuda')
66 |         cudnn.deterministic = True
67 |         cudnn.benchmark = True
68 |     else:
69 |         args.device = torch.device('cpu')
70 |         args.gpu_index = -1
71 | 
72 |     dataset = ContrastiveLearningDataset(args.data)
73 | 
74 |     train_dataset = dataset.get_dataset(args.dataset_name, args.n_views)
75 | 
76 |     train_loader = torch.utils.data.DataLoader(
77 |         train_dataset, batch_size=args.batch_size, shuffle=True,
78 |         num_workers=args.workers, pin_memory=True, drop_last=True)
79 | 
80 |     model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)
81 | 
82 |     optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
83 | 
84 |     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
85 |                                                            last_epoch=-1)
86 | 
87 |     # It’s a no-op if the 'gpu_index' argument is a negative integer or None.
88 |     with torch.cuda.device(args.gpu_index):
89 |         simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
90 |         simclr.train(train_loader)
91 | 
92 | 
93 | if __name__ == "__main__":
94 |     main()
95 | 
--------------------------------------------------------------------------------
/simclr.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | 
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.cuda.amp import GradScaler, autocast
8 | from torch.utils.tensorboard import SummaryWriter
9 | from tqdm import tqdm
10 | from utils import save_config_file, accuracy, save_checkpoint
11 | 
12 | torch.manual_seed(0)
13 | 
14 | 
15 | class SimCLR(object):
16 | 
17 |     def __init__(self, *args, **kwargs):
18 |         self.args = kwargs['args']
19 |         self.model = kwargs['model'].to(self.args.device)
20 |         self.optimizer = kwargs['optimizer']
21 |         self.scheduler = kwargs['scheduler']
22 |         self.writer = SummaryWriter()
23 |         logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
24 |         self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
25 | 
26 |     def info_nce_loss(self, features):
27 | 
28 |         labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
29 |         labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
30 |         labels = labels.to(self.args.device)
31 | 
32 |         features = F.normalize(features, dim=1)
33 | 
34 |         similarity_matrix = torch.matmul(features, features.T)
35 |         # assert similarity_matrix.shape == (
36 |         #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
37 |         # assert similarity_matrix.shape == labels.shape
38 | 
39 |         # discard the main diagonal from both the labels and the similarity matrix
40 |         mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
41 |         labels = labels[~mask].view(labels.shape[0], -1)
42 |         similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
43 |         # assert similarity_matrix.shape == labels.shape
44 | 
45 |         # select and combine multiple positives
46 |         positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
47 | 
48 |         # select only the negatives
49 |         negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
50 | 
51 |         logits = torch.cat([positives, negatives], dim=1)
52 |         labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)
53 | 
54 |         logits = logits / self.args.temperature
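55 |         # Each row of `logits` now holds the 1 positive similarity followed by
56 |         # the (N - 2) negatives, where N = n_views * batch_size. The positive
57 |         # always sits in column 0, which is why `labels` is all zeros:
58 |         # CrossEntropyLoss(logits, labels) then computes exactly the InfoNCE loss.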
59 |         return logits, labels
60 | 
61 |     def train(self, train_loader):
62 | 
63 |         scaler = GradScaler(enabled=self.args.fp16_precision)
64 | 
65 |         # save config file
66 |         save_config_file(self.writer.log_dir, self.args)
67 | 
68 |         n_iter = 0
69 |         logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
70 |         logging.info(f"Training with device: {self.args.device}.")
71 | 
72 |         for epoch_counter in range(self.args.epochs):
73 |             for images, _ in tqdm(train_loader):
74 |                 images = torch.cat(images, dim=0)
75 | 
76 |                 images = images.to(self.args.device)
77 | 
78 |                 with autocast(enabled=self.args.fp16_precision):
79 |                     features = self.model(images)
80 |                     logits, labels = self.info_nce_loss(features)
81 |                     loss = self.criterion(logits, labels)
82 | 
83 |                 self.optimizer.zero_grad()
84 | 
85 |                 scaler.scale(loss).backward()
86 | 
87 |                 scaler.step(self.optimizer)
88 |                 scaler.update()
89 | 
90 |                 if n_iter % self.args.log_every_n_steps == 0:
91 |                     top1, top5 = accuracy(logits, labels, topk=(1, 5))
92 |                     self.writer.add_scalar('loss', loss, global_step=n_iter)
93 |                     self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
94 |                     self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
95 |                     self.writer.add_scalar('learning_rate', self.scheduler.get_last_lr()[0], global_step=n_iter)
96 | 
97 |                 n_iter += 1
98 | 
99 |             # warmup for the first 10 epochs: only start stepping the scheduler afterwards
100 |             if epoch_counter >= 10:
101 |                 self.scheduler.step()
102 |             logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")
103 | 
104 |         logging.info("Training has finished.")
105 |         # save model checkpoints
106 |         checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
107 |         save_checkpoint({
108 |             'epoch': self.args.epochs,
109 |             'arch': self.args.arch,
110 |             'state_dict': self.model.state_dict(),
111 |             'optimizer': self.optimizer.state_dict(),
112 |         }, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
113 |         logging.info(f"Model checkpoint and metadata have been saved at {self.writer.log_dir}.")
114 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | 
4 | import torch
5 | import yaml
6 | 
7 | 
8 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
9 |     torch.save(state, filename)
10 |     if is_best:
11 |         shutil.copyfile(filename, 'model_best.pth.tar')
12 | 
13 | 
14 | def save_config_file(model_checkpoints_folder, args):
15 |     if not os.path.exists(model_checkpoints_folder):
16 |         os.makedirs(model_checkpoints_folder)
17 |     with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
18 |         yaml.dump(args, outfile, default_flow_style=False)
19 | 
20 | 
21 | def accuracy(output, target, topk=(1,)):
22 |     """Computes the accuracy over the k top predictions for the specified values of k"""
23 |     with torch.no_grad():
24 |         maxk = max(topk)
25 |         batch_size = target.size(0)
26 | 
27 |         _, pred = output.topk(maxk, 1, True, True)
28 |         pred = pred.t()
29 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
30 | 
31 |         res = []
32 |         for k in topk:
33 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
34 |             res.append(correct_k.mul_(100.0 / batch_size))
35 |         return res
36 | 
--------------------------------------------------------------------------------