├── .gitignore ├── LICENSE ├── README.md ├── data └── mihi-chewie │ └── processed │ └── full-chewie-10032013-0.10.pkl ├── docs ├── cifar.md ├── custom_image_dataset.md ├── custom_sequence_dataset.md └── monkey_reach.md ├── requirements.txt ├── scripts ├── cifar-eval.py ├── cifar-train.py └── monkey-train.py └── self_supervised ├── __init__.py ├── data ├── __init__.py ├── generators │ ├── __init__.py │ └── local_global_generator.py ├── io.py ├── monkey_reach_dataset.py └── utils.py ├── loss ├── __init__.py └── cosine_loss.py ├── model ├── __init__.py ├── byol.py ├── double_byol.py ├── mlp3.py └── myow_factory.py ├── nets ├── __init__.py ├── mlp.py └── resnets.py ├── optimizer ├── __init__.py └── lars.py ├── tasks ├── __init__.py ├── classification.py ├── fast_classification.py └── neural_tasks.py ├── tensorboard ├── __init__.py └── embedding_projector.py ├── trainer ├── __init__.py ├── byol_trainer.py └── myow_trainer.py ├── transforms ├── __init__.py └── neural.py └── utils ├── __init__.py ├── metric_logger.py └── random_seeders.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | # DotEnv configuration 60 | .env 61 | 62 | # Database 63 | *.db 64 | *.rdb 65 | 66 | # Pycharm 67 | .idea 68 | 69 | # VS Code 70 | .vscode/ 71 | 72 | # Spyder 73 | .spyproject/ 74 | 75 | # Jupyter NB Checkpoints 76 | .ipynb_checkpoints/ 77 | 78 | # exclude data from source control by default 79 | /data/ 80 | 81 | # Mac OS-specific storage files 82 | .DS_Store 83 | 84 | # vim 85 | *.swp 86 | *.swo 87 | 88 | # Mypy cache 89 | .mypy_cache/ 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MYOW 2 | 3 | PyTorch implementation of 4 | [Mine Your Own vieW: Self-Supervised Learning Through Across-Sample Prediction](https://arxiv.org/abs/2102.10106). 5 | 6 | ## Installation 7 | To install requirements run: 8 | ```bash 9 | python3 -m pip install -r requirements.txt 10 | export PYTHONPATH=$PYTHONPATH:$PWD 11 | ``` 12 | 13 | 14 | ## Training 15 | * Running MYOW on CIFAR-10 dataset.
16 | * Running MYOW on neural recordings from primates.
17 | 18 | Setting up your own datasets: 19 | 20 | * Image dataset.
21 | * Temporal dataset.
22 | 
23 | ## Contributors
24 | 
25 | * Mehdi Azabou (Maintainer), github: [mazabou](https://github.com/mazabou)
26 | * Ran Liu, github: [ranliu98](https://github.com/ranliu98)
27 | * Kiran Bhaskaran-Nair, github: [kbn-gh](https://github.com/kbn-gh)
28 | * Erik C. Johnson, github: [erikjohnson24](https://github.com/erikjohnson24)
29 | 
30 | ## Citation
31 | If you find the code useful for your research, please consider citing our work:
32 | 
33 | ```
34 | @misc{azabou2021view,
35 |       title={Mine Your Own vieW: Self-Supervised Learning Through Across-Sample Prediction},
36 |       author={Mehdi Azabou and Mohammad Gheshlaghi Azar and Ran Liu and Chi-Heng Lin and Erik C. Johnson
37 |               and Kiran Bhaskaran-Nair and Max Dabagia and Keith B. Hengen and William Gray-Roncal
38 |               and Michal Valko and Eva L. Dyer},
39 |       year={2021},
40 |       eprint={2102.10106},
41 |       archivePrefix={arXiv},
42 |       primaryClass={cs.LG}
43 | }
44 | ```
45 | 
46 | 
--------------------------------------------------------------------------------
/data/mihi-chewie/processed/full-chewie-10032013-0.10.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerdslab/myow/73f2b430d284bcd6bb0c6faaac30576e84393ff1/data/mihi-chewie/processed/full-chewie-10032013-0.10.pkl
--------------------------------------------------------------------------------
/docs/cifar.md:
--------------------------------------------------------------------------------
1 | # Running MYOW on the CIFAR-10 dataset
2 | 
3 | This page walks through the steps required to run MYOW on the CIFAR-10 dataset.
4 | 
5 | ## Running training
6 | Training is parallelized using `DistributedDataParallel`. The pool of candidate views used during mining is shared across
7 | all instances.
8 | 
9 | To start training, run:
10 | 
11 | ```bash
12 | CUDA_VISIBLE_DEVICES=0,1 python3 scripts/cifar-train.py \
13 |     --lr 2.0 \
14 |     --mm 0.98 \
15 |     --weight_decay 5e-5 \
16 |     --optimizer sgd \
17 |     --lr_warmup_epochs 30 \
18 |     --batch_size 256 \
19 |     --port 12354 \
20 |     --logdir myow-cifar \
21 |     --ckptpath myow-cifar
22 | ```
23 | 
24 | ## Running evaluation
25 | 
26 | Evaluation can be run concurrently with training, or afterwards, on a separate GPU instance. The eval script automatically
27 | runs evaluation each time a new checkpoint is saved to `ckptpath`. It is also possible to start evaluation only after
28 | a certain number of epochs using the `resume_eval` argument.
29 | 
30 | ```bash
31 | CUDA_VISIBLE_DEVICES=2 python3 scripts/cifar-eval.py \
32 |     --lr 0.04 \
33 |     --resume_eval 0 \
34 |     --logdir runs-cifar \
35 |     --ckptpath myow-cifar
36 | ```
37 | 
38 | ## Running tensorboard
39 | 
40 | ```bash
41 | tensorboard --logdir=runs-cifar
42 | ```
43 | 
--------------------------------------------------------------------------------
/docs/custom_image_dataset.md:
--------------------------------------------------------------------------------
1 | # Using your own image datasets
2 | 
3 | Setting up your own dataset is straightforward. This code already contains an example of MYOW applied to the CIFAR-10
4 | dataset. We have made it easy to train MYOW on any dataset; here, we give a brief step-by-step guide to setting
5 | up your own.
6 | 
7 | Augmentations can be applied either on the CPU or on the GPU.
8 | 
9 | ### On-GPU augmentations
10 | When it is possible and desirable to perform augmentations on the GPU in batch mode, use the `transform` argument
11 | to specify the augmentations used during training to generate views; a minimal sketch of such a transform is shown below.
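
For example (a sketch only; the choice of augmentations and their parameters is not prescribed by the trainer, and the ones below simply mirror the CIFAR-10 settings used in `scripts/cifar-eval.py`, assuming torchvision transforms that operate on image tensors):

```python
from torchvision import transforms

# Example batch-friendly tensor augmentations; values follow the CIFAR-10 scripts.
transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
])
```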
12 | With on-GPU augmentations, view generation is fully handled by the trainer, which only needs to receive a batch of the original images.
13 | 
14 | Any standard PyTorch dataset can be used out of the box. A function called `prepare_views` needs to be defined; it
15 | builds the dictionary containing the two views expected by the trainer. In this case,
16 | the same image is passed in twice.
17 | 
18 | ```python
19 | def prepare_views(inputs):
20 |     x, labels = inputs
21 |     outputs = {'view1': x, 'view2': x}
22 |     return outputs
23 | ```
24 | 
25 | When constructing the trainer, simply specify:
26 | 
27 | ```python
28 | MYOWTrainer(..., prepare_views=prepare_views, transform=transform)
29 | ```
30 | 
31 | In some situations, it might be desirable to define two different classes of transformations; this can be done by
32 | separately defining `transform_1` and `transform_2`, with each being applied to its respective view.
33 | 
34 | ```python
35 | MYOWTrainer(..., prepare_views=prepare_views, transform_1=transform_1, transform_2=transform_2)
36 | ```
37 | 
38 | **View mining**
39 | 
40 | A pool of candidate views needs to be generated from the dataset during mining. When working with an image dataset, this can
41 | easily be done through a second dataloader passed to the trainer through `view_pool_dataloader`. The batch size of this
42 | dataloader can be different from the batch size of the main dataloader. The transformation used during mining is
43 | specified through `transform_m`, which is applied to the key sample as well as to all candidate samples.
44 | 
45 | ```python
46 | MYOWTrainer(..., view_pool_dataloader=view_pool_dataloader, transform_m=transform_m)
47 | ```
48 | 
49 | ### On-CPU augmentations
50 | 
51 | To perform augmentations on the CPU in single-image mode, the augmented views need to be generated before being passed
52 | to the trainer. In this case, a custom Dataset object can be implemented, along with a `prepare_views` function that
53 | assigns the views.
54 | 
55 | ```python
56 | def prepare_views(inputs):
57 |     x1, x2 = inputs
58 |     outputs = {'view1': x1, 'view2': x2}
59 |     return outputs
60 | ```
61 | 
62 | When constructing the trainer, there is no need to specify a `transform`:
63 | 
64 | ```python
65 | MYOWTrainer(..., prepare_views=prepare_views, transform=None)
66 | ```
67 | 
68 | **View mining**
69 | 
70 | Similarly, the batch of candidate views needs to be generated before it reaches the trainer. Given an input `x`, the custom
71 | Dataset needs to produce four tensors: two augmented views of `x` obtained with the main class of transformations, one augmented view
72 | obtained with `transform_m`, and a batch of candidate views.
73 | 
74 | ```python
75 | def prepare_views(inputs):
76 |     x1, x2, x3, xpool = inputs
77 |     outputs = {'view1': x1, 'view2': x2, 'view3': x3, 'view_pool': xpool}
78 |     return outputs
79 | ```
80 | 
--------------------------------------------------------------------------------
/docs/custom_sequence_dataset.md:
--------------------------------------------------------------------------------
1 | # Using your own sequence datasets
2 | 
3 | This code already contains an example of MYOW applied to a dataset made up of a number of sequences.
4 | Here, we give a brief step-by-step guide to setting up your own sequence dataset.
5 | 
6 | The main augmentation is temporal shift; for this we use `data.generators.LocalGlobalGenerator`, which takes as input the
7 | matrix of features as well as the list of possible pairs. This list is pre-computed to allow for faster data loading.
8 | `data.utils.onlywithin_indices` can be used to generate such a list.
9 | 
10 | ```python
11 | pair_sets = utils.onlywithin_indices(sequence_lengths, k_min=-2, k_max=2)
12 | ```
13 | 
14 | Next comes the generation of the pool of candidate views used for mining. Because we are working with time-varying data,
15 | we need to restrict the candidate views used for mining to be separated by a minimum distance in time from the key
16 | views. For our monkey datasets, we do this by sampling two different sets of sequences, from which we then sample the
17 | key views and the pool of candidate views separately.
18 | 
19 | Additional transforms can be added through the `transform` argument.
20 | 
21 | ```python
22 | generator = generators.LocalGlobalGenerator(firing_rates, pair_sets, sequence_lengths,
23 |                                             num_examples=firing_rates.shape[0],
24 |                                             batch_size=batch_size,
25 |                                             pool_batch_size=pool_batch_size,
26 |                                             transform=transform, num_workers=1,
27 |                                             structured_transform=True)
28 | ```
29 | 
30 | As with images, the data needs to be specified for the trainer through the `prepare_views` function, which in this
31 | case is defined as a static method of the generator.
32 | 
33 | ```python
34 | @staticmethod
35 | def prepare_views(inputs):
36 |     x1, x2, x3, x4 = inputs
37 |     outputs = {'view1': x1.squeeze(), 'view2': x2.squeeze(),
38 |                'view3': x3.squeeze(), 'view_pool': x4.squeeze()}
39 |     return outputs
40 | ```
41 | 
--------------------------------------------------------------------------------
/docs/monkey_reach.md:
--------------------------------------------------------------------------------
1 | # Running MYOW on the Reach dataset
2 | 
3 | This page walks through the steps required to run MYOW on the monkey datasets.
4 | 
5 | ## Dataset
6 | Processed data can be found in `data/`.
7 | 
8 | ## Running training and evaluation
9 | 
10 | Training can be run using the following command:
11 | 
12 | ```bash
13 | python3 scripts/monkey-train.py \
14 |     --data_path=./data/mihi-chewie \
15 |     --primate="chewie" \
16 |     --day=1 \
17 |     --max_lookahead=4 \
18 |     --noise_sigma=0.1 \
19 |     --dropout_p=0.8 \
20 |     --dropout_apply_p=0.9 \
21 |     --structured_transform=True \
22 |     --batch_size=256 \
23 |     --pool_batch_size=512 \
24 |     --miner_k=3 \
25 |     --myow_warmup_epochs=10 \
26 |     --myow_rampup_epochs=110
27 | ```
28 | where `primate` and `day` select the animal and the recording day.
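
To sanity-check the processed recording that ships with the repository before training, the file can be loaded directly. This is a minimal sketch and assumes the `.pkl` file is a standard pickle; the exact layout of the stored object depends on the preprocessing, so inspect it after loading:

```python
import pickle

# Processed recording shipped with the repository.
path = "data/mihi-chewie/processed/full-chewie-10032013-0.10.pkl"

with open(path, "rb") as f:
    data = pickle.load(f)

# The structure of `data` is not documented here; inspect it interactively.
print(type(data))
if isinstance(data, dict):
    print(list(data.keys()))
```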
29 | 
30 | 
31 | ## Running tensorboard
32 | 
33 | ```bash
34 | tensorboard --logdir=runs-chewie1
35 | ```
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1
2 | torchvision==0.8.2
3 | tensorboard==2.3.0
4 | tensorflow==2.3.1
5 | tqdm~=4.48.2
6 | scipy~=1.4.1
7 | numpy~=1.18.5
8 | matplotlib~=3.3.1
9 | h5py~=2.10.0
10 | kornia~=0.4.0
11 | sklearn~=0.0
12 | scikit-learn~=0.23.2
13 | opencv-python~=4.4.0.44
14 | Pillow~=7.2.0
15 | pandas~=1.1.3
16 | 
--------------------------------------------------------------------------------
/scripts/cifar-eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import time
4 | 
5 | from absl import app
6 | from absl import flags
7 | from glob import glob
8 | import torch
9 | import tensorflow as tf
10 | from torch.utils.data import DataLoader
11 | from torchvision.datasets import CIFAR10, CIFAR100
12 | from torchvision.transforms import ToTensor
13 | from torchvision import transforms
14 | 
15 | from self_supervised.trainer import MYOWTrainer
16 | from self_supervised.nets import resnet_cifar
17 | from self_supervised.tasks import fast_classification
18 | 
19 | 
20 | FLAGS = flags.FLAGS
21 | 
22 | flags.DEFINE_float('lr', 0.02, 'Base learning rate.')
23 | flags.DEFINE_integer('resume_eval', 0, 'Epoch at which evaluation starts.')
24 | flags.DEFINE_string('logdir', 'myow-run', 'Tensorboard dir name.')
25 | flags.DEFINE_string('ckptpath', 'myow-model', 'Checkpoint folder name.')
26 | 
27 | 
28 | def eval(gpu, ckpt_epoch, args, dataset_class=CIFAR10, num_classes=10):
29 |     # load dataset
30 |     image_size = 32
31 |     transform = transforms.Compose([transforms.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
32 |                                     transforms.RandomHorizontalFlip(p=0.5),
33 |                                     transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
34 | 
35 |     pre_transform = transforms.Compose([ToTensor(),
36 |                                         transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
37 | 
38 |     dataset = dataset_class("../datasets/cifar10", train=True, download=True, transform=pre_transform,
39 |                             target_transform=torch.tensor)
40 | 
41 |     dataset_val = dataset_class("../datasets/cifar10", train=False, download=True, transform=pre_transform,
42 |                                 target_transform=torch.tensor)
43 | 
44 |     # build dataloaders
45 |     batch_size = 1024
46 |     num_workers = 2
47 | 
48 |     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False,
49 |                             drop_last=False, pin_memory=True)
50 |     dataloader_val = DataLoader(dataset_val, batch_size=batch_size, num_workers=num_workers, shuffle=False,
51 |                                 drop_last=False, pin_memory=True)
52 | 
53 |     def prepare_views(inputs):
54 |         x, labels = inputs
55 |         outputs = {'view1': x, 'view2': x}
56 |         return outputs
57 | 
58 |     # build ResNet encoder
59 |     encoder = resnet_cifar()
60 |     representation_size = 512
61 |     projection_hidden_size = 4096
62 |     projection_size = 256
63 | 
64 |     # build trainer (settings should match the pretraining run so the checkpoint loads correctly)
65 |     projection_size_2 = 64
66 |     projection_hidden_size_2 = 1024
67 |     lr = 0.08
68 |     momentum = 0.99
69 | 
70 |     train_epochs = 800
71 |     byol_warmup_epochs = 10
72 |     n_decay = 1.5
73 |     myow_warmup_epochs = 100
74 |     myow_rampup_epochs = 110
75 |     myow_max_weight = 1.0
76 | 
77 |     trainer = MYOWTrainer(encoder=encoder, representation_size=representation_size, projection_size=projection_size,
projection_hidden_size=projection_hidden_size, projection_size_2=projection_size_2, 79 | projection_hidden_size_2=projection_hidden_size_2, prepare_views=prepare_views, 80 | train_dataloader=dataloader, view_pool_dataloader=dataloader, transform=transform, 81 | total_epochs=train_epochs, warmup_epochs=byol_warmup_epochs, 82 | myow_warmup_epochs=myow_warmup_epochs, myow_rampup_epochs=myow_rampup_epochs, 83 | myow_max_weight=myow_max_weight, batch_size=batch_size, 84 | base_lr=lr, base_momentum=momentum, momentum=0.9, weight_decay=args.wd, 85 | exclude_bias_and_bn=True, optimizer_type='sgd', symmetric_loss=True, view_miner_k=1, 86 | gpu=gpu, distributed=True, decay='cosine', m_decay='cosine', 87 | log_step=10, log_dir='runs-cifar/{}-elr{}'.format(args.logdir, args.lr), 88 | ckpt_path='ckpt/{}/ckpt-%d.pt'.format(args.ckptpath)) 89 | 90 | trainer.load_checkpoint(ckpt_epoch) 91 | print('checkpoint loaded') 92 | 93 | trainer.model.eval() 94 | # representation 95 | print('computing representations') 96 | data_train = fast_classification.compute_representations(trainer.model.online_encoder.eval(), dataloader, 97 | device=trainer.device) 98 | data_val = fast_classification.compute_representations(trainer.model.online_encoder.eval(), dataloader_val, 99 | device=trainer.device) 100 | 101 | clr = args.lr 102 | classifier = resnet_cifar.get_linear_classifier(output_dim=num_classes).to(trainer.device) 103 | class_optimizer = torch.optim.SGD(classifier.parameters(), lr=clr, momentum=0.9) 104 | scheduler = torch.optim.lr_scheduler.MultiStepLR(class_optimizer, milestones=[60, 80], gamma=0.1) 105 | batch_size = 512 106 | acc = fast_classification.train_linear_layer(classifier, data_train, data_val, class_optimizer, scheduler=scheduler, 107 | writer=trainer.writer, tag=ckpt_epoch, batch_size=batch_size, 108 | num_epochs=100, device=trainer.device, tqdm_progress=True) 109 | 110 | print('Train', acc.train_last, ', Test', acc.val_smooth) 111 | trainer.writer.add_scalar('eval-train-%d' % num_classes, acc.train_last, ckpt_epoch) 112 | trainer.writer.add_scalar('eval-test-%d' % num_classes, acc.val_smooth, ckpt_epoch) 113 | 114 | 115 | def find_checkpoints(ckpt_path): 116 | def atoi(text): 117 | return int(text) if text.isdigit() else text 118 | 119 | def natural_keys(text): 120 | """alist.sort(key=natural_keys) sorts in human order 121 | http://nedbatchelder.com/blog/200712/human_sorting.html 122 | """ 123 | return [atoi(c) for c in re.split(r'(\d+)', text)] 124 | 125 | ckpt_list = glob(os.path.join(ckpt_path, 'byol-ckpt-*.pt')) 126 | ckpt_list.sort(key=natural_keys) 127 | return ckpt_list 128 | 129 | def main(): 130 | ckptpath = 'ckpt/{}'.format(FLAGS.ckptpath) 131 | already_computed = [] 132 | while True: 133 | ckpt_list = find_checkpoints(ckptpath) 134 | for c in ckpt_list: 135 | ckpt_epoch = int(re.findall(r'(\d{1,3})\.pt$', c)[0]) 136 | if ckpt_epoch