├── .gitignore ├── LICENSE.md ├── README.md ├── colabs ├── model_apply.ipynb └── read_logs.ipynb ├── configs ├── cifar_eval.yaml ├── cifar_train_epochs1000_bs1024.yaml ├── imagenet_eval.yaml ├── imagenet_train_epochs100_bs512.yaml ├── imagenet_train_epochs200_bs2k.yaml └── imagenet_train_epochs600_bs2k.yaml ├── environment.yml ├── models ├── __init__.py ├── encoder.py ├── losses.py ├── resnet.py └── ssl.py ├── myexman ├── __init__.py ├── index.py └── parser.py ├── train.py └── utils ├── datautils.py ├── lars_optimizer.py ├── logger.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *__pycache__* 3 | *pretrained_models/* 4 | *logs 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # pytype static type analyzer 140 | .pytype/ 141 | 142 | # Cython debug symbols 143 | cython_debug/ -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Andrei Atanov*, Arsenii Ashukha; Bayesian Methods Research Group, Samsung AI Center Moscow, Samsung-HSE Laboratory, EPFL 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimCLR PyTorch 2 | 3 | This is an unofficial repository reproducing results of the paper [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709). The implementation supports multi-GPU distributed training on several nodes with PyTorch `DistributedDataParallel`. 4 | 5 | ## How close are we to the original SimCLR? 6 | 7 | The implementation closely reproduces the original ResNet50 results on ImageNet and CIFAR-10. 8 | 9 |

10 | 11 |

12 | 13 | | Dataset | Batch Size | \# Epochs | Training GPUs | Training time | Top\-1 accuracy of Linear evaluation (100% labels)| Reference | 14 | |----------|------------|-----------|---------------|---------------|-----------------------------------|------------| 15 | | CIFAR-10 | 1024 | 1000 | 2v100 | 13h | 93\.44 | 93.95 | 16 | | ImageNet | 512 | 100 | 4v100 | 85h | 60\.14 | 60.62 | 17 | | ImageNet | 2048 | 200 | 16v100 | 55h | 65\.58 | 65.83 | 18 | | ImageNet | 2048 | 600 | 16v100 | 170h | 67\.84 | 68.71 | 19 | 20 | ## Pre-trained weights 21 | 22 | Try out the pre-trained models [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AndrewAtanov/simclr-pytorch/blob/master/colabs/model_apply.ipynb) 23 | 24 | You can download pre-trained weights from [here](https://drive.google.com/file/d/13tjpWYTzV8qLB5yY5raBn5cwtIyFtt6-/view?usp=sharing). 25 | 26 | To evaluate the pretrained CIFAR-10 linear model and encoder, use the following command: 27 | ```(bash) 28 | python train.py --problem eval --eval_only true --iters 1 --arch linear \ 29 | --ckpt pretrained_models/resnet50_cifar10_bs1024_epochs1000_linear.pth.tar \ 30 | --encoder_ckpt pretrained_models/resnet50_cifar10_bs1024_epochs1000.pth.tar 31 | ``` 32 | 33 | To evaluate the pretrained ImageNet linear model and encoder, use the following command: 34 | ```(bash) 35 | export IMAGENET_PATH=.../raw-data 36 | python train.py --problem eval --eval_only true --iters 1 --arch linear --data imagenet \ 37 | --ckpt pretrained_models/resnet50_imagenet_bs2k_epochs600_linear.pth.tar \ 38 | --encoder_ckpt pretrained_models/resnet50_imagenet_bs2k_epochs600.pth.tar 39 | ``` 40 | 41 | ## Environment Setup 42 | 43 | 44 | Create a Python environment with the provided config file and [miniconda](https://docs.conda.io/en/latest/miniconda.html): 45 | 46 | ```(bash) 47 | conda env create -f environment.yml 48 | conda activate simclr_pytorch 49 | 50 | export IMAGENET_PATH=...
# If you have enough RAM, using /dev/shm usually speeds up data loading 51 | export EXMAN_PATH=... # A path to logs 52 | ``` 53 | 54 | ## Training 55 | Model training consists of two steps: (1) self-supervised encoder pretraining and (2) classifier learning with the encoder representations. Both steps are done with the `train.py` script. To see the help for the `sim-clr` or `eval` problem, run `python train.py --help --problem sim-clr` (or `--problem eval`). 56 | 57 | ### Self-supervised pretraining 58 | 59 | #### CIFAR-10 60 | The config `cifar_train_epochs1000_bs1024.yaml` contains the parameters to reproduce results for the CIFAR-10 dataset. It requires 2 V100 GPUs. The pretraining command is: 61 | 62 | ```(bash) 63 | python train.py --config configs/cifar_train_epochs1000_bs1024.yaml 64 | ``` 65 | 66 | #### ImageNet 67 | The configs `imagenet_train_epochs*_bs*.yaml` contain the parameters to reproduce results for the ImageNet dataset. They require 4 to 16 V100 GPUs, depending on the batch size. The single-node (4 V100 GPUs) pretraining command is: 68 | 69 | ```(bash) 70 | python train.py --config configs/imagenet_train_epochs100_bs512.yaml 71 | ``` 72 | 73 | #### Logs 74 | The logs and the model will be stored at `./logs/exman-train.py/runs/<id>/`. You can access all the experiments from Python with `myexman.Index('./logs/exman-train.py').info()`. 75 | 76 | See how to work with logs [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AndrewAtanov/simclr-pytorch/blob/master/colabs/read_logs.ipynb) 77 | 78 | ### Linear Evaluation 79 | To train a linear classifier on top of the pretrained encoder, run the following command: 80 | 81 | ```(bash) 82 | python train.py --config configs/cifar_eval.yaml --encoder_ckpt <path-to-encoder-checkpoint> 83 | ``` 84 | 85 | The above model with batch size 1024 gives `93.5` linear eval test accuracy.
86 | 87 | ### Pretraining with `DistributedDataParallel` 88 | To train a model with a larger batch size on several nodes, set the `--dist ddp` flag and specify the following parameters: 89 | - `--dist_address`: the address and port of the main node in the `<address>:<port>` format 90 | - `--node_rank`: 0 for the main node and 1, 2, ... for the others. 91 | - `--world_size`: the number of nodes. 92 | 93 | For example, to train with two nodes, run the following command on the main node: 94 | ```(bash) 95 | python train.py --config configs/cifar_train_epochs1000_bs1024.yaml --dist ddp --dist_address <address>:<port> --node_rank 0 --world_size 2 96 | ``` 97 | and on the second node: 98 | ```(bash) 99 | python train.py --config configs/cifar_train_epochs1000_bs1024.yaml --dist ddp --dist_address <address>:<port> --node_rank 1 --world_size 2 100 | ``` 101 | 102 | ImageNet pretraining on 4 nodes, each with 4 GPUs, looks as follows: 103 | ``` 104 | node1: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 0 105 | node2: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 1 106 | node3: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 2 107 | node4: python train.py --config configs/imagenet_train_epochs200_bs2k.yaml --dist ddp --world_size 4 --dist_address <address>:<port> --node_rank 3 108 | ``` 109 | 110 | ## Attribution 111 | Parts of this code are based on the following repositories: 112 | - [PyTorch](https://github.com/pytorch/pytorch), [PyTorch Examples](https://github.com/pytorch/examples/tree/ee964a2/imagenet), [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) for standard backbones, training loops, etc. 113 | - [SimCLR - A Simple Framework for Contrastive Learning of Visual Representations](https://github.com/google-research/simclr) for more details on the original implementation 114 | - [diffdist](https://github.com/ag14774/diffdist) for the multi-GPU contrastive loss implementation; it allows backpropagation through the `all_gather` operation (see [models/losses.py#L58](https://github.com/AndrewAtanov/simclr-pytorch/blob/master/models/losses.py#L62)) 115 | - [Experiment Manager (exman)](https://github.com/ferrine/exman), a tool that stores logs, checkpoints, and parameter dicts in folders and allows loading them into a pandas DataFrame, which is handy for processing in IPython notebooks. 116 | - [NVIDIA APEX](https://github.com/NVIDIA/apex) for the LARS optimizer. We modified LARC to make it consistent with the SimCLR repo.
117 | 118 | ## Acknowledgements 119 | - This work was supported in part through computational resources of HPC facilities at NRU HSE 120 | -------------------------------------------------------------------------------- /colabs/read_logs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "PjSN8gOUIQ1t" 7 | }, 8 | "source": [ 9 | "\n", 10 | "# Experiment Manager" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": { 17 | "colab": { 18 | "base_uri": "https://localhost:8080/" 19 | }, 20 | "id": "CkMiXmImIhUN", 21 | "outputId": "c6604c8a-ecd1-4170-d7b6-42bc06ca4977" 22 | }, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | " % Total % Received % Xferd Average Speed Time Time Time Current\n", 29 | " Dload Upload Total Spent Left Speed\n", 30 | " 0 0 0 0 0 0 0 0 --:--:-- 0:00:01 --:--:-- 0\n", 31 | "100 869M 0 869M 0 0 9.8M 0 --:--:-- 0:01:28 --:--:-- 10.2M\n", 32 | "Archive: logs.zip\n", 33 | " inflating: logs/exman-train.py/index/000002.yaml \n", 34 | " inflating: logs/exman-train.py/index/000004.yaml \n", 35 | " inflating: logs/exman-train.py/index/000010.yaml \n", 36 | " inflating: logs/exman-train.py/index/000012.yaml \n", 37 | " inflating: logs/exman-train.py/index/000023.yaml \n", 38 | " inflating: logs/exman-train.py/index/000027.yaml \n", 39 | " inflating: logs/exman-train.py/index/000030.yaml \n", 40 | " inflating: logs/exman-train.py/index/000031.yaml \n", 41 | " inflating: logs/exman-train.py/index/000033.yaml \n", 42 | "replace logs/exman-train.py/runs/000002/checkpoint.pth.tar? 
[y]es, [n]o, [A]ll, [N]one, [r]ename: " 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "!!pip install diffdist wldhx.yadisk-direct configargparse strconv\n", 48 | "!curl -L $(yadisk-direct https://yadi.sk/d/GYMBGjXGQr9oFw?w=1) -o logs.zip\n", 49 | "!unzip logs.zip > unzip.out" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": { 56 | "id": "bamF1nUS80W0" 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "!git clone https://github.com/AndrewAtanov/simclr-pytorch.git" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "import sys\n", 70 | "sys.path.append('./simclr-pytorch')" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 32, 76 | "metadata": { 77 | "id": "90vfPvSRIQ1u" 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "import myexman\n", 82 | "import pandas as pd\n", 83 | "\n", 84 | "index = myexman.Index('./logs/exman-train.py').info().set_index('id')\n", 85 | "index.root = index.root.apply(lambda x: str(x).replace('/home/aashukha/simclr-pytorch/', ''))" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 33, 91 | "metadata": { 92 | "colab": { 93 | "base_uri": "https://localhost:8080/" 94 | }, 95 | "id": "xniSdGb0IQ1u", 96 | "outputId": "821e111c-a660-4c20-9c96-930251e13e34" 97 | }, 98 | "outputs": [ 99 | { 100 | "data": { 101 | "text/plain": [ 102 | "Index(['arch', 'aug', 'augmentation', 'batch_size', 'ckpt', 'config_file',\n", 103 | " 'data', 'dist', 'dist_address', 'encoder_ckpt', 'eval_freq', 'finetune',\n", 104 | " 'gpu', 'iters', 'log_freq', 'lr', 'lr_schedule', 'name', 'node_rank',\n", 105 | " 'number_of_processes', 'opt', 'precompute_emb_bs', 'problem', 'root',\n", 106 | " 'save_freq', 'scale_lower', 'seed', 'test_bs', 'tmp', 'verbose',\n", 107 | " 'warmup', 'weight_decay', 'workers', 'world_size', 'time',\n", 108 | " 'base_lr_linear_scale', 'color_dist_s', 'cooldown', 
'cooldown_after',\n", 109 | " 'momentum', 'multiplier', 'norm_multiplier', 'projection', 'status',\n", 110 | " 'sync_bn', 'temperature', 'ckpt_iter', 'encode_layer', 'model_id',\n", 111 | " 'use_all_classes'],\n", 112 | " dtype='object')" 113 | ] 114 | }, 115 | "execution_count": 33, 116 | "metadata": { 117 | "tags": [] 118 | }, 119 | "output_type": "execute_result" 120 | } 121 | ], 122 | "source": [ 123 | "index.columns" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 35, 129 | "metadata": { 130 | "colab": { 131 | "base_uri": "https://localhost:8080/", 132 | "height": 414 133 | }, 134 | "id": "Ps-fforc-r_N", 135 | "outputId": "59109585-78d3-45fa-f599-e468b8416b7f" 136 | }, 137 | "outputs": [ 138 | { 139 | "data": { 140 | "text/html": [ 141 | "
\n", 142 | "\n", 155 | "\n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " 
\n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | 
" \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | "
archaugaugmentationbatch_sizeckptconfig_filedatadistdist_addressencoder_ckpteval_freqfinetunegpuiterslog_freqlrlr_schedulenamenode_ranknumber_of_processesoptprecompute_emb_bsproblemrootsave_freqscale_lowerseedtest_bstmpverbosewarmupweight_decayworkersworld_sizetimebase_lr_linear_scalecolor_dist_scooldowncooldown_aftermomentummultipliernorm_multiplierprojectionstatussync_bntemperatureckpt_iterencode_layermodel_iduse_all_classes
id
2ResNet50TrueNaN512cifar_params.yamlcifarddpcn-012:8881NaN4800NaN048000484.0warmup-anneal02larsNaNsim-clrlogs/exman-train.py/runs/00000248000.08-1NaNFalseTrue0.010.000001222020-11-18 21:32:53False0.5linear-1.00.92.0FalseMLPv2failTrue0.5NaNNaNNaNNaN
4ResNet50TrueNaN128imagenet_params_epochs200_bs2k.yamlimagenetddpcn-010:8881NaN12510NaN01251001002.4warmup-annealimagenet-reproduce016larsNaNsim-clrlogs/exman-train.py/runs/000004125100.08-1NaNFalseTrue0.100.000001842020-11-23 16:44:02False1.0linear-1.00.92.0FalseMLPv2failTrue0.1NaNNaNNaNNaN
10linearTrueRandomResizedCrop4096configs/imagenet_eval_params.yamlimagenetdp/home/aashukha/simclr-pytorch/logs/exman-train...100False02808010001.6lineareval_imagenet_newmodels01sgd-1.0evallogs/exman-train.py/runs/000010100000000000000000.08-14096.0FalseFalse0.000.0000002012020-11-26 00:44:34FalseNaNlinear-1.00.9NaNNaNNaNfailNaNNaN-1.0h-1.0False
12ResNet50TrueNaN128configs/imagenet_params_epochs600_bs2k.yamlimagenetddpcn-010:8881NaN12510NaN03753001002.4warmup-annealimagenet-reproduce016larsNaNsim-clrlogs/exman-train.py/runs/000012125100.08-1NaNFalseTrue0.100.000001842020-11-26 01:17:43False1.0linear-1.00.92.0FalseMLPv2NaNTrue0.1NaNNaNNaNNaN
23linearTrueRandomCrop1024configs/cifar_eval.yamlcifardp127.0.0.1:1234logs/exman-train.py/runs/000002/checkpoint.pth...1000False0800001000.1linear01sgd-1.0evallogs/exman-train.py/runs/0000231000000000.08-11024.0FalseFalse0.000.000100212020-11-26 16:20:18NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
\n", 532 | "
" 533 | ], 534 | "text/plain": [ 535 | " arch aug augmentation ... encode_layer model_id use_all_classes\n", 536 | "id ... \n", 537 | "2 ResNet50 True NaN ... NaN NaN NaN\n", 538 | "4 ResNet50 True NaN ... NaN NaN NaN\n", 539 | "10 linear True RandomResizedCrop ... h -1.0 False\n", 540 | "12 ResNet50 True NaN ... NaN NaN NaN\n", 541 | "23 linear True RandomCrop ... NaN NaN NaN\n", 542 | "\n", 543 | "[5 rows x 50 columns]" 544 | ] 545 | }, 546 | "execution_count": 35, 547 | "metadata": { 548 | "tags": [] 549 | }, 550 | "output_type": "execute_result" 551 | } 552 | ], 553 | "source": [ 554 | "index.head()" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 39, 560 | "metadata": { 561 | "colab": { 562 | "base_uri": "https://localhost:8080/" 563 | }, 564 | "id": "Wy8Mwu_z_EBF", 565 | "outputId": "a190b8d4-5d30-4c35-ee34-0e8a7da354e2" 566 | }, 567 | "outputs": [ 568 | { 569 | "data": { 570 | "text/plain": [ 571 | "{'arch': 'linear',\n", 572 | " 'aug': True,\n", 573 | " 'augmentation': 'RandomResizedCrop',\n", 574 | " 'base_lr_linear_scale': nan,\n", 575 | " 'batch_size': 4096,\n", 576 | " 'ckpt': '',\n", 577 | " 'ckpt_iter': nan,\n", 578 | " 'color_dist_s': nan,\n", 579 | " 'config_file': 'configs/imagenet_eval_params.yaml',\n", 580 | " 'cooldown': nan,\n", 581 | " 'cooldown_after': nan,\n", 582 | " 'data': 'imagenet',\n", 583 | " 'dist': 'dp',\n", 584 | " 'dist_address': '',\n", 585 | " 'encode_layer': nan,\n", 586 | " 'encoder_ckpt': '/home/aashukha/simclr-pytorch/logs/exman-train.py/runs/000012/checkpoint.pth.tar',\n", 587 | " 'eval_freq': 100,\n", 588 | " 'finetune': False,\n", 589 | " 'gpu': 0,\n", 590 | " 'iters': 28080,\n", 591 | " 'log_freq': 1000,\n", 592 | " 'lr': 1.6,\n", 593 | " 'lr_schedule': 'linear',\n", 594 | " 'model_id': nan,\n", 595 | " 'momentum': nan,\n", 596 | " 'multiplier': nan,\n", 597 | " 'name': 'eval_imagenet_newmodels',\n", 598 | " 'node_rank': 0,\n", 599 | " 'norm_multiplier': nan,\n", 600 | " 
'number_of_processes': 1,\n", 601 | " 'opt': 'sgd',\n", 602 | " 'precompute_emb_bs': -1.0,\n", 603 | " 'problem': 'eval',\n", 604 | " 'projection': nan,\n", 605 | " 'root': 'logs/exman-train.py/runs/000033',\n", 606 | " 'save_freq': 10000000000000000,\n", 607 | " 'scale_lower': 0.08,\n", 608 | " 'seed': -1,\n", 609 | " 'status': nan,\n", 610 | " 'sync_bn': nan,\n", 611 | " 'temperature': nan,\n", 612 | " 'test_bs': 4096.0,\n", 613 | " 'time': Timestamp('2020-12-05 14:49:17'),\n", 614 | " 'tmp': False,\n", 615 | " 'use_all_classes': nan,\n", 616 | " 'verbose': False,\n", 617 | " 'warmup': 0.0,\n", 618 | " 'weight_decay': 0.0,\n", 619 | " 'workers': 20,\n", 620 | " 'world_size': 1}" 621 | ] 622 | }, 623 | "execution_count": 39, 624 | "metadata": { 625 | "tags": [] 626 | }, 627 | "output_type": "execute_result" 628 | } 629 | ], 630 | "source": [ 631 | "dict(index.loc[33])" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 34, 637 | "metadata": { 638 | "colab": { 639 | "base_uri": "https://localhost:8080/", 640 | "height": 107 641 | }, 642 | "id": "Mc15MLjhIQ1w", 643 | "outputId": "82065f2b-76c9-4b85-fef4-b42621fa8760" 644 | }, 645 | "outputs": [ 646 | { 647 | "data": { 648 | "text/html": [ 649 | "
\n", 650 | "\n", 663 | "\n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | "
ttest_losstest_acctrain_losstrain_acctrain_epochlrdata_timeit_time
279280001.2799210.678401.1585020.71676089.4568690.004615316.1050782025.544294
280280801.2799010.678421.1611200.71624889.7124600.00005712.917784141.360860
\n", 705 | "
" 706 | ], 707 | "text/plain": [ 708 | " t test_loss test_acc ... lr data_time it_time\n", 709 | "279 28000 1.279921 0.67840 ... 0.004615 316.105078 2025.544294\n", 710 | "280 28080 1.279901 0.67842 ... 0.000057 12.917784 141.360860\n", 711 | "\n", 712 | "[2 rows x 9 columns]" 713 | ] 714 | }, 715 | "execution_count": 34, 716 | "metadata": { 717 | "tags": [] 718 | }, 719 | "output_type": "execute_result" 720 | } 721 | ], 722 | "source": [ 723 | "logs = pd.read_csv(index.loc[33].root + '/logs.csv')\n", 724 | "logs.tail(2)" 725 | ] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "execution_count": null, 730 | "metadata": { 731 | "id": "w-QviLaJ9P-u" 732 | }, 733 | "outputs": [], 734 | "source": [] 735 | } 736 | ], 737 | "metadata": { 738 | "colab": { 739 | "collapsed_sections": [], 740 | "name": "read_logs.ipynb", 741 | "provenance": [] 742 | }, 743 | "kernelspec": { 744 | "display_name": "Python 3", 745 | "language": "python", 746 | "name": "python3" 747 | }, 748 | "language_info": { 749 | "codemirror_mode": { 750 | "name": "ipython", 751 | "version": 3 752 | }, 753 | "file_extension": ".py", 754 | "mimetype": "text/x-python", 755 | "name": "python", 756 | "nbconvert_exporter": "python", 757 | "pygments_lexer": "ipython3", 758 | "version": "3.7.6" 759 | } 760 | }, 761 | "nbformat": 4, 762 | "nbformat_minor": 1 763 | } 764 | -------------------------------------------------------------------------------- /configs/cifar_eval.yaml: -------------------------------------------------------------------------------- 1 | arch: linear 2 | aug: true 3 | augmentation: RandomCrop 4 | batch_size: 1024 5 | ckpt: '' 6 | ckpt_iter: -1 7 | config_file: null 8 | data: cifar 9 | dist: dp 10 | dist_address: 127.0.0.1:1234 11 | encode_layer: h 12 | encoder_ckpt: '' 13 | eval_freq: 1000 14 | finetune: false 15 | iters: 80000 16 | log_freq: 100 17 | lr: 0.1 18 | lr_schedule: linear 19 | model_id: -1 20 | n_augs_test: 50 21 | n_augs_train: 10 22 | name: '' 23 | node_rank: 0 24 | opt: sgd 
25 | precompute_emb_bs: -1 26 | problem: eval 27 | root: '' 28 | save_freq: 100000000 29 | scale_lower: 0.08 30 | seed: -1 31 | status: fail 32 | test_bs: 1024 33 | tmp: false 34 | warmup: 0.0 35 | weight_decay: 0.0001 36 | workers: 2 37 | world_size: 1 38 | 39 | time: '2020-07-18T00:55:23' 40 | id: 3549 -------------------------------------------------------------------------------- /configs/cifar_train_epochs1000_bs1024.yaml: -------------------------------------------------------------------------------- 1 | arch: ResNet50 2 | aug: true 3 | batch_size: 1024 4 | ckpt: '' 5 | color_dist_s: 0.5 6 | config_file: null 7 | data: cifar 8 | dist: ddp 9 | dist_address: '127.0.0.1:1234' 10 | eval_freq: 4800 11 | iters: 48000 12 | log_freq: 48 13 | lr: 4.0 14 | lr_schedule: warmup-anneal 15 | multiplier: 2 16 | name: 'reproduce-cifar10' 17 | node_rank: 0 18 | opt: lars 19 | problem: sim-clr 20 | root: 'none' 21 | save_freq: 4800 22 | scale_lower: 0.08 23 | seed: -1 24 | sync_bn: true 25 | temperature: 0.5 26 | tmp: false 27 | verbose: true 28 | warmup: 0.01 29 | weight_decay: 1.0e-06 30 | workers: 2 31 | world_size: 1 -------------------------------------------------------------------------------- /configs/imagenet_eval.yaml: -------------------------------------------------------------------------------- 1 | arch: linear 2 | aug: true 3 | augmentation: RandomResizedCrop 4 | batch_size: 4096 5 | ckpt: '' 6 | ckpt_iter: -1 7 | config_file: null 8 | data: imagenet 9 | dist: dp 10 | dist_address: '' 11 | encode_layer: h 12 | encoder_ckpt: '' 13 | eval_freq: 100 14 | finetune: false 15 | iters: 28080 16 | log_freq: 1000 17 | lr: 1.6 18 | lr_schedule: linear 19 | model_id: -1 20 | name: eval_imagenet_newmodels 21 | node_rank: 0 22 | opt: sgd 23 | precompute_emb_bs: -1 24 | problem: eval 25 | save_freq: 10000000000000000 26 | scale_lower: 0.08 27 | seed: -1 28 | test_bs: 4096 29 | tmp: false 30 | warmup: 0.0 31 | weight_decay: 0.0 32 | workers: 20 33 | world_size: 1 
-------------------------------------------------------------------------------- /configs/imagenet_train_epochs100_bs512.yaml: -------------------------------------------------------------------------------- 1 | arch: ResNet50 2 | aug: true 3 | batch_size: 512 4 | ckpt: '' 5 | color_dist_s: 1.0 6 | config_file: '' 7 | data: imagenet 8 | dist: ddp 9 | dist_address: '127.0.0.1:1234' 10 | eval_freq: 50040 11 | gpu: 0 12 | iters: 250200 13 | log_freq: 100 14 | lr: 0.6 15 | lr_schedule: warmup-anneal 16 | multiplier: 2 17 | name: imagenet-reproduce 18 | node_rank: 0 19 | opt: lars 20 | problem: sim-clr 21 | root: '' 22 | save_freq: 12510 23 | scale_lower: 0.08 24 | seed: -1 25 | sync_bn: true 26 | temperature: 0.1 27 | tmp: false 28 | verbose: true 29 | warmup: 0.1 30 | weight_decay: 1.0e-06 31 | workers: 8 32 | world_size: 1 33 | -------------------------------------------------------------------------------- /configs/imagenet_train_epochs200_bs2k.yaml: -------------------------------------------------------------------------------- 1 | arch: ResNet50 2 | aug: true 3 | batch_size: 2048 4 | ckpt: '' 5 | color_dist_s: 1.0 6 | config_file: '' 7 | data: imagenet 8 | dist: ddp 9 | dist_address: '' 10 | eval_freq: 12510 11 | gpu: 0 12 | iters: 125100 13 | log_freq: 100 14 | lr: 2.4 15 | lr_schedule: warmup-anneal 16 | momentum: 0.9 17 | multiplier: 2 18 | name: imagenet-reproduce 19 | node_rank: 0 20 | number_of_processes: 16 21 | opt: lars 22 | problem: sim-clr 23 | root: '' 24 | save_freq: 12510 25 | scale_lower: 0.08 26 | seed: -1 27 | sync_bn: true 28 | temperature: 0.1 29 | tmp: false 30 | verbose: true 31 | warmup: 0.1 32 | weight_decay: 1.0e-06 33 | workers: 8 34 | world_size: 4 -------------------------------------------------------------------------------- /configs/imagenet_train_epochs600_bs2k.yaml: -------------------------------------------------------------------------------- 1 | arch: ResNet50 2 | aug: true 3 | batch_size: 2048 4 | ckpt: '' 5 | color_dist_s: 
1.0 6 | config_file: '' 7 | data: imagenet 8 | dist: ddp 9 | dist_address: '' 10 | eval_freq: 12510 11 | gpu: 0 12 | iters: 375300 13 | log_freq: 100 14 | lr: 2.4 15 | lr_schedule: warmup-anneal 16 | multiplier: 2 17 | name: imagenet-reproduce 18 | node_rank: 0 19 | opt: lars 20 | problem: sim-clr 21 | root: '' 22 | save_freq: 12510 23 | scale_lower: 0.08 24 | seed: -1 25 | sync_bn: true 26 | temperature: 0.1 27 | tmp: false 28 | verbose: true 29 | warmup: 0.1 30 | weight_decay: 1.0e-06 31 | workers: 8 32 | world_size: 4 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: simclr_pytorch 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - backcall=0.2.0=py_0 8 | - blas=1.0=mkl 9 | - ca-certificates=2020.10.14=0 10 | - certifi=2020.12.5=py36h06a4308_0 11 | - configargparse=1.2.3=py_0 12 | - cudatoolkit=10.1.243=h6bb024c_0 13 | - dataclasses=0.7=py36_0 14 | - decorator=4.4.2=py_0 15 | - filelock=3.0.12=py_0 16 | - freetype=2.10.4=h5ab3b9f_0 17 | - intel-openmp=2020.2=254 18 | - ipython=7.16.1=py36h5ca1d4c_0 19 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 20 | - jedi=0.17.2=py36h06a4308_1 21 | - joblib=0.17.0=py_0 22 | - jpeg=9b=h024ee3a_2 23 | - lcms2=2.11=h396b838_0 24 | - ld_impl_linux-64=2.33.1=h53a641e_7 25 | - libedit=3.1.20191231=h14c3975_1 26 | - libffi=3.3=he6710b0_2 27 | - libgcc-ng=9.1.0=hdf63c60_0 28 | - libgfortran-ng=7.3.0=hdf63c60_0 29 | - libpng=1.6.37=hbc83047_0 30 | - libstdcxx-ng=9.1.0=hdf63c60_0 31 | - libtiff=4.1.0=h2733197_1 32 | - libuv=1.40.0=h7b6447c_0 33 | - lz4-c=1.9.2=heb0550a_3 34 | - mkl=2020.2=256 35 | - mkl-service=2.3.0=py36he8ac12f_0 36 | - mkl_fft=1.2.0=py36h23d657b_0 37 | - mkl_random=1.1.1=py36h0573a6f_0 38 | - ncurses=6.2=he6710b0_1 39 | - ninja=1.10.2=py36hff7bd54_0 40 | - numpy=1.19.2=py36h54aff64_0 41 | - numpy-base=1.19.2=py36hfa32c7d_0 42 | - 
olefile=0.46=py36_0 43 | - openssl=1.1.1h=h7b6447c_0 44 | - pandas=1.1.3=py36he6710b0_0 45 | - parso=0.7.0=py_0 46 | - pexpect=4.8.0=pyhd3eb1b0_3 47 | - pickleshare=0.7.5=pyhd3eb1b0_1003 48 | - pillow=8.0.1=py36he98fc37_0 49 | - pip=20.3.1=py36h06a4308_0 50 | - prompt-toolkit=3.0.8=py_0 51 | - ptyprocess=0.6.0=pyhd3eb1b0_2 52 | - pygments=2.7.3=pyhd3eb1b0_0 53 | - python=3.6.12=hcff3b4d_2 54 | - python-dateutil=2.8.1=py_0 55 | - pytorch=1.7.0=py3.6_cuda10.1.243_cudnn7.6.3_0 56 | - pytz=2020.4=pyhd3eb1b0_0 57 | - pyyaml=5.3.1=py36h7b6447c_1 58 | - readline=8.0=h7b6447c_0 59 | - scikit-learn=0.23.2=py36h0573a6f_0 60 | - scipy=1.5.2=py36h0b6359f_0 61 | - setuptools=51.0.0=py36h06a4308_2 62 | - six=1.15.0=py36h06a4308_0 63 | - sqlite=3.33.0=h62c20be_0 64 | - tabulate=0.8.7=py36_0 65 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 66 | - tk=8.6.10=hbc83047_0 67 | - torchaudio=0.7.0=py36 68 | - torchvision=0.8.1=py36_cu101 69 | - tqdm=4.54.1=pyhd3eb1b0_0 70 | - traitlets=4.3.3=py36_0 71 | - typing_extensions=3.7.4.3=py_0 72 | - wcwidth=0.2.5=py_0 73 | - wheel=0.36.1=pyhd3eb1b0_0 74 | - xz=5.2.5=h7b6447c_0 75 | - yaml=0.2.5=h7b6447c_0 76 | - zlib=1.2.11=h7b6447c_3 77 | - zstd=1.4.5=h9ceee32_0 78 | - pip: 79 | - diffdist==0.1 80 | - strconv==0.4.2 81 | prefix: /home/aashukha/miniconda3/envs/simclr_pytorch 82 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models import encoder 2 | from models import losses 3 | from models import resnet 4 | from models import ssl 5 | 6 | REGISTERED_MODELS = { 7 | 'sim-clr': ssl.SimCLR, 8 | 'eval': ssl.SSLEval, 9 | 'semi-supervised-eval': ssl.SemiSupervisedEval, 10 | } 11 | -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import 
models 4 | from collections import OrderedDict 5 | from argparse import Namespace 6 | import yaml 7 | import os 8 | 9 | 10 | class BatchNorm1dNoBias(nn.BatchNorm1d): 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.bias.requires_grad = False 14 | 15 | 16 | class EncodeProject(nn.Module): 17 | def __init__(self, hparams): 18 | super().__init__() 19 | 20 | if hparams.arch == 'ResNet50': 21 | cifar_head = (hparams.data == 'cifar') 22 | self.convnet = models.resnet.ResNet50(cifar_head=cifar_head, hparams=hparams) 23 | self.encoder_dim = 2048 24 | elif hparams.arch == 'resnet18': 25 | self.convnet = models.resnet.ResNet18(cifar_head=(hparams.data == 'cifar')) 26 | self.encoder_dim = 512 27 | else: 28 | raise NotImplementedError 29 | 30 | num_params = sum(p.numel() for p in self.convnet.parameters() if p.requires_grad) 31 | 32 | print(f'======> Encoder: output dim {self.encoder_dim} | {num_params/1e6:.3f}M parameters') 33 | 34 | self.proj_dim = 128 35 | projection_layers = [ 36 | ('fc1', nn.Linear(self.encoder_dim, self.encoder_dim, bias=False)), 37 | ('bn1', nn.BatchNorm1d(self.encoder_dim)), 38 | ('relu1', nn.ReLU()), 39 | ('fc2', nn.Linear(self.encoder_dim, 128, bias=False)), 40 | ('bn2', BatchNorm1dNoBias(128)), 41 | ] 42 | 43 | self.projection = nn.Sequential(OrderedDict(projection_layers)) 44 | 45 | def forward(self, x, out='z'): 46 | h = self.convnet(x) 47 | if out == 'h': 48 | return h 49 | return self.projection(h) 50 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import diffdist 6 | import torch.distributed as dist 7 | 8 | 9 | def gather(z): 10 | gather_z = [torch.zeros_like(z) for _ in range(torch.distributed.get_world_size())] 11 | gather_z = 
diffdist.functional.all_gather(gather_z, z) 12 | gather_z = torch.cat(gather_z) 13 | 14 | return gather_z 15 | 16 | 17 | def accuracy(logits, labels, k): 18 | topk = torch.sort(logits.topk(k, dim=1)[1], 1)[0] 19 | labels = torch.sort(labels, 1)[0] 20 | acc = (topk == labels).all(1).float() 21 | return acc 22 | 23 | 24 | def mean_cumulative_gain(logits, labels, k): 25 | topk = torch.sort(logits.topk(k, dim=1)[1], 1)[0] 26 | labels = torch.sort(labels, 1)[0] 27 | mcg = (topk == labels).float().mean(1) 28 | return mcg 29 | 30 | 31 | def mean_average_precision(logits, labels, k): 32 | # TODO: not the fastest solution but looks fine 33 | argsort = torch.argsort(logits, dim=1, descending=True) 34 | labels_to_sorted_idx = torch.sort(torch.gather(torch.argsort(argsort, dim=1), 1, labels), dim=1)[0] + 1 35 | precision = (1 + torch.arange(k, device=logits.device).float()) / labels_to_sorted_idx 36 | return precision.sum(1) / k 37 | 38 | 39 | class NTXent(nn.Module): 40 | """ 41 | Contrastive loss with distributed data parallel support 42 | """ 43 | LARGE_NUMBER = 1e9 44 | 45 | def __init__(self, tau=1., gpu=None, multiplier=2, distributed=False): 46 | super().__init__() 47 | self.tau = tau 48 | self.multiplier = multiplier 49 | self.distributed = distributed 50 | self.norm = 1. 51 | 52 | def forward(self, z, get_map=False): 53 | n = z.shape[0] 54 | assert n % self.multiplier == 0 55 | 56 | z = F.normalize(z, p=2, dim=1) / np.sqrt(self.tau) 57 | 58 | if self.distributed: 59 | z_list = [torch.zeros_like(z) for _ in range(dist.get_world_size())] 60 | # all_gather fills the list as [, , ...] 61 | # TODO: try to rewrite it with pytorch official tools 62 | z_list = diffdist.functional.all_gather(z_list, z) 63 | # split it into [, , ..., , , ...] 64 | z_list = [chunk for x in z_list for chunk in x.chunk(self.multiplier)] 65 | # sort it to [, , ...] that simply means [, , ...] 
as expected below 66 | z_sorted = [] 67 | for m in range(self.multiplier): 68 | for i in range(dist.get_world_size()): 69 | z_sorted.append(z_list[i * self.multiplier + m]) 70 | z = torch.cat(z_sorted, dim=0) 71 | n = z.shape[0] 72 | 73 | logits = z @ z.t() 74 | logits[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER 75 | 76 | logprob = F.log_softmax(logits, dim=1) 77 | 78 | # choose all positive objects for an example, for i it would be (i + k * n/m), where k=0...(m-1) 79 | m = self.multiplier 80 | labels = (np.repeat(np.arange(n), m) + np.tile(np.arange(m) * n//m, n)) % n 81 | # remove labels pointet to itself, i.e. (i, i) 82 | labels = labels.reshape(n, m)[:, 1:].reshape(-1) 83 | 84 | # TODO: maybe different terms for each process should only be computed here... 85 | loss = -logprob[np.repeat(np.arange(n), m-1), labels].sum() / n / (m-1) / self.norm 86 | 87 | # zero the probability of identical pairs 88 | pred = logprob.data.clone() 89 | pred[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER 90 | acc = accuracy(pred, torch.LongTensor(labels.reshape(n, m-1)).to(logprob.device), m-1) 91 | 92 | if get_map: 93 | _map = mean_average_precision(pred, torch.LongTensor(labels.reshape(n, m-1)).to(logprob.device), m-1) 94 | return loss, acc, _map 95 | 96 | return loss, acc 97 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | # 8 | 9 | import torch.nn as nn 10 | import torchvision.models as models 11 | import torch 12 | 13 | 14 | class Flatten(nn.Module): 15 | def __init__(self, dim=-1): 16 | super(Flatten, self).__init__() 17 | self.dim = dim 18 | 19 | def forward(self, feat): 20 | return torch.flatten(feat, start_dim=self.dim) 21 | 22 | 23 | class ResNetEncoder(models.resnet.ResNet): 24 | """Wrapper for TorchVison ResNet Model 25 | This was needed to remove the final FC Layer from the ResNet Model""" 26 | def __init__(self, block, layers, cifar_head=False, hparams=None): 27 | super().__init__(block, layers) 28 | self.cifar_head = cifar_head 29 | if cifar_head: 30 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 31 | self.bn1 = self._norm_layer(64) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.hparams = hparams 34 | 35 | print('** Using avgpool **') 36 | 37 | def forward(self, x): 38 | x = self.conv1(x) 39 | x = self.bn1(x) 40 | x = self.relu(x) 41 | if not self.cifar_head: 42 | x = self.maxpool(x) 43 | 44 | x = self.layer1(x) 45 | x = self.layer2(x) 46 | x = self.layer3(x) 47 | x = self.layer4(x) 48 | 49 | x = self.avgpool(x) 50 | x = torch.flatten(x, 1) 51 | 52 | return x 53 | 54 | class ResNet18(ResNetEncoder): 55 | def __init__(self, cifar_head=True): 56 | super().__init__(models.resnet.BasicBlock, [2, 2, 2, 2], cifar_head=cifar_head) 57 | 58 | 59 | class ResNet50(ResNetEncoder): 60 | def __init__(self, cifar_head=True, hparams=None): 61 | super().__init__(models.resnet.Bottleneck, [3, 4, 6, 3], cifar_head=cifar_head, hparams=hparams) 62 | -------------------------------------------------------------------------------- /models/ssl.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace, ArgumentParser 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torchvision import datasets 8 | import torchvision.transforms as 
transforms 9 | from utils import datautils 10 | import models 11 | from utils import utils 12 | import numpy as np 13 | import PIL 14 | from tqdm import tqdm 15 | import sklearn 16 | from utils.lars_optimizer import LARS 17 | import scipy 18 | from torch.nn.parallel import DistributedDataParallel as DDP 19 | import torch.distributed as dist 20 | 21 | import copy 22 | 23 | class BaseSSL(nn.Module): 24 | """ 25 | Inspired by the PYTORCH LIGHTNING https://pytorch-lightning.readthedocs.io/en/latest/ 26 | Similar but lighter and customized version. 27 | """ 28 | DATA_ROOT = os.environ.get('DATA_ROOT', os.path.dirname(os.path.abspath(__file__)) + '/data') 29 | IMAGENET_PATH = os.environ.get('IMAGENET_PATH', '/home/aashukha/imagenet/raw-data/') 30 | 31 | def __init__(self, hparams): 32 | super().__init__() 33 | self.hparams = hparams 34 | if hparams.data == 'imagenet': 35 | print(f"IMAGENET_PATH = {self.IMAGENET_PATH}") 36 | 37 | def get_ckpt(self): 38 | return { 39 | 'state_dict': self.state_dict(), 40 | 'hparams': self.hparams, 41 | } 42 | 43 | @classmethod 44 | def load(cls, ckpt, device=None): 45 | parser = ArgumentParser() 46 | cls.add_model_hparams(parser) 47 | hparams = parser.parse_args([], namespace=ckpt['hparams']) 48 | 49 | res = cls(hparams, device=device) 50 | res.load_state_dict(ckpt['state_dict']) 51 | return res 52 | 53 | @classmethod 54 | def default(cls, device=None, **kwargs): 55 | parser = ArgumentParser() 56 | cls.add_model_hparams(parser) 57 | hparams = parser.parse_args([], namespace=Namespace(**kwargs)) 58 | res = cls(hparams, device=device) 59 | return res 60 | 61 | def forward(self, x): 62 | pass 63 | 64 | def transforms(self): 65 | pass 66 | 67 | def samplers(self): 68 | return None, None 69 | 70 | def prepare_data(self): 71 | train_transform, test_transform = self.transforms() 72 | # print('The following train transform is used:\n', train_transform) 73 | # print('The following test transform is used:\n', test_transform) 74 | if 
self.hparams.data == 'cifar': 75 | self.trainset = datasets.CIFAR10(root=self.DATA_ROOT, train=True, download=True, transform=train_transform) 76 | self.testset = datasets.CIFAR10(root=self.DATA_ROOT, train=False, download=True, transform=test_transform) 77 | elif self.hparams.data == 'imagenet': 78 | traindir = os.path.join(self.IMAGENET_PATH, 'train') 79 | valdir = os.path.join(self.IMAGENET_PATH, 'val') 80 | self.trainset = datasets.ImageFolder(traindir, transform=train_transform) 81 | self.testset = datasets.ImageFolder(valdir, transform=test_transform) 82 | else: 83 | raise NotImplementedError 84 | 85 | def dataloaders(self, iters=None): 86 | train_batch_sampler, test_batch_sampler = self.samplers() 87 | if iters is not None: 88 | train_batch_sampler = datautils.ContinousSampler( 89 | train_batch_sampler, 90 | iters 91 | ) 92 | 93 | train_loader = torch.utils.data.DataLoader( 94 | self.trainset, 95 | num_workers=self.hparams.workers, 96 | pin_memory=True, 97 | batch_sampler=train_batch_sampler, 98 | ) 99 | test_loader = torch.utils.data.DataLoader( 100 | self.testset, 101 | num_workers=self.hparams.workers, 102 | pin_memory=True, 103 | batch_sampler=test_batch_sampler, 104 | ) 105 | 106 | return train_loader, test_loader 107 | 108 | @staticmethod 109 | def add_parent_hparams(add_model_hparams): 110 | def foo(cls, parser): 111 | for base in cls.__bases__: 112 | base.add_model_hparams(parser) 113 | add_model_hparams(cls, parser) 114 | return foo 115 | 116 | @classmethod 117 | def add_model_hparams(cls, parser): 118 | parser.add_argument('--data', help='Dataset to use', default='cifar') 119 | parser.add_argument('--arch', default='ResNet50', help='Encoder architecture') 120 | parser.add_argument('--batch_size', default=256, type=int, help='The number of unique images in the batch') 121 | parser.add_argument('--aug', default=True, type=bool, help='Applies random augmentations if True') 122 | 123 | 124 | class SimCLR(BaseSSL): 125 | @classmethod 126 | 
@BaseSSL.add_parent_hparams 127 | def add_model_hparams(cls, parser): 128 | # loss params 129 | parser.add_argument('--temperature', default=0.1, type=float, help='Temperature in the NTXent loss') 130 | # data params 131 | parser.add_argument('--multiplier', default=2, type=int) 132 | parser.add_argument('--color_dist_s', default=1., type=float, help='Color distortion strength') 133 | parser.add_argument('--scale_lower', default=0.08, type=float, help='The minimum scale factor for RandomResizedCrop') 134 | # ddp 135 | parser.add_argument('--sync_bn', default=True, type=bool, 136 | help='Syncronises BatchNorm layers between all processes if True' 137 | ) 138 | 139 | def __init__(self, hparams, device=None): 140 | super().__init__(hparams) 141 | 142 | self.hparams.dist = getattr(self.hparams, 'dist', 'dp') 143 | 144 | model = models.encoder.EncodeProject(hparams) 145 | self.reset_parameters() 146 | if device is not None: 147 | model = model.to(device) 148 | if self.hparams.dist == 'ddp': 149 | if self.hparams.sync_bn: 150 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 151 | dist.barrier() 152 | if device is not None: 153 | model = model.to(device) 154 | self.model = DDP(model, [hparams.gpu], find_unused_parameters=True) 155 | elif self.hparams.dist == 'dp': 156 | self.model = nn.DataParallel(model) 157 | else: 158 | raise NotImplementedError 159 | 160 | self.criterion = models.losses.NTXent( 161 | tau=hparams.temperature, 162 | multiplier=hparams.multiplier, 163 | distributed=(hparams.dist == 'ddp'), 164 | ) 165 | 166 | def reset_parameters(self): 167 | def conv2d_weight_truncated_normal_init(p): 168 | fan_in = p.shape[1] 169 | stddev = np.sqrt(1. 
/ fan_in) / .87962566103423978 170 | r = scipy.stats.truncnorm.rvs(-2, 2, loc=0, scale=1., size=p.shape) 171 | r = stddev * r 172 | with torch.no_grad(): 173 | p.copy_(torch.FloatTensor(r)) 174 | 175 | def linear_normal_init(p): 176 | with torch.no_grad(): 177 | p.normal_(std=0.01) 178 | 179 | for m in self.modules(): 180 | if isinstance(m, nn.Conv2d): 181 | conv2d_weight_truncated_normal_init(m.weight) 182 | elif isinstance(m, nn.Linear): 183 | linear_normal_init(m.weight) 184 | 185 | def step(self, batch): 186 | x, _ = batch 187 | z = self.model(x) 188 | loss, acc = self.criterion(z) 189 | return { 190 | 'loss': loss, 191 | 'contrast_acc': acc, 192 | } 193 | 194 | def encode(self, x): 195 | return self.model(x, out='h') 196 | 197 | def forward(self, *args, **kwargs): 198 | return self.model(*args, **kwargs) 199 | 200 | def train_step(self, batch, it=None): 201 | logs = self.step(batch) 202 | 203 | if self.hparams.dist == 'ddp': 204 | self.trainsampler.set_epoch(it) 205 | if it is not None: 206 | logs['epoch'] = it / len(self.batch_trainsampler) 207 | 208 | return logs 209 | 210 | def test_step(self, batch): 211 | return self.step(batch) 212 | 213 | def samplers(self): 214 | if self.hparams.dist == 'ddp': 215 | # trainsampler = torch.utils.data.distributed.DistributedSampler(self.trainset, num_replicas=1, rank=0) 216 | trainsampler = torch.utils.data.distributed.DistributedSampler(self.trainset) 217 | print(f'Process {dist.get_rank()}: {len(trainsampler)} training samples per epoch') 218 | testsampler = torch.utils.data.distributed.DistributedSampler(self.testset) 219 | print(f'Process {dist.get_rank()}: {len(testsampler)} test samples') 220 | else: 221 | trainsampler = torch.utils.data.sampler.RandomSampler(self.trainset) 222 | testsampler = torch.utils.data.sampler.RandomSampler(self.testset) 223 | 224 | batch_sampler = datautils.MultiplyBatchSampler 225 | # batch_sampler.MULTILPLIER = self.hparams.multiplier if self.hparams.dist == 'dp' else 1 226 | 
batch_sampler.MULTILPLIER = self.hparams.multiplier 227 | 228 | # need for DDP to sync samplers between processes 229 | self.trainsampler = trainsampler 230 | self.batch_trainsampler = batch_sampler(trainsampler, self.hparams.batch_size, drop_last=True) 231 | 232 | return ( 233 | self.batch_trainsampler, 234 | batch_sampler(testsampler, self.hparams.batch_size, drop_last=True) 235 | ) 236 | 237 | def transforms(self): 238 | if self.hparams.data == 'cifar': 239 | train_transform = transforms.Compose([ 240 | transforms.RandomResizedCrop( 241 | 32, 242 | scale=(self.hparams.scale_lower, 1.0), 243 | interpolation=PIL.Image.BICUBIC, 244 | ), 245 | transforms.RandomHorizontalFlip(), 246 | datautils.get_color_distortion(s=self.hparams.color_dist_s), 247 | transforms.ToTensor(), 248 | datautils.Clip(), 249 | ]) 250 | test_transform = train_transform 251 | 252 | elif self.hparams.data == 'imagenet': 253 | from utils.datautils import GaussianBlur 254 | 255 | im_size = 224 256 | train_transform = transforms.Compose([ 257 | transforms.RandomResizedCrop( 258 | im_size, 259 | scale=(self.hparams.scale_lower, 1.0), 260 | interpolation=PIL.Image.BICUBIC, 261 | ), 262 | transforms.RandomHorizontalFlip(0.5), 263 | datautils.get_color_distortion(s=self.hparams.color_dist_s), 264 | transforms.ToTensor(), 265 | GaussianBlur(im_size // 10, 0.5), 266 | datautils.Clip(), 267 | ]) 268 | test_transform = train_transform 269 | return train_transform, test_transform 270 | 271 | def get_ckpt(self): 272 | return { 273 | 'state_dict': self.model.module.state_dict(), 274 | 'hparams': self.hparams, 275 | } 276 | 277 | def load_state_dict(self, state): 278 | k = next(iter(state.keys())) 279 | if k.startswith('model.module'): 280 | super().load_state_dict(state) 281 | else: 282 | self.model.module.load_state_dict(state) 283 | 284 | 285 | class SSLEval(BaseSSL): 286 | @classmethod 287 | @BaseSSL.add_parent_hparams 288 | def add_model_hparams(cls, parser): 289 | parser.add_argument('--test_bs', 
default=256, type=int) 290 | parser.add_argument('--encoder_ckpt', default='', help='Path to the encoder checkpoint') 291 | parser.add_argument('--precompute_emb_bs', default=-1, type=int, 292 | help='If it\'s not equal to -1 embeddings are precomputed and fixed before training with batch size equal to this.' 293 | ) 294 | parser.add_argument('--finetune', default=False, type=bool, help='Finetunes the encoder if True') 295 | parser.add_argument('--augmentation', default='RandomResizedCrop', help='') 296 | parser.add_argument('--scale_lower', default=0.08, type=float, help='The minimum scale factor for RandomResizedCrop') 297 | 298 | def __init__(self, hparams, device=None): 299 | super().__init__(hparams) 300 | 301 | self.hparams.dist = getattr(self.hparams, 'dist', 'dp') 302 | 303 | if hparams.encoder_ckpt != '': 304 | ckpt = torch.load(hparams.encoder_ckpt, map_location=device) 305 | if getattr(ckpt['hparams'], 'dist', 'dp') == 'ddp': 306 | ckpt['hparams'].dist = 'dp' 307 | if self.hparams.dist == 'ddp': 308 | ckpt['hparams'].dist = 'gpu:%d' % hparams.gpu 309 | 310 | self.encoder = models.REGISTERED_MODELS[ckpt['hparams'].problem].load(ckpt, device=device) 311 | else: 312 | print('===> Random encoder is used!!!') 313 | self.encoder = SimCLR.default(device=device) 314 | self.encoder.to(device) 315 | 316 | if not hparams.finetune: 317 | for p in self.encoder.parameters(): 318 | p.requires_grad = False 319 | elif hparams.dist == 'ddp': 320 | raise NotImplementedError 321 | 322 | self.encoder.eval() 323 | if hparams.data == 'cifar': 324 | hdim = self.encode(torch.ones(10, 3, 32, 32).to(device)).shape[1] 325 | n_classes = 10 326 | elif hparams.data == 'imagenet': 327 | hdim = self.encode(torch.ones(10, 3, 224, 224).to(device)).shape[1] 328 | n_classes = 1000 329 | 330 | if hparams.arch == 'linear': 331 | model = nn.Linear(hdim, n_classes).to(device) 332 | model.weight.data.zero_() 333 | model.bias.data.zero_() 334 | self.model = model 335 | else: 336 | raise 
NotImplementedError 337 | 338 | if hparams.dist == 'ddp': 339 | self.model = DDP(model, [hparams.gpu]) 340 | 341 | def encode(self, x): 342 | return self.encoder.model(x, out='h') 343 | 344 | def step(self, batch): 345 | if self.hparams.problem == 'eval' and self.hparams.data == 'imagenet': 346 | batch[0] = batch[0] / 255. 347 | h, y = batch 348 | if self.hparams.precompute_emb_bs == -1: 349 | h = self.encode(h) 350 | p = self.model(h) 351 | loss = F.cross_entropy(p, y) 352 | acc = (p.argmax(1) == y).float() 353 | return { 354 | 'loss': loss, 355 | 'acc': acc, 356 | } 357 | 358 | def forward(self, *args, **kwargs): 359 | return self.model(*args, **kwargs) 360 | 361 | def train_step(self, batch, it=None): 362 | logs = self.step(batch) 363 | if it is not None: 364 | iters_per_epoch = len(self.trainset) / self.hparams.batch_size 365 | iters_per_epoch = max(1, int(np.around(iters_per_epoch))) 366 | logs['epoch'] = it / iters_per_epoch 367 | if self.hparams.dist == 'ddp' and self.hparams.precompute_emb_bs == -1: 368 | self.object_trainsampler.set_epoch(it) 369 | 370 | return logs 371 | 372 | def test_step(self, batch): 373 | logs = self.step(batch) 374 | if self.hparams.dist == 'ddp': 375 | utils.gather_metrics(logs) 376 | return logs 377 | 378 | def prepare_data(self): 379 | super().prepare_data() 380 | 381 | def create_emb_dataset(dataset): 382 | embs, labels = [], [] 383 | loader = torch.utils.data.DataLoader( 384 | dataset, 385 | num_workers=self.hparams.workers, 386 | pin_memory=True, 387 | batch_size=self.hparams.precompute_emb_bs, 388 | shuffle=False, 389 | ) 390 | for x, y in tqdm(loader): 391 | if self.hparams.data == 'imagenet': 392 | x = x.to(torch.device('cuda')) 393 | x = x / 255. 
394 | e = self.encode(x) 395 | embs.append(utils.tonp(e)) 396 | labels.append(utils.tonp(y)) 397 | embs, labels = np.concatenate(embs), np.concatenate(labels) 398 | dataset = torch.utils.data.TensorDataset(torch.FloatTensor(embs), torch.LongTensor(labels)) 399 | return dataset 400 | 401 | if self.hparams.precompute_emb_bs != -1: 402 | print('===> Precompute embeddings:') 403 | assert not self.hparams.aug 404 | with torch.no_grad(): 405 | self.encoder.eval() 406 | self.testset = create_emb_dataset(self.testset) 407 | self.trainset = create_emb_dataset(self.trainset) 408 | 409 | print(f'Train size: {len(self.trainset)}') 410 | print(f'Test size: {len(self.testset)}') 411 | 412 | def dataloaders(self, iters=None): 413 | if self.hparams.dist == 'ddp' and self.hparams.precompute_emb_bs == -1: 414 | trainsampler = torch.utils.data.distributed.DistributedSampler(self.trainset) 415 | testsampler = torch.utils.data.distributed.DistributedSampler(self.testset, shuffle=False) 416 | else: 417 | trainsampler = torch.utils.data.RandomSampler(self.trainset) 418 | testsampler = torch.utils.data.SequentialSampler(self.testset) 419 | 420 | self.object_trainsampler = trainsampler 421 | trainsampler = torch.utils.data.BatchSampler( 422 | self.object_trainsampler, 423 | batch_size=self.hparams.batch_size, drop_last=False, 424 | ) 425 | if iters is not None: 426 | trainsampler = datautils.ContinousSampler(trainsampler, iters) 427 | 428 | train_loader = torch.utils.data.DataLoader( 429 | self.trainset, 430 | num_workers=self.hparams.workers, 431 | pin_memory=True, 432 | batch_sampler=trainsampler, 433 | ) 434 | test_loader = torch.utils.data.DataLoader( 435 | self.testset, 436 | num_workers=self.hparams.workers, 437 | pin_memory=True, 438 | sampler=testsampler, 439 | batch_size=self.hparams.test_bs, 440 | ) 441 | return train_loader, test_loader 442 | 443 | def transforms(self): 444 | if self.hparams.data == 'cifar': 445 | trs = [] 446 | if 'RandomResizedCrop' in 
self.hparams.augmentation: 447 | trs.append( 448 | transforms.RandomResizedCrop( 449 | 32, 450 | scale=(self.hparams.scale_lower, 1.0), 451 | interpolation=PIL.Image.BICUBIC, 452 | ) 453 | ) 454 | if 'RandomCrop' in self.hparams.augmentation: 455 | trs.append(transforms.RandomCrop(32, padding=4, padding_mode='reflect')) 456 | if 'color_distortion' in self.hparams.augmentation: 457 | trs.append(datautils.get_color_distortion(self.encoder.hparams.color_dist_s)) 458 | 459 | train_transform = transforms.Compose(trs + [ 460 | transforms.RandomHorizontalFlip(), 461 | transforms.ToTensor(), 462 | datautils.Clip(), 463 | ]) 464 | test_transform = transforms.Compose([ 465 | transforms.ToTensor(), 466 | ]) 467 | elif self.hparams.data == 'imagenet': 468 | train_transform = transforms.Compose([ 469 | transforms.RandomResizedCrop( 470 | 224, 471 | scale=(self.hparams.scale_lower, 1.0), 472 | interpolation=PIL.Image.BICUBIC, 473 | ), 474 | transforms.RandomHorizontalFlip(), 475 | transforms.ToTensor(), 476 | lambda x: (255*x).byte(), 477 | ]) 478 | test_transform = transforms.Compose([ 479 | datautils.CenterCropAndResize(proportion=0.875, size=224), 480 | transforms.ToTensor(), 481 | lambda x: (255 * x).byte(), 482 | ]) 483 | return train_transform if self.hparams.aug else test_transform, test_transform 484 | 485 | def train(self, mode=True): 486 | if self.hparams.finetune: 487 | super().train(mode) 488 | else: 489 | self.model.train(mode) 490 | 491 | def get_ckpt(self): 492 | return { 493 | 'state_dict': self.state_dict() if self.hparams.finetune else self.model.state_dict(), 494 | 'hparams': self.hparams, 495 | } 496 | 497 | def load_state_dict(self, state): 498 | if self.hparams.finetune: 499 | super().load_state_dict(state) 500 | else: 501 | if hasattr(self.model, 'module'): 502 | self.model.module.load_state_dict(state) 503 | else: 504 | self.model.load_state_dict(state) 505 | 506 | class SemiSupervisedEval(SSLEval): 507 | @classmethod 508 | @BaseSSL.add_parent_hparams 509 
| def add_model_hparams(cls, parser): 510 | parser.add_argument('--train_size', default=-1, type=int) 511 | parser.add_argument('--data_split_seed', default=42, type=int) 512 | parser.add_argument('--n_augs_train', default=-1, type=int) 513 | parser.add_argument('--n_augs_test', default=-1, type=int) 514 | parser.add_argument('--acc_on_unlabeled', default=False, type=bool) 515 | 516 | def prepare_data(self): 517 | super(SSLEval, self).prepare_data() 518 | 519 | if len(self.trainset) != self.hparams.train_size: 520 | idxs, unlabeled_idxs = sklearn.model_selection.train_test_split( 521 | np.arange(len(self.trainset)), 522 | train_size=self.hparams.train_size, 523 | random_state=self.hparams.data_split_seed, 524 | ) 525 | if self.hparams.data == 'cifar' or self.hparams.data == 'cifar100': 526 | if self.hparams.acc_on_unlabeled: 527 | self.trainset_unlabeled = copy.deepcopy(self.trainset) 528 | self.trainset_unlabeled.data = self.trainset.data[unlabeled_idxs] 529 | self.trainset_unlabeled.targets = np.array(self.trainset.targets)[unlabeled_idxs] 530 | print(f'Test size (0): {len(self.testset)}') 531 | print(f'Unlabeled train size (1): {len(self.trainset_unlabeled)}') 532 | 533 | self.trainset.data = self.trainset.data[idxs] 534 | self.trainset.targets = np.array(self.trainset.targets)[idxs] 535 | 536 | print('Training dataset size:', len(self.trainset)) 537 | else: 538 | assert not self.hparams.acc_on_unlabeled 539 | if isinstance(self.trainset, torch.utils.data.TensorDataset): 540 | self.trainset.tensors = [t[idxs] for t in self.trainset.tensors] 541 | else: 542 | self.trainset.samples = [self.trainset.samples[i] for i in idxs] 543 | 544 | print('Training dataset size:', len(self.trainset)) 545 | 546 | self.encoder.eval() 547 | with torch.no_grad(): 548 | if self.hparams.n_augs_train != -1: 549 | self.trainset = EmbEnsEval.create_emb_dataset(self, self.trainset, n_augs=self.hparams.n_augs_train) 550 | if self.hparams.n_augs_test != -1: 551 | self.testset = 
EmbEnsEval.create_emb_dataset(self, self.testset, n_augs=self.hparams.n_augs_test) 552 | if self.hparams.acc_on_unlabeled: 553 | self.trainset_unlabeled = EmbEnsEval.create_emb_dataset( 554 | self, 555 | self.trainset_unlabeled, 556 | n_augs=self.hparams.n_augs_test 557 | ) 558 | if self.hparams.acc_on_unlabeled: 559 | self.testset = torch.utils.data.ConcatDataset([ 560 | datautils.DummyOutputWrapper(self.testset, 0), 561 | datautils.DummyOutputWrapper(self.trainset_unlabeled, 1) 562 | ]) 563 | 564 | def transforms(self): 565 | ens_train_transfom, ens_test_transform = EmbEnsEval.transforms(self) 566 | train_transform, test_transform = SSLEval.transforms(self) 567 | return ( 568 | train_transform if self.hparams.n_augs_train == -1 else ens_train_transfom, 569 | test_transform if self.hparams.n_augs_test == -1 else ens_test_transform 570 | ) 571 | 572 | def step(self, batch, it=None): 573 | if self.hparams.problem == 'eval' and self.hparams.data == 'imagenet': 574 | batch[0] = batch[0] / 255. 
575 | h, y = batch 576 | if len(h.shape) == 4: 577 | h = self.encode(h) 578 | p = self.model(h) 579 | loss = F.cross_entropy(p, y) 580 | acc = (p.argmax(1) == y).float() 581 | return { 582 | 'loss': loss, 583 | 'acc': acc, 584 | } 585 | 586 | def test_step(self, batch): 587 | if not self.hparams.acc_on_unlabeled: 588 | return super().test_step(batch) 589 | # TODO: refactor 590 | x, y, d = batch 591 | logs = {} 592 | keys = set() 593 | for didx in [0, 1]: 594 | if torch.any(d == didx): 595 | t = super().test_step([x[d == didx], y[d == didx]]) 596 | for k, v in t.items(): 597 | keys.add(k) 598 | logs[k + f'_{didx}'] = v 599 | for didx in [0, 1]: 600 | for k in keys: 601 | logs[k + f'_{didx}'] = logs.get(k + f'_{didx}', torch.tensor([])) 602 | return logs 603 | 604 | 605 | def configure_optimizers(args, model, cur_iter=-1): 606 | iters = args.iters 607 | 608 | def exclude_from_wd_and_adaptation(name): 609 | if 'bn' in name: 610 | return True 611 | if args.opt == 'lars' and 'bias' in name: 612 | return True 613 | 614 | param_groups = [ 615 | { 616 | 'params': [p for name, p in model.named_parameters() if not exclude_from_wd_and_adaptation(name)], 617 | 'weight_decay': args.weight_decay, 618 | 'layer_adaptation': True, 619 | }, 620 | { 621 | 'params': [p for name, p in model.named_parameters() if exclude_from_wd_and_adaptation(name)], 622 | 'weight_decay': 0., 623 | 'layer_adaptation': False, 624 | }, 625 | ] 626 | 627 | LR = args.lr 628 | 629 | if args.opt == 'sgd': 630 | optimizer = torch.optim.SGD( 631 | param_groups, 632 | lr=LR, 633 | momentum=0.9, 634 | ) 635 | elif args.opt == 'adam': 636 | optimizer = torch.optim.Adam( 637 | param_groups, 638 | lr=LR, 639 | ) 640 | elif args.opt == 'lars': 641 | optimizer = torch.optim.SGD( 642 | param_groups, 643 | lr=LR, 644 | momentum=0.9, 645 | ) 646 | larc_optimizer = LARS(optimizer) 647 | else: 648 | raise NotImplementedError 649 | 650 | if args.lr_schedule == 'warmup-anneal': 651 | scheduler = 
utils.LinearWarmupAndCosineAnneal( 652 | optimizer, 653 | args.warmup, 654 | iters, 655 | last_epoch=cur_iter, 656 | ) 657 | elif args.lr_schedule == 'linear': 658 | scheduler = utils.LinearLR(optimizer, iters, last_epoch=cur_iter) 659 | elif args.lr_schedule == 'const': 660 | scheduler = None 661 | else: 662 | raise NotImplementedError 663 | 664 | if args.opt == 'lars': 665 | optimizer = larc_optimizer 666 | 667 | # if args.verbose: 668 | # print('Optimizer : ', optimizer) 669 | # print('Scheduler : ', scheduler) 670 | 671 | return optimizer, scheduler 672 | -------------------------------------------------------------------------------- /myexman/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser import ( 2 | ExParser, 3 | simpleroot 4 | ) 5 | from .index import ( 6 | Index 7 | ) 8 | from . import index 9 | from . import parser 10 | __version__ = '0.0.2' 11 | -------------------------------------------------------------------------------- /myexman/index.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | import pandas as pd 3 | import pathlib 4 | import strconv 5 | import json 6 | import functools 7 | import datetime 8 | from . 
import parser 9 | import yaml 10 | from argparse import Namespace 11 | __all__ = [ 12 | 'Index' 13 | ] 14 | 15 | 16 | def only_value_error(conv): 17 | @functools.wraps(conv) 18 | def new_conv(value): 19 | try: 20 | return conv(value) 21 | except Exception as e: 22 | raise ValueError from e 23 | return new_conv 24 | 25 | 26 | def none2none(none): 27 | if none is None: 28 | return None 29 | else: 30 | raise ValueError 31 | 32 | 33 | converter = strconv.Strconv(converters=[ 34 | ('int', strconv.convert_int), 35 | ('float', strconv.convert_float), 36 | ('bool', only_value_error(parser.str2bool)), 37 | ('time', strconv.convert_time), 38 | ('datetime', strconv.convert_datetime), 39 | ('datetime1', lambda time: datetime.datetime.strptime(time, parser.TIME_FORMAT)), 40 | ('date', strconv.convert_date), 41 | ('json', only_value_error(json.loads)), 42 | ]) 43 | 44 | 45 | def get_args(path): 46 | with open(path, 'rb') as f: 47 | return Namespace(**yaml.load(f)) 48 | 49 | 50 | class Index(object): 51 | def __init__(self, root): 52 | self.root = pathlib.Path(root) 53 | 54 | @property 55 | def index(self): 56 | return self.root / 'index' 57 | 58 | @property 59 | def marked(self): 60 | return self.root / 'marked' 61 | 62 | def info(self, source=None, nlast=None): 63 | if source is None: 64 | source = self.index 65 | files = source.iterdir() 66 | if nlast is not None: 67 | files = sorted(list(files))[-nlast:] 68 | else: 69 | source = self.marked / source 70 | files = source.glob('**/*/'+parser.PARAMS_FILE) 71 | 72 | def get_dict(cfg): 73 | return configargparse.YAMLConfigFileParser().parse(cfg.open('r')) 74 | 75 | def convert_column(col): 76 | if any(isinstance(v, str) for v in converter.convert_series(col)): 77 | return col 78 | else: 79 | return pd.Series(converter.convert_series(col), name=col.name, index=col.index) 80 | try: 81 | df = (pd.DataFrame 82 | .from_records((get_dict(c) for c in files)) 83 | .apply(lambda s: convert_column(s)) 84 | .sort_values('id') 85 | 
.assign(root=lambda _: _.root.apply(self.root.__truediv__)) 86 | .reset_index(drop=True)) 87 | cols = df.columns.tolist() 88 | cols.insert(0, cols.pop(cols.index('id'))) 89 | return df.reindex(columns=cols) 90 | except FileNotFoundError as e: 91 | raise KeyError(source.name) from e 92 | -------------------------------------------------------------------------------- /myexman/parser.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | import argparse 3 | import pathlib 4 | import datetime 5 | import yaml 6 | import yaml.representer 7 | import os 8 | import functools 9 | import itertools 10 | from filelock import FileLock 11 | __all__ = [ 12 | 'ExParser', 13 | 'simpleroot', 14 | ] 15 | 16 | 17 | TIME_FORMAT_DIR = '%Y-%m-%d-%H-%M-%S' 18 | TIME_FORMAT = '%Y-%m-%dT%H:%M:%S' 19 | DIR_FORMAT = '{num}' 20 | EXT = 'yaml' 21 | PARAMS_FILE = 'params.'+EXT 22 | FOLDER_DEFAULT = 'exman' 23 | RESERVED_DIRECTORIES = { 24 | 'runs', 'index', 25 | 'tmp', 'marked' 26 | } 27 | 28 | 29 | def yaml_file(name): 30 | return name + '.' 
+ EXT 31 | 32 | 33 | def simpleroot(__file__): 34 | return pathlib.Path(os.path.dirname(os.path.abspath(__file__)))/FOLDER_DEFAULT 35 | 36 | 37 | def represent_as_str(self, data, tostr=str): 38 | return yaml.representer.Representer.represent_str(self, tostr(data)) 39 | 40 | 41 | def register_str_converter(*types, tostr=str): 42 | for T in types: 43 | yaml.add_representer(T, functools.partial(represent_as_str, tostr=tostr)) 44 | 45 | 46 | register_str_converter(pathlib.PosixPath, pathlib.WindowsPath) 47 | 48 | 49 | def str2bool(s): 50 | true = ('true', 't', 'yes', 'y', 'on', '1') 51 | false = ('false', 'f', 'no', 'n', 'off', '0') 52 | 53 | if s.lower() in true: 54 | return True 55 | elif s.lower() in false: 56 | return False 57 | else: 58 | raise argparse.ArgumentTypeError(s, 'bool argument should be one of {}'.format(str(true + false))) 59 | 60 | 61 | class ParserWithRoot(configargparse.ArgumentParser): 62 | def __init__(self, *args, root=None, zfill=6, 63 | **kwargs): 64 | super().__init__(*args, **kwargs) 65 | if root is None: 66 | raise ValueError('Root directory is not specified') 67 | root = pathlib.Path(root) 68 | if not root.is_absolute(): 69 | raise ValueError(root, 'Root directory is not absolute path') 70 | if not root.exists(): 71 | raise ValueError(root, 'Root directory does not exist') 72 | self.root = pathlib.Path(root) 73 | self.zfill = zfill 74 | self.register('type', bool, str2bool) 75 | for directory in RESERVED_DIRECTORIES: 76 | getattr(self, directory).mkdir(exist_ok=True) 77 | self.lock = FileLock(str(self.root/'lock')) 78 | 79 | @property 80 | def runs(self): 81 | return self.root / 'runs' 82 | 83 | @property 84 | def marked(self): 85 | return self.root / 'marked' 86 | 87 | @property 88 | def index(self): 89 | return self.root / 'index' 90 | 91 | @property 92 | def tmp(self): 93 | return self.root / 'tmp' 94 | 95 | def max_ex(self): 96 | max_num = 0 97 | for directory in itertools.chain(self.runs.iterdir(), self.tmp.iterdir()): 98 | num = 
int(directory.name.split('-', 1)[0]) 99 | if num > max_num: 100 | max_num = num 101 | return max_num 102 | 103 | def num_ex(self): 104 | return len(list(self.runs.iterdir())) 105 | 106 | def next_ex(self): 107 | return self.max_ex() + 1 108 | 109 | def next_ex_str(self): 110 | return str(self.next_ex()).zfill(self.zfill) 111 | 112 | 113 | class ExParser(ParserWithRoot): 114 | """ 115 | Parser responsible for creating the following structure of experiments 116 | ``` 117 | root 118 | |-- runs 119 | | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS 120 | | |-- params.yaml 121 | | `-- ... 122 | |-- index 123 | | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS.yaml (symlink) 124 | |-- marked 125 | | `-- 126 | | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS (symlink) 127 | | |-- params.yaml 128 | | `-- ... 129 | `-- tmp 130 | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS 131 | |-- params.yaml 132 | `-- ... 133 | ``` 134 | """ 135 | def __init__(self, *args, zfill=6, file=None, 136 | args_for_setting_config_path=('--config', ), 137 | automark=(), 138 | parents=[], 139 | **kwargs): 140 | 141 | root = os.path.join(os.path.abspath(os.environ.get('EXMAN_PATH', './logs')), ('exman-' + str(file))) 142 | if not os.path.exists(root): 143 | os.makedirs(root) 144 | 145 | if len(parents) == 1: 146 | self.yaml_params_path = parents[0].yaml_params_path 147 | root = parents[0].root 148 | 149 | super().__init__(*args, root=root, zfill=zfill, 150 | args_for_setting_config_path=args_for_setting_config_path, 151 | config_file_parser_class=configargparse.YAMLConfigFileParser, 152 | ignore_unknown_config_file_keys=True, 153 | parents=parents, 154 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 155 | **kwargs) 156 | self.automark = automark 157 | if len(parents) == 0: 158 | self.add_argument('--tmp', action='store_true') 159 | 160 | def _initialize_dir(self, tmp): 161 | try: 162 | # with self.lock: # different processes can make it same time, this is needed to avoid collision 163 | time = datetime.datetime.now() 164 | num = self.next_ex_str() 
165 | name = DIR_FORMAT.format(num=num, time=time.strftime(TIME_FORMAT_DIR)) 166 | if tmp: 167 | absroot = self.tmp / name 168 | relroot = pathlib.Path('tmp') / name 169 | else: 170 | absroot = self.runs / name 171 | relroot = pathlib.Path('runs') / name 172 | # this process now safely owns root directory 173 | # raises FileExistsError on fail 174 | absroot.mkdir() 175 | except FileExistsError: # shit still happens 176 | return self._initialize_dir(tmp) 177 | return absroot, relroot, name, time, num 178 | 179 | def parse_known_args(self, *args, log_params=True, **kwargs): 180 | args, argv = super().parse_known_args(*args, **kwargs) 181 | if not log_params: 182 | return args, argv 183 | 184 | if hasattr(self, 'yaml_params_path'): 185 | with self.yaml_params_path.open('w') as f: 186 | self.dumpd = args.__dict__.copy() 187 | yaml.dump(self.dumpd, f, default_flow_style=False) 188 | print("\ntime: '{}'".format(self.time.strftime(TIME_FORMAT)), file=f) 189 | print("id:", int(self.num), file=f) 190 | print(self.yaml_params_path.read_text()) 191 | return args, argv 192 | 193 | absroot, relroot, name, time, num = self._initialize_dir(args.tmp) 194 | self.time = time 195 | self.num = num 196 | args.root = absroot 197 | self.yaml_params_path = args.root / PARAMS_FILE 198 | rel_yaml_params_path = pathlib.Path('..', 'runs', name, PARAMS_FILE) 199 | with self.yaml_params_path.open('a') as f: 200 | self.dumpd = args.__dict__.copy() 201 | # dumpd['root'] = relroot 202 | yaml.dump(self.dumpd, f, default_flow_style=False) 203 | print("\ntime: '{}'".format(time.strftime(TIME_FORMAT)), file=f) 204 | print("id:", int(num), file=f) 205 | print(self.yaml_params_path.read_text()) 206 | symlink = self.index / yaml_file(name) 207 | if not args.tmp: 208 | symlink.symlink_to(rel_yaml_params_path) 209 | print('Created symlink from', symlink, '->', rel_yaml_params_path) 210 | if self.automark and not args.tmp: 211 | automark_path_part = pathlib.Path(*itertools.chain.from_iterable( 212 | (mark, 
str(getattr(args, mark, ''))) 213 | for mark in self.automark)) 214 | markpath = pathlib.Path(self.marked, automark_path_part) 215 | markpath.mkdir(exist_ok=True, parents=True) 216 | relpathmark = pathlib.Path('..', *(['..']*len(automark_path_part.parts))) / 'runs' / name 217 | (markpath / name).symlink_to(relpathmark, target_is_directory=True) 218 | print('Created symlink from', markpath / name, '->', relpathmark) 219 | return args, argv 220 | 221 | def done(self): 222 | print('Success.') 223 | self.dumpd['status'] = 'done' 224 | with self.yaml_params_path.open('a') as f: 225 | yaml.dump(self.dumpd, f, default_flow_style=False) 226 | 227 | def update_params_file(self, args): 228 | dumpd = args.__dict__.copy() 229 | with self.yaml_params_path.open('w') as f: 230 | yaml.dump(dumpd, f, default_flow_style=False) 231 | print("\ntime: '{}'".format(self.time.strftime(TIME_FORMAT)), file=f) 232 | print("id:", int(self.num), file=f) 233 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import models 8 | from utils.logger import Logger 9 | import myexman 10 | from utils import utils 11 | import sys 12 | import torch.multiprocessing as mp 13 | import torch.distributed as dist 14 | import socket 15 | 16 | 17 | def add_learner_params(parser): 18 | parser.add_argument('--problem', default='sim-clr', 19 | help='The problem to train', 20 | choices=models.REGISTERED_MODELS, 21 | ) 22 | parser.add_argument('--name', default='', 23 | help='Name for the experiment', 24 | ) 25 | parser.add_argument('--ckpt', default='', 26 | help='Optional checkpoint to init the model.' 
27 | ) 28 | parser.add_argument('--verbose', default=False, type=bool) 29 | # optimizer params 30 | parser.add_argument('--lr_schedule', default='warmup-anneal') 31 | parser.add_argument('--opt', default='lars', help='Optimizer to use', choices=['sgd', 'adam', 'lars']) 32 | parser.add_argument('--iters', default=-1, type=int, help='The number of optimizer updates') 33 | parser.add_argument('--warmup', default=0, type=float, help='The number of warmup iterations in proportion to \'iters\'') 34 | parser.add_argument('--lr', default=0.1, type=float, help='Base learning rate') 35 | parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float, dest='weight_decay') 36 | # trainer params 37 | parser.add_argument('--save_freq', default=10000000000000000, type=int, help='Frequency to save the model') 38 | parser.add_argument('--log_freq', default=100, type=int, help='Logging frequency') 39 | parser.add_argument('--eval_freq', default=10000000000000000, type=int, help='Evaluation frequency') 40 | parser.add_argument('-j', '--workers', default=4, type=int, help='The number of data loader workers') 41 | parser.add_argument('--eval_only', default=False, type=bool, help='Skips the training step if True') 42 | parser.add_argument('--seed', default=-1, type=int, help='Random seed') 43 | # parallelizm params: 44 | parser.add_argument('--dist', default='dp', type=str, 45 | help='dp: DataParallel, ddp: DistributedDataParallel', 46 | choices=['dp', 'ddp'], 47 | ) 48 | parser.add_argument('--dist_address', default='127.0.0.1:1234', type=str, 49 | help='the address and a port of the main node in the
: format' 50 | ) 51 | parser.add_argument('--node_rank', default=0, type=int, 52 | help='Rank of the node (script launched): 0 for the main node and 1,... for the others', 53 | ) 54 | parser.add_argument('--world_size', default=1, type=int, 55 | help='the number of nodes (scripts launched)', 56 | ) 57 | 58 | 59 | def main(): 60 | parser = myexman.ExParser(file=os.path.basename(__file__)) 61 | add_learner_params(parser) 62 | 63 | is_help = False 64 | if '--help' in sys.argv or '-h' in sys.argv: 65 | sys.argv.pop(sys.argv.index('--help' if '--help' in sys.argv else '-h')) 66 | is_help = True 67 | 68 | args, _ = parser.parse_known_args(log_params=False) 69 | 70 | models.REGISTERED_MODELS[args.problem].add_model_hparams(parser) 71 | 72 | if is_help: 73 | sys.argv.append('--help') 74 | 75 | args = parser.parse_args(namespace=args) 76 | 77 | if args.data == 'imagenet' and args.aug == False: 78 | raise Exception('ImageNet models should be eval with aug=True!') 79 | 80 | if args.seed != -1: 81 | random.seed(args.seed) 82 | torch.manual_seed(args.seed) 83 | cudnn.deterministic = True 84 | 85 | args.gpu = 0 86 | ngpus = torch.cuda.device_count() 87 | args.number_of_processes = 1 88 | if args.dist == 'ddp': 89 | # add additional argument to be able to retrieve # of processes from logs 90 | # and don't change initial arguments to reproduce the experiment 91 | args.number_of_processes = args.world_size * ngpus 92 | parser.update_params_file(args) 93 | 94 | args.world_size *= ngpus 95 | mp.spawn( 96 | main_worker, 97 | nprocs=ngpus, 98 | args=(ngpus, args), 99 | ) 100 | else: 101 | parser.update_params_file(args) 102 | main_worker(args.gpu, -1, args) 103 | 104 | 105 | def main_worker(gpu, ngpus, args): 106 | fmt = { 107 | 'train_time': '.3f', 108 | 'val_time': '.3f', 109 | 'lr': '.1e', 110 | } 111 | logger = Logger('logs', base=args.root, fmt=fmt) 112 | 113 | args.gpu = gpu 114 | torch.cuda.set_device(gpu) 115 | args.rank = args.node_rank * ngpus + gpu 116 | 117 | device = 
torch.device('cuda:%d' % args.gpu) 118 | 119 | if args.dist == 'ddp': 120 | dist.init_process_group( 121 | backend='nccl', 122 | init_method='tcp://%s' % args.dist_address, 123 | world_size=args.world_size, 124 | rank=args.rank, 125 | ) 126 | 127 | n_gpus_total = dist.get_world_size() 128 | assert args.batch_size % n_gpus_total == 0 129 | args.batch_size //= n_gpus_total 130 | if args.rank == 0: 131 | print(f'===> {n_gpus_total} GPUs total; batch_size={args.batch_size} per GPU') 132 | 133 | print(f'===> Proc {dist.get_rank()}/{dist.get_world_size()}@{socket.gethostname()}', flush=True) 134 | 135 | # create model 136 | model = models.REGISTERED_MODELS[args.problem](args, device=device) 137 | 138 | if args.ckpt != '': 139 | ckpt = torch.load(args.ckpt, map_location=device) 140 | model.load_state_dict(ckpt['state_dict']) 141 | 142 | # Data loading code 143 | model.prepare_data() 144 | train_loader, val_loader = model.dataloaders(iters=args.iters) 145 | 146 | # define optimizer 147 | cur_iter = 0 148 | optimizer, scheduler = models.ssl.configure_optimizers(args, model, cur_iter - 1) 149 | 150 | # optionally resume from a checkpoint 151 | if args.ckpt and not args.eval_only: 152 | optimizer.load_state_dict(ckpt['opt_state_dict']) 153 | 154 | cudnn.benchmark = True 155 | 156 | continue_training = args.iters != 0 157 | data_time, it_time = 0, 0 158 | 159 | while continue_training: 160 | train_logs = [] 161 | model.train() 162 | 163 | start_time = time.time() 164 | for _, batch in enumerate(train_loader): 165 | cur_iter += 1 166 | 167 | batch = [x.to(device) for x in batch] 168 | data_time += time.time() - start_time 169 | 170 | logs = {} 171 | if not args.eval_only: 172 | # forward pass and compute loss 173 | logs = model.train_step(batch, cur_iter) 174 | loss = logs['loss'] 175 | 176 | # gradient step 177 | optimizer.zero_grad() 178 | loss.backward() 179 | optimizer.step() 180 | 181 | # save logs for the batch 182 | train_logs.append({k: utils.tonp(v) for k, v in 
logs.items()}) 183 | 184 | if cur_iter % args.save_freq == 0 and args.rank == 0: 185 | save_checkpoint(args.root, model, optimizer, cur_iter) 186 | 187 | if cur_iter % args.eval_freq == 0 or cur_iter >= args.iters: 188 | # TODO: aggregate metrics over all processes 189 | test_logs = [] 190 | model.eval() 191 | with torch.no_grad(): 192 | for batch in val_loader: 193 | batch = [x.to(device) for x in batch] 194 | # forward pass 195 | logs = model.test_step(batch) 196 | # save logs for the batch 197 | test_logs.append(logs) 198 | model.train() 199 | 200 | test_logs = utils.agg_all_metrics(test_logs) 201 | logger.add_logs(cur_iter, test_logs, pref='test_') 202 | 203 | it_time += time.time() - start_time 204 | 205 | if (cur_iter % args.log_freq == 0 or cur_iter >= args.iters) and args.rank == 0: 206 | save_checkpoint(args.root, model, optimizer) 207 | train_logs = utils.agg_all_metrics(train_logs) 208 | 209 | logger.add_logs(cur_iter, train_logs, pref='train_') 210 | logger.add_scalar(cur_iter, 'lr', optimizer.param_groups[0]['lr']) 211 | logger.add_scalar(cur_iter, 'data_time', data_time) 212 | logger.add_scalar(cur_iter, 'it_time', it_time) 213 | logger.iter_info() 214 | logger.save() 215 | 216 | data_time, it_time = 0, 0 217 | train_logs = [] 218 | 219 | if scheduler is not None: 220 | scheduler.step() 221 | 222 | if cur_iter >= args.iters: 223 | continue_training = False 224 | break 225 | 226 | start_time = time.time() 227 | 228 | save_checkpoint(args.root, model, optimizer) 229 | 230 | if args.dist == 'ddp': 231 | dist.destroy_process_group() 232 | 233 | 234 | def save_checkpoint(path, model, optimizer, cur_iter=None): 235 | if cur_iter is None: 236 | fname = os.path.join(path, 'checkpoint.pth.tar') 237 | else: 238 | fname = os.path.join(path, 'checkpoint-%d.pth.tar' % cur_iter) 239 | 240 | ckpt = model.get_ckpt() 241 | ckpt.update( 242 | { 243 | 'opt_state_dict': optimizer.state_dict(), 244 | 'iter': cur_iter, 245 | } 246 | ) 247 | 248 | torch.save(ckpt, fname) 
if __name__ == '__main__':
    main()


