├── .gitignore
├── LICENSE.md
├── README.md
├── colabs
├── model_apply.ipynb
└── read_logs.ipynb
├── configs
├── cifar_eval.yaml
├── cifar_train_epochs1000_bs1024.yaml
├── imagenet_eval.yaml
├── imagenet_train_epochs100_bs512.yaml
├── imagenet_train_epochs200_bs2k.yaml
└── imagenet_train_epochs600_bs2k.yaml
├── environment.yml
├── models
├── __init__.py
├── encoder.py
├── losses.py
├── resnet.py
└── ssl.py
├── myexman
├── __init__.py
├── index.py
└── parser.py
├── train.py
└── utils
├── datautils.py
├── lars_optimizer.py
├── logger.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *__pycache__*
3 | *pretrained_models/*
4 | *logs
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
103 | __pypackages__/
104 |
105 | # Celery stuff
106 | celerybeat-schedule
107 | celerybeat.pid
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | # Pyre type checker
137 | .pyre/
138 |
139 | # pytype static type analyzer
140 | .pytype/
141 |
142 | # Cython debug symbols
143 | cython_debug/
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Andrei Atanov*, Arsenii Ashukha; Bayesian Methods Research Group, Samsung AI Center Moscow, Samsung-HSE Laboratory, EPFL
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SimCLR PyTorch
2 |
3 | This is an unofficial repository reproducing results of the paper [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709). The implementation supports multi-GPU distributed training on several nodes with PyTorch `DistributedDataParallel`.
4 |
5 | ## How close are we to the original SimCLR?
6 |
7 | The implementation closely reproduces the original ResNet50 results on ImageNet and CIFAR-10.
8 |
9 |
10 |
11 |
12 |
13 | | Dataset | Batch Size | \# Epochs | Training GPUs | Training time | Top\-1 accuracy of Linear evaluation (100% labels)| Reference |
14 | |----------|------------|-----------|---------------|---------------|-----------------------------------|------------|
15 | | CIFAR-10 | 1024 | 1000 | 2v100 | 13h | 93\.44 | 93.95 |
16 | | ImageNet | 512 | 100 | 4v100 | 85h | 60\.14 | 60.62 |
17 | | ImageNet | 2048 | 200 | 16v100 | 55h | 65\.58 | 65.83 |
18 | | ImageNet | 2048 | 600 | 16v100 | 170h | 67\.84 | 68.71 |
19 |
20 | ## Pre-trained weights
21 |
22 | Try out the pre-trained models [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AndrewAtanov/simclr-pytorch/blob/master/colabs/model_apply.ipynb)
23 |
24 | You can download pre-trained weights from [here](https://drive.google.com/file/d/13tjpWYTzV8qLB5yY5raBn5cwtIyFtt6-/view?usp=sharing).
25 |
26 | To evaluate the pretrained CIFAR-10 linear model and encoder, use the following command:
27 | ```(bash)
28 | python train.py --problem eval --eval_only true --iters 1 --arch linear \
29 | --ckpt pretrained_models/resnet50_cifar10_bs1024_epochs1000_linear.pth.tar \
30 | --encoder_ckpt pretrained_models/resnet50_cifar10_bs1024_epochs1000.pth.tar
31 | ```
32 |
33 | To evaluate the pretrained ImageNet linear model and encoder, use the following command:
34 | ```(bash)
35 | export IMAGENET_PATH=.../raw-data
36 | python train.py --problem eval --eval_only true --iters 1 --arch linear --data imagenet \
37 | --ckpt pretrained_models/resnet50_imagenet_bs2k_epochs600_linear.pth.tar \
38 | --encoder_ckpt pretrained_models/resnet50_imagenet_bs2k_epochs600.pth.tar
39 | ```
40 |
41 | ## Environment Setup
42 |
43 |
44 | Create a Python environment with the provided config file and [miniconda](https://docs.conda.io/en/latest/miniconda.html):
45 |
46 | ```(bash)
47 | conda env create -f environment.yml
48 | conda activate simclr_pytorch
49 |
50 | export IMAGENET_PATH=... # If you have enough RAM, placing the dataset in /dev/shm usually speeds up data loading
51 | export EXMAN_PATH=... # A path to logs
52 | ```
53 |
54 | ## Training
55 | Model training consists of two steps: (1) self-supervised encoder pretraining and (2) classifier learning with the encoder representations. Both steps are done with the `train.py` script. To see the help for the `sim-clr` or `eval` problem, run `python train.py --help --problem sim-clr` (or `--problem eval`).
56 |
57 | ### Self-supervised pretraining
58 |
59 | #### CIFAR-10
60 | The config `cifar_train_epochs1000_bs1024.yaml` contains the parameters to reproduce results for CIFAR-10 dataset. It requires 2 V100 GPUs. The pretraining command is:
61 |
62 | ```(bash)
63 | python train.py --config configs/cifar_train_epochs1000_bs1024.yaml
64 | ```
65 |
66 | #### ImageNet
67 | The configs `imagenet_train_epochs*_bs*.yaml` contain the parameters to reproduce results for the ImageNet dataset. They require 4 to 16 V100 GPUs, depending on the batch size. The single-node (4 V100 GPUs) pretraining command is:
68 |
69 | ```(bash)
70 | python train.py --config configs/imagenet_train_epochs100_bs512.yaml
71 | ```
72 |
73 | #### Logs
74 | The logs and the model will be stored at `./logs/exman-train.py/runs/<experiment-id>/`. You can access all the experiments from Python with `myexman.Index('./logs/exman-train.py').info()`.
75 |
76 | See how to work with logs [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AndrewAtanov/simclr-pytorch/blob/master/colabs/read_logs.ipynb)
77 |
78 | ### Linear Evaluation
79 | To train a linear classifier on top of the pretrained encoder, run the following command:
80 |
81 | ```(bash)
82 | python train.py --config configs/cifar_eval.yaml --encoder_ckpt <path-to-encoder-checkpoint>
83 | ```
84 |
85 | The above model with batch size 1024 gives `93.5` linear eval test accuracy.
86 |
87 | ### Pretraining with `DistributedDataParallel`
88 | To train a model with larger batch size on several nodes you need to set `--dist ddp` flag and specify the following parameters:
89 | - `--dist_address`: the address and port of the main node, in the `<address>:<port>` format
90 | - `--node_rank`: 0 for the main node and 1,... for the others.
91 | - `--world_size`: the number of nodes.
92 |
93 | For example, to train with two nodes you need to run the following command on the main node:
94 | ```(bash)
95 | python train.py --config configs/cifar_train_epochs1000_bs1024.yaml --dist ddp --dist_address <address>:<port> --node_rank 0 --world_size 2
96 | ```
97 | and on the second node:
98 | ```(bash)
99 | python train.py --config configs/cifar_train_epochs1000_bs1024.yaml --dist ddp --dist_address <address>:<port> --node_rank 1 --world_size 2
100 | ```
101 |
102 | ImageNet pretraining on 4 nodes, each with 4 GPUs, looks as follows:
103 | ```
104 | node1: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 0
105 | node2: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 1
106 | node3: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 2
107 | node4: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 3
108 | ```
109 |
110 | ## Attribution
111 | Parts of this code are based on the following repositories:
112 | - [PyTorch](https://github.com/pytorch/pytorch), [PyTorch Examples](https://github.com/pytorch/examples/tree/ee964a2/imagenet), [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) for standard backbones, training loops, etc.
113 | - [SimCLR - A Simple Framework for Contrastive Learning of Visual Representations](https://github.com/google-research/simclr) for more details on the original implementation
114 | - [diffdist](https://github.com/ag14774/diffdist) for the multi-GPU contrastive loss implementation; it allows backpropagation through the `all_gather` operation (see [models/losses.py#L62](https://github.com/AndrewAtanov/simclr-pytorch/blob/master/models/losses.py#L62))
115 | - [Experiment Manager (exman)](https://github.com/ferrine/exman), a tool that distributes logs, checkpoints, and parameter dicts via folders and allows loading them into a pandas DataFrame, which is handy for processing in IPython notebooks.
116 | - [NVIDIA APEX](https://github.com/NVIDIA/apex) for the LARS optimizer. We modified LARC to make it consistent with the SimCLR repo.
117 |
118 | ## Acknowledgements
119 | - This work was supported in part through computational resources of HPC facilities at NRU HSE
120 |
--------------------------------------------------------------------------------
/colabs/read_logs.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "PjSN8gOUIQ1t"
7 | },
8 | "source": [
9 | "\n",
10 | "# Experiment Manager"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {
17 | "colab": {
18 | "base_uri": "https://localhost:8080/"
19 | },
20 | "id": "CkMiXmImIhUN",
21 | "outputId": "c6604c8a-ecd1-4170-d7b6-42bc06ca4977"
22 | },
23 | "outputs": [
24 | {
25 | "name": "stdout",
26 | "output_type": "stream",
27 | "text": [
28 | " % Total % Received % Xferd Average Speed Time Time Time Current\n",
29 | " Dload Upload Total Spent Left Speed\n",
30 | " 0 0 0 0 0 0 0 0 --:--:-- 0:00:01 --:--:-- 0\n",
31 | "100 869M 0 869M 0 0 9.8M 0 --:--:-- 0:01:28 --:--:-- 10.2M\n",
32 | "Archive: logs.zip\n",
33 | " inflating: logs/exman-train.py/index/000002.yaml \n",
34 | " inflating: logs/exman-train.py/index/000004.yaml \n",
35 | " inflating: logs/exman-train.py/index/000010.yaml \n",
36 | " inflating: logs/exman-train.py/index/000012.yaml \n",
37 | " inflating: logs/exman-train.py/index/000023.yaml \n",
38 | " inflating: logs/exman-train.py/index/000027.yaml \n",
39 | " inflating: logs/exman-train.py/index/000030.yaml \n",
40 | " inflating: logs/exman-train.py/index/000031.yaml \n",
41 | " inflating: logs/exman-train.py/index/000033.yaml \n",
42 | "replace logs/exman-train.py/runs/000002/checkpoint.pth.tar? [y]es, [n]o, [A]ll, [N]one, [r]ename: "
43 | ]
44 | }
45 | ],
46 | "source": [
47 | "!!pip install diffdist wldhx.yadisk-direct configargparse strconv\n",
48 | "!curl -L $(yadisk-direct https://yadi.sk/d/GYMBGjXGQr9oFw?w=1) -o logs.zip\n",
49 | "!unzip logs.zip > unzip.out"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {
56 | "id": "bamF1nUS80W0"
57 | },
58 | "outputs": [],
59 | "source": [
60 | "!git clone https://github.com/AndrewAtanov/simclr-pytorch.git"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "import sys\n",
70 | "sys.path.append('./simclr-pytorch')"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 32,
76 | "metadata": {
77 | "id": "90vfPvSRIQ1u"
78 | },
79 | "outputs": [],
80 | "source": [
81 | "import myexman\n",
82 | "import pandas as pd\n",
83 | "\n",
84 | "index = myexman.Index('./logs/exman-train.py').info().set_index('id')\n",
85 | "index.root = index.root.apply(lambda x: str(x).replace('/home/aashukha/simclr-pytorch/', ''))"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 33,
91 | "metadata": {
92 | "colab": {
93 | "base_uri": "https://localhost:8080/"
94 | },
95 | "id": "xniSdGb0IQ1u",
96 | "outputId": "821e111c-a660-4c20-9c96-930251e13e34"
97 | },
98 | "outputs": [
99 | {
100 | "data": {
101 | "text/plain": [
102 | "Index(['arch', 'aug', 'augmentation', 'batch_size', 'ckpt', 'config_file',\n",
103 | " 'data', 'dist', 'dist_address', 'encoder_ckpt', 'eval_freq', 'finetune',\n",
104 | " 'gpu', 'iters', 'log_freq', 'lr', 'lr_schedule', 'name', 'node_rank',\n",
105 | " 'number_of_processes', 'opt', 'precompute_emb_bs', 'problem', 'root',\n",
106 | " 'save_freq', 'scale_lower', 'seed', 'test_bs', 'tmp', 'verbose',\n",
107 | " 'warmup', 'weight_decay', 'workers', 'world_size', 'time',\n",
108 | " 'base_lr_linear_scale', 'color_dist_s', 'cooldown', 'cooldown_after',\n",
109 | " 'momentum', 'multiplier', 'norm_multiplier', 'projection', 'status',\n",
110 | " 'sync_bn', 'temperature', 'ckpt_iter', 'encode_layer', 'model_id',\n",
111 | " 'use_all_classes'],\n",
112 | " dtype='object')"
113 | ]
114 | },
115 | "execution_count": 33,
116 | "metadata": {
117 | "tags": []
118 | },
119 | "output_type": "execute_result"
120 | }
121 | ],
122 | "source": [
123 | "index.columns"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 35,
129 | "metadata": {
130 | "colab": {
131 | "base_uri": "https://localhost:8080/",
132 | "height": 414
133 | },
134 | "id": "Ps-fforc-r_N",
135 | "outputId": "59109585-78d3-45fa-f599-e468b8416b7f"
136 | },
137 | "outputs": [
138 | {
139 | "data": {
140 | "text/html": [
141 | "\n",
142 | "\n",
155 | "
\n",
156 | " \n",
157 | " \n",
158 | " | \n",
159 | " arch | \n",
160 | " aug | \n",
161 | " augmentation | \n",
162 | " batch_size | \n",
163 | " ckpt | \n",
164 | " config_file | \n",
165 | " data | \n",
166 | " dist | \n",
167 | " dist_address | \n",
168 | " encoder_ckpt | \n",
169 | " eval_freq | \n",
170 | " finetune | \n",
171 | " gpu | \n",
172 | " iters | \n",
173 | " log_freq | \n",
174 | " lr | \n",
175 | " lr_schedule | \n",
176 | " name | \n",
177 | " node_rank | \n",
178 | " number_of_processes | \n",
179 | " opt | \n",
180 | " precompute_emb_bs | \n",
181 | " problem | \n",
182 | " root | \n",
183 | " save_freq | \n",
184 | " scale_lower | \n",
185 | " seed | \n",
186 | " test_bs | \n",
187 | " tmp | \n",
188 | " verbose | \n",
189 | " warmup | \n",
190 | " weight_decay | \n",
191 | " workers | \n",
192 | " world_size | \n",
193 | " time | \n",
194 | " base_lr_linear_scale | \n",
195 | " color_dist_s | \n",
196 | " cooldown | \n",
197 | " cooldown_after | \n",
198 | " momentum | \n",
199 | " multiplier | \n",
200 | " norm_multiplier | \n",
201 | " projection | \n",
202 | " status | \n",
203 | " sync_bn | \n",
204 | " temperature | \n",
205 | " ckpt_iter | \n",
206 | " encode_layer | \n",
207 | " model_id | \n",
208 | " use_all_classes | \n",
209 | "
\n",
210 | " \n",
211 | " id | \n",
212 | " | \n",
213 | " | \n",
214 | " | \n",
215 | " | \n",
216 | " | \n",
217 | " | \n",
218 | " | \n",
219 | " | \n",
220 | " | \n",
221 | " | \n",
222 | " | \n",
223 | " | \n",
224 | " | \n",
225 | " | \n",
226 | " | \n",
227 | " | \n",
228 | " | \n",
229 | " | \n",
230 | " | \n",
231 | " | \n",
232 | " | \n",
233 | " | \n",
234 | " | \n",
235 | " | \n",
236 | " | \n",
237 | " | \n",
238 | " | \n",
239 | " | \n",
240 | " | \n",
241 | " | \n",
242 | " | \n",
243 | " | \n",
244 | " | \n",
245 | " | \n",
246 | " | \n",
247 | " | \n",
248 | " | \n",
249 | " | \n",
250 | " | \n",
251 | " | \n",
252 | " | \n",
253 | " | \n",
254 | " | \n",
255 | " | \n",
256 | " | \n",
257 | " | \n",
258 | " | \n",
259 | " | \n",
260 | " | \n",
261 | " | \n",
262 | "
\n",
263 | " \n",
264 | " \n",
265 | " \n",
266 | " 2 | \n",
267 | " ResNet50 | \n",
268 | " True | \n",
269 | " NaN | \n",
270 | " 512 | \n",
271 | " | \n",
272 | " cifar_params.yaml | \n",
273 | " cifar | \n",
274 | " ddp | \n",
275 | " cn-012:8881 | \n",
276 | " NaN | \n",
277 | " 4800 | \n",
278 | " NaN | \n",
279 | " 0 | \n",
280 | " 48000 | \n",
281 | " 48 | \n",
282 | " 4.0 | \n",
283 | " warmup-anneal | \n",
284 | " | \n",
285 | " 0 | \n",
286 | " 2 | \n",
287 | " lars | \n",
288 | " NaN | \n",
289 | " sim-clr | \n",
290 | " logs/exman-train.py/runs/000002 | \n",
291 | " 4800 | \n",
292 | " 0.08 | \n",
293 | " -1 | \n",
294 | " NaN | \n",
295 | " False | \n",
296 | " True | \n",
297 | " 0.01 | \n",
298 | " 0.000001 | \n",
299 | " 2 | \n",
300 | " 2 | \n",
301 | " 2020-11-18 21:32:53 | \n",
302 | " False | \n",
303 | " 0.5 | \n",
304 | " linear | \n",
305 | " -1.0 | \n",
306 | " 0.9 | \n",
307 | " 2.0 | \n",
308 | " False | \n",
309 | " MLPv2 | \n",
310 | " fail | \n",
311 | " True | \n",
312 | " 0.5 | \n",
313 | " NaN | \n",
314 | " NaN | \n",
315 | " NaN | \n",
316 | " NaN | \n",
317 | "
\n",
318 | " \n",
319 | " 4 | \n",
320 | " ResNet50 | \n",
321 | " True | \n",
322 | " NaN | \n",
323 | " 128 | \n",
324 | " | \n",
325 | " imagenet_params_epochs200_bs2k.yaml | \n",
326 | " imagenet | \n",
327 | " ddp | \n",
328 | " cn-010:8881 | \n",
329 | " NaN | \n",
330 | " 12510 | \n",
331 | " NaN | \n",
332 | " 0 | \n",
333 | " 125100 | \n",
334 | " 100 | \n",
335 | " 2.4 | \n",
336 | " warmup-anneal | \n",
337 | " imagenet-reproduce | \n",
338 | " 0 | \n",
339 | " 16 | \n",
340 | " lars | \n",
341 | " NaN | \n",
342 | " sim-clr | \n",
343 | " logs/exman-train.py/runs/000004 | \n",
344 | " 12510 | \n",
345 | " 0.08 | \n",
346 | " -1 | \n",
347 | " NaN | \n",
348 | " False | \n",
349 | " True | \n",
350 | " 0.10 | \n",
351 | " 0.000001 | \n",
352 | " 8 | \n",
353 | " 4 | \n",
354 | " 2020-11-23 16:44:02 | \n",
355 | " False | \n",
356 | " 1.0 | \n",
357 | " linear | \n",
358 | " -1.0 | \n",
359 | " 0.9 | \n",
360 | " 2.0 | \n",
361 | " False | \n",
362 | " MLPv2 | \n",
363 | " fail | \n",
364 | " True | \n",
365 | " 0.1 | \n",
366 | " NaN | \n",
367 | " NaN | \n",
368 | " NaN | \n",
369 | " NaN | \n",
370 | "
\n",
371 | " \n",
372 | " 10 | \n",
373 | " linear | \n",
374 | " True | \n",
375 | " RandomResizedCrop | \n",
376 | " 4096 | \n",
377 | " | \n",
378 | " configs/imagenet_eval_params.yaml | \n",
379 | " imagenet | \n",
380 | " dp | \n",
381 | " | \n",
382 | " /home/aashukha/simclr-pytorch/logs/exman-train... | \n",
383 | " 100 | \n",
384 | " False | \n",
385 | " 0 | \n",
386 | " 28080 | \n",
387 | " 1000 | \n",
388 | " 1.6 | \n",
389 | " linear | \n",
390 | " eval_imagenet_newmodels | \n",
391 | " 0 | \n",
392 | " 1 | \n",
393 | " sgd | \n",
394 | " -1.0 | \n",
395 | " eval | \n",
396 | " logs/exman-train.py/runs/000010 | \n",
397 | " 10000000000000000 | \n",
398 | " 0.08 | \n",
399 | " -1 | \n",
400 | " 4096.0 | \n",
401 | " False | \n",
402 | " False | \n",
403 | " 0.00 | \n",
404 | " 0.000000 | \n",
405 | " 20 | \n",
406 | " 1 | \n",
407 | " 2020-11-26 00:44:34 | \n",
408 | " False | \n",
409 | " NaN | \n",
410 | " linear | \n",
411 | " -1.0 | \n",
412 | " 0.9 | \n",
413 | " NaN | \n",
414 | " NaN | \n",
415 | " NaN | \n",
416 | " fail | \n",
417 | " NaN | \n",
418 | " NaN | \n",
419 | " -1.0 | \n",
420 | " h | \n",
421 | " -1.0 | \n",
422 | " False | \n",
423 | "
\n",
424 | " \n",
425 | " 12 | \n",
426 | " ResNet50 | \n",
427 | " True | \n",
428 | " NaN | \n",
429 | " 128 | \n",
430 | " | \n",
431 | " configs/imagenet_params_epochs600_bs2k.yaml | \n",
432 | " imagenet | \n",
433 | " ddp | \n",
434 | " cn-010:8881 | \n",
435 | " NaN | \n",
436 | " 12510 | \n",
437 | " NaN | \n",
438 | " 0 | \n",
439 | " 375300 | \n",
440 | " 100 | \n",
441 | " 2.4 | \n",
442 | " warmup-anneal | \n",
443 | " imagenet-reproduce | \n",
444 | " 0 | \n",
445 | " 16 | \n",
446 | " lars | \n",
447 | " NaN | \n",
448 | " sim-clr | \n",
449 | " logs/exman-train.py/runs/000012 | \n",
450 | " 12510 | \n",
451 | " 0.08 | \n",
452 | " -1 | \n",
453 | " NaN | \n",
454 | " False | \n",
455 | " True | \n",
456 | " 0.10 | \n",
457 | " 0.000001 | \n",
458 | " 8 | \n",
459 | " 4 | \n",
460 | " 2020-11-26 01:17:43 | \n",
461 | " False | \n",
462 | " 1.0 | \n",
463 | " linear | \n",
464 | " -1.0 | \n",
465 | " 0.9 | \n",
466 | " 2.0 | \n",
467 | " False | \n",
468 | " MLPv2 | \n",
469 | " NaN | \n",
470 | " True | \n",
471 | " 0.1 | \n",
472 | " NaN | \n",
473 | " NaN | \n",
474 | " NaN | \n",
475 | " NaN | \n",
476 | "
\n",
477 | " \n",
478 | " 23 | \n",
479 | " linear | \n",
480 | " True | \n",
481 | " RandomCrop | \n",
482 | " 1024 | \n",
483 | " | \n",
484 | " configs/cifar_eval.yaml | \n",
485 | " cifar | \n",
486 | " dp | \n",
487 | " 127.0.0.1:1234 | \n",
488 | " logs/exman-train.py/runs/000002/checkpoint.pth... | \n",
489 | " 1000 | \n",
490 | " False | \n",
491 | " 0 | \n",
492 | " 80000 | \n",
493 | " 100 | \n",
494 | " 0.1 | \n",
495 | " linear | \n",
496 | " | \n",
497 | " 0 | \n",
498 | " 1 | \n",
499 | " sgd | \n",
500 | " -1.0 | \n",
501 | " eval | \n",
502 | " logs/exman-train.py/runs/000023 | \n",
503 | " 100000000 | \n",
504 | " 0.08 | \n",
505 | " -1 | \n",
506 | " 1024.0 | \n",
507 | " False | \n",
508 | " False | \n",
509 | " 0.00 | \n",
510 | " 0.000100 | \n",
511 | " 2 | \n",
512 | " 1 | \n",
513 | " 2020-11-26 16:20:18 | \n",
514 | " NaN | \n",
515 | " NaN | \n",
516 | " NaN | \n",
517 | " NaN | \n",
518 | " NaN | \n",
519 | " NaN | \n",
520 | " NaN | \n",
521 | " NaN | \n",
522 | " NaN | \n",
523 | " NaN | \n",
524 | " NaN | \n",
525 | " NaN | \n",
526 | " NaN | \n",
527 | " NaN | \n",
528 | " NaN | \n",
529 | "
\n",
530 | " \n",
531 | "
\n",
532 | "
"
533 | ],
534 | "text/plain": [
535 | " arch aug augmentation ... encode_layer model_id use_all_classes\n",
536 | "id ... \n",
537 | "2 ResNet50 True NaN ... NaN NaN NaN\n",
538 | "4 ResNet50 True NaN ... NaN NaN NaN\n",
539 | "10 linear True RandomResizedCrop ... h -1.0 False\n",
540 | "12 ResNet50 True NaN ... NaN NaN NaN\n",
541 | "23 linear True RandomCrop ... NaN NaN NaN\n",
542 | "\n",
543 | "[5 rows x 50 columns]"
544 | ]
545 | },
546 | "execution_count": 35,
547 | "metadata": {
548 | "tags": []
549 | },
550 | "output_type": "execute_result"
551 | }
552 | ],
553 | "source": [
554 | "index.head()"
555 | ]
556 | },
557 | {
558 | "cell_type": "code",
559 | "execution_count": 39,
560 | "metadata": {
561 | "colab": {
562 | "base_uri": "https://localhost:8080/"
563 | },
564 | "id": "Wy8Mwu_z_EBF",
565 | "outputId": "a190b8d4-5d30-4c35-ee34-0e8a7da354e2"
566 | },
567 | "outputs": [
568 | {
569 | "data": {
570 | "text/plain": [
571 | "{'arch': 'linear',\n",
572 | " 'aug': True,\n",
573 | " 'augmentation': 'RandomResizedCrop',\n",
574 | " 'base_lr_linear_scale': nan,\n",
575 | " 'batch_size': 4096,\n",
576 | " 'ckpt': '',\n",
577 | " 'ckpt_iter': nan,\n",
578 | " 'color_dist_s': nan,\n",
579 | " 'config_file': 'configs/imagenet_eval_params.yaml',\n",
580 | " 'cooldown': nan,\n",
581 | " 'cooldown_after': nan,\n",
582 | " 'data': 'imagenet',\n",
583 | " 'dist': 'dp',\n",
584 | " 'dist_address': '',\n",
585 | " 'encode_layer': nan,\n",
586 | " 'encoder_ckpt': '/home/aashukha/simclr-pytorch/logs/exman-train.py/runs/000012/checkpoint.pth.tar',\n",
587 | " 'eval_freq': 100,\n",
588 | " 'finetune': False,\n",
589 | " 'gpu': 0,\n",
590 | " 'iters': 28080,\n",
591 | " 'log_freq': 1000,\n",
592 | " 'lr': 1.6,\n",
593 | " 'lr_schedule': 'linear',\n",
594 | " 'model_id': nan,\n",
595 | " 'momentum': nan,\n",
596 | " 'multiplier': nan,\n",
597 | " 'name': 'eval_imagenet_newmodels',\n",
598 | " 'node_rank': 0,\n",
599 | " 'norm_multiplier': nan,\n",
600 | " 'number_of_processes': 1,\n",
601 | " 'opt': 'sgd',\n",
602 | " 'precompute_emb_bs': -1.0,\n",
603 | " 'problem': 'eval',\n",
604 | " 'projection': nan,\n",
605 | " 'root': 'logs/exman-train.py/runs/000033',\n",
606 | " 'save_freq': 10000000000000000,\n",
607 | " 'scale_lower': 0.08,\n",
608 | " 'seed': -1,\n",
609 | " 'status': nan,\n",
610 | " 'sync_bn': nan,\n",
611 | " 'temperature': nan,\n",
612 | " 'test_bs': 4096.0,\n",
613 | " 'time': Timestamp('2020-12-05 14:49:17'),\n",
614 | " 'tmp': False,\n",
615 | " 'use_all_classes': nan,\n",
616 | " 'verbose': False,\n",
617 | " 'warmup': 0.0,\n",
618 | " 'weight_decay': 0.0,\n",
619 | " 'workers': 20,\n",
620 | " 'world_size': 1}"
621 | ]
622 | },
623 | "execution_count": 39,
624 | "metadata": {
625 | "tags": []
626 | },
627 | "output_type": "execute_result"
628 | }
629 | ],
630 | "source": [
631 | "dict(index.loc[33])"
632 | ]
633 | },
634 | {
635 | "cell_type": "code",
636 | "execution_count": 34,
637 | "metadata": {
638 | "colab": {
639 | "base_uri": "https://localhost:8080/",
640 | "height": 107
641 | },
642 | "id": "Mc15MLjhIQ1w",
643 | "outputId": "82065f2b-76c9-4b85-fef4-b42621fa8760"
644 | },
645 | "outputs": [
646 | {
647 | "data": {
648 | "text/html": [
649 | "\n",
650 | "\n",
663 | "
\n",
664 | " \n",
665 | " \n",
666 | " | \n",
667 | " t | \n",
668 | " test_loss | \n",
669 | " test_acc | \n",
670 | " train_loss | \n",
671 | " train_acc | \n",
672 | " train_epoch | \n",
673 | " lr | \n",
674 | " data_time | \n",
675 | " it_time | \n",
676 | "
\n",
677 | " \n",
678 | " \n",
679 | " \n",
680 | " 279 | \n",
681 | " 28000 | \n",
682 | " 1.279921 | \n",
683 | " 0.67840 | \n",
684 | " 1.158502 | \n",
685 | " 0.716760 | \n",
686 | " 89.456869 | \n",
687 | " 0.004615 | \n",
688 | " 316.105078 | \n",
689 | " 2025.544294 | \n",
690 | "
\n",
691 | " \n",
692 | " 280 | \n",
693 | " 28080 | \n",
694 | " 1.279901 | \n",
695 | " 0.67842 | \n",
696 | " 1.161120 | \n",
697 | " 0.716248 | \n",
698 | " 89.712460 | \n",
699 | " 0.000057 | \n",
700 | " 12.917784 | \n",
701 | " 141.360860 | \n",
702 | "
\n",
703 | " \n",
704 | "
\n",
705 | "
"
706 | ],
707 | "text/plain": [
708 | " t test_loss test_acc ... lr data_time it_time\n",
709 | "279 28000 1.279921 0.67840 ... 0.004615 316.105078 2025.544294\n",
710 | "280 28080 1.279901 0.67842 ... 0.000057 12.917784 141.360860\n",
711 | "\n",
712 | "[2 rows x 9 columns]"
713 | ]
714 | },
715 | "execution_count": 34,
716 | "metadata": {
717 | "tags": []
718 | },
719 | "output_type": "execute_result"
720 | }
721 | ],
722 | "source": [
723 | "logs = pd.read_csv(index.loc[33].root + '/logs.csv')\n",
724 | "logs.tail(2)"
725 | ]
726 | },
727 | {
728 | "cell_type": "code",
729 | "execution_count": null,
730 | "metadata": {
731 | "id": "w-QviLaJ9P-u"
732 | },
733 | "outputs": [],
734 | "source": []
735 | }
736 | ],
737 | "metadata": {
738 | "colab": {
739 | "collapsed_sections": [],
740 | "name": "read_logs.ipynb",
741 | "provenance": []
742 | },
743 | "kernelspec": {
744 | "display_name": "Python 3",
745 | "language": "python",
746 | "name": "python3"
747 | },
748 | "language_info": {
749 | "codemirror_mode": {
750 | "name": "ipython",
751 | "version": 3
752 | },
753 | "file_extension": ".py",
754 | "mimetype": "text/x-python",
755 | "name": "python",
756 | "nbconvert_exporter": "python",
757 | "pygments_lexer": "ipython3",
758 | "version": "3.7.6"
759 | }
760 | },
761 | "nbformat": 4,
762 | "nbformat_minor": 1
763 | }
764 |
--------------------------------------------------------------------------------
/configs/cifar_eval.yaml:
--------------------------------------------------------------------------------
1 | arch: linear
2 | aug: true
3 | augmentation: RandomCrop
4 | batch_size: 1024
5 | ckpt: ''
6 | ckpt_iter: -1
7 | config_file: null
8 | data: cifar
9 | dist: dp
10 | dist_address: 127.0.0.1:1234
11 | encode_layer: h
12 | encoder_ckpt: ''
13 | eval_freq: 1000
14 | finetune: false
15 | iters: 80000
16 | log_freq: 100
17 | lr: 0.1
18 | lr_schedule: linear
19 | model_id: -1
20 | n_augs_test: 50
21 | n_augs_train: 10
22 | name: ''
23 | node_rank: 0
24 | opt: sgd
25 | precompute_emb_bs: -1
26 | problem: eval
27 | root: ''
28 | save_freq: 100000000
29 | scale_lower: 0.08
30 | seed: -1
31 | status: fail
32 | test_bs: 1024
33 | tmp: false
34 | warmup: 0.0
35 | weight_decay: 0.0001
36 | workers: 2
37 | world_size: 1
38 |
39 | time: '2020-07-18T00:55:23'
40 | id: 3549
--------------------------------------------------------------------------------
/configs/cifar_train_epochs1000_bs1024.yaml:
--------------------------------------------------------------------------------
1 | arch: ResNet50
2 | aug: true
3 | batch_size: 1024
4 | ckpt: ''
5 | color_dist_s: 0.5
6 | config_file: null
7 | data: cifar
8 | dist: ddp
9 | dist_address: '127.0.0.1:1234'
10 | eval_freq: 4800
11 | iters: 48000
12 | log_freq: 48
13 | lr: 4.0
14 | lr_schedule: warmup-anneal
15 | multiplier: 2
16 | name: 'reproduce-cifar10'
17 | node_rank: 0
18 | opt: lars
19 | problem: sim-clr
20 | root: 'none'
21 | save_freq: 4800
22 | scale_lower: 0.08
23 | seed: -1
24 | sync_bn: true
25 | temperature: 0.5
26 | tmp: false
27 | verbose: true
28 | warmup: 0.01
29 | weight_decay: 1.0e-06
30 | workers: 2
31 | world_size: 1
--------------------------------------------------------------------------------
/configs/imagenet_eval.yaml:
--------------------------------------------------------------------------------
1 | arch: linear
2 | aug: true
3 | augmentation: RandomResizedCrop
4 | batch_size: 4096
5 | ckpt: ''
6 | ckpt_iter: -1
7 | config_file: null
8 | data: imagenet
9 | dist: dp
10 | dist_address: ''
11 | encode_layer: h
12 | encoder_ckpt: ''
13 | eval_freq: 100
14 | finetune: false
15 | iters: 28080
16 | log_freq: 1000
17 | lr: 1.6
18 | lr_schedule: linear
19 | model_id: -1
20 | name: eval_imagenet_newmodels
21 | node_rank: 0
22 | opt: sgd
23 | precompute_emb_bs: -1
24 | problem: eval
25 | save_freq: 10000000000000000
26 | scale_lower: 0.08
27 | seed: -1
28 | test_bs: 4096
29 | tmp: false
30 | warmup: 0.0
31 | weight_decay: 0.0
32 | workers: 20
33 | world_size: 1
--------------------------------------------------------------------------------
/configs/imagenet_train_epochs100_bs512.yaml:
--------------------------------------------------------------------------------
1 | arch: ResNet50
2 | aug: true
3 | batch_size: 512
4 | ckpt: ''
5 | color_dist_s: 1.0
6 | config_file: ''
7 | data: imagenet
8 | dist: ddp
9 | dist_address: '127.0.0.1:1234'
10 | eval_freq: 50040
11 | gpu: 0
12 | iters: 250200
13 | log_freq: 100
14 | lr: 0.6
15 | lr_schedule: warmup-anneal
16 | multiplier: 2
17 | name: imagenet-reproduce
18 | node_rank: 0
19 | opt: lars
20 | problem: sim-clr
21 | root: ''
22 | save_freq: 12510
23 | scale_lower: 0.08
24 | seed: -1
25 | sync_bn: true
26 | temperature: 0.1
27 | tmp: false
28 | verbose: true
29 | warmup: 0.1
30 | weight_decay: 1.0e-06
31 | workers: 8
32 | world_size: 1
33 |
--------------------------------------------------------------------------------
/configs/imagenet_train_epochs200_bs2k.yaml:
--------------------------------------------------------------------------------
1 | arch: ResNet50
2 | aug: true
3 | batch_size: 2048
4 | ckpt: ''
5 | color_dist_s: 1.0
6 | config_file: ''
7 | data: imagenet
8 | dist: ddp
9 | dist_address: ''
10 | eval_freq: 12510
11 | gpu: 0
12 | iters: 125100
13 | log_freq: 100
14 | lr: 2.4
15 | lr_schedule: warmup-anneal
16 | momentum: 0.9
17 | multiplier: 2
18 | name: imagenet-reproduce
19 | node_rank: 0
20 | number_of_processes: 16
21 | opt: lars
22 | problem: sim-clr
23 | root: ''
24 | save_freq: 12510
25 | scale_lower: 0.08
26 | seed: -1
27 | sync_bn: true
28 | temperature: 0.1
29 | tmp: false
30 | verbose: true
31 | warmup: 0.1
32 | weight_decay: 1.0e-06
33 | workers: 8
34 | world_size: 4
--------------------------------------------------------------------------------
/configs/imagenet_train_epochs600_bs2k.yaml:
--------------------------------------------------------------------------------
1 | arch: ResNet50
2 | aug: true
3 | batch_size: 2048
4 | ckpt: ''
5 | color_dist_s: 1.0
6 | config_file: ''
7 | data: imagenet
8 | dist: ddp
9 | dist_address: ''
10 | eval_freq: 12510
11 | gpu: 0
12 | iters: 375300
13 | log_freq: 100
14 | lr: 2.4
15 | lr_schedule: warmup-anneal
16 | multiplier: 2
17 | name: imagenet-reproduce
18 | node_rank: 0
19 | opt: lars
20 | problem: sim-clr
21 | root: ''
22 | save_freq: 12510
23 | scale_lower: 0.08
24 | seed: -1
25 | sync_bn: true
26 | temperature: 0.1
27 | tmp: false
28 | verbose: true
29 | warmup: 0.1
30 | weight_decay: 1.0e-06
31 | workers: 8
32 | world_size: 4
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: simclr_pytorch
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - backcall=0.2.0=py_0
8 | - blas=1.0=mkl
9 | - ca-certificates=2020.10.14=0
10 | - certifi=2020.12.5=py36h06a4308_0
11 | - configargparse=1.2.3=py_0
12 | - cudatoolkit=10.1.243=h6bb024c_0
13 | - dataclasses=0.7=py36_0
14 | - decorator=4.4.2=py_0
15 | - filelock=3.0.12=py_0
16 | - freetype=2.10.4=h5ab3b9f_0
17 | - intel-openmp=2020.2=254
18 | - ipython=7.16.1=py36h5ca1d4c_0
19 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
20 | - jedi=0.17.2=py36h06a4308_1
21 | - joblib=0.17.0=py_0
22 | - jpeg=9b=h024ee3a_2
23 | - lcms2=2.11=h396b838_0
24 | - ld_impl_linux-64=2.33.1=h53a641e_7
25 | - libedit=3.1.20191231=h14c3975_1
26 | - libffi=3.3=he6710b0_2
27 | - libgcc-ng=9.1.0=hdf63c60_0
28 | - libgfortran-ng=7.3.0=hdf63c60_0
29 | - libpng=1.6.37=hbc83047_0
30 | - libstdcxx-ng=9.1.0=hdf63c60_0
31 | - libtiff=4.1.0=h2733197_1
32 | - libuv=1.40.0=h7b6447c_0
33 | - lz4-c=1.9.2=heb0550a_3
34 | - mkl=2020.2=256
35 | - mkl-service=2.3.0=py36he8ac12f_0
36 | - mkl_fft=1.2.0=py36h23d657b_0
37 | - mkl_random=1.1.1=py36h0573a6f_0
38 | - ncurses=6.2=he6710b0_1
39 | - ninja=1.10.2=py36hff7bd54_0
40 | - numpy=1.19.2=py36h54aff64_0
41 | - numpy-base=1.19.2=py36hfa32c7d_0
42 | - olefile=0.46=py36_0
43 | - openssl=1.1.1h=h7b6447c_0
44 | - pandas=1.1.3=py36he6710b0_0
45 | - parso=0.7.0=py_0
46 | - pexpect=4.8.0=pyhd3eb1b0_3
47 | - pickleshare=0.7.5=pyhd3eb1b0_1003
48 | - pillow=8.0.1=py36he98fc37_0
49 | - pip=20.3.1=py36h06a4308_0
50 | - prompt-toolkit=3.0.8=py_0
51 | - ptyprocess=0.6.0=pyhd3eb1b0_2
52 | - pygments=2.7.3=pyhd3eb1b0_0
53 | - python=3.6.12=hcff3b4d_2
54 | - python-dateutil=2.8.1=py_0
55 | - pytorch=1.7.0=py3.6_cuda10.1.243_cudnn7.6.3_0
56 | - pytz=2020.4=pyhd3eb1b0_0
57 | - pyyaml=5.3.1=py36h7b6447c_1
58 | - readline=8.0=h7b6447c_0
59 | - scikit-learn=0.23.2=py36h0573a6f_0
60 | - scipy=1.5.2=py36h0b6359f_0
61 | - setuptools=51.0.0=py36h06a4308_2
62 | - six=1.15.0=py36h06a4308_0
63 | - sqlite=3.33.0=h62c20be_0
64 | - tabulate=0.8.7=py36_0
65 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
66 | - tk=8.6.10=hbc83047_0
67 | - torchaudio=0.7.0=py36
68 | - torchvision=0.8.1=py36_cu101
69 | - tqdm=4.54.1=pyhd3eb1b0_0
70 | - traitlets=4.3.3=py36_0
71 | - typing_extensions=3.7.4.3=py_0
72 | - wcwidth=0.2.5=py_0
73 | - wheel=0.36.1=pyhd3eb1b0_0
74 | - xz=5.2.5=h7b6447c_0
75 | - yaml=0.2.5=h7b6447c_0
76 | - zlib=1.2.11=h7b6447c_3
77 | - zstd=1.4.5=h9ceee32_0
78 | - pip:
79 | - diffdist==0.1
80 | - strconv==0.4.2
81 | prefix: /home/aashukha/miniconda3/envs/simclr_pytorch
82 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models import encoder
2 | from models import losses
3 | from models import resnet
4 | from models import ssl
5 |
# Lookup table from a ``problem`` name (e.g. the ``problem`` field stored in a
# checkpoint's hparams) to the model class implementing it.
REGISTERED_MODELS = dict([
    ('sim-clr', ssl.SimCLR),
    ('eval', ssl.SSLEval),
    ('semi-supervised-eval', ssl.SemiSupervisedEval),
])
11 |
--------------------------------------------------------------------------------
/models/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import models
4 | from collections import OrderedDict
5 | from argparse import Namespace
6 | import yaml
7 | import os
8 |
9 |
class BatchNorm1dNoBias(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` whose shift (``bias``) parameter is excluded from autograd.

    The bias is still a registered parameter, so it stays in the state dict, but
    ``requires_grad`` is turned off and no gradient is ever computed for it.
    All constructor arguments are forwarded to ``nn.BatchNorm1d``.
    """

    def __init__(self, *args, **kwargs):
        super(BatchNorm1dNoBias, self).__init__(*args, **kwargs)
        self.bias.requires_grad_(False)
14 |
15 |
class EncodeProject(nn.Module):
    """Backbone encoder followed by a projection head.

    ``forward(x, out='h')`` returns the backbone features; any other value of
    ``out`` (default ``'z'``) returns the projection of those features.

    Args:
        hparams: namespace read for ``arch`` ('ResNet50' or 'resnet18') and
            ``data`` (``'cifar'`` selects the ResNet's CIFAR stem).

    Raises:
        NotImplementedError: if ``hparams.arch`` is not a supported encoder.
    """

    def __init__(self, hparams):
        super().__init__()

        cifar_head = (hparams.data == 'cifar')
        if hparams.arch == 'ResNet50':
            self.convnet = models.resnet.ResNet50(cifar_head=cifar_head, hparams=hparams)
            self.encoder_dim = 2048
        elif hparams.arch == 'resnet18':
            self.convnet = models.resnet.ResNet18(cifar_head=cifar_head)
            self.encoder_dim = 512
        else:
            raise NotImplementedError(f'Unsupported encoder architecture: {hparams.arch!r}')

        num_params = sum(p.numel() for p in self.convnet.parameters() if p.requires_grad)

        print(f'======> Encoder: output dim {self.encoder_dim} | {num_params/1e6:.3f}M parameters')

        # Projection head: Linear -> BN -> ReLU -> Linear -> BN (bias frozen).
        # Its output width is ``self.proj_dim``; use the attribute everywhere so the
        # advertised dimension and the actual layers cannot drift apart.
        self.proj_dim = 128
        projection_layers = [
            ('fc1', nn.Linear(self.encoder_dim, self.encoder_dim, bias=False)),
            ('bn1', nn.BatchNorm1d(self.encoder_dim)),
            ('relu1', nn.ReLU()),
            ('fc2', nn.Linear(self.encoder_dim, self.proj_dim, bias=False)),
            ('bn2', BatchNorm1dNoBias(self.proj_dim)),
        ]

        self.projection = nn.Sequential(OrderedDict(projection_layers))

    def forward(self, x, out='z'):
        """Encode ``x``; return backbone features if ``out == 'h'``, else their projection."""
        h = self.convnet(x)
        if out == 'h':
            return h
        return self.projection(h)
50 |
--------------------------------------------------------------------------------
/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | import diffdist
6 | import torch.distributed as dist
7 |
8 |
def gather(z):
    """All-gather ``z`` from every process and concatenate the copies along dim 0.

    Needs an initialised ``torch.distributed`` process group. Uses
    ``diffdist.functional.all_gather`` (presumably for its autograd support,
    unlike ``torch.distributed.all_gather``). Every process must pass a tensor
    of the same shape.
    """
    world_size = torch.distributed.get_world_size()
    buffers = [torch.zeros_like(z) for _ in range(world_size)]
    return torch.cat(diffdist.functional.all_gather(buffers, z))
15 |
16 |
def accuracy(logits, labels, k):
    """Per-row exact-match accuracy of the top-``k`` predictions.

    Row ``i`` scores 1.0 when the ``k`` highest-scoring column indices of
    ``logits[i]``, taken as a sorted list, equal the sorted ``labels[i]``;
    otherwise 0.0.

    Args:
        logits: 2-D score tensor, ranked along dim 1.
        labels: integer tensor of target indices, expected to be (rows, k).
        k: number of top predictions to compare.

    Returns:
        Float tensor with one 0/1 entry per row.
    """
    predicted, _ = logits.topk(k, dim=1)[1].sort(dim=1)
    expected, _ = labels.sort(dim=1)
    return predicted.eq(expected).all(dim=1).float()
22 |
23 |
def mean_cumulative_gain(logits, labels, k):
    """Per-row fraction of the top-``k`` predictions that are relevant.

    A prediction counts as a hit if its index appears anywhere in the row's
    ``labels``. The metric is computed as set overlap; the previous
    implementation compared sorted lists position by position and missed hits
    that were not aligned (e.g. top {0, 2} vs labels {2, 3} gave 0 instead of 0.5).

    Args:
        logits: 2-D score tensor, ranked along dim 1.
        labels: 2-D integer tensor of relevant indices per row (any width).
        k: number of top-scoring predictions to consider.

    Returns:
        Float tensor with one value in [0, 1] per row.
    """
    topk = logits.topk(k, dim=1)[1]
    # (rows, k, 1) vs (rows, 1, n_labels): is each predicted index among the labels?
    hits = (topk.unsqueeze(2) == labels.unsqueeze(1)).any(dim=2)
    return hits.float().mean(dim=1)
29 |
30 |
def mean_average_precision(logits, labels, k):
    """Average precision of each row's ranking, with ``labels`` as the relevant set.

    Each row of ``logits`` is ranked in descending order. The 1-based ranks of
    the ``k`` relevant indices are sorted, and AP is the mean over ``j = 1..k``
    of ``j / rank_j``.

    Args:
        logits: 2-D score tensor, ranked along dim 1.
        labels: (rows, k) integer tensor of relevant indices.
        k: number of relevant indices per row.

    Returns:
        Float tensor with one AP value per row.
    """
    # TODO: not the fastest solution but looks fine
    order = logits.argsort(dim=1, descending=True)
    # position_of[i, c] = 0-based rank of column c within row i
    position_of = order.argsort(dim=1)
    hit_ranks, _ = position_of.gather(1, labels).sort(dim=1)
    hit_ranks = hit_ranks + 1  # 1-based, ascending
    hits_so_far = torch.arange(1, k + 1, device=logits.device).float()
    return (hits_so_far / hit_ranks).sum(dim=1) / k
37 |
38 |
class NTXent(nn.Module):
    """
    Contrastive loss with distributed data parallel support

    NT-Xent (normalized temperature-scaled cross entropy). ``forward(z)`` expects
    the rows of ``z`` to form ``multiplier`` contiguous blocks, one per augmented
    view, so that rows ``i`` and ``(i + k * n / multiplier) % n`` are views of the
    same image (this is the layout the ``labels`` computation below assumes).
    It returns ``(loss, acc)``, or ``(loss, acc, map)`` when ``get_map=True``.
    ``acc`` and ``map`` are per-row retrieval metrics of the positives, computed
    on a detached copy of the log-probabilities.

    Args:
        tau: temperature. Embeddings are L2-normalised and divided by
            ``sqrt(tau)``, so ``z @ z.t()`` equals cosine similarity / ``tau``.
        gpu: accepted for interface compatibility; not used by this class.
        multiplier: number of augmented views per image in the batch.
        distributed: if True, gather ``z`` from all processes (needs an
            initialised ``torch.distributed`` process group) before computing
            the loss.
    """
    # Diagonal logits (a sample against itself) are set to -LARGE_NUMBER so a
    # sample can never be selected as its own positive/negative.
    LARGE_NUMBER = 1e9

    def __init__(self, tau=1., gpu=None, multiplier=2, distributed=False):
        super().__init__()
        self.tau = tau
        self.multiplier = multiplier
        self.distributed = distributed
        # Constant extra divisor applied to the loss.
        self.norm = 1.

    def forward(self, z, get_map=False):
        n = z.shape[0]
        assert n % self.multiplier == 0

        # L2-normalise, then scale by 1/sqrt(tau) so dot products are cos_sim / tau.
        z = F.normalize(z, p=2, dim=1) / np.sqrt(self.tau)

        if self.distributed:
            z_list = [torch.zeros_like(z) for _ in range(dist.get_world_size())]
            # all_gather fills the list as [<proc0>, <proc1>, ...]
            # diffdist's all_gather is used instead of torch.distributed.all_gather,
            # presumably so gradients can flow through the gathered copies.
            # TODO: try to rewrite it with pytorch official tools
            z_list = diffdist.functional.all_gather(z_list, z)
            # split it into [<proc0_aug0>, <proc0_aug1>, ..., <proc0_aug(m-1)>, <proc1_aug0>, ...]
            # (assumes each process's local batch is itself laid out as m view blocks)
            z_list = [chunk for x in z_list for chunk in x.chunk(self.multiplier)]
            # sort it to [<proc0_aug0>, <proc1_aug0>, ..., <proc0_aug1>, <proc1_aug1>, ...]
            # that simply means [<batch_aug0>, <batch_aug1>, ...] as expected below
            z_sorted = []
            for m in range(self.multiplier):
                for i in range(dist.get_world_size()):
                    z_sorted.append(z_list[i * self.multiplier + m])
            z = torch.cat(z_sorted, dim=0)
            n = z.shape[0]

        # Pairwise similarity logits; mask self-similarity on the diagonal.
        logits = z @ z.t()
        logits[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER

        logprob = F.log_softmax(logits, dim=1)

        # choose all positive objects for an example, for i it would be (i + k * n/m), where k=0...(m-1)
        m = self.multiplier
        labels = (np.repeat(np.arange(n), m) + np.tile(np.arange(m) * n//m, n)) % n
        # remove labels pointing to the example itself, i.e. (i, i) -- the k=0 column
        labels = labels.reshape(n, m)[:, 1:].reshape(-1)

        # Mean negative log-probability over all n * (m - 1) (anchor, positive) pairs.
        # TODO: maybe different terms for each process should only be computed here...
        loss = -logprob[np.repeat(np.arange(n), m-1), labels].sum() / n / (m-1) / self.norm

        # zero the probability of identical pairs (log-prob -> -LARGE_NUMBER) on a
        # detached copy used only for the metrics below
        pred = logprob.data.clone()
        pred[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER
        acc = accuracy(pred, torch.LongTensor(labels.reshape(n, m-1)).to(logprob.device), m-1)

        if get_map:
            _map = mean_average_precision(pred, torch.LongTensor(labels.reshape(n, m-1)).to(logprob.device), m-1)
            return loss, acc, _map

        return loss, acc
97 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | import torch.nn as nn
10 | import torchvision.models as models
11 | import torch
12 |
13 |
class Flatten(nn.Module):
    """Module form of ``torch.flatten``: merges axes ``dim`` .. last into one."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, feat):
        return feat.flatten(start_dim=self.dim)
21 |
22 |
class ResNetEncoder(models.resnet.ResNet):
    """Torchvision ResNet used as a feature extractor.

    ``forward`` returns the average-pooled, flattened output of the last stage
    and never applies the classification layer ``fc``. The torchvision parent
    still constructs ``fc``, so it remains a (gradient-less) parameter set.

    Args:
        block, layers: forwarded to ``torchvision.models.resnet.ResNet``.
        cifar_head: if True, replace the stem with a 3x3 stride-1 convolution
            and skip the max-pool, for small inputs such as CIFAR images.
        hparams: stored on the instance as ``self.hparams``.
    """

    def __init__(self, block, layers, cifar_head=False, hparams=None):
        super().__init__(block, layers)
        self.cifar_head = cifar_head
        if cifar_head:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = self._norm_layer(64)
            self.relu = nn.ReLU(inplace=True)
        self.hparams = hparams

        print('** Using avgpool **')

    def forward(self, x):
        # Stem (conv -> bn -> relu); the CIFAR stem has no max-pool.
        x = self.relu(self.bn1(self.conv1(x)))
        if not self.cifar_head:
            x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        return torch.flatten(self.avgpool(x), 1)
53 |
class ResNet18(ResNetEncoder):
    """ResNet-18 feature extractor (BasicBlock, stage depths 2-2-2-2).

    Takes the same ``hparams`` keyword as ``ResNet50`` and forwards it to
    ``ResNetEncoder``, so both encoders can be built with identical arguments.
    The default ``None`` keeps existing call sites working.
    """

    def __init__(self, cifar_head=True, hparams=None):
        super().__init__(models.resnet.BasicBlock, [2, 2, 2, 2], cifar_head=cifar_head, hparams=hparams)
57 |
58 |
class ResNet50(ResNetEncoder):
    """ResNet-50 feature extractor (Bottleneck blocks, stage depths 3-4-6-3)."""

    def __init__(self, cifar_head=True, hparams=None):
        bottleneck = models.resnet.Bottleneck
        stage_depths = [3, 4, 6, 3]
        super().__init__(bottleneck, stage_depths, cifar_head=cifar_head, hparams=hparams)
62 |
--------------------------------------------------------------------------------
/models/ssl.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace, ArgumentParser
2 |
3 | import os
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torchvision import datasets
8 | import torchvision.transforms as transforms
9 | from utils import datautils
10 | import models
11 | from utils import utils
12 | import numpy as np
13 | import PIL
14 | from tqdm import tqdm
15 | import sklearn
16 | from utils.lars_optimizer import LARS
17 | import scipy
18 | from torch.nn.parallel import DistributedDataParallel as DDP
19 | import torch.distributed as dist
20 |
21 | import copy
22 |
class BaseSSL(nn.Module):
    """
    Inspired by the PYTORCH LIGHTNING https://pytorch-lightning.readthedocs.io/en/latest/
    Similar but lighter and customized version.

    Shared base for the models in this module. It keeps ``hparams``, builds the
    CIFAR-10 / ImageNet datasets and their data loaders, and provides checkpoint
    helpers. Subclasses override ``transforms()``, ``samplers()`` and ``forward()``.
    """
    # Dataset locations, overridable through environment variables.
    DATA_ROOT = os.environ.get('DATA_ROOT', os.path.dirname(os.path.abspath(__file__)) + '/data')
    IMAGENET_PATH = os.environ.get('IMAGENET_PATH', '/home/aashukha/imagenet/raw-data/')

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        if hparams.data == 'imagenet':
            print(f"IMAGENET_PATH = {self.IMAGENET_PATH}")

    def get_ckpt(self):
        """Return a checkpoint dict holding the full ``state_dict`` and ``hparams``."""
        return dict(state_dict=self.state_dict(), hparams=self.hparams)

    @classmethod
    def _complete_hparams(cls, namespace):
        """Add the default of every model hyper-parameter that ``namespace`` lacks.

        ``argparse`` only sets defaults for attributes missing from the given
        namespace, which is then updated in place and returned.
        """
        parser = ArgumentParser()
        cls.add_model_hparams(parser)
        return parser.parse_args([], namespace=namespace)

    @classmethod
    def load(cls, ckpt, device=None):
        """Rebuild a model from a checkpoint dict with ``hparams`` and ``state_dict`` keys."""
        model = cls(cls._complete_hparams(ckpt['hparams']), device=device)
        model.load_state_dict(ckpt['state_dict'])
        return model

    @classmethod
    def default(cls, device=None, **kwargs):
        """Build a freshly initialised model; ``kwargs`` override default hyper-parameters."""
        return cls(cls._complete_hparams(Namespace(**kwargs)), device=device)

    def forward(self, x):
        # Overridden by subclasses.
        return None

    def transforms(self):
        # Overridden by subclasses; must return (train_transform, test_transform).
        return None

    def samplers(self):
        # Default: no batch samplers for either split.
        return None, None

    def prepare_data(self):
        """Create ``self.trainset`` / ``self.testset`` for ``hparams.data``.

        Raises:
            NotImplementedError: for datasets other than 'cifar' and 'imagenet'.
        """
        train_transform, test_transform = self.transforms()
        dataset_name = self.hparams.data
        if dataset_name == 'cifar':
            self.trainset = datasets.CIFAR10(root=self.DATA_ROOT, train=True, download=True, transform=train_transform)
            self.testset = datasets.CIFAR10(root=self.DATA_ROOT, train=False, download=True, transform=test_transform)
        elif dataset_name == 'imagenet':
            self.trainset = datasets.ImageFolder(os.path.join(self.IMAGENET_PATH, 'train'), transform=train_transform)
            self.testset = datasets.ImageFolder(os.path.join(self.IMAGENET_PATH, 'val'), transform=test_transform)
        else:
            raise NotImplementedError

    def dataloaders(self, iters=None):
        """Return ``(train_loader, test_loader)`` built from ``self.samplers()``.

        When ``iters`` is given, the train batch sampler is wrapped in
        ``datautils.ContinousSampler`` with that iteration count.
        """
        train_batches, test_batches = self.samplers()
        if iters is not None:
            train_batches = datautils.ContinousSampler(train_batches, iters)

        def make_loader(dataset, batch_sampler):
            return torch.utils.data.DataLoader(
                dataset,
                num_workers=self.hparams.workers,
                pin_memory=True,
                batch_sampler=batch_sampler,
            )

        train_loader = make_loader(self.trainset, train_batches)
        test_loader = make_loader(self.testset, test_batches)
        return train_loader, test_loader

    @staticmethod
    def add_parent_hparams(add_model_hparams):
        """Decorator: register every direct base class's hparams, then this class's own."""
        def with_parents(cls, parser):
            for parent in cls.__bases__:
                parent.add_model_hparams(parser)
            add_model_hparams(cls, parser)
        return with_parents

    @classmethod
    def add_model_hparams(cls, parser):
        # NOTE(review): with a plain argparse parser, ``type=bool`` turns any
        # non-empty string (even 'False') into True -- confirm the project's parser
        # registers a string-to-bool converter for ``bool``.
        add = parser.add_argument
        add('--data', help='Dataset to use', default='cifar')
        add('--arch', default='ResNet50', help='Encoder architecture')
        add('--batch_size', default=256, type=int, help='The number of unique images in the batch')
        add('--aug', default=True, type=bool, help='Applies random augmentations if True')
122 |
123 |
class SimCLR(BaseSSL):
    """SimCLR pre-training model.

    Builds ``models.encoder.EncodeProject`` (encoder + projection head) and wraps it
    in ``DistributedDataParallel`` (``hparams.dist == 'ddp'``) or ``DataParallel``
    (``'dp'``, the default when ``dist`` is missing). It is trained with the
    ``NTXent`` contrastive loss.

    Raises:
        NotImplementedError: for any other ``hparams.dist`` value, or for datasets
            other than 'cifar' / 'imagenet' in ``transforms()``.
    """

    @classmethod
    @BaseSSL.add_parent_hparams
    def add_model_hparams(cls, parser):
        # loss params
        parser.add_argument('--temperature', default=0.1, type=float, help='Temperature in the NTXent loss')
        # data params
        parser.add_argument('--multiplier', default=2, type=int)
        parser.add_argument('--color_dist_s', default=1., type=float, help='Color distortion strength')
        parser.add_argument('--scale_lower', default=0.08, type=float, help='The minimum scale factor for RandomResizedCrop')
        # ddp
        parser.add_argument('--sync_bn', default=True, type=bool,
            help='Syncronises BatchNorm layers between all processes if True'
        )

    def __init__(self, hparams, device=None):
        super().__init__(hparams)

        self.hparams.dist = getattr(self.hparams, 'dist', 'dp')

        model = models.encoder.EncodeProject(hparams)
        # NOTE(review): ``model`` is not registered on ``self`` yet, so this call finds
        # no Conv2d/Linear layers and the network keeps PyTorch's default init.
        # Kept deliberately: enabling it would change the training behaviour this code
        # was run with (see the notes in ``reset_parameters``).
        self.reset_parameters()
        if device is not None:
            model = model.to(device)
        if self.hparams.dist == 'ddp':
            if self.hparams.sync_bn:
                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            dist.barrier()
            if device is not None:
                model = model.to(device)
            # find_unused_parameters=True: presumably because some parameters (e.g. the
            # ResNet classification layer) never take part in forward -- confirm.
            self.model = DDP(model, [hparams.gpu], find_unused_parameters=True)
        elif self.hparams.dist == 'dp':
            self.model = nn.DataParallel(model)
        else:
            raise NotImplementedError

        self.criterion = models.losses.NTXent(
            tau=hparams.temperature,
            multiplier=hparams.multiplier,
            distributed=(hparams.dist == 'ddp'),
        )

    def reset_parameters(self):
        """Re-initialise every Conv2d and Linear weight found in ``self.modules()``.

        Conv weights are drawn from a unit normal truncated to [-2, 2], then scaled
        by ``sqrt(1 / fan_in) / 0.8796...``. The constant is the std of that
        truncated normal, so the result has std ``sqrt(1 / fan_in)``. Linear weights
        get ``N(0, 0.01^2)``. Biases and BatchNorm parameters are not touched.

        NOTE(review): ``fan_in`` is ``p.shape[1]`` (input channels only; the kernel
        area is not included), and ``__init__`` calls this before the network is
        attached to ``self``. Confirm both before relying on this initialiser.
        """
        def conv2d_weight_truncated_normal_init(p):
            fan_in = p.shape[1]
            stddev = np.sqrt(1. / fan_in) / .87962566103423978
            r = scipy.stats.truncnorm.rvs(-2, 2, loc=0, scale=1., size=p.shape)
            r = stddev * r
            with torch.no_grad():
                p.copy_(torch.FloatTensor(r))

        def linear_normal_init(p):
            with torch.no_grad():
                p.normal_(std=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv2d_weight_truncated_normal_init(m.weight)
            elif isinstance(m, nn.Linear):
                linear_normal_init(m.weight)

    def step(self, batch):
        """Compute the contrastive loss and accuracy for an ``(images, labels)`` batch."""
        x, _ = batch
        z = self.model(x)
        loss, acc = self.criterion(z)
        return {
            'loss': loss,
            'contrast_acc': acc,
        }

    def encode(self, x):
        """Return the encoder features (before the projection head)."""
        return self.model(x, out='h')

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def train_step(self, batch, it=None):
        """Run ``step`` on a training batch and add the fractional epoch to the logs.

        With ``dist == 'ddp'`` the DistributedSampler is re-seeded with the
        iteration index. This is skipped when ``it`` is None, since
        ``set_epoch(None)`` is invalid.
        """
        logs = self.step(batch)

        if it is not None:
            if self.hparams.dist == 'ddp':
                self.trainsampler.set_epoch(it)
            logs['epoch'] = it / len(self.batch_trainsampler)

        return logs

    def test_step(self, batch):
        return self.step(batch)

    def samplers(self):
        """Return ``(train_batch_sampler, test_batch_sampler)``.

        Both use ``datautils.MultiplyBatchSampler`` with ``batch_size`` unique
        images per batch, presumably each repeated ``hparams.multiplier`` times
        (confirm in datautils). Both drop the last incomplete batch.
        """
        if self.hparams.dist == 'ddp':
            trainsampler = torch.utils.data.distributed.DistributedSampler(self.trainset)
            print(f'Process {dist.get_rank()}: {len(trainsampler)} training samples per epoch')
            testsampler = torch.utils.data.distributed.DistributedSampler(self.testset)
            print(f'Process {dist.get_rank()}: {len(testsampler)} test samples')
        else:
            trainsampler = torch.utils.data.sampler.RandomSampler(self.trainset)
            testsampler = torch.utils.data.sampler.RandomSampler(self.testset)

        batch_sampler = datautils.MultiplyBatchSampler
        # Sets a class attribute, so it applies to every MultiplyBatchSampler
        # created afterwards in this process.
        batch_sampler.MULTILPLIER = self.hparams.multiplier

        # need for DDP to sync samplers between processes
        self.trainsampler = trainsampler
        self.batch_trainsampler = batch_sampler(trainsampler, self.hparams.batch_size, drop_last=True)

        return (
            self.batch_trainsampler,
            batch_sampler(testsampler, self.hparams.batch_size, drop_last=True)
        )

    def transforms(self):
        """Return the SimCLR augmentation pipeline, used for both train and test.

        The pipeline is RandomResizedCrop -> horizontal flip -> colour distortion ->
        ToTensor -> (ImageNet only: Gaussian blur) -> Clip.

        Raises:
            NotImplementedError: for datasets other than 'cifar' and 'imagenet'.
        """
        if self.hparams.data == 'cifar':
            im_size = 32
            tensor_ops = []
        elif self.hparams.data == 'imagenet':
            im_size = 224
            tensor_ops = [datautils.GaussianBlur(im_size // 10, 0.5)]
        else:
            raise NotImplementedError(f'Unsupported dataset: {self.hparams.data!r}')

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(
                im_size,
                scale=(self.hparams.scale_lower, 1.0),
                interpolation=PIL.Image.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(0.5),
            datautils.get_color_distortion(s=self.hparams.color_dist_s),
            transforms.ToTensor(),
        ] + tensor_ops + [
            datautils.Clip(),
        ])
        test_transform = train_transform
        return train_transform, test_transform

    def get_ckpt(self):
        """Checkpoint holding the unwrapped ``EncodeProject`` state dict and ``hparams``."""
        return {
            'state_dict': self.model.module.state_dict(),
            'hparams': self.hparams,
        }

    def load_state_dict(self, state):
        """Load a full SimCLR state dict (``model.module.*`` keys) or a bare network one.

        An empty ``state`` falls through to the network's own ``load_state_dict``,
        which then reports the missing keys. Previously it raised ``StopIteration``.
        """
        first_key = next(iter(state), '')
        if first_key.startswith('model.module'):
            super().load_state_dict(state)
        else:
            self.model.module.load_state_dict(state)
283 |
284 |
class SSLEval(BaseSSL):
    """Linear evaluation of a self-supervised encoder.

    A zero-initialised linear classifier (``hparams.arch == 'linear'``) is trained on
    the encoder's ``h`` features. The encoder is loaded from ``hparams.encoder_ckpt``
    or, when that is empty, is a randomly initialised ``SimCLR`` model. It is frozen
    unless ``hparams.finetune``. With ``precompute_emb_bs != -1`` the features are
    computed once in ``prepare_data`` and the classifier trains on them directly.
    """

    @classmethod
    @BaseSSL.add_parent_hparams
    def add_model_hparams(cls, parser):
        parser.add_argument('--test_bs', default=256, type=int)
        parser.add_argument('--encoder_ckpt', default='', help='Path to the encoder checkpoint')
        parser.add_argument('--precompute_emb_bs', default=-1, type=int,
            help='If it\'s not equal to -1 embeddings are precomputed and fixed before training with batch size equal to this.'
        )
        parser.add_argument('--finetune', default=False, type=bool, help='Finetunes the encoder if True')
        parser.add_argument('--augmentation', default='RandomResizedCrop', help='')
        parser.add_argument('--scale_lower', default=0.08, type=float, help='The minimum scale factor for RandomResizedCrop')

    def __init__(self, hparams, device=None):
        super().__init__(hparams)

        self.hparams.dist = getattr(self.hparams, 'dist', 'dp')

        if hparams.encoder_ckpt != '':
            ckpt = torch.load(hparams.encoder_ckpt, map_location=device)
            if getattr(ckpt['hparams'], 'dist', 'dp') == 'ddp':
                ckpt['hparams'].dist = 'dp'
            if self.hparams.dist == 'ddp':
                # NOTE(review): SimCLR.__init__ only accepts 'ddp' / 'dp', so a SimCLR
                # encoder loaded with this value looks like it would raise -- confirm.
                ckpt['hparams'].dist = 'gpu:%d' % hparams.gpu

            self.encoder = models.REGISTERED_MODELS[ckpt['hparams'].problem].load(ckpt, device=device)
        else:
            print('===> Random encoder is used!!!')
            # Build the random encoder for the evaluated dataset so it gets the
            # matching stem (the CIFAR head is only used when data == 'cifar').
            self.encoder = SimCLR.default(device=device, data=hparams.data)
            self.encoder.to(device)

        if not hparams.finetune:
            for p in self.encoder.parameters():
                p.requires_grad = False
        elif hparams.dist == 'ddp':
            raise NotImplementedError

        self.encoder.eval()
        if hparams.data == 'cifar':
            probe_size, n_classes = 32, 10
        elif hparams.data == 'imagenet':
            probe_size, n_classes = 224, 1000
        else:
            raise NotImplementedError(f'Unsupported dataset: {hparams.data!r}')
        # Infer the feature dimension with a dummy forward pass; no autograd graph needed.
        with torch.no_grad():
            hdim = self.encode(torch.ones(10, 3, probe_size, probe_size).to(device)).shape[1]

        if hparams.arch == 'linear':
            model = nn.Linear(hdim, n_classes).to(device)
            model.weight.data.zero_()
            model.bias.data.zero_()
            self.model = model
        else:
            raise NotImplementedError

        if hparams.dist == 'ddp':
            self.model = DDP(model, [hparams.gpu])

    def encode(self, x):
        """Return the encoder's ``h`` features for images ``x``."""
        return self.encoder.model(x, out='h')

    def step(self, batch):
        """Cross-entropy loss and per-example accuracy for an ``(inputs, labels)`` batch.

        ``inputs`` are raw images when ``precompute_emb_bs == -1``, otherwise
        precomputed embeddings. Only raw ImageNet images (emitted as uint8 by
        ``transforms()``) are rescaled by 1/255. Precomputed embeddings are already
        features and used to be wrongly divided by 255 a second time. The caller's
        ``batch`` is no longer modified in place.
        """
        h, y = batch
        if self.hparams.precompute_emb_bs == -1:
            if self.hparams.problem == 'eval' and self.hparams.data == 'imagenet':
                h = h / 255.
            h = self.encode(h)
        p = self.model(h)
        loss = F.cross_entropy(p, y)
        acc = (p.argmax(1) == y).float()
        return {
            'loss': loss,
            'acc': acc,
        }

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def train_step(self, batch, it=None):
        """Run ``step`` and, when ``it`` is given, log the epoch and re-seed the DDP sampler."""
        logs = self.step(batch)
        if it is not None:
            iters_per_epoch = len(self.trainset) / self.hparams.batch_size
            iters_per_epoch = max(1, int(np.around(iters_per_epoch)))
            logs['epoch'] = it / iters_per_epoch
            # Only the DistributedSampler has set_epoch; it is invalid with ``None``.
            if self.hparams.dist == 'ddp' and self.hparams.precompute_emb_bs == -1:
                self.object_trainsampler.set_epoch(it)

        return logs

    def test_step(self, batch):
        logs = self.step(batch)
        if self.hparams.dist == 'ddp':
            utils.gather_metrics(logs)
        return logs

    def prepare_data(self):
        """Load the datasets; optionally replace them by precomputed embeddings."""
        super().prepare_data()

        def create_emb_dataset(dataset):
            embs, labels = [], []
            loader = torch.utils.data.DataLoader(
                dataset,
                num_workers=self.hparams.workers,
                pin_memory=True,
                batch_size=self.hparams.precompute_emb_bs,
                shuffle=False,
            )
            for x, y in tqdm(loader):
                if self.hparams.data == 'imagenet':
                    # ImageNet transforms emit uint8; rescale to [0, 1] on the GPU.
                    x = x.to(torch.device('cuda'))
                    x = x / 255.
                e = self.encode(x)
                embs.append(utils.tonp(e))
                labels.append(utils.tonp(y))
            embs, labels = np.concatenate(embs), np.concatenate(labels)
            dataset = torch.utils.data.TensorDataset(torch.FloatTensor(embs), torch.LongTensor(labels))
            return dataset

        if self.hparams.precompute_emb_bs != -1:
            print('===> Precompute embeddings:')
            # Fixed embeddings cannot reflect per-epoch random augmentation.
            assert not self.hparams.aug, 'precompute_emb_bs requires aug to be False'
            with torch.no_grad():
                self.encoder.eval()
                self.testset = create_emb_dataset(self.testset)
                self.trainset = create_emb_dataset(self.trainset)

        print(f'Train size: {len(self.trainset)}')
        print(f'Test size: {len(self.testset)}')

    def dataloaders(self, iters=None):
        """Return ``(train_loader, test_loader)``; the test loader is not shuffled."""
        if self.hparams.dist == 'ddp' and self.hparams.precompute_emb_bs == -1:
            trainsampler = torch.utils.data.distributed.DistributedSampler(self.trainset)
            testsampler = torch.utils.data.distributed.DistributedSampler(self.testset, shuffle=False)
        else:
            trainsampler = torch.utils.data.RandomSampler(self.trainset)
            testsampler = torch.utils.data.SequentialSampler(self.testset)

        self.object_trainsampler = trainsampler
        trainsampler = torch.utils.data.BatchSampler(
            self.object_trainsampler,
            batch_size=self.hparams.batch_size, drop_last=False,
        )
        if iters is not None:
            trainsampler = datautils.ContinousSampler(trainsampler, iters)

        train_loader = torch.utils.data.DataLoader(
            self.trainset,
            num_workers=self.hparams.workers,
            pin_memory=True,
            batch_sampler=trainsampler,
        )
        test_loader = torch.utils.data.DataLoader(
            self.testset,
            num_workers=self.hparams.workers,
            pin_memory=True,
            sampler=testsampler,
            batch_size=self.hparams.test_bs,
        )
        return train_loader, test_loader

    def transforms(self):
        """Return ``(train_transform, test_transform)``.

        The train transform is used only when ``hparams.aug`` is True; otherwise
        the test transform is returned for both. ImageNet transforms output uint8
        tensors (``255 * x`` cast to byte).

        Raises:
            NotImplementedError: for datasets other than 'cifar' and 'imagenet'.
        """
        if self.hparams.data == 'cifar':
            trs = []
            if 'RandomResizedCrop' in self.hparams.augmentation:
                trs.append(
                    transforms.RandomResizedCrop(
                        32,
                        scale=(self.hparams.scale_lower, 1.0),
                        interpolation=PIL.Image.BICUBIC,
                    )
                )
            if 'RandomCrop' in self.hparams.augmentation:
                trs.append(transforms.RandomCrop(32, padding=4, padding_mode='reflect'))
            if 'color_distortion' in self.hparams.augmentation:
                trs.append(datautils.get_color_distortion(self.encoder.hparams.color_dist_s))

            train_transform = transforms.Compose(trs + [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                datautils.Clip(),
            ])
            test_transform = transforms.Compose([
                transforms.ToTensor(),
            ])
        elif self.hparams.data == 'imagenet':
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(
                    224,
                    scale=(self.hparams.scale_lower, 1.0),
                    interpolation=PIL.Image.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                lambda x: (255*x).byte(),
            ])
            test_transform = transforms.Compose([
                datautils.CenterCropAndResize(proportion=0.875, size=224),
                transforms.ToTensor(),
                lambda x: (255 * x).byte(),
            ])
        else:
            raise NotImplementedError(f'Unsupported dataset: {self.hparams.data!r}')
        return train_transform if self.hparams.aug else test_transform, test_transform

    def train(self, mode=True):
        """Switch train/eval mode; the frozen encoder stays in eval mode unless finetuning."""
        if self.hparams.finetune:
            super().train(mode)
        else:
            self.model.train(mode)

    def get_ckpt(self):
        """Checkpoint dict: the full state when finetuning, else only the classifier.

        The classifier state is taken from the unwrapped module, so a DDP run
        saves plain ``weight``/``bias`` keys that ``load_state_dict`` can read back.
        """
        if self.hparams.finetune:
            state = self.state_dict()
        else:
            state = getattr(self.model, 'module', self.model).state_dict()
        return {
            'state_dict': state,
            'hparams': self.hparams,
        }

    def load_state_dict(self, state):
        """Load a checkpoint written by ``get_ckpt``.

        For the classifier-only case, a leading ``module.`` (as saved from a DDP
        wrapper by older checkpoints) is stripped so DP and DDP checkpoints are
        interchangeable.
        """
        if self.hparams.finetune:
            super().load_state_dict(state)
            return
        prefix = 'module.'
        state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state.items()
        }
        getattr(self.model, 'module', self.model).load_state_dict(state)
505 |
class SemiSupervisedEval(SSLEval):
    """Evaluate an encoder with labels for only a subset of the training set.

    Extra hyper-parameters:
        --train_size: number of labelled training examples to keep. -1 (the
            default) keeps the whole training set.
        --data_split_seed: ``random_state`` of the labelled/unlabelled split.
        --n_augs_train / --n_augs_test: when != -1, the train/test sets are
            replaced by ``EmbEnsEval.create_emb_dataset`` outputs built with that
            many augmentations.
        --acc_on_unlabeled: also report metrics on the unlabelled part of the
            training set (CIFAR only). Test metrics get suffix ``_0`` and
            unlabelled-train metrics get suffix ``_1``.
    """
    @classmethod
    @BaseSSL.add_parent_hparams
    def add_model_hparams(cls, parser):
        parser.add_argument('--train_size', default=-1, type=int)
        parser.add_argument('--data_split_seed', default=42, type=int)
        parser.add_argument('--n_augs_train', default=-1, type=int)
        parser.add_argument('--n_augs_test', default=-1, type=int)
        parser.add_argument('--acc_on_unlabeled', default=False, type=bool)

    def prepare_data(self):
        # Runs the prepare_data that follows SSLEval in the MRO (skipping SSLEval's own).
        super(SSLEval, self).prepare_data()

        train_size = self.hparams.train_size
        # -1 is the "use the whole training set" sentinel; train_test_split rejects a
        # negative train_size, so only split when a concrete size was requested.
        if train_size != -1 and len(self.trainset) != train_size:
            idxs, unlabeled_idxs = sklearn.model_selection.train_test_split(
                np.arange(len(self.trainset)),
                train_size=train_size,
                random_state=self.hparams.data_split_seed,
            )
            if self.hparams.data == 'cifar' or self.hparams.data == 'cifar100':
                if self.hparams.acc_on_unlabeled:
                    self.trainset_unlabeled = copy.deepcopy(self.trainset)
                    self.trainset_unlabeled.data = self.trainset.data[unlabeled_idxs]
                    self.trainset_unlabeled.targets = np.array(self.trainset.targets)[unlabeled_idxs]
                    print(f'Test size (0): {len(self.testset)}')
                    print(f'Unlabeled train size (1): {len(self.trainset_unlabeled)}')

                self.trainset.data = self.trainset.data[idxs]
                self.trainset.targets = np.array(self.trainset.targets)[idxs]

                print('Training dataset size:', len(self.trainset))
            else:
                assert not self.hparams.acc_on_unlabeled
                if isinstance(self.trainset, torch.utils.data.TensorDataset):
                    self.trainset.tensors = [t[idxs] for t in self.trainset.tensors]
                else:
                    self.trainset.samples = [self.trainset.samples[i] for i in idxs]

                print('Training dataset size:', len(self.trainset))

        self.encoder.eval()
        with torch.no_grad():
            if self.hparams.n_augs_train != -1:
                self.trainset = EmbEnsEval.create_emb_dataset(self, self.trainset, n_augs=self.hparams.n_augs_train)
            if self.hparams.n_augs_test != -1:
                self.testset = EmbEnsEval.create_emb_dataset(self, self.testset, n_augs=self.hparams.n_augs_test)
                if self.hparams.acc_on_unlabeled:
                    self.trainset_unlabeled = EmbEnsEval.create_emb_dataset(
                        self,
                        self.trainset_unlabeled,
                        n_augs=self.hparams.n_augs_test
                    )
        if self.hparams.acc_on_unlabeled:
            # Each item is tagged with its origin: 0 = test set, 1 = unlabelled train part.
            self.testset = torch.utils.data.ConcatDataset([
                datautils.DummyOutputWrapper(self.testset, 0),
                datautils.DummyOutputWrapper(self.trainset_unlabeled, 1)
            ])

    def transforms(self):
        """Return (train, test) transforms; the ensemble variants are used when n_augs_* != -1."""
        ens_train_transform, ens_test_transform = EmbEnsEval.transforms(self)
        train_transform, test_transform = SSLEval.transforms(self)
        return (
            train_transform if self.hparams.n_augs_train == -1 else ens_train_transform,
            test_transform if self.hparams.n_augs_test == -1 else ens_test_transform
        )

    def step(self, batch, it=None):
        """Cross-entropy loss and per-example accuracy of the head on ``batch``.

        4-D inputs are treated as images and encoded first; other inputs are
        assumed to be precomputed embeddings.
        """
        if self.hparams.problem == 'eval' and self.hparams.data == 'imagenet':
            batch[0] = batch[0] / 255.
        h, y = batch
        if len(h.shape) == 4:
            h = self.encode(h)
        p = self.model(h)
        loss = F.cross_entropy(p, y)
        acc = (p.argmax(1) == y).float()
        return {
            'loss': loss,
            'acc': acc,
        }

    def test_step(self, batch):
        """Run the parent test_step, reporting metrics per origin tag when acc_on_unlabeled is set."""
        if not self.hparams.acc_on_unlabeled:
            return super().test_step(batch)
        # TODO: refactor
        # The third element is the origin tag added in prepare_data (0 = test, 1 = unlabelled).
        x, y, d = batch
        logs = {}
        keys = set()
        for didx in [0, 1]:
            if torch.any(d == didx):
                t = super().test_step([x[d == didx], y[d == didx]])
                for k, v in t.items():
                    keys.add(k)
                    logs[k + f'_{didx}'] = v
        # Make both suffixes present for every metric so batches aggregate uniformly.
        for didx in [0, 1]:
            for k in keys:
                logs[k + f'_{didx}'] = logs.get(k + f'_{didx}', torch.tensor([]))
        return logs
603 |
604 |
def configure_optimizers(args, model, cur_iter=-1):
    """Build the optimizer and learning-rate scheduler described by ``args``.

    Parameters are split into two groups:
      * regular parameters: ``weight_decay=args.weight_decay``, ``layer_adaptation=True``;
      * parameters whose name contains ``'bn'`` (and, for LARS, ``'bias'``):
        no weight decay and ``layer_adaptation=False``.

    Args:
        args: namespace with ``opt`` ('sgd' | 'adam' | 'lars'), ``lr``,
            ``weight_decay``, ``lr_schedule`` ('warmup-anneal' | 'linear' | 'const'),
            ``warmup`` and ``iters``.
        model: module whose ``named_parameters()`` are optimized.
        cur_iter: forwarded to the scheduler as ``last_epoch`` (-1 for a fresh start).

    Returns:
        ``(optimizer, scheduler)``. ``scheduler`` is None for 'const'. For 'lars'
        the optimizer is a ``LARS`` wrapper; the scheduler drives the wrapped SGD
        instance.

    Raises:
        NotImplementedError: for an unknown ``opt`` or ``lr_schedule``.
    """
    def exclude_from_wd_and_adaptation(name):
        return 'bn' in name or (args.opt == 'lars' and 'bias' in name)

    # Single pass over the parameters, preserving their registration order.
    decayed, not_decayed = [], []
    for name, p in model.named_parameters():
        (not_decayed if exclude_from_wd_and_adaptation(name) else decayed).append(p)

    param_groups = [
        {
            'params': decayed,
            'weight_decay': args.weight_decay,
            'layer_adaptation': True,
        },
        {
            'params': not_decayed,
            'weight_decay': 0.,
            'layer_adaptation': False,
        },
    ]

    if args.opt in ('sgd', 'lars'):
        # LARS wraps plain momentum SGD; the wrapper is applied after the scheduler is built.
        optimizer = torch.optim.SGD(param_groups, lr=args.lr, momentum=0.9)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(param_groups, lr=args.lr)
    else:
        raise NotImplementedError

    if args.lr_schedule == 'warmup-anneal':
        scheduler = utils.LinearWarmupAndCosineAnneal(
            optimizer,
            args.warmup,
            args.iters,
            last_epoch=cur_iter,
        )
    elif args.lr_schedule == 'linear':
        scheduler = utils.LinearLR(optimizer, args.iters, last_epoch=cur_iter)
    elif args.lr_schedule == 'const':
        scheduler = None
    else:
        raise NotImplementedError

    if args.opt == 'lars':
        # The scheduler must see the torch optimizer, so wrap only after it is created.
        optimizer = LARS(optimizer)

    return optimizer, scheduler
672 |
--------------------------------------------------------------------------------
/myexman/__init__.py:
--------------------------------------------------------------------------------
1 | from .parser import (
2 | ExParser,
3 | simpleroot
4 | )
5 | from .index import (
6 | Index
7 | )
8 | from . import index
9 | from . import parser
# Version string of the myexman package.
__version__ = '0.0.2'
11 |
--------------------------------------------------------------------------------
/myexman/index.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | import pandas as pd
3 | import pathlib
4 | import strconv
5 | import json
6 | import functools
7 | import datetime
8 | from . import parser
9 | import yaml
10 | from argparse import Namespace
# Public API of this module: ``from myexman.index import *`` exports only ``Index``.
__all__ = [
    'Index'
]
14 |
15 |
def only_value_error(conv):
    """Wrap the converter ``conv`` so that every failure surfaces as ``ValueError``.

    Any exception raised by ``conv`` is re-raised as a ``ValueError`` chained to
    the original (``__cause__``). Presumably this is because strconv only skips
    converters that raise ValueError; verify against the strconv version in use.
    """
    @functools.wraps(conv)
    def wrapped(raw):
        try:
            result = conv(raw)
        except Exception as exc:
            raise ValueError from exc
        return result
    return wrapped
24 |
25 |
def none2none(none):
    """Return None for a None input; reject any other value with ``ValueError``."""
    if none is not None:
        raise ValueError
    return None
31 |
32 |
# Shared strconv converter used by ``Index.info`` to infer the type of each
# parameter column. 'bool' and 'json' are wrapped with ``only_value_error`` so
# their failures surface as ValueError; 'datetime1' parses the ``TIME_FORMAT``
# written into params files by the parser module.
converter = strconv.Strconv(converters=[
    ('int', strconv.convert_int),
    ('float', strconv.convert_float),
    ('bool', only_value_error(parser.str2bool)),
    ('time', strconv.convert_time),
    ('datetime', strconv.convert_datetime),
    ('datetime1', lambda time: datetime.datetime.strptime(time, parser.TIME_FORMAT)),
    ('date', strconv.convert_date),
    ('json', only_value_error(json.loads)),
])
43 |
44 |
def get_args(path):
    """Load a run's YAML parameter file into an ``argparse.Namespace``.

    Args:
        path: path to a ``params.yaml`` file written by ``ExParser``.

    Returns:
        A Namespace whose attributes are the top-level keys of the YAML mapping.
    """
    with open(path, 'rb') as f:
        # ``yaml.load`` without an explicit Loader is an error in PyYAML >= 6.
        # FullLoader was the implicit default in PyYAML 5.x, so existing files
        # load exactly as before. NOTE(review): only use on trusted files.
        return Namespace(**yaml.load(f, Loader=yaml.FullLoader))
48 |
49 |
class Index(object):
    """Tabular view over the runs stored under an experiment root."""

    def __init__(self, root):
        self.root = pathlib.Path(root)

    @property
    def index(self):
        """Directory holding one ``params.yaml`` symlink per non-tmp run."""
        return self.root / 'index'

    @property
    def marked(self):
        """Directory holding run symlinks grouped by automark values."""
        return self.root / 'marked'

    def info(self, source=None, nlast=None):
        """Collect run parameters into a DataFrame with one row per run.

        Args:
            source: None to read every entry of ``index``. Otherwise a path
                relative to ``marked``, searched recursively for params files.
            nlast: when reading ``index``, keep only the last ``nlast`` entries
                in sorted (name) order. Ignored when ``source`` is given.

        Returns:
            A DataFrame sorted by ``id`` with ``id`` as the first column. Columns
            whose values all convert to non-string types are replaced by the
            converted values. ``root`` values are joined onto ``self.root``.

        Raises:
            KeyError: with the scanned directory's name when a file or
                directory is missing.
        """
        if source is None:
            source = self.index
            from_index = True
        else:
            source = self.marked / source
            from_index = False

        def get_dict(cfg):
            # Close each params file as soon as it has been parsed.
            with cfg.open('r') as stream:
                return configargparse.YAMLConfigFileParser().parse(stream)

        def convert_column(col):
            # Convert once and reuse the result for both the check and the new Series.
            converted = list(converter.convert_series(col))
            if any(isinstance(v, str) for v in converted):
                return col
            return pd.Series(converted, name=col.name, index=col.index)

        try:
            # Listing happens inside the try so a missing directory is reported
            # as KeyError, including when ``nlast`` forces an eager sort.
            if from_index:
                files = source.iterdir()
                if nlast is not None:
                    files = sorted(files)[-nlast:]
            else:
                files = source.glob('**/*/' + parser.PARAMS_FILE)
            df = (pd.DataFrame
                  .from_records([get_dict(c) for c in files])
                  .apply(convert_column)
                  .sort_values('id')
                  .assign(root=lambda frame: frame.root.apply(self.root.__truediv__))
                  .reset_index(drop=True))
            cols = df.columns.tolist()
            cols.insert(0, cols.pop(cols.index('id')))
            return df.reindex(columns=cols)
        except FileNotFoundError as e:
            raise KeyError(source.name) from e
