├── LICENSE.txt
├── README.md
├── data_aug
│   ├── contrastive_learning_dataset.py
│   ├── gaussian_blur.py
│   └── view_generator.py
├── env.yml
├── exceptions
│   └── exceptions.py
├── feature_eval
│   └── mini_batch_logistic_regression_evaluator.ipynb
├── models
│   └── resnet_simclr.py
├── requirements.txt
├── run.py
├── simclr.py
└── utils.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Thalles Silva
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
2 | [![DOI](https://zenodo.org/badge/241184407.svg)](https://zenodo.org/badge/latestdoi/241184407)
3 |
4 |
5 | ### Blog post with full documentation: [Exploring SimCLR: A Simple Framework for Contrastive Learning of Visual Representations](https://sthalles.github.io/simple-self-supervised-learning/)
6 |
7 | 
8 |
9 | ### See also [PyTorch Implementation for BYOL - Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning](https://github.com/sthalles/PyTorch-BYOL).
10 |
11 | ## Installation
12 |
13 | ```
14 | $ conda env create --name simclr --file env.yml
15 | $ conda activate simclr
16 | $ python run.py
17 | ```
18 |
19 | ## Config file
20 |
21 | Before running SimCLR, make sure you choose the correct running configuration. You can change the configuration by passing command-line arguments to the ```run.py``` script.
22 |
23 | ```
24 |
25 | $ python run.py -data ./datasets --dataset-name stl10 --log-every-n-steps 100 --epochs 100
26 |
27 | ```
28 |
29 | If you want to run it on CPU (for debugging purposes) use the ```--disable-cuda``` option.
30 |
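For example, to run a CPU-only debugging session:

```
$ python run.py --disable-cuda
```
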
31 | For 16-bit precision GPU training, there is **no** need to install [NVIDIA apex](https://github.com/NVIDIA/apex). Just use the ```--fp16-precision``` flag and this implementation will use [PyTorch's built-in AMP training](https://pytorch.org/docs/stable/notes/amp_examples.html).
32 |
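Under the hood, the flag simply enables PyTorch's native ```autocast```/```GradScaler``` recipe inside the training loop of ```simclr.py```. A minimal sketch of that pattern, using a toy model and random data for illustration:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(128, 10).cuda()    # toy stand-in for the SimCLR model
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()
scaler = GradScaler(enabled=True)          # what --fp16-precision turns on

x = torch.randn(32, 128, device='cuda')    # random batch for illustration
y = torch.randint(0, 10, (32,), device='cuda')

with autocast(enabled=True):               # forward pass runs in mixed precision
    loss = criterion(model(x), y)
optimizer.zero_grad()
scaler.scale(loss).backward()              # scale the loss to avoid fp16 underflow
scaler.step(optimizer)                     # unscales gradients, then steps
scaler.update()
```
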
33 | ## Feature Evaluation
34 |
35 | Feature evaluation is done using a linear model protocol.
36 |
37 | First, we learn features using SimCLR on the ```STL10 unsupervised``` set. Then, we train a linear classifier on top of the frozen features from SimCLR. The linear model is trained on features extracted from the ```STL10 train``` set and evaluated on the ```STL10 test``` set.
38 |
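Concretely, this means loading the pretrained backbone weights, dropping the projection head, and training only a fresh ```fc``` layer. A condensed sketch of what the notebook does (the checkpoint filename here is illustrative):

```python
import torch
import torchvision

# rebuild the backbone with a fresh 10-way linear head
model = torchvision.models.resnet18(pretrained=False, num_classes=10)

# keep the backbone weights, drop the projection head
checkpoint = torch.load('checkpoint_0100.pth.tar', map_location='cpu')
state_dict = {k[len('backbone.'):]: v
              for k, v in checkpoint['state_dict'].items()
              if k.startswith('backbone.') and not k.startswith('backbone.fc')}
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']  # only the new head is missing

# freeze everything but the final linear layer
for name, param in model.named_parameters():
    param.requires_grad = name in ['fc.weight', 'fc.bias']

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
```
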
39 | Check the [mini_batch_logistic_regression_evaluator.ipynb](https://github.com/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb) notebook for reproducibility.
40 |
41 | Note that SimCLR benefits from **longer training**.
42 |
43 | | Linear Classification | Dataset | Feature Extractor | Architecture | Feature dimensionality | Projection Head dimensionality | Epochs | Top1 % |
44 | |----------------------------|---------|-------------------|---------------------------------------------------------------------------------|------------------------|--------------------------------|--------|--------|
45 | | Logistic Regression (Adam) | STL10 | SimCLR | [ResNet-18](https://drive.google.com/open?id=14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF) | 512 | 128 | 100 | 74.45 |
46 | | Logistic Regression (Adam) | CIFAR10 | SimCLR | [ResNet-18](https://drive.google.com/open?id=1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C) | 512 | 128 | 100 | 69.82 |
47 | | Logistic Regression (Adam) | STL10 | SimCLR | [ResNet-50](https://drive.google.com/open?id=1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu) | 2048 | 128 | 50 | 70.075 |
48 |
--------------------------------------------------------------------------------
/data_aug/contrastive_learning_dataset.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms, datasets
2 | 
3 | from data_aug.gaussian_blur import GaussianBlur
4 | from data_aug.view_generator import ContrastiveLearningViewGenerator
5 | from exceptions.exceptions import InvalidDatasetSelection
6 |
7 |
8 | class ContrastiveLearningDataset:
9 | def __init__(self, root_folder):
10 | self.root_folder = root_folder
11 |
12 | @staticmethod
13 | def get_simclr_pipeline_transform(size, s=1):
14 |         """Return a set of data augmentation transformations as described in the SimCLR paper; `s` scales the color-jitter strength."""
15 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
16 | data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
17 | transforms.RandomHorizontalFlip(),
18 | transforms.RandomApply([color_jitter], p=0.8),
19 | transforms.RandomGrayscale(p=0.2),
20 | GaussianBlur(kernel_size=int(0.1 * size)),
21 | transforms.ToTensor()])
22 | return data_transforms
23 |
24 | def get_dataset(self, name, n_views):
25 | valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
26 | transform=ContrastiveLearningViewGenerator(
27 | self.get_simclr_pipeline_transform(32),
28 | n_views),
29 | download=True),
30 |
31 | 'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
32 | transform=ContrastiveLearningViewGenerator(
33 | self.get_simclr_pipeline_transform(96),
34 | n_views),
35 | download=True)}
36 |
37 | try:
38 | dataset_fn = valid_datasets[name]
39 | except KeyError:
40 | raise InvalidDatasetSelection()
41 | else:
42 | return dataset_fn()
43 |
--------------------------------------------------------------------------------
/data_aug/gaussian_blur.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torchvision.transforms import transforms
5 |
6 | np.random.seed(0)
7 |
8 |
9 | class GaussianBlur(object):
10 |     """Blur a single image on the CPU with a separable Gaussian kernel and a random sigma."""
11 |     def __init__(self, kernel_size):
12 |         radius = kernel_size // 2
13 |         kernel_size = radius * 2 + 1
14 |         self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
15 |                                 stride=1, padding=0, bias=False, groups=3)
16 |         self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
17 |                                 stride=1, padding=0, bias=False, groups=3)
18 |         self.k = kernel_size
19 |         self.r = radius
20 | 
21 |         self.blur = nn.Sequential(
22 |             nn.ReflectionPad2d(radius),
23 |             self.blur_h,
24 |             self.blur_v
25 |         )
26 |
27 | self.pil_to_tensor = transforms.ToTensor()
28 | self.tensor_to_pil = transforms.ToPILImage()
29 |
30 | def __call__(self, img):
31 | img = self.pil_to_tensor(img).unsqueeze(0)
32 |
33 | sigma = np.random.uniform(0.1, 2.0)
34 | x = np.arange(-self.r, self.r + 1)
35 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
36 | x = x / x.sum()
37 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
38 |
39 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
40 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
41 |
42 | with torch.no_grad():
43 | img = self.blur(img)
44 | img = img.squeeze()
45 |
46 | img = self.tensor_to_pil(img)
47 |
48 | return img
--------------------------------------------------------------------------------
/data_aug/view_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | np.random.seed(0)
4 |
5 |
6 | class ContrastiveLearningViewGenerator(object):
7 |     """Generate n_views randomly augmented views of the same image."""
8 |
9 | def __init__(self, base_transform, n_views=2):
10 | self.base_transform = base_transform
11 | self.n_views = n_views
12 |
13 | def __call__(self, x):
14 |         return [self.base_transform(x) for _ in range(self.n_views)]
15 |
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: simclr
2 | channels:
3 | - pytorch
4 | - anaconda
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - cudatoolkit=10.1
9 | - numpy=1.18.1
10 | - opencv=3.4.2
11 | - pillow=7.0
12 | - pip=20.0
13 | - python=3.7.6
14 | - pytorch=1.4.0
15 | - torchvision=0.5
16 | - tensorboard=2.1
17 | - matplotlib=3.1.3
18 | - scikit-learn=0.22.1
19 | - pyyaml=5.3.1
20 | - nvidia-apex=0.1
21 |
22 |
--------------------------------------------------------------------------------
/exceptions/exceptions.py:
--------------------------------------------------------------------------------
1 | class BaseSimCLRException(Exception):
2 | """Base exception"""
3 |
4 |
5 | class InvalidBackboneError(BaseSimCLRException):
6 | """Raised when the choice of backbone Convnet is invalid."""
7 |
8 |
9 | class InvalidDatasetSelection(BaseSimCLRException):
10 | """Raised when the choice of dataset is invalid."""
11 |
--------------------------------------------------------------------------------
/feature_eval/mini_batch_logistic_regression_evaluator.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "display_name": "pytorch",
7 | "language": "python",
8 | "name": "pytorch"
9 | },
10 | "language_info": {
11 | "codemirror_mode": {
12 | "name": "ipython",
13 | "version": 3
14 | },
15 | "file_extension": ".py",
16 | "mimetype": "text/x-python",
17 | "name": "python",
18 | "nbconvert_exporter": "python",
19 | "pygments_lexer": "ipython3",
20 | "version": "3.6.6"
21 | },
22 | "colab": {
23 | "name": "Copy of mini-batch-logistic-regression-evaluator.ipynb",
24 | "provenance": [],
25 | "include_colab_link": true
26 | },
27 | "accelerator": "GPU",
28 | "widgets": {
29 | "application/vnd.jupyter.widget-state+json": {
30 | "149b9ce8fb68473a837a77431c12281a": {
31 | "model_module": "@jupyter-widgets/controls",
32 | "model_name": "HBoxModel",
33 | "state": {
34 | "_view_name": "HBoxView",
35 | "_dom_classes": [],
36 | "_model_name": "HBoxModel",
37 | "_view_module": "@jupyter-widgets/controls",
38 | "_model_module_version": "1.5.0",
39 | "_view_count": null,
40 | "_view_module_version": "1.5.0",
41 | "box_style": "",
42 | "layout": "IPY_MODEL_88cd3db2831e4c13a4a634709700d6b2",
43 | "_model_module": "@jupyter-widgets/controls",
44 | "children": [
45 | "IPY_MODEL_a88c31d74f5c40a2b24bcff5a35d216c",
46 | "IPY_MODEL_60c6150177694717a622936b830427b5"
47 | ]
48 | }
49 | },
50 | "88cd3db2831e4c13a4a634709700d6b2": {
51 | "model_module": "@jupyter-widgets/base",
52 | "model_name": "LayoutModel",
53 | "state": {
54 | "_view_name": "LayoutView",
55 | "grid_template_rows": null,
56 | "right": null,
57 | "justify_content": null,
58 | "_view_module": "@jupyter-widgets/base",
59 | "overflow": null,
60 | "_model_module_version": "1.2.0",
61 | "_view_count": null,
62 | "flex_flow": null,
63 | "width": null,
64 | "min_width": null,
65 | "border": null,
66 | "align_items": null,
67 | "bottom": null,
68 | "_model_module": "@jupyter-widgets/base",
69 | "top": null,
70 | "grid_column": null,
71 | "overflow_y": null,
72 | "overflow_x": null,
73 | "grid_auto_flow": null,
74 | "grid_area": null,
75 | "grid_template_columns": null,
76 | "flex": null,
77 | "_model_name": "LayoutModel",
78 | "justify_items": null,
79 | "grid_row": null,
80 | "max_height": null,
81 | "align_content": null,
82 | "visibility": null,
83 | "align_self": null,
84 | "height": null,
85 | "min_height": null,
86 | "padding": null,
87 | "grid_auto_rows": null,
88 | "grid_gap": null,
89 | "max_width": null,
90 | "order": null,
91 | "_view_module_version": "1.2.0",
92 | "grid_template_areas": null,
93 | "object_position": null,
94 | "object_fit": null,
95 | "grid_auto_columns": null,
96 | "margin": null,
97 | "display": null,
98 | "left": null
99 | }
100 | },
101 | "a88c31d74f5c40a2b24bcff5a35d216c": {
102 | "model_module": "@jupyter-widgets/controls",
103 | "model_name": "FloatProgressModel",
104 | "state": {
105 | "_view_name": "ProgressView",
106 | "style": "IPY_MODEL_dba019efadee4fdc8c799f309b9a7e70",
107 | "_dom_classes": [],
108 | "description": "",
109 | "_model_name": "FloatProgressModel",
110 | "bar_style": "info",
111 | "max": 1,
112 | "_view_module": "@jupyter-widgets/controls",
113 | "_model_module_version": "1.5.0",
114 | "value": 1,
115 | "_view_count": null,
116 | "_view_module_version": "1.5.0",
117 | "orientation": "horizontal",
118 | "min": 0,
119 | "description_tooltip": null,
120 | "_model_module": "@jupyter-widgets/controls",
121 | "layout": "IPY_MODEL_5901c2829a554c8ebbd5926610088041"
122 | }
123 | },
124 | "60c6150177694717a622936b830427b5": {
125 | "model_module": "@jupyter-widgets/controls",
126 | "model_name": "HTMLModel",
127 | "state": {
128 | "_view_name": "HTMLView",
129 | "style": "IPY_MODEL_957362a11d174407979cf17012bf9208",
130 | "_dom_classes": [],
131 | "description": "",
132 | "_model_name": "HTMLModel",
133 | "placeholder": "",
134 | "_view_module": "@jupyter-widgets/controls",
135 | "_model_module_version": "1.5.0",
136 | "value": " 2640404480/? [00:51<00:00, 32685718.58it/s]",
137 | "_view_count": null,
138 | "_view_module_version": "1.5.0",
139 | "description_tooltip": null,
140 | "_model_module": "@jupyter-widgets/controls",
141 | "layout": "IPY_MODEL_a4f82234388e4701a02a9f68a177193a"
142 | }
143 | },
144 | "dba019efadee4fdc8c799f309b9a7e70": {
145 | "model_module": "@jupyter-widgets/controls",
146 | "model_name": "ProgressStyleModel",
147 | "state": {
148 | "_view_name": "StyleView",
149 | "_model_name": "ProgressStyleModel",
150 | "description_width": "initial",
151 | "_view_module": "@jupyter-widgets/base",
152 | "_model_module_version": "1.5.0",
153 | "_view_count": null,
154 | "_view_module_version": "1.2.0",
155 | "bar_color": null,
156 | "_model_module": "@jupyter-widgets/controls"
157 | }
158 | },
159 | "5901c2829a554c8ebbd5926610088041": {
160 | "model_module": "@jupyter-widgets/base",
161 | "model_name": "LayoutModel",
162 | "state": {
163 | "_view_name": "LayoutView",
164 | "grid_template_rows": null,
165 | "right": null,
166 | "justify_content": null,
167 | "_view_module": "@jupyter-widgets/base",
168 | "overflow": null,
169 | "_model_module_version": "1.2.0",
170 | "_view_count": null,
171 | "flex_flow": null,
172 | "width": null,
173 | "min_width": null,
174 | "border": null,
175 | "align_items": null,
176 | "bottom": null,
177 | "_model_module": "@jupyter-widgets/base",
178 | "top": null,
179 | "grid_column": null,
180 | "overflow_y": null,
181 | "overflow_x": null,
182 | "grid_auto_flow": null,
183 | "grid_area": null,
184 | "grid_template_columns": null,
185 | "flex": null,
186 | "_model_name": "LayoutModel",
187 | "justify_items": null,
188 | "grid_row": null,
189 | "max_height": null,
190 | "align_content": null,
191 | "visibility": null,
192 | "align_self": null,
193 | "height": null,
194 | "min_height": null,
195 | "padding": null,
196 | "grid_auto_rows": null,
197 | "grid_gap": null,
198 | "max_width": null,
199 | "order": null,
200 | "_view_module_version": "1.2.0",
201 | "grid_template_areas": null,
202 | "object_position": null,
203 | "object_fit": null,
204 | "grid_auto_columns": null,
205 | "margin": null,
206 | "display": null,
207 | "left": null
208 | }
209 | },
210 | "957362a11d174407979cf17012bf9208": {
211 | "model_module": "@jupyter-widgets/controls",
212 | "model_name": "DescriptionStyleModel",
213 | "state": {
214 | "_view_name": "StyleView",
215 | "_model_name": "DescriptionStyleModel",
216 | "description_width": "",
217 | "_view_module": "@jupyter-widgets/base",
218 | "_model_module_version": "1.5.0",
219 | "_view_count": null,
220 | "_view_module_version": "1.2.0",
221 | "_model_module": "@jupyter-widgets/controls"
222 | }
223 | },
224 | "a4f82234388e4701a02a9f68a177193a": {
225 | "model_module": "@jupyter-widgets/base",
226 | "model_name": "LayoutModel",
227 | "state": {
228 | "_view_name": "LayoutView",
229 | "grid_template_rows": null,
230 | "right": null,
231 | "justify_content": null,
232 | "_view_module": "@jupyter-widgets/base",
233 | "overflow": null,
234 | "_model_module_version": "1.2.0",
235 | "_view_count": null,
236 | "flex_flow": null,
237 | "width": null,
238 | "min_width": null,
239 | "border": null,
240 | "align_items": null,
241 | "bottom": null,
242 | "_model_module": "@jupyter-widgets/base",
243 | "top": null,
244 | "grid_column": null,
245 | "overflow_y": null,
246 | "overflow_x": null,
247 | "grid_auto_flow": null,
248 | "grid_area": null,
249 | "grid_template_columns": null,
250 | "flex": null,
251 | "_model_name": "LayoutModel",
252 | "justify_items": null,
253 | "grid_row": null,
254 | "max_height": null,
255 | "align_content": null,
256 | "visibility": null,
257 | "align_self": null,
258 | "height": null,
259 | "min_height": null,
260 | "padding": null,
261 | "grid_auto_rows": null,
262 | "grid_gap": null,
263 | "max_width": null,
264 | "order": null,
265 | "_view_module_version": "1.2.0",
266 | "grid_template_areas": null,
267 | "object_position": null,
268 | "object_fit": null,
269 | "grid_auto_columns": null,
270 | "margin": null,
271 | "display": null,
272 | "left": null
273 | }
274 | }
275 | }
276 | }
277 | },
278 | "cells": [
279 | {
280 | "cell_type": "markdown",
281 | "metadata": {
282 | "id": "view-in-github",
283 | "colab_type": "text"
284 | },
285 | "source": [
286 | "
"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "metadata": {
292 | "id": "YUemQib7ZE4D"
293 | },
294 | "source": [
295 | "import torch\n",
296 | "import sys\n",
297 | "import numpy as np\n",
298 | "import os\n",
299 | "import yaml\n",
300 | "import matplotlib.pyplot as plt\n",
301 | "import torchvision"
302 | ],
303 | "execution_count": 10,
304 | "outputs": []
305 | },
306 | {
307 | "cell_type": "code",
308 | "metadata": {
309 | "id": "WSgRE1CcLqdS",
310 | "colab": {
311 | "base_uri": "https://localhost:8080/"
312 | },
313 | "outputId": "48a2ae15-f672-495b-8d43-9a23b85fa3b8"
314 | },
315 | "source": [
316 | "!pip install gdown"
317 | ],
318 | "execution_count": 11,
319 | "outputs": [
320 | {
321 | "output_type": "stream",
322 | "text": [
323 | "Requirement already satisfied: gdown in /usr/local/lib/python3.6/dist-packages (3.6.4)\n",
324 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from gdown) (1.15.0)\n",
325 | "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from gdown) (2.23.0)\n",
326 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from gdown) (4.41.1)\n",
327 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2020.12.5)\n",
328 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (1.24.3)\n",
329 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (3.0.4)\n",
330 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->gdown) (2.10)\n"
331 | ],
332 | "name": "stdout"
333 | }
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "metadata": {
339 | "id": "NOIJEui1ZziV"
340 | },
341 | "source": [
342 | "def get_file_id_by_model(folder_name):\n",
343 | " file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',\n",
344 | " 'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',\n",
345 | " 'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu'}\n",
346 | " return file_id.get(folder_name, \"Model not found.\")"
347 | ],
348 | "execution_count": 12,
349 | "outputs": []
350 | },
351 | {
352 | "cell_type": "code",
353 | "metadata": {
354 | "id": "G7YMxsvEZMrX",
355 | "colab": {
356 | "base_uri": "https://localhost:8080/"
357 | },
358 | "outputId": "59475430-69d2-45a2-b61b-ae755d5d6e88"
359 | },
360 | "source": [
361 | "folder_name = 'resnet50_50-epochs_stl10'\n",
362 | "file_id = get_file_id_by_model(folder_name)\n",
363 | "print(folder_name, file_id)"
364 | ],
365 | "execution_count": 13,
366 | "outputs": [
367 | {
368 | "output_type": "stream",
369 | "text": [
370 | "resnet50_50-epochs_stl10 1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu\n"
371 | ],
372 | "name": "stdout"
373 | }
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "metadata": {
379 | "id": "PWZ8fet_YoJm",
380 | "colab": {
381 | "base_uri": "https://localhost:8080/"
382 | },
383 | "outputId": "fbaeb858-221b-4d1b-dd90-001a6e713b75"
384 | },
385 | "source": [
386 | "# download and extract model files\n",
387 | "os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))\n",
388 | "os.system('unzip {}'.format(folder_name))\n",
389 | "!ls"
390 | ],
391 | "execution_count": 14,
392 | "outputs": [
393 | {
394 | "output_type": "stream",
395 | "text": [
396 | "checkpoint_0040.pth.tar\n",
397 | "config.yml\n",
398 | "events.out.tfevents.1610927742.4cb2c837708d.2694093.0\n",
399 | "resnet50_50-epochs_stl10.zip\n",
400 | "sample_data\n",
401 | "training.log\n"
402 | ],
403 | "name": "stdout"
404 | }
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "metadata": {
410 | "id": "3_nypQVEv-hn"
411 | },
412 | "source": [
413 | "from torch.utils.data import DataLoader\n",
414 | "import torchvision.transforms as transforms\n",
415 | "from torchvision import datasets"
416 | ],
417 | "execution_count": 15,
418 | "outputs": []
419 | },
420 | {
421 | "cell_type": "code",
422 | "metadata": {
423 | "id": "lDfbL3w_Z0Od",
424 | "colab": {
425 | "base_uri": "https://localhost:8080/"
426 | },
427 | "outputId": "7532966e-1c4a-4641-c928-4cda14c53389"
428 | },
429 | "source": [
430 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
431 | "print(\"Using device:\", device)"
432 | ],
433 | "execution_count": 16,
434 | "outputs": [
435 | {
436 | "output_type": "stream",
437 | "text": [
438 | "Using device: cuda\n"
439 | ],
440 | "name": "stdout"
441 | }
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "metadata": {
447 | "id": "BfIPl0G6_RrT"
448 | },
449 | "source": [
450 | "def get_stl10_data_loaders(download, shuffle=False, batch_size=256):\n",
451 | " train_dataset = datasets.STL10('./data', split='train', download=download,\n",
452 | " transform=transforms.ToTensor())\n",
453 | "\n",
454 | " train_loader = DataLoader(train_dataset, batch_size=batch_size,\n",
455 | " num_workers=0, drop_last=False, shuffle=shuffle)\n",
456 | " \n",
457 | " test_dataset = datasets.STL10('./data', split='test', download=download,\n",
458 | " transform=transforms.ToTensor())\n",
459 | "\n",
460 | " test_loader = DataLoader(test_dataset, batch_size=2*batch_size,\n",
461 | " num_workers=10, drop_last=False, shuffle=shuffle)\n",
462 | " return train_loader, test_loader\n",
463 | "\n",
464 | "def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):\n",
465 | " train_dataset = datasets.CIFAR10('./data', train=True, download=download,\n",
466 | " transform=transforms.ToTensor())\n",
467 | "\n",
468 | " train_loader = DataLoader(train_dataset, batch_size=batch_size,\n",
469 | " num_workers=0, drop_last=False, shuffle=shuffle)\n",
470 | " \n",
471 | " test_dataset = datasets.CIFAR10('./data', train=False, download=download,\n",
472 | " transform=transforms.ToTensor())\n",
473 | "\n",
474 | " test_loader = DataLoader(test_dataset, batch_size=2*batch_size,\n",
475 | " num_workers=10, drop_last=False, shuffle=shuffle)\n",
476 | " return train_loader, test_loader"
477 | ],
478 | "execution_count": 17,
479 | "outputs": []
480 | },
481 | {
482 | "cell_type": "code",
483 | "metadata": {
484 | "id": "6N8lYkbmDTaK"
485 | },
486 | "source": [
487 | "with open(os.path.join('./config.yml')) as file:\n",
488 | "    config = yaml.load(file, Loader=yaml.UnsafeLoader)  # config.yml stores a dumped argparse.Namespace"
489 | ],
490 | "execution_count": 18,
491 | "outputs": []
492 | },
493 | {
494 | "cell_type": "code",
495 | "metadata": {
496 | "id": "a18lPD-tIle6"
497 | },
498 | "source": [
499 | "if config.arch == 'resnet18':\n",
500 | " model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)\n",
501 | "elif config.arch == 'resnet50':\n",
502 | " model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)"
503 | ],
504 | "execution_count": 19,
505 | "outputs": []
506 | },
507 | {
508 | "cell_type": "code",
509 | "metadata": {
510 | "id": "4AIfgq41GuTT"
511 | },
512 | "source": [
513 | "checkpoint = torch.load('checkpoint_0040.pth.tar', map_location=device)\n",
514 | "state_dict = checkpoint['state_dict']\n",
515 | "\n",
516 | "for k in list(state_dict.keys()):\n",
517 | "\n",
518 | " if k.startswith('backbone.'):\n",
519 | " if k.startswith('backbone') and not k.startswith('backbone.fc'):\n",
520 | " # remove prefix\n",
521 | " state_dict[k[len(\"backbone.\"):]] = state_dict[k]\n",
522 | " del state_dict[k]"
523 | ],
524 | "execution_count": 21,
525 | "outputs": []
526 | },
527 | {
528 | "cell_type": "code",
529 | "metadata": {
530 | "id": "VVjA83PPJYWl"
531 | },
532 | "source": [
533 | "log = model.load_state_dict(state_dict, strict=False)\n",
534 | "assert log.missing_keys == ['fc.weight', 'fc.bias']"
535 | ],
536 | "execution_count": 22,
537 | "outputs": []
538 | },
539 | {
540 | "cell_type": "code",
541 | "metadata": {
542 | "id": "_GC0a14uWRr6",
543 | "colab": {
544 | "base_uri": "https://localhost:8080/",
545 | "height": 117,
546 | "referenced_widgets": [
547 | "149b9ce8fb68473a837a77431c12281a",
548 | "88cd3db2831e4c13a4a634709700d6b2",
549 | "a88c31d74f5c40a2b24bcff5a35d216c",
550 | "60c6150177694717a622936b830427b5",
551 | "dba019efadee4fdc8c799f309b9a7e70",
552 | "5901c2829a554c8ebbd5926610088041",
553 | "957362a11d174407979cf17012bf9208",
554 | "a4f82234388e4701a02a9f68a177193a"
555 | ]
556 | },
557 | "outputId": "4c2558db-921c-425e-f947-6cc746d8c749"
558 | },
559 | "source": [
560 | "if config.dataset_name == 'cifar10':\n",
561 | " train_loader, test_loader = get_cifar10_data_loaders(download=True)\n",
562 | "elif config.dataset_name == 'stl10':\n",
563 | " train_loader, test_loader = get_stl10_data_loaders(download=True)\n",
564 | "print(\"Dataset:\", config.dataset_name)"
565 | ],
566 | "execution_count": 23,
567 | "outputs": [
568 | {
569 | "output_type": "stream",
570 | "text": [
571 | "Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz\n"
572 | ],
573 | "name": "stdout"
574 | },
575 | {
576 | "output_type": "display_data",
577 | "data": {
578 | "application/vnd.jupyter.widget-view+json": {
579 | "model_id": "149b9ce8fb68473a837a77431c12281a",
580 | "version_minor": 0,
581 | "version_major": 2
582 | },
583 | "text/plain": [
584 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
585 | ]
586 | },
587 | "metadata": {
588 | "tags": []
589 | }
590 | },
591 | {
592 | "output_type": "stream",
593 | "text": [
594 | "Extracting ./data/stl10_binary.tar.gz to ./data\n",
595 | "Files already downloaded and verified\n",
596 | "Dataset: stl10\n"
597 | ],
598 | "name": "stdout"
599 | }
600 | ]
601 | },
602 | {
603 | "cell_type": "code",
604 | "metadata": {
605 | "id": "pYT_KsM0Mnnr"
606 | },
607 | "source": [
608 | "# freeze all layers but the last fc\n",
609 | "for name, param in model.named_parameters():\n",
610 | " if name not in ['fc.weight', 'fc.bias']:\n",
611 | " param.requires_grad = False\n",
612 | "\n",
613 | "parameters = list(filter(lambda p: p.requires_grad, model.parameters()))\n",
614 | "assert len(parameters) == 2 # fc.weight, fc.bias"
615 | ],
616 | "execution_count": 24,
617 | "outputs": []
618 | },
619 | {
620 | "cell_type": "code",
621 | "metadata": {
622 | "id": "aPVh1S_eMRDU"
623 | },
624 | "source": [
625 | "optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)\n",
626 | "criterion = torch.nn.CrossEntropyLoss().to(device)"
627 | ],
628 | "execution_count": 25,
629 | "outputs": []
630 | },
631 | {
632 | "cell_type": "code",
633 | "metadata": {
634 | "id": "edr6RhP2PdVq"
635 | },
636 | "source": [
637 | "def accuracy(output, target, topk=(1,)):\n",
638 | " \"\"\"Computes the accuracy over the k top predictions for the specified values of k\"\"\"\n",
639 | " with torch.no_grad():\n",
640 | " maxk = max(topk)\n",
641 | " batch_size = target.size(0)\n",
642 | "\n",
643 | " _, pred = output.topk(maxk, 1, True, True)\n",
644 | " pred = pred.t()\n",
645 | " correct = pred.eq(target.view(1, -1).expand_as(pred))\n",
646 | "\n",
647 | " res = []\n",
648 | " for k in topk:\n",
649 | " correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n",
650 | " res.append(correct_k.mul_(100.0 / batch_size))\n",
651 | " return res"
652 | ],
653 | "execution_count": 26,
654 | "outputs": []
655 | },
656 | {
657 | "cell_type": "code",
658 | "metadata": {
659 | "id": "qOder0dAMI7X",
660 | "colab": {
661 | "base_uri": "https://localhost:8080/"
662 | },
663 | "outputId": "5f723b91-5a5e-43eb-ca01-a9b5ae2f1346"
664 | },
665 | "source": [
666 | "epochs = 100\n",
667 | "for epoch in range(epochs):\n",
668 | " top1_train_accuracy = 0\n",
669 | " for counter, (x_batch, y_batch) in enumerate(train_loader):\n",
670 | " x_batch = x_batch.to(device)\n",
671 | " y_batch = y_batch.to(device)\n",
672 | "\n",
673 | " logits = model(x_batch)\n",
674 | " loss = criterion(logits, y_batch)\n",
675 | " top1 = accuracy(logits, y_batch, topk=(1,))\n",
676 | " top1_train_accuracy += top1[0]\n",
677 | "\n",
678 | " optimizer.zero_grad()\n",
679 | " loss.backward()\n",
680 | " optimizer.step()\n",
681 | "\n",
682 | " top1_train_accuracy /= (counter + 1)\n",
683 | " top1_accuracy = 0\n",
684 | " top5_accuracy = 0\n",
685 | " for counter, (x_batch, y_batch) in enumerate(test_loader):\n",
686 | " x_batch = x_batch.to(device)\n",
687 | " y_batch = y_batch.to(device)\n",
688 | "\n",
689 | " logits = model(x_batch)\n",
690 | " \n",
691 | " top1, top5 = accuracy(logits, y_batch, topk=(1,5))\n",
692 | " top1_accuracy += top1[0]\n",
693 | " top5_accuracy += top5[0]\n",
694 | " \n",
695 | " top1_accuracy /= (counter + 1)\n",
696 | " top5_accuracy /= (counter + 1)\n",
697 | " print(f\"Epoch {epoch}\\tTop1 Train accuracy {top1_train_accuracy.item()}\\tTop1 Test accuracy: {top1_accuracy.item()}\\tTop5 test acc: {top5_accuracy.item()}\")"
698 | ],
699 | "execution_count": 27,
700 | "outputs": [
701 | {
702 | "output_type": "stream",
703 | "text": [
704 | "Epoch 0\tTop1 Train accuracy 28.7109375\tTop1 Test accuracy: 43.75\tTop5 test acc: 93.837890625\n",
705 | "Epoch 1\tTop1 Train accuracy 49.37959671020508\tTop1 Test accuracy: 52.8662109375\tTop5 test acc: 95.439453125\n",
706 | "Epoch 2\tTop1 Train accuracy 55.257354736328125\tTop1 Test accuracy: 56.45263671875\tTop5 test acc: 95.91796875\n",
707 | "Epoch 3\tTop1 Train accuracy 57.51838302612305\tTop1 Test accuracy: 57.39013671875\tTop5 test acc: 96.19384765625\n",
708 | "Epoch 4\tTop1 Train accuracy 58.727020263671875\tTop1 Test accuracy: 58.2568359375\tTop5 test acc: 96.435546875\n",
709 | "Epoch 5\tTop1 Train accuracy 59.677162170410156\tTop1 Test accuracy: 58.7353515625\tTop5 test acc: 96.50390625\n",
710 | "Epoch 6\tTop1 Train accuracy 60.065486907958984\tTop1 Test accuracy: 59.17724609375\tTop5 test acc: 96.708984375\n",
711 | "Epoch 7\tTop1 Train accuracy 60.612361907958984\tTop1 Test accuracy: 59.482421875\tTop5 test acc: 96.74560546875\n",
712 | "Epoch 8\tTop1 Train accuracy 60.827205657958984\tTop1 Test accuracy: 59.66064453125\tTop5 test acc: 96.77490234375\n",
713 | "Epoch 9\tTop1 Train accuracy 61.100643157958984\tTop1 Test accuracy: 60.09521484375\tTop5 test acc: 96.82373046875\n",
714 | "Epoch 10\tTop1 Train accuracy 61.52803421020508\tTop1 Test accuracy: 60.3466796875\tTop5 test acc: 96.82861328125\n",
715 | "Epoch 11\tTop1 Train accuracy 61.80147171020508\tTop1 Test accuracy: 60.6640625\tTop5 test acc: 96.8896484375\n",
716 | "Epoch 12\tTop1 Train accuracy 62.09444046020508\tTop1 Test accuracy: 60.96435546875\tTop5 test acc: 96.99462890625\n",
717 | "Epoch 13\tTop1 Train accuracy 62.541358947753906\tTop1 Test accuracy: 61.13037109375\tTop5 test acc: 97.0068359375\n",
718 | "Epoch 14\tTop1 Train accuracy 62.853858947753906\tTop1 Test accuracy: 61.34033203125\tTop5 test acc: 97.01904296875\n",
719 | "Epoch 15\tTop1 Train accuracy 62.951515197753906\tTop1 Test accuracy: 61.5673828125\tTop5 test acc: 96.99951171875\n",
720 | "Epoch 16\tTop1 Train accuracy 63.400733947753906\tTop1 Test accuracy: 61.806640625\tTop5 test acc: 97.0361328125\n",
721 | "Epoch 17\tTop1 Train accuracy 63.66958236694336\tTop1 Test accuracy: 61.98974609375\tTop5 test acc: 97.0849609375\n",
722 | "Epoch 18\tTop1 Train accuracy 63.82583236694336\tTop1 Test accuracy: 62.265625\tTop5 test acc: 97.07275390625\n",
723 | "Epoch 19\tTop1 Train accuracy 64.1187973022461\tTop1 Test accuracy: 62.412109375\tTop5 test acc: 97.09716796875\n",
724 | "Epoch 20\tTop1 Train accuracy 64.2750473022461\tTop1 Test accuracy: 62.56591796875\tTop5 test acc: 97.12158203125\n",
725 | "Epoch 21\tTop1 Train accuracy 64.4140625\tTop1 Test accuracy: 62.724609375\tTop5 test acc: 97.20703125\n",
726 | "Epoch 22\tTop1 Train accuracy 64.53125\tTop1 Test accuracy: 62.90771484375\tTop5 test acc: 97.255859375\n",
727 | "Epoch 23\tTop1 Train accuracy 64.6484375\tTop1 Test accuracy: 62.95654296875\tTop5 test acc: 97.29248046875\n",
728 | "Epoch 24\tTop1 Train accuracy 64.86328125\tTop1 Test accuracy: 63.12255859375\tTop5 test acc: 97.35595703125\n",
729 | "Epoch 25\tTop1 Train accuracy 65.1344223022461\tTop1 Test accuracy: 63.330078125\tTop5 test acc: 97.40478515625\n",
730 | "Epoch 26\tTop1 Train accuracy 65.3297348022461\tTop1 Test accuracy: 63.3984375\tTop5 test acc: 97.44873046875\n",
731 | "Epoch 27\tTop1 Train accuracy 65.4469223022461\tTop1 Test accuracy: 63.34228515625\tTop5 test acc: 97.412109375\n",
732 | "Epoch 28\tTop1 Train accuracy 65.6227035522461\tTop1 Test accuracy: 63.48876953125\tTop5 test acc: 97.412109375\n",
733 | "Epoch 29\tTop1 Train accuracy 65.85478210449219\tTop1 Test accuracy: 63.56201171875\tTop5 test acc: 97.42431640625\n",
734 | "Epoch 30\tTop1 Train accuracy 66.06732940673828\tTop1 Test accuracy: 63.67431640625\tTop5 test acc: 97.4560546875\n",
735 | "Epoch 31\tTop1 Train accuracy 66.20404815673828\tTop1 Test accuracy: 63.80859375\tTop5 test acc: 97.48046875\n",
736 | "Epoch 32\tTop1 Train accuracy 66.24080657958984\tTop1 Test accuracy: 63.92578125\tTop5 test acc: 97.5048828125\n",
737 | "Epoch 33\tTop1 Train accuracy 66.58777618408203\tTop1 Test accuracy: 63.9990234375\tTop5 test acc: 97.529296875\n",
738 | "Epoch 34\tTop1 Train accuracy 66.70496368408203\tTop1 Test accuracy: 64.1455078125\tTop5 test acc: 97.51708984375\n",
739 | "Epoch 35\tTop1 Train accuracy 66.80261993408203\tTop1 Test accuracy: 64.20654296875\tTop5 test acc: 97.529296875\n",
740 | "Epoch 36\tTop1 Train accuracy 66.91980743408203\tTop1 Test accuracy: 64.32861328125\tTop5 test acc: 97.51708984375\n",
741 | "Epoch 37\tTop1 Train accuracy 66.93933868408203\tTop1 Test accuracy: 64.3896484375\tTop5 test acc: 97.51708984375\n",
742 | "Epoch 38\tTop1 Train accuracy 66.97840118408203\tTop1 Test accuracy: 64.47021484375\tTop5 test acc: 97.529296875\n",
743 | "Epoch 39\tTop1 Train accuracy 67.11282348632812\tTop1 Test accuracy: 64.53125\tTop5 test acc: 97.56591796875\n",
744 | "Epoch 40\tTop1 Train accuracy 67.24954223632812\tTop1 Test accuracy: 64.6044921875\tTop5 test acc: 97.6025390625\n",
745 | "Epoch 41\tTop1 Train accuracy 67.34949493408203\tTop1 Test accuracy: 64.62890625\tTop5 test acc: 97.59033203125\n",
746 | "Epoch 42\tTop1 Train accuracy 67.42761993408203\tTop1 Test accuracy: 64.7265625\tTop5 test acc: 97.6025390625\n",
747 | "Epoch 43\tTop1 Train accuracy 67.52527618408203\tTop1 Test accuracy: 64.84375\tTop5 test acc: 97.61474609375\n",
748 | "Epoch 44\tTop1 Train accuracy 67.58386993408203\tTop1 Test accuracy: 64.87548828125\tTop5 test acc: 97.61474609375\n",
749 | "Epoch 45\tTop1 Train accuracy 67.64246368408203\tTop1 Test accuracy: 64.9365234375\tTop5 test acc: 97.626953125\n",
750 | "Epoch 46\tTop1 Train accuracy 67.75735473632812\tTop1 Test accuracy: 65.0341796875\tTop5 test acc: 97.66357421875\n",
751 | "Epoch 47\tTop1 Train accuracy 67.85501098632812\tTop1 Test accuracy: 65.1318359375\tTop5 test acc: 97.7001953125\n",
752 | "Epoch 48\tTop1 Train accuracy 67.89407348632812\tTop1 Test accuracy: 65.1318359375\tTop5 test acc: 97.73681640625\n",
753 | "Epoch 49\tTop1 Train accuracy 67.95266723632812\tTop1 Test accuracy: 65.15625\tTop5 test acc: 97.73681640625\n",
754 | "Epoch 50\tTop1 Train accuracy 68.01126098632812\tTop1 Test accuracy: 65.21728515625\tTop5 test acc: 97.76123046875\n",
755 | "Epoch 51\tTop1 Train accuracy 68.05032348632812\tTop1 Test accuracy: 65.29052734375\tTop5 test acc: 97.7490234375\n",
756 | "Epoch 52\tTop1 Train accuracy 68.05032348632812\tTop1 Test accuracy: 65.3564453125\tTop5 test acc: 97.78564453125\n",
757 | "Epoch 53\tTop1 Train accuracy 68.20657348632812\tTop1 Test accuracy: 65.3759765625\tTop5 test acc: 97.7978515625\n",
758 | "Epoch 54\tTop1 Train accuracy 68.28469848632812\tTop1 Test accuracy: 65.45654296875\tTop5 test acc: 97.822265625\n",
759 | "Epoch 55\tTop1 Train accuracy 68.41912078857422\tTop1 Test accuracy: 65.46875\tTop5 test acc: 97.8466796875\n",
760 | "Epoch 56\tTop1 Train accuracy 68.45818328857422\tTop1 Test accuracy: 65.5615234375\tTop5 test acc: 97.85888671875\n",
761 | "Epoch 57\tTop1 Train accuracy 68.61443328857422\tTop1 Test accuracy: 65.56640625\tTop5 test acc: 97.87109375\n",
762 | "Epoch 58\tTop1 Train accuracy 68.71208953857422\tTop1 Test accuracy: 65.5859375\tTop5 test acc: 97.90771484375\n",
763 | "Epoch 59\tTop1 Train accuracy 68.69255828857422\tTop1 Test accuracy: 65.64697265625\tTop5 test acc: 97.919921875\n",
764 | "Epoch 60\tTop1 Train accuracy 68.80744934082031\tTop1 Test accuracy: 65.64697265625\tTop5 test acc: 97.93212890625\n",
765 | "Epoch 61\tTop1 Train accuracy 68.94416809082031\tTop1 Test accuracy: 65.72021484375\tTop5 test acc: 97.93212890625\n",
766 | "Epoch 62\tTop1 Train accuracy 69.04182434082031\tTop1 Test accuracy: 65.76904296875\tTop5 test acc: 97.919921875\n",
767 | "Epoch 63\tTop1 Train accuracy 69.06135559082031\tTop1 Test accuracy: 65.84228515625\tTop5 test acc: 97.90771484375\n",
768 | "Epoch 64\tTop1 Train accuracy 69.19807434082031\tTop1 Test accuracy: 65.93505859375\tTop5 test acc: 97.90771484375\n",
769 | "Epoch 65\tTop1 Train accuracy 69.23713684082031\tTop1 Test accuracy: 65.95947265625\tTop5 test acc: 97.9150390625\n",
770 | "Epoch 66\tTop1 Train accuracy 69.25666809082031\tTop1 Test accuracy: 66.0888671875\tTop5 test acc: 97.939453125\n",
771 | "Epoch 67\tTop1 Train accuracy 69.31526184082031\tTop1 Test accuracy: 66.02783203125\tTop5 test acc: 97.939453125\n",
772 | "Epoch 68\tTop1 Train accuracy 69.43014526367188\tTop1 Test accuracy: 66.07666015625\tTop5 test acc: 97.9638671875\n",
773 | "Epoch 69\tTop1 Train accuracy 69.48873901367188\tTop1 Test accuracy: 66.12060546875\tTop5 test acc: 97.9638671875\n",
774 | "Epoch 70\tTop1 Train accuracy 69.50827026367188\tTop1 Test accuracy: 66.083984375\tTop5 test acc: 97.95166015625\n",
775 | "Epoch 71\tTop1 Train accuracy 69.60592651367188\tTop1 Test accuracy: 66.1572265625\tTop5 test acc: 97.9638671875\n",
776 | "Epoch 72\tTop1 Train accuracy 69.68635559082031\tTop1 Test accuracy: 66.2060546875\tTop5 test acc: 97.95166015625\n",
777 | "Epoch 73\tTop1 Train accuracy 69.78170776367188\tTop1 Test accuracy: 66.2744140625\tTop5 test acc: 97.92724609375\n",
778 | "Epoch 74\tTop1 Train accuracy 69.84030151367188\tTop1 Test accuracy: 66.31591796875\tTop5 test acc: 97.92724609375\n",
779 | "Epoch 75\tTop1 Train accuracy 69.89889526367188\tTop1 Test accuracy: 66.328125\tTop5 test acc: 97.9150390625\n",
780 | "Epoch 76\tTop1 Train accuracy 69.93795776367188\tTop1 Test accuracy: 66.41357421875\tTop5 test acc: 97.92724609375\n",
781 | "Epoch 77\tTop1 Train accuracy 69.95748901367188\tTop1 Test accuracy: 66.41357421875\tTop5 test acc: 97.9150390625\n",
782 | "Epoch 78\tTop1 Train accuracy 70.01608276367188\tTop1 Test accuracy: 66.474609375\tTop5 test acc: 97.9150390625\n",
783 | "Epoch 79\tTop1 Train accuracy 69.99655151367188\tTop1 Test accuracy: 66.53564453125\tTop5 test acc: 97.939453125\n",
784 | "Epoch 80\tTop1 Train accuracy 70.01608276367188\tTop1 Test accuracy: 66.56005859375\tTop5 test acc: 97.939453125\n",
785 | "Epoch 81\tTop1 Train accuracy 70.09420776367188\tTop1 Test accuracy: 66.56494140625\tTop5 test acc: 97.939453125\n",
786 | "Epoch 82\tTop1 Train accuracy 70.11373901367188\tTop1 Test accuracy: 66.650390625\tTop5 test acc: 97.939453125\n",
787 | "Epoch 83\tTop1 Train accuracy 70.19186401367188\tTop1 Test accuracy: 66.71142578125\tTop5 test acc: 97.92724609375\n",
788 | "Epoch 84\tTop1 Train accuracy 70.26998901367188\tTop1 Test accuracy: 66.7236328125\tTop5 test acc: 97.90283203125\n",
789 | "Epoch 85\tTop1 Train accuracy 70.32858276367188\tTop1 Test accuracy: 66.73583984375\tTop5 test acc: 97.90283203125\n",
790 | "Epoch 86\tTop1 Train accuracy 70.32858276367188\tTop1 Test accuracy: 66.748046875\tTop5 test acc: 97.890625\n",
791 | "Epoch 87\tTop1 Train accuracy 70.46530151367188\tTop1 Test accuracy: 66.7724609375\tTop5 test acc: 97.890625\n",
792 | "Epoch 88\tTop1 Train accuracy 70.52389526367188\tTop1 Test accuracy: 66.78466796875\tTop5 test acc: 97.90283203125\n",
793 | "Epoch 89\tTop1 Train accuracy 70.56295776367188\tTop1 Test accuracy: 66.78466796875\tTop5 test acc: 97.890625\n",
794 | "Epoch 90\tTop1 Train accuracy 70.68014526367188\tTop1 Test accuracy: 66.83349609375\tTop5 test acc: 97.87841796875\n",
795 | "Epoch 91\tTop1 Train accuracy 70.77780151367188\tTop1 Test accuracy: 66.826171875\tTop5 test acc: 97.87841796875\n",
796 | "Epoch 92\tTop1 Train accuracy 70.81686401367188\tTop1 Test accuracy: 66.88720703125\tTop5 test acc: 97.87841796875\n",
797 | "Epoch 93\tTop1 Train accuracy 70.85592651367188\tTop1 Test accuracy: 66.8994140625\tTop5 test acc: 97.87841796875\n",
798 | "Epoch 94\tTop1 Train accuracy 70.91452026367188\tTop1 Test accuracy: 66.9482421875\tTop5 test acc: 97.890625\n",
799 | "Epoch 95\tTop1 Train accuracy 71.03170776367188\tTop1 Test accuracy: 66.98486328125\tTop5 test acc: 97.890625\n",
800 | "Epoch 96\tTop1 Train accuracy 71.09030151367188\tTop1 Test accuracy: 67.001953125\tTop5 test acc: 97.91015625\n",
801 | "Epoch 97\tTop1 Train accuracy 71.09030151367188\tTop1 Test accuracy: 67.0263671875\tTop5 test acc: 97.91015625\n",
802 | "Epoch 98\tTop1 Train accuracy 71.12936401367188\tTop1 Test accuracy: 67.06298828125\tTop5 test acc: 97.89794921875\n",
803 | "Epoch 99\tTop1 Train accuracy 71.12936401367188\tTop1 Test accuracy: 67.0751953125\tTop5 test acc: 97.8857421875\n"
804 | ],
805 | "name": "stdout"
806 | }
807 | ]
808 | },
809 | {
810 | "cell_type": "code",
811 | "metadata": {
812 | "id": "dtYqHZirMNZk"
813 | },
814 | "source": [
815 | ""
816 | ],
817 | "execution_count": 27,
818 | "outputs": []
819 | }
820 | ]
821 | }
--------------------------------------------------------------------------------
/models/resnet_simclr.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torchvision.models as models
3 |
4 | from exceptions.exceptions import InvalidBackboneError
5 |
6 |
7 | class ResNetSimCLR(nn.Module):
8 |
9 | def __init__(self, base_model, out_dim):
10 | super(ResNetSimCLR, self).__init__()
11 | self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
12 | "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}
13 |
14 | self.backbone = self._get_basemodel(base_model)
15 | dim_mlp = self.backbone.fc.in_features
16 |
17 | # add mlp projection head
18 | self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)
19 |
20 | def _get_basemodel(self, model_name):
21 | try:
22 | model = self.resnet_dict[model_name]
23 | except KeyError:
24 | raise InvalidBackboneError(
25 | "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
26 | else:
27 | return model
28 |
29 | def forward(self, x):
30 | return self.backbone(x)
31 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | absl-py=0.9.0=pypi_0
6 | blas=1.0=mkl
7 | bzip2=1.0.8=h516909a_2
8 | ca-certificates=2019.11.28=hecc5488_0
9 | cachetools=4.0.0=pypi_0
10 | cairo=1.14.12=h80bd089_1005
11 | certifi=2019.11.28=py37hc8dfbb8_1
12 | chardet=3.0.4=pypi_0
13 | cudatoolkit=10.1.243=h6bb024c_0
14 | ffmpeg=4.0.2=ha0c5888_2
15 | fontconfig=2.13.1=he4413a7_1000
16 | freeglut=3.0.0=hf484d3e_1005
17 | freetype=2.9.1=h8a8886c_1
18 | gettext=0.19.8.1=hc5be6a0_1002
19 | glib=2.56.2=had28632_1001
20 | gmp=6.1.2=hf484d3e_1000
21 | gnutls=3.5.19=h2a4e5f8_1
22 | google-auth=1.11.3=pypi_0
23 | google-auth-oauthlib=0.4.1=pypi_0
24 | graphite2=1.3.13=hf484d3e_1000
25 | grpcio=1.27.2=pypi_0
26 | harfbuzz=1.9.0=he243708_1001
27 | hdf5=1.10.2=hc401514_3
28 | icu=58.2=hf484d3e_1000
29 | idna=2.9=pypi_0
30 | intel-openmp=2020.0=166
31 | jasper=2.0.14=h07fcdf6_1
32 | jpeg=9b=h024ee3a_2
33 | ld_impl_linux-64=2.33.1=h53a641e_7
34 | libedit=3.1.20181209=hc058e9b_0
35 | libffi=3.2.1=hd88cf55_4
36 | libgcc-ng=9.1.0=hdf63c60_0
37 | libgfortran=3.0.0=1
38 | libgfortran-ng=7.3.0=hdf63c60_0
39 | libglu=9.0.0=hf484d3e_1000
40 | libiconv=1.15=h516909a_1005
41 | libopencv=3.4.2=hb342d67_1
42 | libpng=1.6.37=hbc83047_0
43 | libstdcxx-ng=9.1.0=hdf63c60_0
44 | libtiff=4.1.0=h2733197_0
45 | libuuid=2.32.1=h14c3975_1000
46 | libxcb=1.13=h14c3975_1002
47 | libxml2=2.9.9=h13577e0_2
48 | markdown=3.2.1=pypi_0
49 | mkl=2020.0=166
50 | mkl-service=2.3.0=py37he904b0f_0
51 | mkl_fft=1.0.15=py37ha843d7b_0
52 | mkl_random=1.1.0=py37hd6b4f25_0
53 | ncurses=6.2=he6710b0_0
54 | nettle=3.3=0
55 | ninja=1.9.0=py37hfd86e86_0
56 | numpy=1.18.1=py37h4f9e942_0
57 | numpy-base=1.18.1=py37hde5b4d6_1
58 | oauthlib=3.1.0=pypi_0
59 | olefile=0.46=py37_0
60 | opencv=3.4.2=py37h6fd60c2_1
61 | openh264=1.8.0=hdbcaa40_1000
62 | openssl=1.1.1d=h516909a_0
63 | pcre=8.44=he1b5a44_0
64 | pillow=7.0.0=py37hb39fc2d_0
65 | pip=20.0.2=py37_1
66 | pixman=0.34.0=h14c3975_1003
67 | protobuf=3.11.3=pypi_0
68 | pthread-stubs=0.4=h14c3975_1001
69 | py-opencv=3.4.2=py37hb342d67_1
70 | pyasn1=0.4.8=pypi_0
71 | pyasn1-modules=0.2.8=pypi_0
72 | python=3.7.6=h0371630_2
73 | python_abi=3.7=1_cp37m
74 | pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0
75 | pyyaml=5.3=pypi_0
76 | readline=7.0=h7b6447c_5
77 | requests=2.23.0=pypi_0
78 | requests-oauthlib=1.3.0=pypi_0
79 | rsa=4.0=pypi_0
80 | setuptools=46.0.0=py37_0
81 | six=1.14.0=py37_0
82 | sqlite=3.31.1=h7b6447c_0
83 | tensorboard=2.1.1=pypi_0
84 | tk=8.6.8=hbc83047_0
85 | torchvision=0.5.0=py37_cu101
86 | urllib3=1.25.8=pypi_0
87 | werkzeug=1.0.0=pypi_0
88 | wheel=0.34.2=py37_0
89 | x264=1!152.20180806=h14c3975_0
90 | xorg-fixesproto=5.0=h14c3975_1002
91 | xorg-inputproto=2.3.2=h14c3975_1002
92 | xorg-kbproto=1.0.7=h14c3975_1002
93 | xorg-libice=1.0.10=h516909a_0
94 | xorg-libsm=1.2.3=h84519dc_1000
95 | xorg-libx11=1.6.9=h516909a_0
96 | xorg-libxau=1.0.9=h14c3975_0
97 | xorg-libxdmcp=1.1.3=h516909a_0
98 | xorg-libxext=1.3.4=h516909a_0
99 | xorg-libxfixes=5.0.3=h516909a_1004
100 | xorg-libxi=1.7.10=h516909a_0
101 | xorg-libxrender=0.9.10=h516909a_1002
102 | xorg-renderproto=0.11.1=h14c3975_1002
103 | xorg-xextproto=7.3.0=h14c3975_1002
104 | xorg-xproto=7.0.31=h14c3975_1007
105 | xz=5.2.4=h14c3975_4
106 | zlib=1.2.11=h7b6447c_3
107 | zstd=1.3.7=h0b5b093_0
108 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.backends.cudnn as cudnn
4 | from torchvision import models
5 | from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
6 | from models.resnet_simclr import ResNetSimCLR
7 | from simclr import SimCLR
8 |
9 | model_names = sorted(name for name in models.__dict__
10 | if name.islower() and not name.startswith("__")
11 | and callable(models.__dict__[name]))
12 |
13 | parser = argparse.ArgumentParser(description='PyTorch SimCLR')
14 | parser.add_argument('-data', metavar='DIR', default='./datasets',
15 | help='path to dataset')
16 | parser.add_argument('--dataset-name', default='stl10',
17 |                     help='dataset name', choices=['stl10', 'cifar10'])
18 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
19 |                     choices=model_names,
20 |                     help='model architecture: ' +
21 |                          ' | '.join(model_names) +
22 |                          ' (default: resnet18)')
23 | parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
24 |                     help='number of data loading workers (default: 12)')
25 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
26 | help='number of total epochs to run')
27 | parser.add_argument('-b', '--batch-size', default=256, type=int,
28 | metavar='N',
29 | help='mini-batch size (default: 256), this is the total '
30 | 'batch size of all GPUs on the current node when '
31 | 'using Data Parallel or Distributed Data Parallel')
32 | parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
33 | metavar='LR', help='initial learning rate', dest='lr')
34 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
35 | metavar='W', help='weight decay (default: 1e-4)',
36 | dest='weight_decay')
37 | parser.add_argument('--seed', default=None, type=int,
38 | help='seed for initializing training. ')
39 | parser.add_argument('--disable-cuda', action='store_true',
40 | help='Disable CUDA')
41 | parser.add_argument('--fp16-precision', action='store_true',
42 | help='Whether or not to use 16-bit precision GPU training.')
43 |
44 | parser.add_argument('--out_dim', default=128, type=int,
45 | help='feature dimension (default: 128)')
46 | parser.add_argument('--log-every-n-steps', default=100, type=int,
47 | help='Log every n steps')
48 | parser.add_argument('--temperature', default=0.07, type=float,
49 | help='softmax temperature (default: 0.07)')
50 | parser.add_argument('--n-views', default=2, type=int, metavar='N',
51 | help='Number of views for contrastive learning training.')
52 | parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')
53 |
54 |
55 | def main():
56 | args = parser.parse_args()
57 | assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."
58 | # check if gpu training is available
59 | if not args.disable_cuda and torch.cuda.is_available():
60 | args.device = torch.device('cuda')
61 | cudnn.deterministic = True
62 | cudnn.benchmark = True
63 | else:
64 | args.device = torch.device('cpu')
65 | args.gpu_index = -1
66 |
67 | dataset = ContrastiveLearningDataset(args.data)
68 |
69 | train_dataset = dataset.get_dataset(args.dataset_name, args.n_views)
70 |
71 | train_loader = torch.utils.data.DataLoader(
72 | train_dataset, batch_size=args.batch_size, shuffle=True,
73 | num_workers=args.workers, pin_memory=True, drop_last=True)
74 |
75 | model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)
76 |
77 | optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
78 |
79 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
80 | last_epoch=-1)
81 |
82 | # It’s a no-op if the 'gpu_index' argument is a negative integer or None.
83 | with torch.cuda.device(args.gpu_index):
84 | simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
85 | simclr.train(train_loader)
86 |
87 |
88 | if __name__ == "__main__":
89 | main()
90 |
--------------------------------------------------------------------------------
/simclr.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.cuda.amp import GradScaler, autocast
8 | from torch.utils.tensorboard import SummaryWriter
9 | from tqdm import tqdm
10 | from utils import save_config_file, accuracy, save_checkpoint
11 |
12 | torch.manual_seed(0)
13 |
14 |
15 | class SimCLR(object):
16 |
17 | def __init__(self, *args, **kwargs):
18 | self.args = kwargs['args']
19 | self.model = kwargs['model'].to(self.args.device)
20 | self.optimizer = kwargs['optimizer']
21 | self.scheduler = kwargs['scheduler']
22 | self.writer = SummaryWriter()
23 | logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
24 | self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
25 |
26 | def info_nce_loss(self, features):
27 |
28 | labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
29 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
30 | labels = labels.to(self.args.device)
31 |
32 | features = F.normalize(features, dim=1)
33 |
34 | similarity_matrix = torch.matmul(features, features.T)
35 | # assert similarity_matrix.shape == (
36 | # self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
37 | # assert similarity_matrix.shape == labels.shape
38 |
39 | # discard the main diagonal from both: labels and similarities matrix
40 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
41 | labels = labels[~mask].view(labels.shape[0], -1)
42 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
43 | # assert similarity_matrix.shape == labels.shape
44 |
45 | # select and combine multiple positives
46 | positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
47 |
48 |         # select only the negatives
49 | negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
50 |
51 | logits = torch.cat([positives, negatives], dim=1)
52 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)
53 |
54 | logits = logits / self.args.temperature
55 | return logits, labels
56 |
57 | def train(self, train_loader):
58 |
59 | scaler = GradScaler(enabled=self.args.fp16_precision)
60 |
61 | # save config file
62 | save_config_file(self.writer.log_dir, self.args)
63 |
64 | n_iter = 0
65 | logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
66 |         logging.info(f"Training on device: {self.args.device}.")
67 |
68 | for epoch_counter in range(self.args.epochs):
69 | for images, _ in tqdm(train_loader):
70 | images = torch.cat(images, dim=0)
71 |
72 | images = images.to(self.args.device)
73 |
74 | with autocast(enabled=self.args.fp16_precision):
75 | features = self.model(images)
76 | logits, labels = self.info_nce_loss(features)
77 | loss = self.criterion(logits, labels)
78 |
79 | self.optimizer.zero_grad()
80 |
81 | scaler.scale(loss).backward()
82 |
83 | scaler.step(self.optimizer)
84 | scaler.update()
85 |
86 | if n_iter % self.args.log_every_n_steps == 0:
87 | top1, top5 = accuracy(logits, labels, topk=(1, 5))
88 | self.writer.add_scalar('loss', loss, global_step=n_iter)
89 | self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
90 | self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
91 | self.writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter)
92 |
93 | n_iter += 1
94 |
95 | # warmup for the first 10 epochs
96 | if epoch_counter >= 10:
97 | self.scheduler.step()
98 | logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")
99 |
100 | logging.info("Training has finished.")
101 | # save model checkpoints
102 | checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
103 | save_checkpoint({
104 | 'epoch': self.args.epochs,
105 | 'arch': self.args.arch,
106 | 'state_dict': self.model.state_dict(),
107 | 'optimizer': self.optimizer.state_dict(),
108 | }, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
109 |         logging.info(f"Model checkpoint and metadata have been saved at {self.writer.log_dir}.")
110 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import torch
5 | import yaml
6 |
7 |
8 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
9 | torch.save(state, filename)
10 | if is_best:
11 | shutil.copyfile(filename, 'model_best.pth.tar')
12 |
13 |
14 | def save_config_file(model_checkpoints_folder, args):
15 | if not os.path.exists(model_checkpoints_folder):
16 | os.makedirs(model_checkpoints_folder)
17 | with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
18 | yaml.dump(args, outfile, default_flow_style=False)
19 |
20 |
21 | def accuracy(output, target, topk=(1,)):
22 | """Computes the accuracy over the k top predictions for the specified values of k"""
23 | with torch.no_grad():
24 | maxk = max(topk)
25 | batch_size = target.size(0)
26 |
27 | _, pred = output.topk(maxk, 1, True, True)
28 | pred = pred.t()
29 | correct = pred.eq(target.view(1, -1).expand_as(pred))
30 |
31 | res = []
32 | for k in topk:
33 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
34 | res.append(correct_k.mul_(100.0 / batch_size))
35 | return res
36 |
--------------------------------------------------------------------------------