# ==== /utils/datautils.py ====
import numpy as np
import torch
import os
from torchvision import transforms
import torch.utils.data
import PIL
import torchvision.transforms.functional as FT
from PIL import Image


DATA_ROOT = os.environ.get('DATA_ROOT', './data')

IMAGENET_PATH = './data/imagenet/raw-data'


def pad(img, size, mode):
    """Pad the two leading (spatial) axes of an HWC image by ``size`` with ``np.pad``."""
    if isinstance(img, PIL.Image.Image):
        img = np.array(img)
    return np.pad(img, [(size, size), (size, size), (0, 0)], mode)


mean = {
    'mnist': (0.1307,),
    'cifar10': (0.4914, 0.4822, 0.4465)
}

std = {
    'mnist': (0.3081,),
    'cifar10': (0.2470, 0.2435, 0.2616)
}


class GaussianBlur(object):
    """
    PyTorch version of
    https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311

    With probability ``p`` blurs a 3-channel ``(3, H, W)`` tensor with a
    separable Gaussian whose sigma is drawn uniformly from [0.2, 2].
    """
    def gaussian_blur(self, image, sigma):
        # Add a batch axis; the grouped convolutions below need 3 channels.
        # (Previously hard-coded to (1, 3, 224, 224); any H, W now works.)
        image = image.reshape(1, *image.shape[-3:])
        # Builtin int/float: the np.int / np.float aliases were removed in NumPy 1.24.
        radius = int(self.kernel_size / 2)
        kernel_size = radius * 2 + 1
        x = np.arange(-radius, radius + 1)

        blur_filter = np.exp(
            -np.power(x, 2.0) / (2.0 * np.power(float(sigma), 2.0)))
        blur_filter /= np.sum(blur_filter)

        # Vertical pass: weight shape (3, 1, kernel_size, 1).
        conv1 = torch.nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), groups=3, padding=[kernel_size//2, 0], bias=False)
        conv1.weight = torch.nn.Parameter(
            torch.Tensor(np.tile(blur_filter.reshape(kernel_size, 1, 1, 1), 3).transpose([3, 2, 0, 1])))

        # Horizontal pass: weight shape (3, 1, 1, kernel_size).
        conv2 = torch.nn.Conv2d(3, 3, kernel_size=(1, kernel_size), groups=3, padding=[0, kernel_size//2], bias=False)
        conv2.weight = torch.nn.Parameter(
            torch.Tensor(np.tile(blur_filter.reshape(kernel_size, 1, 1, 1), 3).transpose([3, 2, 1, 0])))

        res = conv2(conv1(image))
        assert res.shape == image.shape
        return res[0]

    def __init__(self, kernel_size, p=0.5):
        self.kernel_size = kernel_size
        self.p = p

    def __call__(self, img):
        with torch.no_grad():
            assert isinstance(img, torch.Tensor)
            if np.random.uniform() < self.p:
                return self.gaussian_blur(img, sigma=np.random.uniform(0.2, 2))
            return img

    def __repr__(self):
        return self.__class__.__name__ + '(kernel_size={0}, p={1})'.format(self.kernel_size, self.p)


class CenterCropAndResize(object):
    """Crop the central part of a PIL image and resize it to a square.

    Args:
        proportion (float): fraction of the width and height kept by the
            center crop.
        size (int): side of the square output, resized with bicubic
            interpolation.
    """

    def __init__(self, proportion, size):
        self.proportion = proportion
        self.size = size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped and resized image.
        """
        # PIL's size is (w, h); torchvision's center_crop expects (h, w).
        w, h = (np.array(img.size) * self.proportion).astype(int)
        img = FT.resize(
            FT.center_crop(img, (h, w)),
            (self.size, self.size),
            interpolation=PIL.Image.BICUBIC
        )
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(proportion={0}, size={1})'.format(self.proportion, self.size)


class Clip(object):
    """Clamp tensor values to [0, 1]."""
    def __call__(self, x):
        return torch.clamp(x, 0, 1)


class MultiplyBatchSampler(torch.utils.data.sampler.BatchSampler):
    """BatchSampler whose index lists are repeated ``MULTILPLIER`` times."""
    MULTILPLIER = 2

    def __iter__(self):
        for batch in super().__iter__():
            yield batch * self.MULTILPLIER


class ContinousSampler(torch.utils.data.sampler.Sampler):
    """Re-iterate ``sampler`` until exactly ``n_iterations`` items were yielded."""
    def __init__(self, sampler, n_iterations):
        self.base_sampler = sampler
        self.n_iterations = n_iterations

    def __iter__(self):
        cur_iter = 0
        while cur_iter < self.n_iterations:
            yielded_in_pass = False
            for batch in self.base_sampler:
                yielded_in_pass = True
                yield batch
                cur_iter += 1
                if cur_iter >= self.n_iterations:
                    return
            if not yielded_in_pass:
                # An empty base sampler would otherwise loop forever.
                return

    def __len__(self):
        return self.n_iterations

    def set_epoch(self, epoch):
        self.base_sampler.set_epoch(epoch)


def get_color_distortion(s=1.0):
    # s is the strength of color distortion.
    # given from https://arxiv.org/pdf/2002.05709.pdf
    color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([
        rnd_color_jitter,
        rnd_gray])
    return color_distort


class DummyOutputWrapper(torch.utils.data.dataset.Dataset):
    """Dataset wrapper appending the constant ``dummy`` to every item tuple."""
    def __init__(self, dataset, dummy):
        self.dummy = dummy
        self.dataset = dataset

    def __getitem__(self, index):
        return (*self.dataset[index], self.dummy)

    def __len__(self):
        return len(self.dataset)


# ==== /utils/lars_optimizer.py ====
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter


class LARS(object):
    """
    Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
    Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py

    Args:
        optimizer: Pytorch optimizer to wrap and modify learning rate for.
        trust_coefficient: Trust coefficient for calculating the adaptive lr. See https://arxiv.org/abs/1708.03888

    Param groups may define ``weight_decay`` (applied here to the gradient,
    before the trust ratio) and are expected to define ``layer_adaptation``
    (whether the trust ratio is applied to that group).
    """

    def __init__(self,
                 optimizer,
                 trust_coefficient=0.001,
                 ):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    def step(self):
        """Apply weight decay and the trust ratio to the gradients, then step.

        Each group's ``weight_decay`` is set to 0 for the wrapped optimizer's
        step (decay was already folded into the gradient) and restored
        afterwards -- also when the step raises.
        """
        saved = []
        try:
            with torch.no_grad():
                for group in self.optim.param_groups:
                    # absorb weight decay control from optimizer
                    weight_decay = group.get('weight_decay', 0)
                    saved.append((group, weight_decay))
                    group['weight_decay'] = 0
                    for p in group['params']:
                        if p.grad is None:
                            continue

                        if weight_decay != 0:
                            p.grad.data += weight_decay * p.data

                        param_norm = torch.norm(p.data)
                        grad_norm = torch.norm(p.grad.data)
                        adaptive_lr = 1.

                        if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']:
                            adaptive_lr = self.trust_coefficient * param_norm / grad_norm

                        p.grad.data *= adaptive_lr

            self.optim.step()
        finally:
            # return weight decay control to optimizer
            for group, weight_decay in saved:
                group['weight_decay'] = weight_decay


# ==== /utils/logger.py ====
import os
import sys
import random
import numpy as np

from collections import OrderedDict
from tabulate import tabulate
from pandas import DataFrame
from time import gmtime, strftime
import time


class Logger:
    """Collect scalar metrics, print them as a table and save them to CSV.

    Output goes to ``<base>/<name>.out`` (printed lines) and
    ``<base>/<name>.csv`` (all scalars indexed by ``t``).
    """
    def __init__(self, name='name', fmt=None, base='./logs'):
        self.handler = True
        self.scalar_metrics = OrderedDict()
        self.fmt = fmt if fmt else dict()

        if not os.path.exists(base):
            os.makedirs(base)

        # These draws are unused, but they advance the global ``random``
        # state; kept so seeded runs consume the same random numbers.
        _unused_suffix = ''.join([chr(random.randint(97, 122)) for _ in range(3)])
        self.path = os.path.join(base, name)

        self.logs = self.path + '.csv'
        self.output = self.path + '.out'
        self.iters_since_last_header = 0

        def prin(*args):
            str_to_write = ' '.join(map(str, args))
            with open(self.output, 'a') as f:
                f.write(str_to_write + '\n')
                f.flush()

            print(str_to_write)
            sys.stdout.flush()

        self.print = prin

    def add_scalar(self, t, key, value):
        if key not in self.scalar_metrics:
            self.scalar_metrics[key] = []
        self.scalar_metrics[key] += [(t, value)]

    def add_logs(self, t, logs, pref=''):
        for k, v in logs.items():
            self.add_scalar(t, pref + k, v)

    def iter_info(self, order=None):
        """Print the latest value of each metric; the header is reprinted every 40 rows."""
        self.iters_since_last_header += 1
        if self.iters_since_last_header > 40:
            self.handler = True

        names = list(self.scalar_metrics.keys())
        if order:
            names = order
        values = [self.scalar_metrics[name][-1][1] for name in names]
        t = int(np.max([self.scalar_metrics[name][-1][0] for name in names]))
        fmt = ['%s'] + [self.fmt[name] if name in self.fmt else '.3f' for name in names]

        if self.handler:
            self.handler = False
            self.iters_since_last_header = 0
            self.print(tabulate([[t] + values], ['t'] + names, floatfmt=fmt))
        else:
            self.print(tabulate([[t] + values], ['t'] + names, tablefmt='plain', floatfmt=fmt).split('\n')[1])

    def save(self):
        """Write every metric to the CSV file, outer-joined on ``t``."""
        if not self.scalar_metrics:
            return  # nothing logged yet: no table to write
        result = None
        for key, records in self.scalar_metrics.items():
            df = DataFrame(records, columns=['t', key]).set_index('t')
            result = df if result is None else result.join(df, how='outer')
        result.to_csv(self.logs)


# ==== /utils/utils.py ====
import functools
import numpy as np
import torch
import warnings
import time
import torch.distributed as dist


def timing(f):
    """Decorator printing the wall-clock duration of each call to ``f`` in ms."""
    @functools.wraps(f)
    def wrap(*args, **kwargs):
        time1 = time.perf_counter()
        ret = f(*args, **kwargs)
        time2 = time.perf_counter()
        print('{:s} function took {:.3f} ms'.format(f.__name__, (time2-time1)*1000.0))

        return ret
    return wrap


def agg_all_metrics(outputs):
    """Aggregate a list of per-batch metric dicts into one dict.

    Non-dict values are flattened and concatenated across batches; every key
    is averaged except ``'epoch'``, which keeps its last value. An empty
    ``outputs`` is returned unchanged.
    """
    if len(outputs) == 0:
        return outputs
    res = {}
    keys = [k for k in outputs[0].keys() if not isinstance(outputs[0][k], dict)]
    for k in keys:
        all_logs = np.concatenate([tonp(x[k]).reshape(-1) for x in outputs])
        if k != 'epoch':
            res[k] = np.mean(all_logs)
        else:
            res[k] = all_logs[-1]
    return res


def gather_metrics(metrics):
    """In place, replace each tensor in ``metrics`` by its all-gather across ranks.

    0-d tensors are promoted to shape (1,) before gathering.
    """
    for k, v in metrics.items():
        if v.dim() == 0:
            v = v[None]
        v_all = [torch.zeros_like(v) for _ in range(dist.get_world_size())]
        dist.all_gather(v_all, v)
        v_all = torch.cat(v_all)
        metrics[k] = v_all


def viz_array_grid(array, rows, cols, padding=0, channels_last=False, normalize=False, **kwargs):
    '''
    Tile a batch of images into one uint8 grid image.

    Args:
        array: (N_images, N_channels, H, W) or (N_images, H, W, N_channels)
        rows, cols: rows and columns of the plot. rows * cols == array.shape[0]
        padding: padding between cells of plot
        channels_last: for Tensorflow = True, for PyTorch = False
        normalize: `False`, `mean_std`, or `min_max`
    Kwargs:
        if normalize == 'mean_std':
            mean: mean of the distribution. Default 0.5
            std: std of the distribution. Default 0.5
            (values are mapped back with array * std + mean)
        if normalize == 'min_max':
            min: min of the distribution. Default array.min()
            max: max of the distribution. Default array.max()
            (values are rescaled with (array - min) / (max - min))
    Returns:
        uint8 canvas of shape (H*rows + padding*(rows-1), W*cols + padding*(cols-1)),
        with a trailing axis of 3 for RGB input; values are clipped to [0, 1]
        before scaling to [0, 255].
    Raises:
        TypeError: if the number of channels is neither 1 nor 3.
    '''
    array = tonp(array)
    if not channels_last:
        array = np.transpose(array, (0, 2, 3, 1))

    array = array.astype('float32')

    if normalize:
        if normalize == 'mean_std':
            mean = kwargs.get('mean', 0.5)
            mean = np.array(mean).reshape((1, 1, 1, -1))
            std = kwargs.get('std', 0.5)
            std = np.array(std).reshape((1, 1, 1, -1))
            array = array * std + mean
        elif normalize == 'min_max':
            min_ = kwargs.get('min', array.min())
            min_ = np.array(min_).reshape((1, 1, 1, -1))
            max_ = kwargs.get('max', array.max())
            max_ = np.array(max_).reshape((1, 1, 1, -1))
            # Divide by the range (not by max alone) so [min, max] -> [0, 1].
            array = (array - min_) / (max_ - min_ + 1e-9)

    batch_size, H, W, channels = array.shape
    assert rows * cols == batch_size

    if channels == 1:
        canvas = np.ones((H * rows + padding * (rows - 1),
                          W * cols + padding * (cols - 1)))
        array = array[:, :, :, 0]
    elif channels == 3:
        canvas = np.ones((H * rows + padding * (rows - 1),
                          W * cols + padding * (cols - 1),
                          3))
    else:
        raise TypeError('number of channels is either 1 or 3')

    for i in range(rows):
        for j in range(cols):
            img = array[i * cols + j]
            start_h = i * padding + i * H
            start_w = j * padding + j * W
            canvas[start_h: start_h + H, start_w: start_w + W] = img

    canvas = np.clip(canvas, 0, 1)
    canvas *= 255.0
    canvas = canvas.astype('uint8')
    return canvas


def tonp(x):
    """Convert a tensor, numpy array/scalar or Python number to a numpy array."""
    if isinstance(x, (np.ndarray, np.generic, float, int)):
        return np.array(x)
    return x.detach().cpu().numpy()


class LinearLR(torch.optim.lr_scheduler._LRScheduler):
    """Decay each base lr linearly to 0 over ``num_epochs`` steps (then stay at 0)."""
    def __init__(self, optimizer, num_epochs, last_epoch=-1):
        self.num_epochs = max(num_epochs, 1)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        res = []
        for lr in self.base_lrs:
            res.append(np.maximum(lr * np.minimum(-self.last_epoch * 1. / self.num_epochs + 1., 1.), 0.))
        return res


class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by cosine annealing, in chained (recursive) form.

    ``warm_up`` is a fraction of ``T_max``. During the first
    ``int(warm_up * T_max)`` steps the lr grows to the base lr as
    ``base * (k + 1) / (warm_up_steps + 1)``; afterwards each step multiplies
    the current lr by the ratio of consecutive ``1 + cos`` terms, i.e. a
    cosine decay over the remaining steps.
    """
    def __init__(self, optimizer, warm_up, T_max, last_epoch=-1, smooth=1e-9):
        self.warm_up = int(warm_up * T_max)
        self.T_max = T_max - self.warm_up
        self.smooth = smooth
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        if self.last_epoch == 0:
            return [lr / (self.warm_up + 1) for lr in self.base_lrs]
        elif self.last_epoch <= self.warm_up:
            c = (self.last_epoch + 1) / self.last_epoch
            return [group['lr'] * c for group in self.optimizer.param_groups]
        else:
            # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493
            le = self.last_epoch - self.warm_up

            if le > self.T_max:
                warnings.warn(f"Epoch {self.last_epoch}: reached maximum number of iterations {self.T_max + self.warm_up}. This is unexpected behavior, and this SimCLR implementation was not tested in this regime!")

            return [(1 + np.cos(np.pi * le / self.T_max)) /
                    (1 + np.cos(np.pi * (le - 1) / self.T_max) + self.smooth) *
                    group['lr']
                    for group in self.optimizer.param_groups]


class BaseLR(torch.optim.lr_scheduler._LRScheduler):
    """Scheduler that keeps every param group's current lr unchanged."""
    def get_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]