92 |
--------------------------------------------------------------------------------
/myexman/parser.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | import argparse
3 | import pathlib
4 | import datetime
5 | import yaml
6 | import yaml.representer
7 | import os
8 | import functools
9 | import itertools
10 | from filelock import FileLock
# Public API of this module: the experiment parser and the ``simpleroot`` helper.
__all__ = [
    'ExParser',
    'simpleroot',
]
15 |
16 |
# strftime format for the run time when embedded in a directory name. DIR_FORMAT
# receives it as ``{time}``, but the current DIR_FORMAT only uses ``{num}``.
TIME_FORMAT_DIR = '%Y-%m-%d-%H-%M-%S'
# Format of the ``time:`` entry written to params files.
TIME_FORMAT = '%Y-%m-%dT%H:%M:%S'
# Run directory name template, filled with the zero-padded run number.
DIR_FORMAT = '{num}'
# Extension of params files and of the index symlinks.
EXT = 'yaml'
# Name of the per-run parameter file.
PARAMS_FILE = 'params.'+EXT
# Default sub-directory name returned by ``simpleroot``.
FOLDER_DEFAULT = 'exman'
# Sub-directories that ParserWithRoot creates under the experiment root.
RESERVED_DIRECTORIES = {
    'runs', 'index',
    'tmp', 'marked'
}
27 |
28 |
def yaml_file(name):
    """Return ``name`` with the params-file extension appended, e.g. ``'000001.yaml'``."""
    return '.'.join((name, EXT))
31 |
32 |
def simpleroot(__file__):
    """Return ``<absolute directory of the given file>/<FOLDER_DEFAULT>`` as a Path.

    Args:
        __file__: a file path, e.g. the caller's own ``__file__``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    return pathlib.Path(script_dir) / FOLDER_DEFAULT
35 |
36 |
def represent_as_str(self, data, tostr=str):
    """YAML representer: emit ``data`` as a plain string ``tostr(data)``.

    ``self`` is the representer/dumper instance supplied by PyYAML.
    """
    text = tostr(data)
    return yaml.representer.Representer.represent_str(self, text)
39 |
40 |
def register_str_converter(*types, tostr=str):
    """Make ``yaml.dump`` write instances of each of ``types`` as ``tostr(obj)`` strings."""
    representer = functools.partial(represent_as_str, tostr=tostr)
    for cls in types:
        yaml.add_representer(cls, representer)
44 |
45 |
46 | register_str_converter(pathlib.PosixPath, pathlib.WindowsPath)
47 |
48 |
def str2bool(s):
    """Parse a human-written boolean such as ``'yes'`` or ``'OFF'`` (case-insensitive).

    ``ParserWithRoot`` registers this as the converter for ``type=bool``
    arguments, so ``--flag False`` yields False. Plain ``bool('False')`` would
    be True.

    Args:
        s: the string to parse. An actual ``bool`` is returned unchanged.

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: if ``s`` is not one of the accepted spellings.
    """
    if isinstance(s, bool):
        return s
    true = ('true', 't', 'yes', 'y', 'on', '1')
    false = ('false', 'f', 'no', 'n', 'off', '0')

    lowered = s.lower()
    if lowered in true:
        return True
    if lowered in false:
        return False
    # Pass a single message: argparse reports str(exc), and a second positional
    # argument would turn that report into a tuple repr.
    raise argparse.ArgumentTypeError(
        'bool argument should be one of {}, got {!r}'.format(str(true + false), s))
59 |
60 |
class ParserWithRoot(configargparse.ArgumentParser):
    """Argument parser bound to an existing, absolute experiment root directory.

    On construction it:
      * validates ``root``;
      * creates the ``RESERVED_DIRECTORIES`` under it;
      * registers ``str2bool`` as the converter for ``type=bool``;
      * creates a file lock object at ``root/lock``.

    Run numbers are zero-padded to ``zfill`` digits.

    Raises:
        ValueError: if ``root`` is missing, not absolute, or does not exist.
    """
    def __init__(self, *args, root=None, zfill=6,
                 **kwargs):
        super().__init__(*args, **kwargs)
        if root is None:
            raise ValueError('Root directory is not specified')
        root = pathlib.Path(root)
        if not root.is_absolute():
            raise ValueError(root, 'Root directory is not absolute path')
        if not root.exists():
            raise ValueError(root, 'Root directory does not exist')
        self.root = root
        self.zfill = zfill
        self.register('type', bool, str2bool)
        for directory in RESERVED_DIRECTORIES:
            getattr(self, directory).mkdir(exist_ok=True)
        # Cross-process lock object; this class only creates it and never acquires it.
        self.lock = FileLock(str(self.root/'lock'))

    @property
    def runs(self):
        """Directory of regular run folders."""
        return self.root / 'runs'

    @property
    def marked(self):
        """Directory of marked-run symlinks."""
        return self.root / 'marked'

    @property
    def index(self):
        """Directory of params-file symlinks."""
        return self.root / 'index'

    @property
    def tmp(self):
        """Directory of temporary run folders."""
        return self.root / 'tmp'

    def max_ex(self):
        """Return the highest run number found in ``runs`` and ``tmp`` (0 if none).

        Run numbers are the integer prefix of an entry's name (before the first
        '-'). Entries without such a prefix, e.g. stray files like
        ``.DS_Store``, are ignored instead of aborting the numbering.
        """
        max_num = 0
        for entry in itertools.chain(self.runs.iterdir(), self.tmp.iterdir()):
            prefix = entry.name.split('-', 1)[0]
            try:
                num = int(prefix)
            except ValueError:
                continue
            max_num = max(max_num, num)
        return max_num

    def num_ex(self):
        """Return the number of entries in ``runs``."""
        return len(list(self.runs.iterdir()))

    def next_ex(self):
        """Return the next free run number."""
        return self.max_ex() + 1

    def next_ex_str(self):
        """Return the next free run number, zero-padded to ``zfill`` digits."""
        return str(self.next_ex()).zfill(self.zfill)
111 |
112 |
class ExParser(ParserWithRoot):
    """
    Parser responsible for creating the following structure of experiments
    ```
    root
    |-- runs
    |   `-- xxxxxx
    |       |-- params.yaml
    |       `-- ...
    |-- index
    |   `-- xxxxxx.yaml (symlink to runs/xxxxxx/params.yaml)
    |-- marked
    |   `-- <mark>/<value>/...
    |       `-- xxxxxx (symlink to runs/xxxxxx)
    |           |-- params.yaml
    |           `-- ...
    `-- tmp
        `-- xxxxxx
            |-- params.yaml
            `-- ...
    ```
    ``xxxxxx`` is the run number zero-padded to ``zfill`` digits (``DIR_FORMAT``).

    ``root`` is ``$EXMAN_PATH/exman-<file>``, with ``EXMAN_PATH`` defaulting to
    ``./logs``. A parser built with exactly one parent reuses that parent's root
    and params file. Runs started with ``--tmp`` go to ``tmp`` and get no
    index/marked symlinks.
    """
    def __init__(self, *args, zfill=6, file=None,
                 args_for_setting_config_path=('--config', ),
                 automark=(),
                 parents=None,
                 **kwargs):
        # None instead of a shared mutable [] default.
        parents = [] if parents is None else parents

        root = os.path.join(os.path.abspath(os.environ.get('EXMAN_PATH', './logs')), ('exman-' + str(file)))
        os.makedirs(root, exist_ok=True)

        if len(parents) == 1:
            parent = parents[0]
            self.yaml_params_path = parent.yaml_params_path
            # parse_known_args writes the run's time/id whenever yaml_params_path
            # is set, so inherit them from the parent too.
            for attr in ('time', 'num'):
                if hasattr(parent, attr):
                    setattr(self, attr, getattr(parent, attr))
            root = parent.root

        super().__init__(*args, root=root, zfill=zfill,
                         args_for_setting_config_path=args_for_setting_config_path,
                         config_file_parser_class=configargparse.YAMLConfigFileParser,
                         ignore_unknown_config_file_keys=True,
                         parents=parents,
                         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                         **kwargs)
        self.automark = automark
        if len(parents) == 0:
            self.add_argument('--tmp', action='store_true')

    def _initialize_dir(self, tmp):
        """Create a fresh, uniquely numbered run directory under ``tmp`` or ``runs``.

        Returns:
            ``(absroot, relroot, name, time, num)``: absolute path, path relative
            to the root, directory name, creation datetime and zero-padded number.
        """
        # Another process may claim the same number between next_ex_str() and
        # mkdir(). mkdir() then raises FileExistsError and the next number is tried.
        while True:
            time = datetime.datetime.now()
            num = self.next_ex_str()
            name = DIR_FORMAT.format(num=num, time=time.strftime(TIME_FORMAT_DIR))
            if tmp:
                absroot = self.tmp / name
                relroot = pathlib.Path('tmp') / name
            else:
                absroot = self.runs / name
                relroot = pathlib.Path('runs') / name
            try:
                absroot.mkdir()
            except FileExistsError:
                continue
            return absroot, relroot, name, time, num

    def _write_params(self, mode):
        """Write ``self.dumpd`` plus the run's ``time`` and ``id`` to the params file."""
        with self.yaml_params_path.open(mode) as f:
            yaml.dump(self.dumpd, f, default_flow_style=False)
            print("\ntime: '{}'".format(self.time.strftime(TIME_FORMAT)), file=f)
            print("id:", int(self.num), file=f)

    def parse_known_args(self, *args, log_params=True, **kwargs):
        """Parse arguments and, unless ``log_params`` is False, record them for the run.

        On the first logged parse a new run directory is created and
        ``args.root`` is set to it. Later parses rewrite that run's params file.
        """
        args, argv = super().parse_known_args(*args, **kwargs)
        if not log_params:
            return args, argv

        if hasattr(self, 'yaml_params_path'):
            # The run directory already exists: just rewrite its params file.
            self.dumpd = args.__dict__.copy()
            self._write_params('w')
            print(self.yaml_params_path.read_text())
            return args, argv

        absroot, _, name, time, num = self._initialize_dir(args.tmp)
        self.time = time
        self.num = num
        args.root = absroot
        self.yaml_params_path = args.root / PARAMS_FILE
        self.dumpd = args.__dict__.copy()
        self._write_params('a')
        print(self.yaml_params_path.read_text())

        if not args.tmp:
            symlink = self.index / yaml_file(name)
            rel_yaml_params_path = pathlib.Path('..', 'runs', name, PARAMS_FILE)
            symlink.symlink_to(rel_yaml_params_path)
            print('Created symlink from', symlink, '->', rel_yaml_params_path)
        if self.automark and not args.tmp:
            automark_path_part = pathlib.Path(*itertools.chain.from_iterable(
                (mark, str(getattr(args, mark, '')))
                for mark in self.automark))
            markpath = pathlib.Path(self.marked, automark_path_part)
            markpath.mkdir(exist_ok=True, parents=True)
            relpathmark = pathlib.Path('..', *(['..']*len(automark_path_part.parts))) / 'runs' / name
            (markpath / name).symlink_to(relpathmark, target_is_directory=True)
            print('Created symlink from', markpath / name, '->', relpathmark)
        return args, argv

    def done(self):
        """Mark the run as finished by rewriting its params file with ``status: done``.

        The file is rewritten rather than appended to, so it stays a single YAML
        mapping without duplicate keys.
        """
        print('Success.')
        self.dumpd['status'] = 'done'
        self._write_params('w')

    def update_params_file(self, args):
        """Rewrite the params file from ``args``, keeping the run's time and id.

        Also refreshes ``self.dumpd`` so a later ``done()`` writes the updated values.
        """
        self.dumpd = args.__dict__.copy()
        self._write_params('w')
233 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 |
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import models
8 | from utils.logger import Logger
9 | import myexman
10 | from utils import utils
11 | import sys
12 | import torch.multiprocessing as mp
13 | import torch.distributed as dist
14 | import socket
15 |
16 |
def add_learner_params(parser):
    """Register the problem-independent CLI options (optimizer, trainer, parallelism).

    Model-specific options are added afterwards by
    ``models.REGISTERED_MODELS[problem].add_model_hparams`` in ``main``.
    """
    parser.add_argument('--problem', default='sim-clr',
        help='The problem to train',
        choices=models.REGISTERED_MODELS,
    )
    parser.add_argument('--name', default='',
        help='Name for the experiment',
    )
    parser.add_argument('--ckpt', default='',
        help='Optional checkpoint to init the model.'
    )
    parser.add_argument('--verbose', default=False, type=bool)
    # optimizer params
    parser.add_argument('--lr_schedule', default='warmup-anneal',
        help='Learning-rate schedule: warmup-anneal, linear or const')
    parser.add_argument('--opt', default='lars', help='Optimizer to use', choices=['sgd', 'adam', 'lars'])
    parser.add_argument('--iters', default=-1, type=int, help='The number of optimizer updates')
    parser.add_argument('--warmup', default=0, type=float, help='The number of warmup iterations in proportion to \'iters\'')
    parser.add_argument('--lr', default=0.1, type=float, help='Base learning rate')
    parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float, dest='weight_decay')
    # trainer params
    parser.add_argument('--save_freq', default=10000000000000000, type=int, help='Frequency to save the model')
    parser.add_argument('--log_freq', default=100, type=int, help='Logging frequency')
    parser.add_argument('--eval_freq', default=10000000000000000, type=int, help='Evaluation frequency')
    parser.add_argument('-j', '--workers', default=4, type=int, help='The number of data loader workers')
    parser.add_argument('--eval_only', default=False, type=bool, help='Skips the training step if True')
    parser.add_argument('--seed', default=-1, type=int, help='Random seed')
    # parallelism params:
    parser.add_argument('--dist', default='dp', type=str,
        help='dp: DataParallel, ddp: DistributedDataParallel',
        choices=['dp', 'ddp'],
    )
    parser.add_argument('--dist_address', default='127.0.0.1:1234', type=str,
        help='the address and a port of the main node in the <address>:<port> format'
    )
    parser.add_argument('--node_rank', default=0, type=int,
        help='Rank of the node (script launched): 0 for the main node and 1,... for the others',
    )
    parser.add_argument('--world_size', default=1, type=int,
        help='the number of nodes (scripts launched)',
    )
57 |
58 |
def main():
    """Parse the configuration and launch training on one GPU or spawn DDP workers.

    Arguments are parsed twice. The first, unlogged pass (with ``--help``/``-h``
    temporarily removed) only determines ``--problem``, so that model can
    register its own hyper-parameters. The second pass parses everything and
    creates the run directory.
    """
    parser = myexman.ExParser(file=os.path.basename(__file__))
    add_learner_params(parser)

    # Hide the help flag during the first pass; it is re-added as '--help' afterwards.
    help_flag = next((flag for flag in ('--help', '-h') if flag in sys.argv), None)
    if help_flag is not None:
        sys.argv.remove(help_flag)

    args, _ = parser.parse_known_args(log_params=False)

    models.REGISTERED_MODELS[args.problem].add_model_hparams(parser)

    if help_flag is not None:
        sys.argv.append('--help')

    args = parser.parse_args(namespace=args)

    if args.data == 'imagenet' and args.aug == False:
        raise Exception('ImageNet models should be eval with aug=True!')

    if args.seed != -1:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    args.gpu = 0
    ngpus = torch.cuda.device_count()
    use_ddp = args.dist == 'ddp'
    # Stored only so the process count can be read back from the logs; the
    # user-facing world_size is saved unchanged so the run can be reproduced.
    args.number_of_processes = args.world_size * ngpus if use_ddp else 1
    parser.update_params_file(args)

    if use_ddp:
        args.world_size *= ngpus
        mp.spawn(
            main_worker,
            nprocs=ngpus,
            args=(ngpus, args),
        )
    else:
        main_worker(args.gpu, -1, args)
103 |
104 |
def main_worker(gpu, ngpus, args):
    """Train and/or evaluate ``args.problem`` in one process bound to one GPU.

    Args:
        gpu: local CUDA device index for this process.
        ngpus: number of local GPUs in DDP mode (``main`` passes -1 otherwise).
            The global rank is ``args.node_rank * ngpus + gpu``.
        args: parsed arguments. In DDP mode ``args.batch_size`` is divided by
            the total number of processes, so it becomes per-process.

    Side effects: checkpoints and logs are written under ``args.root``, by rank 0 only.
    """
    fmt = {
        'train_time': '.3f',
        'val_time': '.3f',
        'lr': '.1e',
    }
    logger = Logger('logs', base=args.root, fmt=fmt)

    args.gpu = gpu
    torch.cuda.set_device(gpu)
    args.rank = args.node_rank * ngpus + gpu

    device = torch.device('cuda:%d' % args.gpu)

    if args.dist == 'ddp':
        dist.init_process_group(
            backend='nccl',
            init_method='tcp://%s' % args.dist_address,
            world_size=args.world_size,
            rank=args.rank,
        )

        n_gpus_total = dist.get_world_size()
        assert args.batch_size % n_gpus_total == 0
        args.batch_size //= n_gpus_total
        if args.rank == 0:
            print(f'===> {n_gpus_total} GPUs total; batch_size={args.batch_size} per GPU')

        print(f'===> Proc {dist.get_rank()}/{dist.get_world_size()}@{socket.gethostname()}', flush=True)

    # create model
    model = models.REGISTERED_MODELS[args.problem](args, device=device)

    if args.ckpt != '':
        ckpt = torch.load(args.ckpt, map_location=device)
        model.load_state_dict(ckpt['state_dict'])

    # Data loading code
    model.prepare_data()
    train_loader, val_loader = model.dataloaders(iters=args.iters)

    # define optimizer
    cur_iter = 0
    optimizer, scheduler = models.ssl.configure_optimizers(args, model, cur_iter - 1)

    # optionally resume from a checkpoint
    if args.ckpt and not args.eval_only:
        optimizer.load_state_dict(ckpt['opt_state_dict'])

    cudnn.benchmark = True

    continue_training = args.iters != 0
    data_time, it_time = 0, 0

    while continue_training:
        train_logs = []
        model.train()

        start_time = time.time()
        for _, batch in enumerate(train_loader):
            cur_iter += 1

            batch = [x.to(device) for x in batch]
            data_time += time.time() - start_time

            logs = {}
            if not args.eval_only:
                # forward pass and compute loss
                logs = model.train_step(batch, cur_iter)
                loss = logs['loss']

                # gradient step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # save logs for the batch
            train_logs.append({k: utils.tonp(v) for k, v in logs.items()})

            if cur_iter % args.save_freq == 0 and args.rank == 0:
                save_checkpoint(args.root, model, optimizer, cur_iter)

            if cur_iter % args.eval_freq == 0 or cur_iter >= args.iters:
                # TODO: aggregate metrics over all processes
                test_logs = []
                model.eval()
                with torch.no_grad():
                    for batch in val_loader:
                        batch = [x.to(device) for x in batch]
                        # forward pass
                        logs = model.test_step(batch)
                        # save logs for the batch
                        test_logs.append(logs)
                model.train()

                test_logs = utils.agg_all_metrics(test_logs)
                logger.add_logs(cur_iter, test_logs, pref='test_')

            it_time += time.time() - start_time

            if cur_iter % args.log_freq == 0 or cur_iter >= args.iters:
                if args.rank == 0:
                    save_checkpoint(args.root, model, optimizer)
                    train_logs = utils.agg_all_metrics(train_logs)

                    logger.add_logs(cur_iter, train_logs, pref='train_')
                    logger.add_scalar(cur_iter, 'lr', optimizer.param_groups[0]['lr'])
                    logger.add_scalar(cur_iter, 'data_time', data_time)
                    logger.add_scalar(cur_iter, 'it_time', it_time)
                    logger.iter_info()
                    logger.save()

                # Reset on every rank; otherwise non-zero ranks accumulate train_logs forever.
                data_time, it_time = 0, 0
                train_logs = []

            if scheduler is not None:
                scheduler.step()

            if cur_iter >= args.iters:
                continue_training = False
                break

            start_time = time.time()

    # Only rank 0 writes: all ranks would otherwise write the same file concurrently.
    if args.rank == 0:
        save_checkpoint(args.root, model, optimizer)

    if args.dist == 'ddp':
        dist.destroy_process_group()
232 |
233 |
def save_checkpoint(path, model, optimizer, cur_iter=None):
    """Save ``model.get_ckpt()`` plus the optimizer state into directory ``path``.

    Args:
        path: run directory.
        model: object whose ``get_ckpt()`` returns a dict (extended in place).
        optimizer: object providing ``state_dict()``.
        cur_iter: None writes the rolling ``checkpoint.pth.tar``, overwriting the
            previous one. An int writes ``checkpoint-<cur_iter>.pth.tar``. The
            value is also stored under ``'iter'``.

    The checkpoint is first written to a temporary file and then atomically
    renamed. An interrupted save therefore never replaces a good checkpoint
    with a truncated one.
    """
    if cur_iter is None:
        fname = os.path.join(path, 'checkpoint.pth.tar')
    else:
        fname = os.path.join(path, 'checkpoint-%d.pth.tar' % cur_iter)

    ckpt = model.get_ckpt()
    ckpt.update(
        {
            'opt_state_dict': optimizer.state_dict(),
            'iter': cur_iter,
        }
    )

    tmp_fname = fname + '.tmp'
    try:
        torch.save(ckpt, tmp_fname)
        os.replace(tmp_fname, fname)
    except BaseException:
        # Do not leave a partial temporary file behind.
        if os.path.exists(tmp_fname):
            os.remove(tmp_fname)
        raise
249 |
250 |
# Script entry point: run main() only when executed directly, not when imported.
if __name__ == '__main__':
    main()
253 |
--------------------------------------------------------------------------------
/utils/datautils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | from torchvision import transforms
5 | import torch.utils.data
6 | import PIL
7 | import torchvision.transforms.functional as FT
8 | from PIL import Image
9 |
10 |
# Root directory for datasets; overridable through the DATA_ROOT environment variable.
DATA_ROOT = os.environ.get('DATA_ROOT', './data')

# NOTE(review): this path is fixed and does not follow DATA_ROOT.
IMAGENET_PATH = './data/imagenet/raw-data'
17 |
18 |
def pad(img, size, mode):
    """Pad the two leading (spatial) axes of an image by ``size`` on each side.

    Args:
        img: array-like image of shape (H, W) or (H, W, C, ...). PIL images
            are accepted and converted with ``np.asarray``.
        size: number of pixels added before and after each spatial axis.
        mode: ``np.pad`` mode, e.g. 'constant', 'edge' or 'reflect'.

    Returns:
        A numpy array of shape (H + 2*size, W + 2*size, ...). Trailing axes
        (e.g. channels) are not padded.
    """
    img = np.asarray(img)
    # Previously hard-coded to exactly three axes; now any trailing axes are left unpadded.
    pad_width = [(size, size), (size, size)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode)
23 |
24 |
# Per-channel statistics keyed by dataset name: one value for 'mnist', three
# for 'cifar10'. Presumably used for input normalisation; NOTE(review): not
# referenced in this chunk, so confirm usage.
mean = {
    'mnist': (0.1307,),
    'cifar10': (0.4914, 0.4822, 0.4465)
}

std = {
    'mnist': (0.3081,),
    'cifar10': (0.2470, 0.2435, 0.2616)
}
34 |
35 |
class GaussianBlur(object):
    """
    PyTorch version of
    https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311

    With probability ``p``, blurs a 3-channel image tensor with a separable
    Gaussian kernel. The kernel has ``2 * int(kernel_size / 2) + 1`` taps and
    sigma drawn from U(0.2, 2). Any spatial size is supported.
    """
    def gaussian_blur(self, image, sigma):
        """Blur ``image`` of shape (3, H, W) or (1, 3, H, W); return shape (3, H, W)."""
        image = image.reshape(1, *image.shape[-3:])
        # Built-in int/float: np.int / np.float were removed in NumPy 1.24.
        radius = int(self.kernel_size / 2)
        kernel_size = radius * 2 + 1
        x = np.arange(-radius, radius + 1)

        blur_filter = np.exp(
            -np.power(x, 2.0) / (2.0 * np.power(float(sigma), 2.0)))
        blur_filter /= np.sum(blur_filter)

        # One 1-D kernel per channel (groups=3): shape (3, 1, k, 1) blurs along
        # H, and shape (3, 1, 1, k) blurs along W.
        taps = torch.as_tensor(blur_filter, dtype=image.dtype, device=image.device)
        vertical = taps.view(1, 1, kernel_size, 1).repeat(3, 1, 1, 1)
        horizontal = taps.view(1, 1, 1, kernel_size).repeat(3, 1, 1, 1)

        res = torch.nn.functional.conv2d(image, vertical, padding=(radius, 0), groups=3)
        res = torch.nn.functional.conv2d(res, horizontal, padding=(0, radius), groups=3)
        assert res.shape == image.shape
        return res[0]

    def __init__(self, kernel_size, p=0.5):
        self.kernel_size = kernel_size
        self.p = p

    def __call__(self, img):
        """Return a blurred copy of tensor ``img`` with probability ``p``, else ``img`` itself."""
        with torch.no_grad():
            assert isinstance(img, torch.Tensor)
            if np.random.uniform() < self.p:
                return self.gaussian_blur(img, sigma=np.random.uniform(0.2, 2))
            return img

    def __repr__(self):
        return self.__class__.__name__ + '(kernel_size={0}, p={1})'.format(self.kernel_size, self.p)
76 |
class CenterCropAndResize(object):
    """Center-crop a PIL image to a fraction of its size, then resize it to a square.

    Args:
        proportion (float): fraction of the width and of the height kept by the
            center crop (e.g. 0.875 keeps the central 87.5% along each side).
        size (int): side length of the square output; resizing is bicubic.
    """

    def __init__(self, proportion, size):
        self.proportion = proportion
        self.size = size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): image to crop and resize.

        Returns:
            PIL Image: the ``size`` x ``size`` bicubic resize of the central crop.
        """
        crop_w, crop_h = (np.array(img.size) * self.proportion).astype(int)
        cropped = FT.center_crop(img, (crop_h, crop_w))
        return FT.resize(cropped, (self.size, self.size), interpolation=PIL.Image.BICUBIC)

    def __repr__(self):
        return '{0}(proportion={1}, size={2})'.format(type(self).__name__, self.proportion, self.size)
108 |
109 |
class Clip(object):
    """Transform that clamps every tensor value into the closed range [0, 1]."""

    def __call__(self, x):
        return torch.clamp(x, min=0, max=1)
113 |
114 |
class MultiplyBatchSampler(torch.utils.data.sampler.BatchSampler):
    """BatchSampler whose batches repeat their index list ``MULTILPLIER`` times.

    With the default multiplier of 2, the batch ``[i, j]`` becomes
    ``[i, j, i, j]``, so each sample is loaded twice per batch. Presumably this
    yields two independently augmented views; verify against the data pipeline.
    ``len()`` is inherited and still counts batches.
    """
    MULTILPLIER = 2

    def __iter__(self):
        yield from map(lambda indices: indices * self.MULTILPLIER, super().__iter__())
121 |
122 |
class ContinousSampler(torch.utils.data.sampler.Sampler):
    """Re-iterate a base (batch) sampler until exactly ``n_iterations`` items are yielded.

    Args:
        sampler: any re-iterable sampler, e.g. a BatchSampler.
        n_iterations: total number of items to yield; also the value of ``len()``.

    Raises:
        ValueError: during iteration, if the base sampler yields nothing. It
            could never reach ``n_iterations``, and previously looped forever.
    """
    def __init__(self, sampler, n_iterations):
        self.base_sampler = sampler
        self.n_iterations = n_iterations

    def __iter__(self):
        cur_iter = 0
        while cur_iter < self.n_iterations:
            produced = False
            for batch in self.base_sampler:
                produced = True
                yield batch
                cur_iter += 1
                if cur_iter >= self.n_iterations:
                    return
            if not produced:
                raise ValueError(
                    'base sampler is empty; cannot produce {} iterations'.format(self.n_iterations))

    def __len__(self):
        return self.n_iterations

    def set_epoch(self, epoch):
        """Forward ``epoch`` to the base sampler (it must implement ``set_epoch``)."""
        self.base_sampler.set_epoch(epoch)
142 |
def get_color_distortion(s=1.0):
    """Build the SimCLR color distortion (https://arxiv.org/pdf/2002.05709.pdf).

    Args:
        s: distortion strength. Brightness, contrast and saturation jitter is
            ``0.8 * s``; hue jitter is ``0.2 * s``.

    Returns:
        A ``transforms.Compose`` that applies the color jitter with probability
        0.8, then converts to grayscale with probability 0.2.
    """
    jitter = transforms.ColorJitter(
        brightness=0.8 * s, contrast=0.8 * s, saturation=0.8 * s, hue=0.2 * s)
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])
153 |
154 |
class DummyOutputWrapper(torch.utils.data.dataset.Dataset):
    """Dataset view that appends a constant tag to every item of ``dataset``.

    ``wrapper[i]`` is ``(*dataset[i], dummy)``. SemiSupervisedEval uses it to tag
    test items (0) and unlabelled training items (1) before concatenating them.
    """
    def __init__(self, dataset, dummy):
        self.dummy = dummy
        self.dataset = dataset

    def __getitem__(self, index):
        item = tuple(self.dataset[index])
        return item + (self.dummy,)

    def __len__(self):
        return len(self.dataset)
165 |
--------------------------------------------------------------------------------
/utils/lars_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from torch.nn.parameter import Parameter
5 |
6 |
class LARS(object):
    """
    Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
    Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py

    Before each inner step, every parameter's gradient is modified in two ways:
      1. the group's weight decay ``wd * p`` is added to it;
      2. for groups with ``layer_adaptation`` (default True), it is rescaled by
         ``trust_coefficient * ||p|| / ||grad||``.
    The wrapped optimizer's own weight decay is zeroed during the step, so decay
    is applied exactly once.

    Args:
        optimizer: Pytorch optimizer to wrap and modify learning rate for.
        trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888
    """

    def __init__(self,
                 optimizer,
                 trust_coefficient=0.001,
                 ):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    def step(self, closure=None):
        """Apply weight decay and layer-wise adaptation, then step the wrapped optimizer.

        Args:
            closure: optional callable that recomputes the gradients and returns
                the loss (``torch.optim.Optimizer.step`` convention). It runs
                before the gradients are rescaled, so the rescaling is not
                overwritten.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        weight_decays = []
        try:
            with torch.no_grad():
                for group in self.optim.param_groups:
                    # absorb weight decay control from optimizer
                    weight_decay = group.get('weight_decay', 0)
                    weight_decays.append(weight_decay)
                    group['weight_decay'] = 0
                    for p in group['params']:
                        if p.grad is None:
                            continue

                        if weight_decay != 0:
                            p.grad.data += weight_decay * p.data

                        param_norm = torch.norm(p.data)
                        grad_norm = torch.norm(p.grad.data)
                        adaptive_lr = 1.

                        if param_norm != 0 and grad_norm != 0 and group.get('layer_adaptation', True):
                            adaptive_lr = self.trust_coefficient * param_norm / grad_norm

                        p.grad.data *= adaptive_lr

                self.optim.step()
        finally:
            # Return weight decay control to the optimizer, even if the step raised.
            for group, weight_decay in zip(self.optim.param_groups, weight_decays):
                group['weight_decay'] = weight_decay
        return loss
74 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import numpy as np
5 |
6 | from collections import OrderedDict
7 | from tabulate import tabulate
8 | from pandas import DataFrame
9 | from time import gmtime, strftime
10 | import time
11 |
12 |
class Logger:
    """Collect scalar metrics, print them as table rows and dump them to CSV.

    Two files share the stem ``os.path.join(base, name)``:
    ``<stem>.out`` receives every line printed through ``self.print`` and
    ``<stem>.csv`` receives the metric table written by ``save()``.

    Args:
        name: Stem of the output files.
        fmt: Optional mapping ``metric name -> tabulate float format`` used by
            ``iter_info``; metrics without an entry use ``'.3f'``.
        base: Output directory; created if missing.
    """

    def __init__(self, name='name', fmt=None, base='./logs'):
        self.handler = True  # when True, the next iter_info() prints a header row
        self.scalar_metrics = OrderedDict()  # metric name -> list of (t, value)
        self.fmt = fmt if fmt else dict()

        # exist_ok also tolerates the directory appearing between check and create.
        os.makedirs(base, exist_ok=True)

        self.path = os.path.join(base, name)
        self.logs = self.path + '.csv'
        self.output = self.path + '.out'
        self.iters_since_last_header = 0

        def prin(*args):
            """Print ``args`` space-separated to stdout and append the line to ``self.output``."""
            str_to_write = ' '.join(map(str, args))
            with open(self.output, 'a') as f:
                f.write(str_to_write + '\n')

            print(str_to_write)
            sys.stdout.flush()

        self.print = prin

    def add_scalar(self, t, key, value):
        """Record ``value`` for metric ``key`` at step ``t``."""
        if key not in self.scalar_metrics:
            self.scalar_metrics[key] = []
        self.scalar_metrics[key] += [(t, value)]

    def add_logs(self, t, logs, pref=''):
        """Record every ``key: value`` of ``logs`` at step ``t``, prefixing keys with ``pref``."""
        for k, v in logs.items():
            self.add_scalar(t, pref + k, v)

    def iter_info(self, order=None):
        """Print the latest value of each metric as one table row.

        A header row is printed on the first call and again once more than 40
        calls have passed since the previous header.

        Args:
            order: Optional list of metric names selecting which columns are
                printed, and in which order; defaults to all metrics in
                insertion order.
        """
        self.iters_since_last_header += 1
        if self.iters_since_last_header > 40:
            self.handler = True

        names = order if order else list(self.scalar_metrics.keys())
        values = [self.scalar_metrics[name][-1][1] for name in names]
        # The row is labelled with the most recent step among the printed metrics.
        t = int(np.max([self.scalar_metrics[name][-1][0] for name in names]))
        fmt = ['%s'] + [self.fmt[name] if name in self.fmt else '.3f' for name in names]

        if self.handler:
            self.handler = False
            self.iters_since_last_header = 0
            self.print(tabulate([[t] + values], ['t'] + names, floatfmt=fmt))
        else:
            # Keep only line 1 (the data row) of the plain table, dropping its header line.
            self.print(tabulate([[t] + values], ['t'] + names, tablefmt='plain', floatfmt=fmt).split('\n')[1])

    def save(self):
        """Write all metrics to ``self.logs`` as CSV, one column per metric.

        Metrics are outer-joined on their step ``t``, so steps at which a
        metric was not recorded show up as empty cells. Does nothing when no
        metric has been recorded yet.
        """
        result = None
        for key, records in self.scalar_metrics.items():
            df = DataFrame(records, columns=['t', key]).set_index('t')
            result = df if result is None else result.join(df, how='outer')
        if result is None:
            return
        result.to_csv(self.logs)
81 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import warnings
4 | import time
5 | import torch.distributed as dist
6 |
7 |
def timing(f):
    """Decorate ``f`` so that every call prints how long it took, in milliseconds.

    The wrapper returns whatever ``f`` returns. ``functools.wraps`` copies
    ``f``'s ``__name__``, ``__doc__`` etc. onto the wrapper, so decorated
    functions stay introspectable.
    """
    import functools

    @functools.wraps(f)
    def wrap(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        time1 = time.perf_counter()
        ret = f(*args, **kwargs)
        time2 = time.perf_counter()
        print('{:s} function took {:.3f} ms'.format(f.__name__, (time2 - time1) * 1000.0))

        return ret
    return wrap
17 |
18 |
def agg_all_metrics(outputs):
    """Merge a sequence of per-step metric dicts into one dict.

    Keys come from the first entry; keys whose value there is a dict are
    skipped. For each remaining key, the values of all entries are converted
    with ``tonp``, flattened and concatenated. ``'epoch'`` keeps its last
    element; every other key is averaged with ``np.mean``. An empty
    ``outputs`` is returned as-is.
    """
    if not len(outputs):
        return outputs

    first = outputs[0]
    aggregated = {}
    for key, sample in first.items():
        if isinstance(sample, dict):
            continue
        flat = np.concatenate([tonp(entry[key]).reshape(-1) for entry in outputs])
        aggregated[key] = flat[-1] if key == 'epoch' else np.mean(flat)
    return aggregated
31 |
32 |
def gather_metrics(metrics):
    """All-gather every tensor in ``metrics`` across the default process group, in place.

    Each value is replaced by the concatenation (along dim 0) of that value
    from every rank, in rank order; 0-dim tensors are first promoted to shape
    ``(1,)`` so they can be concatenated. Receive buffers are
    ``zeros_like`` the local tensor, so all ranks must pass equally-shaped
    tensors. Requires an initialized ``torch.distributed`` process group.
    """
    for name in list(metrics):
        value = metrics[name]
        if value.dim() == 0:
            value = value.unsqueeze(0)
        gathered = [torch.zeros_like(value) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, value)
        metrics[name] = torch.cat(gathered)
41 |
42 |
def viz_array_grid(array, rows, cols, padding=0, channels_last=False, normalize=False, **kwargs):
    '''Tile a batch of images into a single ``uint8`` grid image.

    Args:
        array: (N_images, N_channels, H, W) or (N_images, H, W, N_channels);
            anything ``tonp`` converts (NumPy array or torch tensor).
        rows, cols: rows and columns of the plot. rows * cols == array.shape[0]
        padding: padding (in pixels) between cells of the plot; it is white.
        channels_last: for Tensorflow = True, for PyTorch = False
        normalize: `False`, `mean_std`, or `min_max`
    Kwargs:
        if normalize == 'mean_std' (undoes ``(x - mean) / std``, i.e. x * std + mean):
            mean: mean of the distribution. Default 0.5
            std: std of the distribution. Default 0.5
        if normalize == 'min_max' (maps [min, max] onto [0, 1]):
            min: min of the distribution. Default array.min()
            max: max of the distribution. Default array.max()
    Returns:
        ``uint8`` array of shape (H * rows + padding * (rows - 1),
        W * cols + padding * (cols - 1)) for 1-channel input, with a trailing
        axis of size 3 for 3-channel input. Values are clipped to [0, 1] and
        scaled to [0, 255].
    Raises:
        TypeError: if the number of channels is neither 1 nor 3.
    '''
    array = tonp(array)
    if not channels_last:
        array = np.transpose(array, (0, 2, 3, 1))

    # astype returns a copy, so the in-place normalization below never touches the caller's data.
    array = array.astype('float32')

    if normalize == 'mean_std':
        mean = np.array(kwargs.get('mean', 0.5)).reshape((1, 1, 1, -1))
        std = np.array(kwargs.get('std', 0.5)).reshape((1, 1, 1, -1))
        array = array * std + mean
    elif normalize == 'min_max':
        min_ = np.array(kwargs.get('min', array.min())).reshape((1, 1, 1, -1))
        max_ = np.array(kwargs.get('max', array.max())).reshape((1, 1, 1, -1))
        array -= min_
        # After subtracting min the data spans [0, max - min]; divide by that range.
        array /= (max_ - min_) + 1e-9

    batch_size, H, W, channels = array.shape
    assert rows * cols == batch_size, 'rows * cols must equal the number of images'

    if channels == 1:
        canvas = np.ones((H * rows + padding * (rows - 1),
                          W * cols + padding * (cols - 1)))
        array = array[:, :, :, 0]
    elif channels == 3:
        canvas = np.ones((H * rows + padding * (rows - 1),
                          W * cols + padding * (cols - 1),
                          3))
    else:
        raise TypeError('number of channels is either 1 or 3')

    for i in range(rows):
        for j in range(cols):
            img = array[i * cols + j]
            start_h = i * (H + padding)
            start_w = j * (W + padding)
            canvas[start_h: start_h + H, start_w: start_w + W] = img

    canvas = np.clip(canvas, 0, 1)
    canvas *= 255.0
    return canvas.astype('uint8')
106 |
107 |
def tonp(x):
    """Convert ``x`` to a NumPy array.

    Torch tensors are detached from the autograd graph and moved to the CPU
    before conversion. Everything else -- NumPy arrays and NumPy scalars such
    as ``np.float32``, Python numbers, (nested) lists -- goes through
    ``np.array``, which returns a copy.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.array(x)
112 |
113 |
class LinearLR(torch.optim.lr_scheduler._LRScheduler):
    """Decay every base learning rate linearly to zero over ``num_epochs`` scheduler steps.

    At step ``e`` the learning rate is ``base_lr * min(1 - e / num_epochs, 1)``,
    floored at 0, so it stays at 0 once ``num_epochs`` steps have passed.
    ``num_epochs`` is clamped to at least 1 to avoid dividing by zero.
    """

    def __init__(self, optimizer, num_epochs, last_epoch=-1):
        self.num_epochs = 1 if num_epochs < 1 else num_epochs
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Fraction of the base LR left at this step; negative past the end, clipped to 0 below.
        remaining = np.minimum(1. - self.last_epoch * 1. / self.num_epochs, 1.)
        return [np.maximum(base_lr * remaining, 0.) for base_lr in self.base_lrs]
124 |
125 |
class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by cosine annealing, in PyTorch's recursive (chainable) form.

    Args:
        optimizer: Wrapped optimizer.
        warm_up: Fraction of ``T_max`` used for warm-up; the warm-up lasts
            ``int(warm_up * T_max)`` scheduler steps.
        T_max: Total number of scheduler steps (warm-up plus annealing).
        last_epoch: Index of the last step, as in ``_LRScheduler``.
        smooth: Small constant added to the denominator of the cosine ratio to
            avoid dividing by zero.

    Assuming nothing else changes the LR, step ``e`` of the warm-up gives
    ``base_lr * (e + 1) / (warm_up + 1)``. After the warm-up, each step
    multiplies the current LR by the ratio of consecutive cosine factors, as
    ``torch.optim.lr_scheduler.CosineAnnealingLR`` does.
    """

    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1, smooth=1e-9):
        self.warm_up = int(warm_up * T_max)
        self.T_max = T_max - self.warm_up  # length of the annealing phase only
        self.smooth = smooth
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        step = self.last_epoch
        if step == 0:
            return [base_lr / (self.warm_up + 1) for base_lr in self.base_lrs]

        current_lrs = [group['lr'] for group in self.optimizer.param_groups]
        if step <= self.warm_up:
            # (step + 1) / step telescopes to a linear ramp across the warm-up.
            growth = (step + 1) / step
            return [lr * growth for lr in current_lrs]

        # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
        anneal_step = step - self.warm_up
        if anneal_step > self.T_max:
            warnings.warn(f"Epoch {self.last_epoch}: reached maximum number of iterations {self.T_max + self.warm_up}. This is unexpected behavior, and this SimCLR implementation was not tested in this regime!")

        ratio = ((1 + np.cos(np.pi * anneal_step / self.T_max)) /
                 (1 + np.cos(np.pi * (anneal_step - 1) / self.T_max) + self.smooth))
        return [ratio * lr for lr in current_lrs]
154 |
155 |
class BaseLR(torch.optim.lr_scheduler._LRScheduler):
    """No-op scheduler: each step keeps every param group's current learning rate."""

    def get_lr(self):
        groups = self.optimizer.param_groups
        return [g['lr'] for g in groups]
159 |
--------------------------------------------------------------------------------