├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── mihi-chewie
│       └── processed
│           └── full-chewie-10032013-0.10.pkl
├── docs
│   ├── cifar.md
│   ├── custom_image_dataset.md
│   ├── custom_sequence_dataset.md
│   └── monkey_reach.md
├── requirements.txt
├── scripts
│   ├── cifar-eval.py
│   ├── cifar-train.py
│   └── monkey-train.py
└── self_supervised
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── generators
    │   │   ├── __init__.py
    │   │   └── local_global_generator.py
    │   ├── io.py
    │   ├── monkey_reach_dataset.py
    │   └── utils.py
    ├── loss
    │   ├── __init__.py
    │   └── cosine_loss.py
    ├── model
    │   ├── __init__.py
    │   ├── byol.py
    │   ├── double_byol.py
    │   ├── mlp3.py
    │   └── myow_factory.py
    ├── nets
    │   ├── __init__.py
    │   ├── mlp.py
    │   └── resnets.py
    ├── optimizer
    │   ├── __init__.py
    │   └── lars.py
    ├── tasks
    │   ├── __init__.py
    │   ├── classification.py
    │   ├── fast_classification.py
    │   └── neural_tasks.py
    ├── tensorboard
    │   ├── __init__.py
    │   └── embedding_projector.py
    ├── trainer
    │   ├── __init__.py
    │   ├── byol_trainer.py
    │   └── myow_trainer.py
    ├── transforms
    │   ├── __init__.py
    │   └── neural.py
    └── utils
        ├── __init__.py
        ├── metric_logger.py
        └── random_seeders.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 |
5 | # C extensions
6 | *.so
7 |
8 | # Distribution / packaging
9 | .Python
10 | env/
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | *.egg-info/
23 | .installed.cfg
24 | *.egg
25 |
26 | # PyInstaller
27 | # Usually these files are written by a python script from a template
28 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
29 | *.manifest
30 | *.spec
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 |
36 | # Unit test / coverage reports
37 | htmlcov/
38 | .tox/
39 | .coverage
40 | .coverage.*
41 | .cache
42 | nosetests.xml
43 | coverage.xml
44 | *.cover
45 |
46 | # Translations
47 | *.mo
48 | *.pot
49 |
50 | # Django stuff:
51 | *.log
52 |
53 | # Sphinx documentation
54 | docs/_build/
55 |
56 | # PyBuilder
57 | target/
58 |
59 | # DotEnv configuration
60 | .env
61 |
62 | # Database
63 | *.db
64 | *.rdb
65 |
66 | # Pycharm
67 | .idea
68 |
69 | # VS Code
70 | .vscode/
71 |
72 | # Spyder
73 | .spyproject/
74 |
75 | # Jupyter NB Checkpoints
76 | .ipynb_checkpoints/
77 |
78 | # exclude data from source control by default
79 | /data/
80 |
81 | # Mac OS-specific storage files
82 | .DS_Store
83 |
84 | # vim
85 | *.swp
86 | *.swo
87 |
88 | # Mypy cache
89 | .mypy_cache/
90 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MYOW
2 |
3 | PyTorch implementation of
4 | [Mine Your Own vieW: Self-Supervised Learning Through Across-Sample Prediction](https://arxiv.org/abs/2102.10106).
5 |
6 | ## Installation
7 | To install requirements run:
8 | ```bash
9 | python3 -m pip install -r requirements.txt
10 | export PYTHONPATH=$PYTHONPATH:$PWD
11 | ```
12 |
13 |
14 | ## Training
15 | * [Running MYOW on the CIFAR-10 dataset](docs/cifar.md).
16 | * [Running MYOW on neural recordings from primates](docs/monkey_reach.md).
17 |
18 | Setting up your own datasets:
19 |
20 | * [Image dataset](docs/custom_image_dataset.md).
21 | * [Temporal dataset](docs/custom_sequence_dataset.md).
22 |
23 | ## Contributors
24 |
25 | * Mehdi Azabou (Maintainer), github: [mazabou](https://github.com/mazabou)
26 | * Ran Liu, github: [ranliu98](https://github.com/ranliu98)
27 | * Kiran Bhaskaran-Nair, github: [kbn-gh](https://github.com/kbn-gh)
28 | * Erik C. Johnson, github: [erikjohnson24](https://github.com/erikjohnson24)
29 |
30 | ## Citation
31 | If you find the code useful for your research, please consider citing our work:
32 |
33 | ```
34 | @misc{azabou2021view,
35 | title={Mine Your Own vieW: Self-Supervised Learning Through Across-Sample Prediction},
36 | author={Mehdi Azabou and Mohammad Gheshlaghi Azar and Ran Liu and Chi-Heng Lin and Erik C. Johnson
37 | and Kiran Bhaskaran-Nair and Max Dabagia and Keith B. Hengen and William Gray-Roncal
38 | and Michal Valko and Eva L. Dyer},
39 | year={2021},
40 | eprint={2102.10106},
41 | archivePrefix={arXiv},
42 | primaryClass={cs.LG}
43 | }
44 | ```
45 |
46 |
--------------------------------------------------------------------------------
/data/mihi-chewie/processed/full-chewie-10032013-0.10.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nerdslab/myow/73f2b430d284bcd6bb0c6faaac30576e84393ff1/data/mihi-chewie/processed/full-chewie-10032013-0.10.pkl
--------------------------------------------------------------------------------
/docs/cifar.md:
--------------------------------------------------------------------------------
1 | # Running MYOW on CIFAR10 dataset
2 |
3 | This page walks through the steps required to run MYOW on the CIFAR10 dataset.
4 |
5 | ## Running training
6 | Training is parallelized using `DistributedDataParallel`. The pool of candidate views during mining is shared across
7 | all instances.
8 |
9 | To start training run:
10 |
11 | ```bash
12 | CUDA_VISIBLE_DEVICES=0,1 python3 scripts/cifar-train.py \
13 | --lr 2.0 \
14 | --mm 0.98 \
15 | --weight_decay 5e-5 \
16 | --optimizer sgd \
17 | --lr_warmup_epochs 30 \
18 | --batch_size 256 \
19 | --port 12354 \
20 | --logdir myow-cifar \
21 | --ckptpath myow-cifar
22 | ```
23 |
24 | ## Running evaluation
25 |
26 | Evaluation can be run on a separate GPU, either during or after training. The eval script automatically
27 | runs evaluation each time a new checkpoint is saved to `ckptpath`. It is also possible to start evaluation only after
28 | a certain number of epochs using the `resume_eval` argument.
29 |
30 | ```bash
31 | CUDA_VISIBLE_DEVICES=2 python3 scripts/cifar-eval.py \
32 | --lr 0.04 \
33 | --resume_eval 0 \
34 | --logdir runs-cifar \
35 | --ckptpath myow-cifar
36 | ```
37 |
38 | ## Running tensorboard
39 |
40 | ```bash
41 | tensorboard --logdir=runs-cifar
42 | ```
43 |
--------------------------------------------------------------------------------
/docs/custom_image_dataset.md:
--------------------------------------------------------------------------------
1 | # Using your own image datasets
2 |
3 | Setting up your own dataset is straightforward. This code already contains an example of MYOW applied to the CIFAR-10
4 | dataset. We have made it easy to train MYOW on any dataset, and here we give a brief step-by-step guide to setting
5 | up your own.
6 |
7 | Augmentations can either be done on the CPU or on the GPU.
8 |
9 | ### On-GPU augmentations
10 | When it is possible and desirable to perform augmentations on the GPU in batch mode, use the `transform` argument
11 | to specify the augmentations used during training to generate views. In this case, augmented view generation is
12 | fully handled by the trainer, which only needs to receive a batch of the original images.
13 |
14 | Any standard PyTorch dataset can be used out of the box. A function called `prepare_views` needs to be defined; this
15 | function builds the dictionary that contains the two views needed by the trainer. In this case,
16 | the image is given twice.
17 |
18 | ```python
19 | def prepare_views(inputs):
20 | x, labels = inputs
21 | outputs = {'view1': x, 'view2': x}
22 | return outputs
23 | ```
24 |
25 | When constructing the trainer, simply specify:
26 |
27 | ```python
28 | MYOWTrainer(..., prepare_views=prepare_views, transform=transform)
29 | ```
30 |
31 | In some situations, it might be desirable to define two different classes of transformations. This can be done by
32 | separately defining `transform_1` and `transform_2`, each of which is applied to its respective view.
33 |
34 | ```python
35 | MYOWTrainer(..., prepare_views=prepare_views, transform_1=transform_1, transform_2=transform_2)
36 | ```
37 |
38 | **View mining**
39 |
40 | A set of candidates needs to be generated from the dataset during mining. When working with an image dataset, this can
41 | easily be done through a second dataloader passed to the trainer through `view_pool_dataloader`. The batch size of this
42 | dataloader can be different from the batch size of the main dataloader. The transformation used during mining is
43 | specified through `transform_m` which will be applied to the key sample as well as all candidate samples.
44 |
45 | ```python
46 | MYOWTrainer(..., view_pool_dataloader=view_pool_dataloader, transform_m=transform_m)
47 | ```
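
For concreteness, here is a minimal sketch of building such a pool dataloader (assuming `dataset` is the same PyTorch dataset used for training; `pool_batch_size` is an illustrative name, not a repo constant):

```python
from torch.utils.data import DataLoader

# A second dataloader over the same dataset supplies the mining candidates;
# its batch size is chosen independently of the main dataloader's.
pool_batch_size = 512
view_pool_dataloader = DataLoader(dataset, batch_size=pool_batch_size,
                                  shuffle=True, drop_last=True)
```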
48 |
49 | ### On-CPU augmentations
50 |
51 | To perform augmentations on the CPU in single-image mode, the augmented views need to be generated before being passed
52 | to the trainer. In this case, a custom Dataset object can be implemented along with a `prepare_views` function,
53 | which is used to assign the views.
54 |
55 | ```python
56 | def prepare_views(inputs):
57 | x1, x2 = inputs
58 | outputs = {'view1': x1, 'view2': x2}
59 | return outputs
60 | ```
61 |
62 | When constructing the trainer, there is no need to specify the `transform`:
63 |
64 | ```python
65 | MYOWTrainer(..., prepare_views=prepare_views, transform=None)
66 | ```
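
As a rough sketch of such a dataset (the class and argument names below are placeholders, not part of this repo), assuming `augment` is a CPU-side transform such as a `torchvision` composition:

```python
from torch.utils.data import Dataset

class PairedViewsDataset(Dataset):
    """Hypothetical dataset that applies CPU-side augmentations and
    returns the two views unpacked by `prepare_views`."""
    def __init__(self, base_dataset, augment):
        self.base_dataset = base_dataset  # any map-style dataset of (image, label)
        self.augment = augment            # main class of transformations

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        x, _ = self.base_dataset[index]
        # two independently augmented views of the same sample
        return self.augment(x), self.augment(x)
```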
67 |
68 | **View mining**
69 |
70 | Similarly, the batch of candidate views needs to be generated before reaching the trainer. The custom Dataset needs
71 | to generate four tensors, given an input `x`: two augmented views of `x` using the main class of transformations, one
72 | augmented view using `transform_m`, and a batch of candidate views.
73 |
74 | ```python
75 | def prepare_views(inputs):
76 | x1, x2, x3, xpool = inputs
77 | outputs = {'view1': x1, 'view2': x2, 'view3':x3, 'view_pool':xpool}
78 | return outputs
79 | ```
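
Continuing the hypothetical sketch above, the mining variant returns the four tensors in the order that `prepare_views` unpacks them (`augment_m` and `pool_size` are again placeholder names; the base dataset is assumed to yield image tensors):

```python
import torch

class MiningViewsDataset(PairedViewsDataset):
    """Hypothetical extension of the sketch above for view mining."""
    def __init__(self, base_dataset, augment, augment_m, pool_size=512):
        super().__init__(base_dataset, augment)
        self.augment_m = augment_m  # mining transform (transform_m)
        self.pool_size = pool_size  # number of candidate views per sample

    def __getitem__(self, index):
        x, _ = self.base_dataset[index]
        x1, x2 = self.augment(x), self.augment(x)  # main transformations
        x3 = self.augment_m(x)                     # mining view of x
        # batch of candidate views, drawn at random from the dataset
        pool_idx = torch.randint(len(self.base_dataset), (self.pool_size,))
        xpool = torch.stack([self.augment_m(self.base_dataset[int(i)][0])
                             for i in pool_idx])
        return x1, x2, x3, xpool
```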
80 |
--------------------------------------------------------------------------------
/docs/custom_sequence_dataset.md:
--------------------------------------------------------------------------------
1 | # Using your own sequence datasets
2 |
3 | This code already contains an example of MYOW applied to a dataset that contains a number of sequences.
4 | Here, we give a brief step-by-step guide to setting up your own sequence dataset.
5 |
6 | The main augmentation is temporal shift. For that, we use `data.generators.LocalGlobalGenerator`, which takes as input the
7 | matrix of features as well as the list of possible pairs. This list is pre-computed to allow for faster data loading.
8 | `data.utils.onlywithin_indices` can be used to generate such a list.
9 |
10 | ```python
11 | pair_sets = utils.onlywithin_indices(sequence_lengths, k_min=-2, k_max=2)
12 | ```
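
To make the output format concrete, here is a tiny worked example: each row is `[sequence_id, index_1, index_2]`, with the indices global across the concatenated sequences (row order may vary, since pairs are built from a set).

```python
from self_supervised.data.utils import onlywithin_indices

# two sequences of lengths 3 and 2; keep pairs exactly one time-bin apart
pairs = onlywithin_indices([3, 2], k_min=1, k_max=1)
# rows (in some order): [0, 0, 1], [0, 1, 2], [1, 3, 4]
```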
13 |
14 | Then there is the generation of the pool of candidates for mining. Because we are working with time-varying data,
15 | we need to restrict the candidate views used for mining to be separated by a minimum distance in time from the key
16 | views. We do this for our monkey datasets by sampling two different sets of sequences, from which we then sample the
17 | key views and the pool of candidate views separately.
18 |
19 | The additional transforms can be added through the `transform` argument.
20 |
21 | ```python
22 | generator = generators.LocalGlobalGenerator(firing_rates, pair_sets, sequence_lengths,
23 | num_examples=firing_rates.shape[0],
24 | batch_size=batch_size,
25 | pool_batch_size=pool_batch_size,
26 | transform=transform, num_workers=1,
27 | structured_transform=True)
28 | ```
29 |
30 | As with image datasets, the data needs to be specified for the trainer through the `prepare_views` function, which in this
31 | case is defined as a static method of the generator.
32 |
33 | ```python
34 | @staticmethod
35 | def prepare_views(inputs):
36 | x1, x2, x3, x4 = inputs
37 | outputs = {'view1': x1.squeeze(), 'view2': x2.squeeze(),
38 | 'view3': x3.squeeze(), 'view_pool': x4.squeeze()}
39 | return outputs
40 | ```
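
The generator then stands in for the usual dataloader when constructing the trainer; a minimal sketch, assuming the trainer accepts it through `train_dataloader` as in `scripts/cifar-eval.py`:

```python
MYOWTrainer(..., train_dataloader=generator, prepare_views=generator.prepare_views)
```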
41 |
--------------------------------------------------------------------------------
/docs/monkey_reach.md:
--------------------------------------------------------------------------------
1 | # Running MYOW on Reach dataset
2 |
3 | This page walks through the steps required to run MYOW on the monkey datasets.
4 |
5 | ## Dataset
6 | Processed data can be found in `data/mihi-chewie/processed/`.
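
For reference, the dataset can be loaded with the `ReachNeuralDataset` class defined in `self_supervised/data/monkey_reach_dataset.py`; a minimal sketch using its defaults:

```python
from self_supervised.data.monkey_reach_dataset import ReachNeuralDataset, get_class_data

dataset = ReachNeuralDataset('./data/mihi-chewie', primate='chewie', day=1,
                             binning_period=0.1)
data_train, data_test = get_class_data(dataset)  # [firing_rates, labels] tensor pairs
```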
7 |
8 | ## Running training and evaluation
9 |
10 | Training can be run using the following command:
11 |
12 | ```bash
13 | python3 scripts/monkey-train.py \
14 | --data_path=./data/mihi-chewie \
15 | --primate="chewie" \
16 | --day=1 \
17 | --max_lookahead=4 \
18 | --noise_sigma=0.1 \
19 | --dropout_p=0.8 \
20 | --dropout_apply_p=0.9 \
21 | --structured_transform=True \
22 | --batch_size=256 \
23 | --pool_batch_size=512 \
24 | --miner_k=3 \
25 | --myow_warmup_epochs=10 \
26 | --myow_rampup_epochs=110
27 | ```
28 | where `primate` and `day` select the animal and the recording session.
29 |
30 |
31 | ## Running tensorboard
32 |
33 | ```bash
34 | tensorboard --logdir=runs-chewie1
35 | ```
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1
2 | torchvision==0.8.2
3 | tensorboard==2.3.0
4 | tensorflow==2.3.1
5 | tqdm~=4.48.2
6 | scipy~=1.4.1
7 | numpy~=1.18.5
8 | matplotlib~=3.3.1
9 | h5py~=2.10.0
10 | kornia~=0.4.0
11 | sklearn~=0.0
12 | scikit-learn~=0.23.2
13 | opencv-python~=4.4.0.44
14 | Pillow~=7.2.0
15 | pandas~=1.1.3
16 |
--------------------------------------------------------------------------------
/scripts/cifar-eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import time
4 |
5 | from absl import app
6 | from absl import flags
7 | from glob import glob
8 | import torch
9 | import tensorflow as tf
10 | from torch.utils.data import DataLoader
11 | from torchvision.datasets import CIFAR10, CIFAR100
12 | from torchvision.transforms import ToTensor
13 | from torchvision import transforms
14 |
15 | from self_supervised.trainer import MYOWTrainer
16 | from self_supervised.nets import resnet_cifar
17 | from self_supervised.tasks import fast_classification
18 |
19 |
20 | FLAGS = flags.FLAGS
21 |
22 | flags.DEFINE_float('lr', 0.02, 'Base learning rate.')
23 | flags.DEFINE_integer('resume_eval', 0, 'Epoch at which evaluation starts.')
24 | flags.DEFINE_string('logdir', 'myow-run', 'Tensorboard dir name.')
25 | flags.DEFINE_string('ckptpath', 'myow-model', 'Checkpoint folder dir name.')
26 |
27 |
28 | def eval(gpu, ckpt_epoch, args, dataset_class=CIFAR10, num_classes=10):
29 | # load dataset
30 | image_size = 32
31 | transform = transforms.Compose([transforms.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
32 | transforms.RandomHorizontalFlip(p=0.5),
33 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
34 |
35 | pre_transform = transforms.Compose([ToTensor(),
36 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
37 |
38 | dataset = dataset_class("../datasets/cifar10", train=True, download=True, transform=pre_transform,
39 | target_transform=torch.tensor)
40 |
41 | dataset_val = dataset_class("../datasets/cifar10", train=False, download=True, transform=pre_transform,
42 | target_transform=torch.tensor)
43 |
44 | # Class of transformation for BYOL
45 | batch_size = 1024
46 | num_workers = 2
47 |
48 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False,
49 | drop_last=False, pin_memory=True)
50 | dataloader_val = DataLoader(dataset_val, batch_size=batch_size, num_workers=num_workers, shuffle=False,
51 | drop_last=False, pin_memory=True)
52 |
53 | def prepare_views(inputs):
54 | x, labels = inputs
55 | outputs = {'view1': x, 'view2': x}
56 | return outputs
57 |
58 | # build ResNET network
59 | encoder = resnet_cifar()
60 | representation_size = 512
61 | projection_hidden_size = 4096
62 | projection_size = 256
63 |
64 | # build byol trainer
65 | projection_size_2 = 64
66 | projection_hidden_size_2 = 1024
67 | lr = 0.08
68 | momentum = 0.99
69 |
70 | train_epochs = 800
71 | byol_warmup_epochs = 10
72 | n_decay = 1.5
73 | myow_warmup_epochs = 100
74 | myow_rampup_epochs = 110
75 | myow_max_weight = 1.0
76 |
77 | trainer = MYOWTrainer(encoder=encoder, representation_size=representation_size, projection_size=projection_size,
78 | projection_hidden_size=projection_hidden_size, projection_size_2=projection_size_2,
79 | projection_hidden_size_2=projection_hidden_size_2, prepare_views=prepare_views,
80 | train_dataloader=dataloader, view_pool_dataloader=dataloader, transform=transform,
81 | total_epochs=train_epochs, warmup_epochs=byol_warmup_epochs,
82 | myow_warmup_epochs=myow_warmup_epochs, myow_rampup_epochs=myow_rampup_epochs,
83 | myow_max_weight=myow_max_weight, batch_size=batch_size,
84 | base_lr=lr, base_momentum=momentum, momentum=0.9, weight_decay=args.wd,
85 | exclude_bias_and_bn=True, optimizer_type='sgd', symmetric_loss=True, view_miner_k=1,
86 | gpu=gpu, distributed=True, decay='cosine', m_decay='cosine',
87 | log_step=10, log_dir='runs-cifar/{}-elr{}'.format(args.logdir, args.lr),
88 | ckpt_path='ckpt/{}/ckpt-%d.pt'.format(args.ckptpath))
89 |
90 | trainer.load_checkpoint(ckpt_epoch)
91 | print('checkpoint loaded')
92 |
93 | trainer.model.eval()
94 | # representation
95 | print('computing representations')
96 | data_train = fast_classification.compute_representations(trainer.model.online_encoder.eval(), dataloader,
97 | device=trainer.device)
98 | data_val = fast_classification.compute_representations(trainer.model.online_encoder.eval(), dataloader_val,
99 | device=trainer.device)
100 |
101 | clr = args.lr
102 | classifier = resnet_cifar.get_linear_classifier(output_dim=num_classes).to(trainer.device)
103 | class_optimizer = torch.optim.SGD(classifier.parameters(), lr=clr, momentum=0.9)
104 | scheduler = torch.optim.lr_scheduler.MultiStepLR(class_optimizer, milestones=[60, 80], gamma=0.1)
105 | batch_size = 512
106 | acc = fast_classification.train_linear_layer(classifier, data_train, data_val, class_optimizer, scheduler=scheduler,
107 | writer=trainer.writer, tag=ckpt_epoch, batch_size=batch_size,
108 | num_epochs=100, device=trainer.device, tqdm_progress=True)
109 |
110 | print('Train', acc.train_last, ', Test', acc.val_smooth)
111 | trainer.writer.add_scalar('eval-train-%d' % num_classes, acc.train_last, ckpt_epoch)
112 | trainer.writer.add_scalar('eval-test-%d' % num_classes, acc.val_smooth, ckpt_epoch)
113 |
114 |
115 | def find_checkpoints(ckpt_path):
116 | def atoi(text):
117 | return int(text) if text.isdigit() else text
118 |
119 | def natural_keys(text):
120 | """alist.sort(key=natural_keys) sorts in human order
121 | http://nedbatchelder.com/blog/200712/human_sorting.html
122 | """
123 | return [atoi(c) for c in re.split(r'(\d+)', text)]
124 |
125 | ckpt_list = glob(os.path.join(ckpt_path, 'byol-ckpt-*.pt'))
126 | ckpt_list.sort(key=natural_keys)
127 | return ckpt_list
128 |
129 | def main():
130 | ckptpath = 'ckpt/{}'.format(FLAGS.ckptpath)
131 | already_computed = []
132 | while True:
133 | ckpt_list = find_checkpoints(ckptpath)
134 | for c in ckpt_list:
135 | ckpt_epoch = int(re.findall(r'(\d{1,3})\.pt$', c)[0])
136 |             if ckpt_epoch
--------------------------------------------------------------------------------
/self_supervised/data/generators/local_global_generator.py:
--------------------------------------------------------------------------------
66 |         if sum(mask) >= batch_size:
67 | sample_indices = np.random.choice(sum(mask), batch_size, replace=False)
68 | else:
69 | sample_indices = np.concatenate([np.arange(sum(mask)),
70 | np.random.choice(sum(mask), batch_size-sum(mask), replace=False)])
71 | x = self.features[mask][sample_indices]
72 | return x
73 |
74 | def __iter__(self):
75 | for _ in range(self.num_iterations):
76 | trials_1, trials_2 = self._sample_trials()
77 | x1, x2, dt = self._sample_pairs()
78 | x3 = self._sample_random(trials_1, self.batch_size)
79 | x4 = self._sample_random(trials_2, self.pool_batch_size)
80 |
81 | x1 = torch.tensor(x1)
82 | x2 = torch.tensor(x2)
83 | x3 = torch.tensor(x3)
84 | x4 = torch.tensor(x4)
85 |
86 | if self.transform is not None:
87 | if self.structured_transform:
88 | x1, x2 = self.transform(x1, x2)
89 | else:
90 | [x1] = self.transform(x1)
91 | [x2] = self.transform(x2)
92 | [x3] = self.transform(x3)
93 | [x4] = self.transform(x4)
94 | yield x1.type(torch.float32), x2.type(torch.float32), \
95 | x3.type(torch.float32), x4.type(torch.float32)
96 |
97 | def __len__(self):
98 | return self.num_iterations * self.batch_size * self.num_workers
99 |
100 | @staticmethod
101 | def prepare_views(inputs):
102 | x1, x2, x3, x4 = inputs
103 | outputs = {'view1': x1.squeeze(), 'view2': x2.squeeze(),
104 | 'view3': x3.squeeze(), 'view_pool': x4.squeeze()}
105 | return outputs
106 |
--------------------------------------------------------------------------------
/self_supervised/data/io.py:
--------------------------------------------------------------------------------
1 | import scipy
2 | from scipy import io as spio
3 | import numpy as np
4 |
5 |
6 | def loadmat(filename):
7 | r"""This function should be called instead of direct spio.loadmat
8 | as it cures the problem of not properly recovering python dictionaries
9 | from mat files. It calls the function check keys to cure all entries
10 | which are still mat-objects.
11 | """
12 |
13 | def _check_keys(d):
14 | r"""Checks if entries in dictionary are mat-objects. If yes
15 | todict is called to change them to nested dictionaries.
16 | """
17 | for key in d:
18 | if isinstance(d[key], spio.matlab.mio5_params.mat_struct):
19 | d[key] = _todict(d[key])
20 | return d
21 |
22 | def _todict(matobj):
23 | r"""A recursive function which constructs from matobjects nested dictionaries."""
24 | d = {}
25 | for strg in matobj._fieldnames:
26 | elem = matobj.__dict__[strg]
27 | if isinstance(elem, spio.matlab.mio5_params.mat_struct):
28 | d[strg] = _todict(elem)
29 | elif isinstance(elem, np.ndarray):
30 | d[strg] = _tolist(elem)
31 | else:
32 | d[strg] = elem
33 | return d
34 |
35 | def _tolist(ndarray):
36 | r"""A recursive function which constructs lists from cellarrays
37 | (which are loaded as numpy ndarrays), recursing into the elements
38 | if they contain matobjects.
39 | """
40 | elem_list = []
41 | for sub_elem in ndarray:
42 | if isinstance(sub_elem, spio.matlab.mio5_params.mat_struct):
43 | elem_list.append(_todict(sub_elem))
44 | elif isinstance(sub_elem, np.ndarray):
45 | elem_list.append(_tolist(sub_elem))
46 | else:
47 | elem_list.append(sub_elem)
48 | return elem_list
49 |
50 | data = scipy.io.loadmat(filename, struct_as_record=False, squeeze_me=True)
51 | return _check_keys(data)
52 |
--------------------------------------------------------------------------------
/self_supervised/data/monkey_reach_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import numpy as np
5 | from tqdm import tqdm
6 | import torch
7 |
8 | from self_supervised.data.io import loadmat
9 |
10 |
11 | FILENAMES = {
12 | ('mihi', 1): 'full-mihi-03032014',
13 | ('mihi', 2): 'full-mihi-03062014',
14 | ('chewie', 1): 'full-chewie-10032013',
15 | ('chewie', 2): 'full-chewie-12192013',
16 | }
17 |
18 |
19 | class ReachNeuralDataset:
20 | def __init__(self, path, primate='mihi', day=1,
21 | binning_period=0.1, binning_overlap=0.0, train_split=0.8,
22 | scale_firing_rates=False, scale_velocity=False, sort_by_reach=True):
23 |
24 | self.path = path
25 | # get path to data
26 | assert primate in ['mihi', 'chewie']
27 | assert day in [1, 2]
28 | self.primate = primate
29 |
30 | self.filename = FILENAMES[(self.primate, day)]
31 | self.raw_path = os.path.join(self.path, 'raw/%s.mat') % self.filename
32 | self.processed_path = os.path.join(self.path, 'processed/%s.pkl') % (self.filename + '-%.2f' % binning_period)
33 |
34 | # get binning parameters
35 | self.binning_period = binning_period
36 | self.binning_overlap = binning_overlap
37 | if self.binning_overlap != 0:
38 |             raise NotImplementedError
39 |
40 | # train/val split
41 | self.train_split = train_split
42 |
43 | # initialize some parameters
44 | self.dataset_ = {}
45 | self.subset = 'train' # default selected subset
46 |
47 | ### Process data
48 | # load data
49 | if not os.path.exists(self.processed_path):
50 | data_train_test = self._process_data()
51 | else:
52 | data_train_test = self._load_processed_data()
53 |
54 | # split data
55 | data_train, data_test = self._split_data(data_train_test)
56 | self._num_trials = {'train': len(data_train['firing_rates']),
57 | 'test': len(data_test['firing_rates'])}
58 |
59 | # compute mean and std of firing rates
60 | self.mean, self.std = self._compute_mean_std(data_train, feature='firing_rates')
61 |
62 | # remove neurons with no variance
63 | data_train, data_test = self._remove_static_neurons(data_train, data_test)
64 |
65 | # scale data
66 | if scale_firing_rates:
67 | data_train, data_test = self._scale_data(data_train, data_test, feature='firing_rates')
68 | if scale_velocity:
69 | data_train, data_test = self._scale_data(data_train, data_test, feature='velocity')
70 |
71 | # sort by reach direction
72 | if sort_by_reach:
73 | data_train = self._sort_by_reach_direction(data_train)
74 | data_test = self._sort_by_reach_direction(data_test)
75 |
76 | # build sequences
77 | trial_lengths_train = [seq.shape[0] for seq in data_train['firing_rates']]
78 |
79 | # merge everything
80 | for feature in data_train.keys():
81 | data_train[feature] = np.concatenate(data_train[feature]).squeeze()
82 | data_test[feature] = np.concatenate(data_test[feature]).squeeze()
83 |
84 | data_train['trial_lengths'] = trial_lengths_train
85 | data_train['reach_directions'] = np.unique(data_train['labels']).tolist()
86 | data_train['reach_lengths'] = [np.sum(data_train['labels'] == reach_id)
87 | for reach_id in data_train['reach_directions']]
88 |
89 | # map labels to 0 .. N-1 for training
90 | data_train['raw_labels'] = data_train['labels'].copy()
91 | data_test['raw_labels'] = data_test['labels'].copy()
92 |
93 | data_train['labels'] = self._map_labels(data_train)
94 | data_test['labels'] = self._map_labels(data_test)
95 |
96 | self.dataset_['train'] = data_train
97 | self.dataset_['test'] = data_test
98 |
99 | @property
100 | def dataset(self):
101 | return self.dataset_[self.subset]
102 |
103 | def __getattr__(self, item):
104 | return self.dataset[item]
105 |
106 | def train(self):
107 | self.subset = 'train'
108 |
109 | def test(self):
110 | self.subset = 'test'
111 |
112 | @property
113 | def num_trials(self):
114 | return self._num_trials[self.subset]
115 |
116 | @property
117 | def num_neurons(self):
118 |         return self.dataset['firing_rates'].shape[1]
119 |
120 | def _process_data(self):
121 | print('Preparing dataset: Binning data.')
122 | # load data
123 | mat_dict = loadmat(self.raw_path)
124 |
125 | # bin data
126 | data = self._bin_data(mat_dict)
127 |
128 | self._save_processed_data(data)
129 | return data
130 |
131 | def _save_processed_data(self, data):
132 | with open(self.processed_path, 'wb') as output:
133 | pickle.dump({'data': data}, output)
134 |
135 | def _load_processed_data(self):
136 | with open(self.processed_path, "rb") as fp:
137 | data = pickle.load(fp)['data']
138 | return data
139 |
140 | def _bin_data(self, mat_dict):
141 | # load matrix
142 | trialtable = mat_dict['trial_table']
143 | neurons = mat_dict['out_struct']['units']
144 | pos = np.array(mat_dict['out_struct']['pos'])
145 | vel = np.array(mat_dict['out_struct']['vel'])
146 | acc = np.array(mat_dict['out_struct']['acc'])
147 | force = np.array(mat_dict['out_struct']['force'])
148 | time = vel[:, 0]
149 |
150 | num_neurons = len(neurons)
151 | num_trials = trialtable.shape[0]
152 |
153 | data = {'firing_rates': [], 'position': [], 'velocity': [], 'acceleration': [],
154 | 'force': [], 'labels': [], 'sequence': []}
155 | for trial_id in tqdm(range(num_trials)):
156 | min_T = trialtable[trial_id, 9]
157 | max_T = trialtable[trial_id, 12]
158 |
159 | # grids= minT:(delT-TO):(maxT-delT);
160 | grid = np.arange(min_T, max_T + self.binning_period, self.binning_period)
161 | grids = grid[:-1]
162 | gride = grid[1:]
163 | num_bins = len(grids)
164 |
165 | neurons_binned = np.zeros((num_bins, num_neurons))
166 | pos_binned = np.zeros((num_bins, 2))
167 | vel_binned = np.zeros((num_bins, 2))
168 | acc_binned = np.zeros((num_bins, 2))
169 | force_binned = np.zeros((num_bins, 2))
170 | targets_binned = np.zeros((num_bins, 1))
171 | id_binned = trial_id * np.ones((num_bins, 1))
172 |
173 | for k in range(num_bins):
174 | bin_mask = (time >= grids[k]) & (time <= gride[k])
175 | if len(pos) > 0:
176 | pos_binned[k, :] = np.mean(pos[bin_mask, 1:], axis=0)
177 | vel_binned[k, :] = np.mean(vel[bin_mask, 1:], axis=0)
178 | if len(acc):
179 | acc_binned[k, :] = np.mean(acc[bin_mask, 1:], axis=0)
180 | if len(force) > 0:
181 | force_binned[k, :] = np.mean(force[bin_mask, 1:], axis=0)
182 | targets_binned[k, 0] = trialtable[trial_id, 1]
183 |
184 | for i in range(num_neurons):
185 | for k in range(num_bins):
186 | spike_times = neurons[i]['ts']
187 | bin_mask = (spike_times >= grids[k]) & (spike_times <= gride[k])
188 | neurons_binned[k, i] = np.sum(bin_mask) / self.binning_period
189 |
190 | data['firing_rates'].append(neurons_binned)
191 | data['position'].append(pos_binned)
192 | data['velocity'].append(vel_binned)
193 | data['acceleration'].append(acc_binned)
194 | data['force'].append(force_binned)
195 | data['labels'].append(targets_binned)
196 | data['sequence'].append(id_binned)
197 | return data
198 |
199 | def _split_data(self, data):
200 | num_trials = len(data['firing_rates'])
201 | split_id = int(num_trials * self.train_split)
202 |
203 | data_train = {}
204 | data_test = {}
205 | for key, feature in data.items():
206 | data_train[key] = feature[:split_id]
207 | data_test[key] = feature[split_id:]
208 | return data_train, data_test
209 |
210 | def _remove_static_neurons(self, data_train, data_test):
211 | for i in range(len(data_train['firing_rates'])):
212 | data_train['firing_rates'][i] = data_train['firing_rates'][i][:, self.std > 1e-3]
213 | for i in range(len(data_test['firing_rates'])):
214 | data_test['firing_rates'][i] = data_test['firing_rates'][i][:, self.std > 1e-3]
215 | self.mean = self.mean[self.std > 1e-3]
216 | self.std = self.std[self.std > 1e-3]
217 | return data_train, data_test
218 |
219 | def _compute_mean_std(self, data, feature='firing_rates'):
220 | concatenated_data = np.concatenate(data[feature])
221 | mean = concatenated_data.mean(axis=0)
222 | std = concatenated_data.std(axis=0)
223 | return mean, std
224 |
225 | def _scale_data(self, data_train, data_test, feature):
226 | concatenated_data = np.concatenate(data_train[feature])
227 | mean = concatenated_data.mean(axis=0)
228 | std = concatenated_data.std(axis=0)
229 |
230 | for i in range(len(data_train[feature])):
231 | data_train[feature][i] = (data_train[feature][i] - mean) / std
232 | for i in range(len(data_test[feature])):
233 | data_test[feature][i] = (data_test[feature][i] - mean) / std
234 | return data_train, data_test
235 |
236 | def _sort_by_reach_direction(self, data):
237 | sorted_by_label = np.argsort(np.array([reach_dir[0, 0] for reach_dir in data['labels']]))
238 | for feature in data.keys():
239 | data[feature] = np.array(data[feature])[sorted_by_label]
240 | return data
241 |
242 | def _map_labels(self, data):
243 | labels = data['labels']
244 | for i, l in enumerate(np.unique(labels)):
245 | labels[data['labels']==l] = i
246 | return labels
247 |
248 |
249 | def get_class_data(dataset, device='cpu'):
250 | def get_data():
251 | firing_rates = dataset.firing_rates
252 | labels = dataset.labels
253 | data = [torch.tensor(firing_rates, dtype=torch.float32, device=device),
254 | torch.tensor(labels, dtype=torch.long, device=device)]
255 | return data
256 | dataset.train()
257 | data_train = get_data()
258 |
259 | dataset.test()
260 | data_test = get_data()
261 |
262 | dataset.train()
263 | return data_train, data_test
264 |
265 |
266 | def get_angular_data(dataset, velocity_threshold=-1., device='cpu'):
267 | def get_data():
268 | velocity_mask = np.linalg.norm(dataset.velocity, 2, axis=1) > velocity_threshold
269 | firing_rates = dataset.firing_rates[velocity_mask]
270 | labels = dataset.labels[velocity_mask]
271 |
272 | angles = (2 * np.pi / 8 * labels)[:, np.newaxis]
273 | cos_sin = np.concatenate([np.cos(angles), np.sin(angles)], axis=1)
274 | data = [torch.tensor(firing_rates, dtype=torch.float32, device=device),
275 | torch.tensor(angles, dtype=torch.float32, device=device),
276 | torch.tensor(cos_sin, dtype=torch.float32, device=device)]
277 | return data
278 | dataset.train()
279 | data_train = get_data()
280 |
281 | dataset.test()
282 | data_test = get_data()
283 |
284 | dataset.train()
285 | return data_train, data_test
286 |
--------------------------------------------------------------------------------
/self_supervised/data/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def diagu_indices(n, k_min=1, k_max=None):
6 | if k_max is None:
7 | return np.array(np.triu_indices(n, 1)).T
8 | else:
9 | all_pairs = set(zip(*map(np.ndarray.tolist, np.triu_indices(n, k_min))))
10 | rm_pairs = set(zip(*map(np.ndarray.tolist, np.triu_indices(n, 1 + k_max))))
11 | pairs = all_pairs - rm_pairs
12 | return np.array(list(pairs))
13 |
14 |
15 | def onlywithin_indices(sequence_lengths, k_min=1, k_max=None):
16 | cum_n = 0
17 | pair_arrays = []
18 | for i, n in enumerate(sequence_lengths):
19 | pairs = diagu_indices(n, k_min, k_max) + cum_n
20 | pair_arrays.append(np.hstack([np.ones((pairs.shape[0], 1))*i, pairs]))
21 | cum_n += n
22 | return np.concatenate(pair_arrays).astype(int)
23 |
24 |
25 | def batch_iter(X, *tensors, batch_size=256):
26 | r"""Creates iterator over tensors.
27 |
28 | Args:
29 | X (torch.tensor): Feature tensor (shape: num_instances x num_features).
30 | tensors (torch.tensor): Target tensors (shape: num_instances).
31 | batch_size (int, Optional): Batch size. (default: :obj:`256`)
32 | """
33 | idxs = torch.randperm(X.size(0))
34 | if X.is_cuda:
35 | idxs = idxs.cuda()
36 | for batch_idxs in idxs.split(batch_size):
37 | res = [X[batch_idxs]]
38 | for tensor in tensors:
39 | res.append(tensor[batch_idxs])
40 | yield res
41 |
--------------------------------------------------------------------------------
/self_supervised/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .cosine_loss import CosineLoss
2 |
--------------------------------------------------------------------------------
/self_supervised/loss/cosine_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class CosineLoss(torch.nn.Module):
6 | r"""Cosine loss.
7 |
8 | .. note::
9 |
10 | Also known as normalized L2 distance.
11 | """
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def forward(self, outputs, targets):
16 | outputs = F.normalize(outputs, dim=-1, p=2)
17 | targets = F.normalize(targets, dim=-1, p=2)
18 | return (2 - 2 * (outputs * targets).sum(dim=-1)).mean()
19 |
--------------------------------------------------------------------------------
/self_supervised/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .mlp3 import MLP3
2 | from .byol import BYOL
3 | from .double_byol import DoubleBYOL
4 | from .myow_factory import myow_factory
5 |
6 | # generate myow variants
7 | MYOW = myow_factory(DoubleBYOL)
8 |
9 | __all__ = [
10 | 'BYOL',
11 | 'MLP3',
12 | 'MYOW'
13 | ]
14 |
--------------------------------------------------------------------------------
/self_supervised/model/byol.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 |
5 |
6 | class BYOL(torch.nn.Module):
7 | r"""Base backbone-agnostic BYOL architecture.
8 | The BYOL architecture was proposed by Grill et al. in
9 | Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning
10 | https://arxiv.org/abs/2006.07733
11 |
12 | Two views are separately forwarded through the online and target networks:
13 | .. math::
14 | y = f_{\theta}(x),\ z = g_{\theta}(y)\\
15 | y^\prime = f_{\xi}(x^\prime),\ z^\prime = g_{\xi}(y^\prime)
16 |
17 | then the predictor learns to predict the target projection from the online projection in order to minimize the
18 | following loss:
19 | .. math::
20 | \mathcal{L}_{\theta, \xi} = 2-2 \cdot \frac{\left\langle q_{\theta}\left(z\right),
21 | z^{\prime}\right\rangle}{\left\|q_{\theta}\left(z\right)\right\|_{2}
22 | \cdot\left\|z^{\prime}\right\|_{2}}.
23 |
24 | Args:
25 | encoder (torch.nn.Module): Encoder network to be duplicated and used in both online and target networks.
26 | projector (torch.nn.Module): Projector network to be duplicated and used in both online and target networks.
27 | predictor (torch.nn.Module): Predictor network used to predict the target projection from the online projection.
28 | """
29 | def __init__(self, encoder, projector, predictor):
30 | super().__init__()
31 | self.online_encoder = encoder
32 | self.target_encoder = copy.deepcopy(encoder)
33 | self._stop_gradient(self.target_encoder)
34 |
35 | self.online_projector = projector
36 | self.target_projector = copy.deepcopy(projector)
37 | self._stop_gradient(self.target_projector)
38 |
39 | self.predictor = predictor
40 |
41 | @property
42 | def trainable_modules(self):
43 |         r"""Returns the list of modules that will be updated via an optimizer."""
44 | return [self.online_encoder, self.online_projector, self.predictor]
45 |
46 | @property
47 | def _ema_module_pairs(self):
48 | return [(self.online_encoder, self.target_encoder),
49 | (self.online_projector, self.target_projector)]
50 |
51 | def _stop_gradient(self, network):
52 | r"""Stops parameters of :obj:`network` of being updated through back-propagation."""
53 | for param in network.parameters():
54 | param.requires_grad = False
55 |
56 | @torch.no_grad()
57 | def _reset_moving_average(self):
58 | r"""Resets target network to have the same parameters as the online network."""
59 | for online_module, target_module in self._ema_module_pairs:
60 | for param_q, param_k in zip(online_module.parameters(), target_module.parameters()):
61 | param_k.data.copy_(param_q.data) # initialize
62 | param_k.requires_grad = False # stop gradient
63 |
64 | @torch.no_grad()
65 | def update_target_network(self, mm):
66 | r"""Performs a momentum update of the target network's weights.
67 |
68 | Args:
69 | mm (float): Momentum used in moving average update.
70 | """
71 | assert 0.0 <= mm <= 1.0, "Momentum needs to be between 0.0 and 1.0, got %.5f" % mm
72 | for online_module, target_module in self._ema_module_pairs:
73 | for param_q, param_k in zip(online_module.parameters(), target_module.parameters()):
74 | param_k.data.mul_(mm).add_(param_q.data, alpha=1. - mm)
75 |
76 | def forward(self, inputs, get_embedding='predictor'):
77 | r"""Defines the computation performed at every call. Supports single or dual forwarding through online and/or
78 | target networks. Supports resuming computation from encoder space.
79 |
80 | :obj:`get_embedding` determines whether the computation stops at the encoder space (:obj:`"encoder"`) or all
81 | computations including the projections and the prediction (:obj:`"predictor"`).
82 |
83 | :obj:`inputs` can include :obj:`"online_view"` and/or :obj:`"target_view"`. The prefix determines which
84 | branch the tensor is passed through. It is possible to give one view only, which will be forwarded through
85 | its corresponding branch.
86 |
87 | To resume computation from the encoder space, simply pass :obj:`"online_y"` and/or :obj:`"target_y"`. If, for
88 | example, :obj:`"online_y"` is present in :obj:`inputs` then :obj:`"online_view"`, if passed, would be ignored.
89 |
90 | Args:
91 | inputs (dict): Inputs to be forwarded through the networks.
92 | get_embedding (String, Optional): Determines where the computation stops, can be :obj:`"encoder"` or
93 | :obj:`"predictor"`. (default: :obj:`"predictor"`)
94 |
95 | Returns:
96 | dict
97 |
98 | Example::
99 | net = BYOL(...)
100 | inputs = {'online_view': x1, 'target_view': x2}
101 | outputs = net(inputs) # outputs online_q and target_z
102 |
103 | inputs = {'online_view': x1}
104 | outputs = net(inputs, get_embedding='encoder') # outputs online_y
105 |
106 | inputs = {'online_y': y1, 'target_y': y2}
107 | outputs = net(inputs) # outputs online_q and target_z
108 | """
109 | assert get_embedding in ['encoder', 'predictor'], "Module name needs to be in %r." % ['encoder', 'predictor']
110 |
111 | outputs = {}
112 | if 'online_view' in inputs or 'online_y' in inputs:
113 | # forward online network
114 | if not('online_y' in inputs):
115 | # representation is not already computed, requires forwarding the view through the online encoder.
116 | online_view = inputs['online_view']
117 | online_y = self.online_encoder(online_view)
118 | online_y = online_y.view(online_y.shape[0], -1).contiguous() # flatten
119 | else:
120 | # resume forwarding
121 | online_y = inputs['online_y']
122 |
123 | if get_embedding == 'encoder':
124 | outputs['online_y'] = online_y
125 |
126 | if get_embedding == 'predictor':
127 | online_z = self.online_projector(online_y)
128 | online_q = self.predictor(online_z)
129 |
130 | outputs['online_q'] = online_q
131 |
132 | if 'target_view' in inputs or 'target_y' in inputs:
133 | # forward target network
134 | with torch.no_grad():
135 | if not ('target_y' in inputs):
136 | # representation is not already computed, requires forwarding the view through the target encoder.
137 | target_view = inputs['target_view']
138 | target_y = self.target_encoder(target_view)
139 | target_y = target_y.view(target_y.shape[0], -1).contiguous() # flatten
140 | else:
141 | # resume forwarding
142 | target_y = inputs['target_y']
143 |
144 | if get_embedding == 'encoder':
145 | outputs['target_y'] = target_y
146 |
147 | if get_embedding == 'predictor':
148 | # forward projector and predictor
149 | target_z = self.target_projector(target_y).detach().clone()
150 |
151 | outputs['target_z'] = target_z
152 | return outputs
153 |
--------------------------------------------------------------------------------
/self_supervised/model/double_byol.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 |
5 | from .byol import BYOL
6 |
7 |
8 | class DoubleBYOL(BYOL):
9 | r"""BYOL with dual projector/predictor networks.
10 |
11 | When the projectors are cascaded, the two views are separately forwarded through the online and target networks:
12 | .. math::
13 | y = f_{\theta}(x),\ z = g_{\theta}(y), v = h_{\theta}(z)\\
14 | y^\prime = f_{\xi}(x^\prime),\ z^\prime = g_{\xi}(y^\prime), v^\prime = h_{\xi}(z^\prime)
15 |
16 | then prediction is performed either in the first projection space or the second.
17 | In the first, the predictor learns to predict the target projection from the online projection in order to minimize
18 | the following loss:
19 | .. math::
20 | \mathcal{L}_{\theta, \xi} = 2-2 \cdot \frac{\left\langle q_{\theta}\left(z\right),
21 | z^{\prime}\right\rangle}{\left\|q_{\theta}\left(z\right)\right\|_{2}
22 | \cdot\left\|z^{\prime}\right\|_{2}}
23 |
24 | In the second, the second predictor learns to predict the second target projection from the second online projection
25 | in order to minimize the following loss:
26 | .. math::
27 | \mathcal{L}_{\theta, \xi} = 2-2 \cdot \frac{\left\langle r_{\theta}\left(v\right),
28 |             v^{\prime}\right\rangle}{\left\|r_{\theta}\left(v\right)\right\|_{2}
29 | \cdot\left\|v^{\prime}\right\|_{2}}.
30 |
31 | Args:
32 | encoder (torch.nn.Module): Encoder network to be duplicated and used in both online and target networks.
33 | projector (torch.nn.Module): Projector network to be duplicated and used in both online and target networks.
34 | projector_m (torch.nn.Module): Second projector network to be duplicated and used in both online and
35 | target networks.
36 | predictor (torch.nn.Module): Predictor network used to predict the target projection from the online projection.
37 | predictor_m (torch.nn.Module): Second predictor network used to predict the target projection from the
38 | online projection.
39 | layout (String, Optional): Defines the layout of the dual projectors. Can be either :obj:`"cascaded"` or
40 | :obj:`"parallel"`. (default: :obj:`"cascaded"`)
41 | """
42 | def __init__(self, encoder, projector, projector_m, predictor, predictor_m, layout='cascaded'):
43 | super().__init__(encoder, projector, predictor)
44 |
45 | assert layout in ['cascaded', 'parallel'], "layout should be 'cascaded' or 'parallel', got %s." % layout
46 | self.layout = layout
47 |
48 | self.online_projector_m = projector_m
49 | self.target_projector_m = copy.deepcopy(projector_m)
50 | self._stop_gradient(self.target_projector_m)
51 |
52 | self.predictor_m = predictor_m
53 |
54 | @property
55 | def trainable_modules(self):
56 |         r"""Returns the list of modules that will be updated via an optimizer."""
57 | return [self.online_encoder, self.online_projector, self.online_projector_m, self.predictor, self.predictor_m]
58 |
59 | @property
60 | def _ema_module_pairs(self):
61 | return [(self.online_encoder, self.target_encoder),
62 | (self.online_projector, self.target_projector),
63 | (self.online_projector_m, self.target_projector_m)]
64 |
65 | def forward(self, inputs, get_embedding='predictor'):
66 | r"""Defines the computation performed at every call. Supports single or dual forwarding through online and/or
67 | target networks. Supports resuming computation from encoder space.
68 |
69 | :obj:`get_embedding` determines whether the computation stops at the encoder space (:obj:`"encoder"`) or at
70 | the first projector space (:obj:`"predictor"`) or the second (:obj:`"predictor_m"`). With the last two
71 | options, predictions are also made.
72 |
73 | :obj:`inputs` can include :obj:`"online_view"` and/or :obj:`"target_view"`. The prefix determines which
74 | branch the tensor is passed through. It is possible to give one view only, which will be forwarded through
75 | its corresponding branch.
76 |
77 | To resume computation from the encoder space, simply pass :obj:`"online_y"` and/or :obj:`"target_y"`. If, for
78 | example, :obj:`"online_y"` is present in :obj:`inputs` then :obj:`"online_view"`, if passed, would be ignored.
79 |
80 | Args:
81 | inputs (dict): Inputs to be forwarded through the networks.
82 | get_embedding (String, Optional): Determines where the computation stops, can be :obj:`"encoder"`,
83 | :obj:`"predictor"` or :obj:`"predictor_m"`. (default: :obj:`"predictor"`)
84 |
85 | Returns:
86 | dict
87 |
88 | Example::
89 |             net = DoubleBYOL(...)
90 | inputs = {'online_view': x1, 'target_view': x2}
91 | outputs = net(inputs) # outputs online_q and target_z
92 |
93 | inputs = {'online_view': x1}
94 | outputs = net(inputs, get_embedding='encoder') # outputs online_y
95 |
96 | inputs = {'online_y': y1, 'target_y': y2}
97 | outputs = net(inputs) # outputs online_q and target_z
98 |
99 | inputs = {'online_view': x1, 'target_view': x2}
100 | outputs = net(inputs, get_embedding='predictor_m') # outputs online_q_m and target_v
101 | """
102 | assert get_embedding in ['encoder', 'predictor', 'predictor_m'], \
103 | "Module name needs to be in %r." % ['encoder', 'predictor', 'predictor_m']
104 |
105 | outputs = {}
106 | if 'online_view' in inputs or 'online_y' in inputs:
107 | # forward online network
108 |             if 'online_y' not in inputs:
109 | # representation is not already computed, requires forwarding the view through the online encoder.
110 | online_view = inputs['online_view']
111 | online_y = self.online_encoder(online_view)
112 | online_y = online_y.view(online_y.shape[0], -1).contiguous() # flatten
113 | else:
114 | # resume forwarding
115 | online_y = inputs['online_y']
116 |
117 | if get_embedding == 'encoder':
118 | outputs['online_y'] = online_y
119 |
120 | if get_embedding == 'predictor':
121 | online_z = self.online_projector(online_y)
122 | online_q = self.predictor(online_z)
123 |
124 | outputs['online_q'] = online_q
125 |
126 | if get_embedding == 'predictor_m':
127 | if self.layout == 'parallel':
128 | online_v = self.online_projector_m(online_y)
129 | online_q_m = self.predictor_m(online_v)
130 |
131 | outputs['online_q_m'] = online_q_m
132 |
133 | elif self.layout == 'cascaded':
134 | online_z = self.online_projector(online_y)
135 | online_v = self.online_projector_m(online_z)
136 | online_q_m = self.predictor_m(online_v)
137 |
138 | outputs['online_q_m'] = online_q_m
139 |
140 | if 'target_view' in inputs or 'target_y' in inputs:
141 | # forward target encoder
142 | with torch.no_grad():
143 |                 if 'target_y' not in inputs:
144 | # representation is not already computed, requires forwarding the view through the target encoder.
145 | target_view = inputs['target_view']
146 | target_y = self.target_encoder(target_view)
147 | target_y = target_y.view(target_y.shape[0], -1).contiguous()
148 | else:
149 | # resume forwarding
150 | target_y = inputs['target_y']
151 |
152 | if get_embedding == 'encoder':
153 | outputs['target_y'] = target_y
154 |
155 | if get_embedding == 'predictor':
156 | # forward projector and predictor
157 | target_z = self.target_projector(target_y).detach().clone()
158 |
159 | outputs['target_z'] = target_z
160 |
161 | if get_embedding == 'predictor_m':
162 | if self.layout == 'parallel':
163 | target_v = self.target_projector_m(target_y)
164 |
165 | outputs['target_v'] = target_v
166 |
167 | elif self.layout == 'cascaded':
168 | target_z = self.target_projector(target_y)
169 | target_v = self.target_projector_m(target_z)
170 |
171 | outputs['target_v'] = target_v
172 | return outputs
173 |
--------------------------------------------------------------------------------
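A minimal usage sketch for DoubleBYOL (illustrative only: the stand-in encoder, the sizes, and the random views below are made up, not part of the repository):

import torch
from torch import nn

from self_supervised.model.double_byol import DoubleBYOL
from self_supervised.model.mlp3 import MLP3

# hypothetical encoder producing 512-d representations
encoder = nn.Sequential(nn.Flatten(), nn.Linear(100, 512), nn.ReLU())
net = DoubleBYOL(
    encoder,
    projector=MLP3(512, 128, 1024),    # g: y -> z
    projector_m=MLP3(128, 64, 1024),   # h: z -> v (cascaded on top of g)
    predictor=MLP3(128, 128, 1024),    # q: predicts target z from online z
    predictor_m=MLP3(64, 64, 1024),    # r: predicts target v from online v
    layout='cascaded')

x1, x2 = torch.randn(8, 100), torch.randn(8, 100)
out = net({'online_view': x1, 'target_view': x2}, get_embedding='predictor_m')
print(out['online_q_m'].shape, out['target_v'].shape)  # both torch.Size([8, 64])
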
/self_supervised/model/mlp3.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class MLP3(nn.Module):
5 | r"""MLP class used for projector and predictor in :class:`BYOL`. The MLP has one hidden layer.
6 |
7 | .. note::
8 |
9 | The hidden layer should be larger than both input and output layers, according to the
10 | :class:`BYOL` paper.
11 |
12 | Args:
13 | input_size (int): Size of input features.
14 | output_size (int): Size of output features (projection or prediction).
15 | hidden_size (int): Size of hidden layer.
16 | """
17 | def __init__(self, input_size, output_size, hidden_size):
18 | super().__init__()
19 | self.net = nn.Sequential(
20 | nn.Linear(input_size, hidden_size),
21 | nn.BatchNorm1d(hidden_size),
22 | nn.ReLU(inplace=True),
23 | nn.Linear(hidden_size, output_size)
24 | )
25 |
26 | def forward(self, x):
27 | return self.net(x)
28 |
--------------------------------------------------------------------------------
/self_supervised/model/myow_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def myow_factory(byol_class):
6 | r"""Factory function for adding mining feature to an architecture."""
7 | class MYOW(byol_class):
8 | r"""
9 | Class that adds ability to mine views to base class :obj:`byol_class`.
10 |
11 | Args:
12 | n_neighbors (int, optional): Number of neighbors used in knn. (default: :obj:`1`)
13 | """
14 |
15 | def __init__(self, *args, n_neighbors=1):
16 | super().__init__(*args)
17 |
18 | self.k = n_neighbors
19 |
20 | def _compute_distance(self, x, y):
21 | x = F.normalize(x, dim=-1, p=2)
22 | y = F.normalize(y, dim=-1, p=2)
23 |
24 | dist = 2 - 2 * torch.sum(x.view(x.shape[0], 1, x.shape[1]) *
25 | y.view(1, y.shape[0], y.shape[1]), -1)
26 | return dist
27 |
28 | def _knn(self, x, y):
29 | # compute distance
30 | dist = self._compute_distance(x, y)
31 |
32 | # compute k nearest neighbors
33 | values, indices = torch.topk(dist, k=self.k, largest=False)
34 |
35 | # randomly select one of the neighbors
36 | selection_mask = torch.randint(self.k, size=(indices.size(0),))
37 | mined_views_ids = indices[torch.arange(indices.size(0)).to(selection_mask), selection_mask]
38 | return mined_views_ids
39 |
40 | def mine_views(self, y, y_pool):
41 | r"""Finds, for each element in batch :obj:`y`, its nearest neighbors in :obj:`y_pool`, randomly selects one
42 | of them and returns the corresponding index.
43 |
44 | Args:
45 | y (torch.Tensor): batch of representation vectors.
46 | y_pool (torch.Tensor): pool of candidate representation vectors.
47 |
48 | Returns:
49 | torch.Tensor: Indices of mined views in :obj:`y_pool`.
50 | """
51 | mined_views_ids = self._knn(y, y_pool)
52 | return mined_views_ids
53 | return MYOW
54 |
--------------------------------------------------------------------------------
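The MYOW class used by the trainer is built by applying this factory to DoubleBYOL; a small sketch of the mining step with made-up sizes (the Identity encoder and random vectors are placeholders):

import torch

from self_supervised.model.double_byol import DoubleBYOL
from self_supervised.model.mlp3 import MLP3
from self_supervised.model.myow_factory import myow_factory

MYOW = myow_factory(DoubleBYOL)
net = MYOW(torch.nn.Identity(), MLP3(32, 16, 64), MLP3(16, 8, 64),
           MLP3(16, 16, 64), MLP3(8, 8, 64), n_neighbors=4)

y, y_pool = torch.randn(8, 32), torch.randn(64, 32)  # batch and candidate pool
ids = net.mine_views(y, y_pool)  # one of the 4 nearest neighbors, sampled per row
print(ids.shape)                 # torch.Size([8])
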
/self_supervised/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from .mlp import MLP
2 | from .resnets import resnet_cifar
3 |
--------------------------------------------------------------------------------
/self_supervised/nets/mlp.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class MLP(nn.Module):
5 | r"""Multi-layer perceptron model, with optional regularization layers.
6 |
7 | Args:
8 | hidden_layers (list): List of layer dimensions, from input layer to output layer.
9 | activation (torch.nn.Module, optional): Activation function. (default: :obj:`nn.ReLU`).
10 | batchnorm (boolean, optional): If set to :obj:`True`, batchnorm layers are added after each block.
11 | (default: :obj:`False`).
12 | bias (boolean, optional): If set to :obj:`True`, bias will be used in linear layers. (default: :obj:`True`).
13 | drop_last_nonlin (boolean, optional): If set to :obj:`True`, the last layer won't have non-linearities or
14 | regularization layers. (default: :obj:`True`)
15 | """
16 | def __init__(self, hidden_layers, activation=nn.ReLU(True), batchnorm=False, bias=True, drop_last_nonlin=True):
17 | super().__init__()
18 |
19 | # build the layers
20 | layers = []
21 | for in_dim, out_dim in zip(hidden_layers[:-1], hidden_layers[1:]):
22 | layers.append(nn.Linear(in_dim, out_dim, bias=bias))
23 | if batchnorm:
24 | layers.append(nn.BatchNorm1d(num_features=out_dim))
25 | if activation is not None:
26 | layers.append(activation)
27 |
28 | # remove activation and/or batchnorm layers from the last block
29 | if drop_last_nonlin:
30 | remove_layers = -(int(activation is not None) + int(batchnorm))
31 | if remove_layers:
32 | layers = layers[:remove_layers]
33 |
34 | self.layers = nn.Sequential(*layers)
35 |
36 | def forward(self, x):
37 | x = self.layers(x)
38 | return x
39 |
--------------------------------------------------------------------------------
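A quick shape check of the MLP builder (sizes are arbitrary):

import torch
from self_supervised.nets import MLP

# two blocks: Linear-BatchNorm-ReLU, then a bare Linear (non-linearity dropped on the last block)
net = MLP([100, 64, 10], batchnorm=True)
print(net(torch.randn(4, 100)).shape)  # torch.Size([4, 10])
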
/self_supervised/nets/resnets.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torchvision.models.resnet import resnet18
3 |
4 |
5 | class resnet_cifar(nn.Module):
6 | r"""CIFAR-variant of ResNet18."""
7 | def __init__(self):
8 | super().__init__()
9 |
10 | self.f = []
11 | for name, module in resnet18().named_children():
12 | if name == 'conv1':
13 | module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
14 |
15 | if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
16 | self.f.append(module)
17 |
18 | # encoder
19 | self.f = nn.Sequential(*self.f)
20 |
21 | def forward(self, x):
22 | y = self.f(x)
23 | return y.view(y.size(0), -1).contiguous()
24 |
25 | @staticmethod
26 | def get_linear_classifier(input_dim=512, output_dim=10):
27 | r"""Return linear classification layer."""
28 | return nn.Linear(input_dim, output_dim)
29 |
--------------------------------------------------------------------------------
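Shape check for the CIFAR variant (the 3x3 stem and removed max-pool keep enough spatial resolution for 32x32 inputs; requires torchvision):

import torch
from self_supervised.nets import resnet_cifar

net = resnet_cifar()
y = net(torch.randn(2, 3, 32, 32))  # CIFAR-sized batch
print(y.shape)                      # torch.Size([2, 512])
classifier = resnet_cifar.get_linear_classifier()  # 512 -> 10 head for linear evaluation
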
/self_supervised/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .lars import LARS
2 |
--------------------------------------------------------------------------------
/self_supervised/optimizer/lars.py:
--------------------------------------------------------------------------------
1 | """Source https://github.com/noahgolmant/pytorch-lars"""
2 |
3 | import torch
4 | from torch.optim.optimizer import Optimizer
5 |
6 |
7 | class LARS(Optimizer):
8 | r"""Implements layer-wise adaptive rate scaling for SGD.
9 |
10 | Args:
11 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups
12 | lr (float): base learning rate (\gamma_0)
13 | momentum (float, optional): momentum factor (default: 0) ("m")
14 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) ("\beta")
15 | dampening (float, optional): dampening for momentum (default: 0)
16 | eta (float, optional): LARS coefficient
17 | nesterov (bool, optional): enables Nesterov momentum (default: False)
18 |
19 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
20 | Large Batch Training of Convolutional Networks: https://arxiv.org/abs/1708.03888
21 |
22 | Example:
23 | >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, eta=1e-3)
24 | >>> optimizer.zero_grad()
25 | >>> loss_fn(model(input), target).backward()
26 | >>> optimizer.step()
27 | """
28 |
29 | def __init__(self, params, lr, momentum=0, dampening=0, weight_decay=0, eta=0.001, nesterov=False):
30 | if lr < 0.0:
31 | raise ValueError("Invalid learning rate: {}".format(lr))
32 | if momentum < 0.0:
33 | raise ValueError("Invalid momentum value: {}".format(momentum))
34 | if weight_decay < 0.0:
35 | raise ValueError(
36 | "Invalid weight_decay value: {}".format(weight_decay))
37 | if eta < 0.0:
38 | raise ValueError("Invalid LARS coefficient value: {}".format(eta))
39 |
40 | defaults = dict(
41 | lr=lr, momentum=momentum, dampening=dampening,
42 | weight_decay=weight_decay, nesterov=nesterov, eta=eta)
43 | if nesterov and (momentum <= 0 or dampening != 0):
44 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
45 |
46 | super(LARS, self).__init__(params, defaults)
47 |
48 | def __setstate__(self, state):
49 | super(LARS, self).__setstate__(state)
50 | for group in self.param_groups:
51 | group.setdefault('nesterov', False)
52 |
53 | @torch.no_grad()
54 | def step(self, closure=None):
55 | r"""Performs a single optimization step.
56 |
57 | Args:
58 | closure (callable, optional): A closure that reevaluates the model and returns the loss.
59 | """
60 | loss = None
61 | if closure is not None:
62 | with torch.enable_grad():
63 | loss = closure()
64 |
65 | for group in self.param_groups:
66 | weight_decay = group['weight_decay']
67 | momentum = group['momentum']
68 | dampening = group['dampening']
69 | eta = group['eta']
70 | nesterov = group['nesterov']
71 | lr = group['lr']
72 | lars_exclude = group.get('lars_exclude', False)
73 |
74 | for p in group['params']:
75 | if p.grad is None:
76 | continue
77 |
78 | d_p = p.grad
79 |
80 | if lars_exclude:
81 | local_lr = 1.
82 | else:
83 | weight_norm = torch.norm(p).item()
84 | grad_norm = torch.norm(d_p).item()
85 | # Compute local learning rate for this layer
86 | local_lr = eta * weight_norm / \
87 | (grad_norm + weight_decay * weight_norm)
88 |
89 | actual_lr = local_lr * lr
90 | d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr)
91 | if momentum != 0:
92 | param_state = self.state[p]
93 | if 'momentum_buffer' not in param_state:
94 | buf = param_state['momentum_buffer'] = \
95 | torch.clone(d_p).detach()
96 | else:
97 | buf = param_state['momentum_buffer']
98 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
99 | if nesterov:
100 | d_p = d_p.add(buf, alpha=momentum)
101 | else:
102 | d_p = buf
103 | p.add_(-d_p)
104 |
105 | return loss
106 |
--------------------------------------------------------------------------------
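A sketch of the per-group `lars_exclude` flag read in `step()`; it disables the layer-wise adaptation for a group, which is what `_collect_params` in `byol_trainer.py` does for biases and batch-norm parameters (the tiny model and data here are placeholders):

import torch
from self_supervised.optimizer import LARS

model = torch.nn.Linear(10, 2)
param_groups = [
    {'params': [model.weight]},                                          # LARS-adapted
    {'params': [model.bias], 'weight_decay': 0., 'lars_exclude': True},  # plain SGD-style step
]
optimizer = LARS(param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4)

loss = model(torch.randn(4, 10)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
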
/self_supervised/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .classification import train_classifier
2 | from . import fast_classification
3 | from . import neural_tasks
4 |
--------------------------------------------------------------------------------
/self_supervised/tasks/classification.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 |
4 | from self_supervised.data import utils
5 | from self_supervised.utils import MetricLogger
6 |
7 |
8 | def compute_accuracy(net, classifier, data, transform=None, device='cpu'):
9 | r"""Evaluates the classification accuracy when a list of :class:`torch.Tensor` is given.
10 |
11 | Args:
12 | net (torch.nn.Module): Frozen encoder.
13 | classifier (torch.nn.Module): Linear layer.
14 |         data (list of torch.Tensor): Inputs and target classes.
15 | transform (Callable, Optional): Transformation to use. Added for the purposes of
16 | normalization. (default: :obj:`None`)
17 | device (String, Optional): Device used. (default: :obj:`"cpu"`)
18 |
19 | Returns:
20 | float: Accuracy.
21 | """
22 | classifier.eval()
23 | # prepare inputs
24 | x, label = data
25 | x = x.to(device)
26 | label = label.to(device)
27 |
28 | if transform is not None:
29 | x = transform(x)
30 |
31 | # feed to network and classifier
32 | with torch.no_grad():
33 | representation = net(x)
34 | pred_logits = classifier(representation)
35 |
36 | # compute accuracy
37 | _, pred_class = torch.max(pred_logits, 1)
38 | acc = (pred_class == label).sum().item() / label.size(0)
39 | return acc
40 |
41 |
42 | def compute_accuracy_dataloader(net, classifier, dataloader, transform=None, device='cpu'):
43 | r"""Evaluates the classification accuracy when a :obj:`torch.data.DataLoader` is given.
44 |
45 | Args:
46 | net (torch.nn.Module): Frozen encoder.
47 | classifier (torch.nn.Module): Linear layer.
48 | dataloader (torch.data.DataLoader): Dataloader.
49 | transform (Callable, Optional): Transformation to use. Added for the purposes of
50 | normalization. (default: :obj:`None`)
51 | device (String, Optional): Device used. (default: :obj:`"cpu"`)
52 |
53 | Returns:
54 | float: Accuracy.
55 | """
56 | classifier.eval()
57 | acc = []
58 | for x, label in dataloader:
59 | x = x.to(device)
60 | label = label.to(device)
61 |
62 | if transform is not None:
63 | x = transform(x)
64 |
65 | # feed to network and classifier
66 | with torch.no_grad():
67 | representation = net(x)
68 | representation = representation.view(representation.shape[0], -1)
69 | pred_logits = classifier(representation)
70 | # compute accuracy
71 | _, pred_class = torch.max(pred_logits, 1)
72 | acc.append((pred_class == label).sum().item() / label.size(0))
73 | return sum(acc)/len(acc)
74 |
75 |
76 | def train_classifier(net, classifier, data_train, data_val, optimizer, scheduler=None, transform=None,
77 | transform_val=None, batch_size=256, num_epochs=10, device='cpu',
78 | writer=None, tag='', tqdm_progress=False):
79 | r"""Trains linear layer to predict angle.
80 |
81 | Args:
82 | net (torch.nn.Module): Frozen encoder.
83 | classifier (torch.nn.Module): Trainable linear layer.
84 |         data_train (torch.data.DataLoader or list of torch.Tensor): Inputs and target classes.
85 |         data_val (torch.data.DataLoader or list of torch.Tensor): Inputs and target classes.
86 | optimizer (torch.optim.Optimizer): Optimizer for :obj:`classifier`.
87 | scheduler (torch.optim._LRScheduler, Optional): Learning rate scheduler. (default: :obj:`None`)
88 | transform (Callable, Optional): Transformation to use during training. (default: :obj:`None`)
89 | transform_val (Callable, Optional): Transformation to use during validation. Added for the purposes of
90 | normalization. (default: :obj:`None`)
91 | batch_size (int, Optional): Batch size used during training. (default: :obj:`256`)
92 | num_epochs (int, Optional): Number of training epochs. (default: :obj:`10`)
93 | device (String, Optional): Device used. (default: :obj:`"cpu"`)
94 | writer (torch.utils.tensorboard.SummaryWriter, Optional): Summary writer. (default: :obj:`None`)
95 | tag (String, Optional): Tag used in :obj:`writer`. (default: :obj:`""`)
96 |         tqdm_progress (bool, Optional): If :obj:`True`, show training progress. (default: :obj:`False`)
97 |
98 | Returns:
99 | MetricLogger: Accuracy.
100 | """
101 | class_criterion = torch.nn.CrossEntropyLoss()
102 |
103 | acc = MetricLogger()
104 | for epoch in tqdm(range(num_epochs), disable=not tqdm_progress):
105 | classifier.train()
106 | if isinstance(data_train, list):
107 | iterator = utils.batch_iter(*data_train, batch_size=batch_size)
108 | else:
109 | iterator = iter(data_train)
110 |
111 | for x, label in iterator:
112 | optimizer.zero_grad()
113 |
114 | # load data
115 | x = x.to(device)
116 | label = label.to(device)
117 |
118 | if transform is not None:
119 | x = transform(x)
120 |
121 | # forward
122 | with torch.no_grad():
123 | representation = net(x)
124 | representation = representation.view(representation.shape[0], -1)
125 |
126 | pred_class = classifier(representation)
127 |
128 | # loss
129 | loss = class_criterion(pred_class, label)
130 |
131 | # backward
132 | loss.backward()
133 | optimizer.step()
134 |
135 | if scheduler is not None:
136 | scheduler.step()
137 |
138 | # compute classification accuracies
139 | if isinstance(data_train, list):
140 | acc_val = compute_accuracy(net, classifier, data_val, transform=transform_val, device=device)
141 | else:
142 | acc_val = compute_accuracy_dataloader(net, classifier, data_val, transform=transform_val, device=device)
143 |
144 | acc.update(0., acc_val)
145 | if writer is not None:
146 |             writer.add_scalar('eval_acc/val-%s' % tag, acc_val, epoch)
147 |
148 | if isinstance(data_train, list):
149 | acc_train = compute_accuracy(net, classifier, data_train, transform=transform_val, device=device)
150 | else:
151 | acc_train = compute_accuracy_dataloader(net, classifier, data_train, transform=transform_val, device=device)
152 | acc.update(acc_train, acc_val)
153 | return acc
154 |
155 |
--------------------------------------------------------------------------------
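Illustrative wiring for linear evaluation (the stand-in encoder and random data are made up, and it is assumed that `utils.batch_iter` yields (inputs, labels) mini-batches; only `classifier` receives gradient updates):

import torch
from self_supervised.tasks import train_classifier

net = torch.nn.Linear(20, 16)        # stand-in frozen encoder
classifier = torch.nn.Linear(16, 5)  # trainable linear probe
data_train = [torch.randn(256, 20), torch.randint(5, (256,))]
data_val = [torch.randn(64, 20), torch.randint(5, (64,))]
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1)

acc = train_classifier(net, classifier, data_train, data_val, optimizer, num_epochs=5)
print(acc.val_last)
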
/self_supervised/tasks/fast_classification.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 |
4 | from self_supervised.data import utils
5 | from self_supervised.utils import MetricLogger
6 |
7 |
8 | def compute_representations(net, dataloader, device='cpu'):
9 | r"""Pre-computes the representation for the entire dataset.
10 |
11 | Args:
12 | net (torch.nn.Module): Frozen encoder.
13 | dataloader (torch.data.DataLoader): Dataloader.
14 | device (String, Optional): Device used. (default: :obj:`"cpu"`)
15 |
16 | Returns:
17 | [torch.Tensor, torch.Tensor]: Representations and labels.
18 | """
19 | net.eval()
20 | reps = []
21 | labels = []
22 |
23 | for i, (x, label) in tqdm(enumerate(dataloader)):
24 | # load data
25 | x = x.to(device).squeeze()
26 | labels.append(label)
27 |
28 | # forward
29 | with torch.no_grad():
30 | representation = net(x)
31 | reps.append(representation.detach().cpu().squeeze())
32 |
33 | if i % 10 == 0:
34 | reps = [torch.cat(reps, dim=0)]
35 | labels = [torch.cat(labels, dim=0)]
36 |
37 | reps = torch.cat(reps, dim=0)
38 | labels = torch.cat(labels, dim=0)
39 | return [reps, labels]
40 |
41 |
42 | def compute_accuracy(classifier, data, batch_size=256, device='cpu'):
43 | r"""Evaluates the classification accuracy with representations pre-computed.
44 |
45 | Args:
46 | classifier (torch.nn.Module): Linear layer.
47 |         data (list of torch.Tensor): Pre-computed representations and target classes.
48 | batch_size (int, Optional): Batch size used during evaluation. It has no impact on final accuracy.
49 | (default: :obj:`256`)
50 | device (String, Optional): Device used. (default: :obj:`"cpu"`)
51 |
52 | Returns:
53 | float: Accuracy.
54 | """
55 | # prepare inputs
56 | classifier.eval()
57 | right = []
58 | total = []
59 | for x, label in utils.batch_iter(*data, batch_size=batch_size):
60 | x = x.to(device)
61 | label = label.to(device)
62 |
63 | # feed to network and classifier
64 | with torch.no_grad():
65 | pred_logits = classifier(x)
66 | # compute accuracy
67 | _, pred_class = torch.max(pred_logits, 1)
68 | right.append((pred_class == label).sum().item())
69 | total.append(label.size(0))
70 | classifier.train()
71 | return sum(right) / sum(total)
72 |
73 |
74 | def train_linear_layer(classifier, data_train, data_val, optimizer, scheduler=None,
75 | batch_size=256, num_epochs=10, device='cpu', writer=None, tag="", tqdm_progress=False):
76 | r"""Trains linear layer to predict angle with representation pre-computed.
77 |
78 | Args:
79 | classifier (torch.nn.Module): Trainable linear layer.
80 |         data_train (torch.data.DataLoader or list of torch.Tensor): Representations and target classes.
81 |         data_val (torch.data.DataLoader or list of torch.Tensor): Representations and target classes.
82 | optimizer (torch.optim.Optimizer): Optimizer for :obj:`classifier`.
83 | scheduler (torch.optim._LRScheduler, Optional): Learning rate scheduler. (default: :obj:`None`)
84 | batch_size (int, Optional): Batch size used during training. (default: :obj:`256`)
85 | num_epochs (int, Optional): Number of training epochs. (default: :obj:`10`)
86 | device (String, Optional): Device used. (default: :obj:`"cpu"`)
87 | writer (torch.utils.tensorboard.SummaryWriter, Optional): Summary writer. (default: :obj:`None`)
88 | tag (String, Optional): Tag used in :obj:`writer`. (default: :obj:`""`)
89 |         tqdm_progress (bool, Optional): If :obj:`True`, show training progress. (default: :obj:`False`)
90 |
91 | Returns:
92 | MetricLogger: Accuracy.
93 | """
94 | class_criterion = torch.nn.CrossEntropyLoss()
95 |
96 | acc = MetricLogger(smoothing_factor=0.2)
97 | for epoch in tqdm(range(num_epochs), disable=not tqdm_progress):
98 | for x, label in utils.batch_iter(*data_train, batch_size=batch_size):
99 | classifier.train()
100 | optimizer.zero_grad()
101 |
102 | # load data
103 | x = x.to(device)
104 | label = label.to(device)
105 |
106 | pred_class = classifier(x)
107 |
108 | # loss
109 | loss = class_criterion(pred_class, label)
110 |
111 | # backward
112 | loss.backward()
113 | optimizer.step()
114 | if scheduler is not None:
115 | scheduler.step()
116 |
117 | # compute classification accuracies
118 | acc_val = compute_accuracy(classifier, data_val, batch_size=batch_size, device=device)
119 | acc.update(0., acc_val)
120 | if writer is not None:
121 |             writer.add_scalar('eval_acc/val-%s' % tag, acc_val, epoch)
122 |
123 | acc_train = compute_accuracy(classifier, data_train, batch_size=batch_size, device=device)
124 | acc.update(acc_train, acc_val)
125 | return acc
126 |
--------------------------------------------------------------------------------
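The intended two-phase flow, sketched with a random stand-in encoder and dataset: representations are computed once over a dataloader, then the linear layer trains on the cached tensors (validation reuses the training set purely for the demo):

import torch
from torch.utils.data import DataLoader, TensorDataset

from self_supervised.tasks import fast_classification

net = torch.nn.Linear(20, 16)  # stand-in frozen encoder
loader = DataLoader(TensorDataset(torch.randn(512, 20), torch.randint(5, (512,))),
                    batch_size=64)

data = fast_classification.compute_representations(net, loader)
classifier = torch.nn.Linear(16, 5)
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1)
acc = fast_classification.train_linear_layer(classifier, data, data, optimizer, num_epochs=5)
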
/self_supervised/tasks/neural_tasks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import numpy as np
4 |
5 | from self_supervised.data import utils
6 | from self_supervised.utils import MetricLogger
7 |
8 |
9 | def compute_angle_accuracy(net, classifier, data, transform=None, device='cpu'):
10 | r"""Evaluates the angle prediction performance of the network.
11 |
12 | Args:
13 | net (torch.nn.Module): Frozen encoder.
14 | classifier (torch.nn.Module): Linear layer.
15 |         data (list of torch.Tensor): Inputs, target angles and target (cos, sin) pairs.
16 | transform (Callable, Optional): Transformation to use. Added for the purposes of
17 | normalization. (default: :obj:`None`)
18 | device (String, Optional): Device used. (default: :obj:`"cpu"`)
19 |
20 | Returns:
21 | (float, float): Accuracy and delta-Accuracy.
22 | """
23 | # prepare inputs
24 | classifier.eval()
25 | x, a, y = data
26 |
27 | x = x.to(device).squeeze()
28 | a = a.to(device).squeeze()
29 |
30 | # transform data
31 | if transform is not None:
32 | [x] = transform(x)
33 |
34 | # feed to classifier
35 | with torch.no_grad():
36 | representation = net(x).detach()
37 | pred_cos_sin = classifier(representation).detach().clone()
38 |
39 | pred_angles = torch.atan2(pred_cos_sin[:, 1], pred_cos_sin[:, 0])
40 | pred_angles[pred_angles < 0] = pred_angles[pred_angles < 0] + 2 * np.pi
41 |
42 | diff_angles = torch.abs(pred_angles - a.squeeze())
43 | diff_angles[diff_angles > np.pi] = torch.abs(diff_angles[diff_angles > np.pi] - 2 * np.pi)
44 |
45 |     # count predictions within pi/8 (acc) and 3*pi/16 (delta_acc) of the target angle
46 | acc = (diff_angles < (np.pi / 8)).sum()
47 | acc = acc.item() / x.size(0)
48 | delta_acc = (diff_angles < (3 * np.pi / 16)).sum()
49 | delta_acc = delta_acc.item() / x.size(0)
50 | return acc, delta_acc
51 |
52 |
53 | def train_angle_classifier(net, classifier, data_train, data_val, optimizer, transform=None,
54 | transform_val=None, batch_size=256, num_epochs=10, device='cpu'):
55 | r"""Trains linear layer to predict angle.
56 |
57 | Args:
58 | net (torch.nn.Module): Frozen encoder.
59 | classifier (torch.nn.Module): Trainable linear layer.
60 |         data_train (list of torch.Tensor): Inputs, target angles and target (cos, sin) pairs.
61 |         data_val (list of torch.Tensor): Inputs, target angles and target (cos, sin) pairs.
62 | optimizer (torch.optim.Optimizer): Optimizer for :obj:`classifier`.
63 | transform (Callable, Optional): Transformation to use during training. (default: :obj:`None`)
64 | transform_val (Callable, Optional): Transformation to use during validation. Added for the purposes of
65 | normalization. (default: :obj:`None`)
66 | batch_size (int, Optional): Batch size used during training. (default: :obj:`256`)
67 | num_epochs (int, Optional): Number of training epochs. (default: :obj:`10`)
68 | device (String, Optional): Device used. (default: :obj:`"cpu"`)
69 |
70 | Returns:
71 | (MetricLogger, MetricLogger): Accuracy and delta-Accuracy.
72 | """
73 | class_criterion = torch.nn.MSELoss()
74 |
75 | acc = MetricLogger()
76 | delta_acc = MetricLogger()
77 |
78 | for epoch in tqdm(range(num_epochs), disable=True):
79 | classifier.train()
80 | for x, _, label in utils.batch_iter(*data_train, batch_size=batch_size):
81 | x = x.to(device).squeeze()
82 | label = label.to(device).squeeze()
83 |
84 | # transform data
85 | if transform is not None:
86 | [x] = transform(x)
87 |
88 | optimizer.zero_grad()
89 | # forward
90 | with torch.no_grad():
91 | representation = net(x).detach().clone()
92 | representation = representation.view(representation.shape[0], -1)
93 |
94 | pred_class = classifier(representation)
95 |
96 | # loss
97 | loss = class_criterion(pred_class, label)
98 |
99 | # backward
100 | loss.backward()
101 | optimizer.step()
102 |
103 | # compute classification accuracies
104 | acc_train, delta_acc_train = compute_angle_accuracy(net, classifier, data_train, transform=transform_val,
105 | device=device)
106 | acc_test, delta_acc_test = compute_angle_accuracy(net, classifier, data_val, transform=transform_val,
107 | device=device)
108 |
109 | acc.update(acc_train, acc_test, step=epoch)
110 | delta_acc.update(delta_acc_train, delta_acc_test, step=epoch)
111 | return acc, delta_acc
112 |
--------------------------------------------------------------------------------
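An illustrative setup, inferred from how `data` is unpacked above: three tensors holding inputs, raw angles in [0, 2*pi), and (cos, sin) regression targets (the stand-in encoder and random data are made up):

import math

import torch

from self_supervised.tasks import neural_tasks

net = torch.nn.Linear(50, 16)        # stand-in frozen encoder
classifier = torch.nn.Linear(16, 2)  # predicts (cos, sin)
angles = torch.rand(256) * 2 * math.pi
data = [torch.randn(256, 50), angles,
        torch.stack([angles.cos(), angles.sin()], dim=1)]
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-2)

acc, delta_acc = neural_tasks.train_angle_classifier(net, classifier, data, data,
                                                     optimizer, num_epochs=5)
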
/self_supervised/tensorboard/__init__.py:
--------------------------------------------------------------------------------
1 | from . import embedding_projector
2 |
--------------------------------------------------------------------------------
/self_supervised/tensorboard/embedding_projector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tensorflow as tf
3 | import tensorboard as tb
4 |
5 | # workaround: when tensorflow is installed, tensorboard's embedding projector breaks
6 | tf.io.gfile = tb.compat.tensorflow_stub.io.gfile
7 |
8 |
9 | def log_representation(net, inputs, metadata, writer, step, tag='representation', metadata_header=None,
10 | inputs_are_images=False):
11 | r"""
12 | Computes representations and logs them to tensorboard.
13 |
14 | Args:
15 | net (torch.nn.Module): Encoder.
16 | inputs (torch.Tensor): Inputs.
17 |         metadata (torch.Tensor or list): A list of labels, each element will be converted to a string.
18 |         writer (torch.utils.tensorboard.SummaryWriter): Summary writer.
19 | step (int): Global step value to record.
20 | tag (string, optional): Name for the embedding. (default: :obj:`representation`)
21 | metadata_header (list, optional): Metadata header. (default: :obj:`None`)
22 | inputs_are_images (boolean, optional): Set to :obj:`True` if inputs are images. (default: :obj:`False`)
23 | """
24 | with torch.no_grad():
25 | representation = net(inputs)
26 | representation = representation.view(representation.shape[0], -1).detach()
27 |
28 | label_img = inputs if inputs_are_images else None
29 | writer.add_embedding(representation, metadata, tag=tag, global_step=step, metadata_header=metadata_header,
30 | label_img=label_img)
31 |
--------------------------------------------------------------------------------
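A minimal usage sketch (assumes tensorflow and tensorboard are installed, which is what the gfile workaround above is for; the encoder, log directory, and data are placeholders):

import torch
from torch.utils.tensorboard import SummaryWriter

from self_supervised.tensorboard.embedding_projector import log_representation

net = torch.nn.Linear(20, 8)                # stand-in encoder
writer = SummaryWriter('runs/demo')         # hypothetical log directory
inputs = torch.randn(100, 20)
labels = torch.randint(5, (100,)).tolist()  # one label per sample
log_representation(net, inputs, labels, writer, step=0)
writer.close()
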
/self_supervised/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .byol_trainer import BYOLTrainer
2 | from .myow_trainer import MYOWTrainer
3 |
--------------------------------------------------------------------------------
/self_supervised/trainer/byol_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 |
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | import torch.distributed as dist
8 | from torch.utils.tensorboard import SummaryWriter
9 | from torch.optim import SGD, Adam
10 | from torch.nn.parallel import DistributedDataParallel as DDP
11 |
12 | from self_supervised.model import BYOL, MLP3
13 | from self_supervised.optimizer import LARS
14 | from self_supervised.loss import CosineLoss
15 |
16 |
17 | class BYOLTrainer():
18 | def __init__(self,
19 | encoder, representation_size, projection_size, projection_hidden_size,
20 | train_dataloader, prepare_views, total_epochs, warmup_epochs, base_lr, base_momentum,
21 | batch_size=256, decay='cosine', n_decay=1.5, m_decay='cosine',
22 | optimizer_type="lars", momentum=1.0, weight_decay=1.0, exclude_bias_and_bn=False,
23 | transform=None, transform_1=None, transform_2=None, symmetric_loss=False,
24 | world_size=1, rank=0, distributed=False, gpu=0, master_gpu=0, port='12355',
25 | ckpt_path="./models/ckpt-%d.pt", log_step=1, log_dir=None, **kwargs):
26 |
27 | # device parameters
28 | self.world_size = world_size
29 | self.rank = rank
30 | self.gpu = gpu
31 | self.master_gpu = master_gpu
32 | self.distributed = distributed
33 |
34 | if torch.cuda.is_available():
35 | self.device = torch.device(f'cuda:{self.gpu}')
36 | torch.cuda.set_device(self.device)
37 | else:
38 | self.device = torch.device('cpu')
39 |
40 |         print('Using %r.' % self.device)
41 |
42 | # checkpoint
43 | self.ckpt_path = ckpt_path
44 |
45 | # build network
46 | self.representation_size = representation_size
47 | self.projection_size = projection_size
48 | self.projection_hidden_size = projection_hidden_size
49 | self.model = self.build_model(encoder)
50 |
51 | if self.distributed:
52 | os.environ['MASTER_ADDR'] = 'localhost'
53 | os.environ['MASTER_PORT'] = port
54 | dist.init_process_group(backend='nccl', init_method='env://', rank=self.rank, world_size=self.world_size)
55 | self.group = dist.new_group()
56 |
57 | self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
58 | self.model = DDP(self.model, device_ids=[self.gpu], find_unused_parameters=True)
59 |
60 | # dataloaders
61 | self.train_dataloader = train_dataloader
62 | self.prepare_views = prepare_views # outputs view1 and view2 (pre-gpu-transform)
63 |
64 | # transformers
65 | # these are on gpu transforms! can have cpu transform in dataloaders
66 | self.transform_1 = transform_1 if transform_1 is not None else transform # class 1 of transformations
67 | self.transform_2 = transform_2 if transform_2 is not None else transform # class 2 of transformations
68 | assert (self.transform_1 is None) == (self.transform_2 is None)
69 |
70 | # training parameters
71 | self.total_epochs = total_epochs
72 | self.warmup_epochs = warmup_epochs
73 |
74 | # todo fix batch shape (double batch loader)
75 | self.train_batch_size = batch_size
76 | self.global_batch_size = self.world_size * self.train_batch_size
77 |
78 | self.num_examples = len(self.train_dataloader.dataset)
79 | self.warmup_steps = self.warmup_epochs * self.num_examples // self.global_batch_size
80 | self.total_steps = self.total_epochs * self.num_examples // self.global_batch_size
81 |
82 | self.step = 0
83 |         base_lr = base_lr / 256  # linear scaling rule: max_lr = base_lr * global_batch_size / 256
84 | self.max_lr = base_lr * self.global_batch_size
85 |
86 | self.base_mm = base_momentum
87 |
88 | assert decay in ['cosine', 'poly']
89 | self.decay = decay
90 | self.n_decay = n_decay
91 |
92 | assert m_decay in ['cosine', 'cste']
93 | self.m_decay = m_decay
94 |
95 | # configure optimizer
96 | self.momentum = momentum
97 | self.weight_decay = weight_decay
98 | self.exclude_bias_and_bn = exclude_bias_and_bn
99 |
100 | if self.exclude_bias_and_bn:
101 | if not self.distributed:
102 | params = self._collect_params(self.model.trainable_modules)
103 | else:
104 | # todo make sure this is correct
105 | params = self._collect_params(self.model.module.trainable_module_list)
106 | else:
107 | params = self.model.parameters()
108 |
109 | if optimizer_type == "lars":
110 | self.optimizer = LARS(params, lr=self.max_lr, momentum=self.momentum, weight_decay=self.weight_decay)
111 | elif optimizer_type == "sgd":
112 | self.optimizer = SGD(params, lr=base_lr, momentum=self.momentum, weight_decay=self.weight_decay)
113 | elif optimizer_type == "adam":
114 | if momentum != 1.0:
115 | warnings.warn("Adam optimizer doesn't use momentum. Momentum %.2f will be ignored." % momentum)
116 | self.optimizer = Adam(params, lr=base_lr, weight_decay=self.weight_decay)
117 | else:
118 | raise ValueError("Optimizer type needs to be 'lars', 'sgd' or 'adam', got (%s)." % optimizer_type)
119 |
120 | self.loss = CosineLoss().to(self.device)
121 | self.symmetric_loss = symmetric_loss
122 |
123 | # logging
124 | self.log_step = log_step
125 | if self.rank == 0:
126 | self.writer = SummaryWriter(log_dir)
127 |
128 | def build_model(self, encoder):
129 | projector = MLP3(self.representation_size, self.projection_size, self.projection_hidden_size)
130 | predictor = MLP3(self.projection_size, self.projection_size, self.projection_hidden_size)
131 | net = BYOL(encoder, projector, predictor)
132 | return net.to(self.device)
133 |
134 | def _collect_params(self, model_list):
135 | """
136 |         exclude_bias_and_bn: excludes bias and bn parameters from both weight decay and LARS adaptation.
137 |         In the PyTorch implementation of ResNet, `downsample.1` layers are bn layers.
138 | """
139 | param_list = []
140 | for model in model_list:
141 | for name, param in model.named_parameters():
142 | if self.exclude_bias_and_bn and ('bn' in name or 'downsample.1' in name or 'bias' in name):
143 | param_dict = {'params': param, 'weight_decay': 0., 'lars_exclude': True}
144 | else:
145 | param_dict = {'params': param}
146 | param_list.append(param_dict)
147 | return param_list
148 |
149 | def _cosine_decay(self, step):
150 | return 0.5 * self.max_lr * (1 + np.cos((step - self.warmup_steps) * np.pi / (self.total_steps - self.warmup_steps)))
151 |
152 | def _poly_decay(self, step):
153 |         return self.max_lr * (1 - ((step - self.warmup_steps) / (self.total_steps - self.warmup_steps)) ** self.n_decay)
154 |
155 |     def update_learning_rate(self, step):
156 | """learning rate warm up and decay"""
157 | if step <= self.warmup_steps:
158 | lr = self.max_lr * step / self.warmup_steps
159 | else:
160 | if self.decay == 'cosine':
161 | lr = self._cosine_decay(step)
162 | elif self.decay == 'poly':
163 | lr = self._poly_decay(step)
164 | else:
165 | raise AttributeError
166 | for param_group in self.optimizer.param_groups:
167 | param_group['lr'] = lr
168 |
169 | def update_momentum(self, step):
170 | if self.m_decay == 'cosine':
171 | self.mm = 1 - (1 - self.base_mm) * (np.cos(np.pi * step / self.total_steps) + 1) / 2
172 | elif self.m_decay == 'cste':
173 | self.mm = self.base_mm
174 | else:
175 | raise AttributeError
176 |
177 | def save_checkpoint(self, epoch):
178 | if self.rank == 0:
179 | state = {
180 | 'epoch': epoch,
181 | 'steps': self.step,
182 | 'model': self.model.state_dict(),
183 | 'optimizer': self.optimizer.state_dict(),
184 | }
185 |             torch.save(state, self.ckpt_path % epoch)
186 |
187 | def load_checkpoint(self, epoch):
188 |         model_path = self.ckpt_path % epoch
189 |         # remap tensors saved on the master gpu to this process's gpu
190 |         map_location = "cuda:{}".format(self.gpu)
191 | checkpoint = torch.load(model_path, map_location=map_location)
192 |
193 | self.step = checkpoint['steps']
194 | self.model.load_state_dict(checkpoint['model'], strict=False)
195 |
196 | self.optimizer.load_state_dict(checkpoint['optimizer'])
197 |
198 | def cleanup(self):
199 | dist.destroy_process_group()
200 |
201 | def forward_loss(self, preds, targets):
202 | loss = self.loss(preds, targets)
203 | return loss
204 |
205 | def update_target_network(self):
206 | if not self.distributed:
207 | self.model.update_target_network(self.mm)
208 | else:
209 | self.model.module.update_target_network(self.mm)
210 |
211 | def log_schedule(self, loss):
212 | self.writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'], self.step)
213 | self.writer.add_scalar('mm', self.mm, self.step)
214 | self.writer.add_scalar('loss', loss, self.step)
215 |
216 | def train_epoch(self):
217 | self.model.train()
218 | for inputs in self.train_dataloader:
219 | # update parameters
220 | self.update_learning_rate(self.step)
221 | self.update_momentum(self.step)
222 |
223 | inputs = self.prepare_views(inputs)
224 | view1 = inputs['view1'].to(self.device)
225 | view2 = inputs['view2'].to(self.device)
226 |
227 | if self.transform_1 is not None:
228 | # apply transforms
229 | view1 = self.transform_1(view1)
230 | view2 = self.transform_2(view2)
231 |
232 | # forward
233 |             outputs = self.model({'online_view': view1, 'target_view': view2})
234 | loss = self.forward_loss(outputs['online_q'], outputs['target_z'])
235 | if self.symmetric_loss:
236 | outputs = self.model({'online_view': view2, 'target_view': view1})
237 | loss += self.forward_loss(outputs['online_q'], outputs['target_z'])
238 | loss /= 2
239 |
240 | # backprop online network
241 | self.optimizer.zero_grad()
242 | loss.backward()
243 | self.optimizer.step()
244 |
245 | # update moving average
246 | self.update_target_network()
247 |
248 | # log
249 | if self.step % self.log_step == 0 and self.rank == 0:
250 | self.log_schedule(loss=loss.item())
251 |
252 | # update parameters
253 | self.step += 1
254 |
--------------------------------------------------------------------------------
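A single-process construction sketch with made-up sizes (the dataset, encoder, and hyperparameters are placeholders; assumes tensorboard is available for the SummaryWriter). `prepare_views` just has to map a raw batch to a dict with 'view1' and 'view2'; here the dataset already yields a pair:

import torch
from torch.utils.data import DataLoader, TensorDataset

from self_supervised.trainer import BYOLTrainer

x = torch.randn(512, 100)
loader = DataLoader(TensorDataset(x, x), batch_size=64, shuffle=True)  # identical "views", demo only

trainer = BYOLTrainer(
    encoder=torch.nn.Linear(100, 32), representation_size=32,
    projection_size=16, projection_hidden_size=64,
    train_dataloader=loader,
    prepare_views=lambda batch: {'view1': batch[0], 'view2': batch[1]},
    total_epochs=10, warmup_epochs=1, base_lr=0.2, base_momentum=0.99,
    batch_size=64)

for epoch in range(10):
    trainer.train_epoch()
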
/self_supervised/trainer/myow_trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.distributed as dist
4 |
5 | from self_supervised.model import MYOW, MLP3
6 | from self_supervised.trainer import BYOLTrainer
7 |
8 |
9 | class MYOWTrainer(BYOLTrainer):
10 | def __init__(self, view_pool_dataloader=None, transform_m=None,
11 | myow_warmup_epochs=0, myow_rampup_epochs=None, myow_max_weight=1., view_miner_k=4,
12 | log_img_step=0, untransform_vis=None, projection_size_2=None, projection_hidden_size_2=None, **kwargs):
13 |
14 | self.projection_size_2 = projection_size_2 if projection_size_2 is not None else kwargs['projection_size']
15 | self.projection_hidden_size_2 = projection_hidden_size_2 if projection_hidden_size_2 is not None \
16 | else kwargs['projection_hidden_size']
17 |
18 | # view pool dataloader
19 | self.view_pool_dataloader = view_pool_dataloader
20 |
21 | # view miner
22 | self.view_miner_k = view_miner_k
23 |
24 |         # transform class for mining
25 | self.transform_m = transform_m
26 |
27 | # myow loss
28 | self.mined_loss_weight = 0.
29 | self.myow_max_weight = myow_max_weight
30 | self.myow_warmup_epochs = myow_warmup_epochs if myow_warmup_epochs is not None else 0
31 | self.myow_rampup_epochs = myow_rampup_epochs if myow_rampup_epochs is not None else kwargs['total_epochs']
32 |
33 | # convert to steps
34 | world_size = kwargs['world_size'] if 'world_size' in kwargs else 1
35 | self.num_examples = len(kwargs['train_dataloader'].dataset)
36 | self.train_batch_size = kwargs['batch_size']
37 | self.global_batch_size = world_size * self.train_batch_size
38 | self.myow_warmup_steps = self.myow_warmup_epochs * self.num_examples // self.global_batch_size
39 | self.myow_rampup_steps = self.myow_rampup_epochs * self.num_examples // self.global_batch_size
40 | self.total_steps = kwargs['total_epochs'] * self.num_examples // self.global_batch_size
41 |
42 | # logger
43 | self.log_img_step = log_img_step
44 | self.untransform_vis = untransform_vis
45 |
46 | super().__init__(**kwargs)
47 |
48 | def build_model(self, encoder):
49 | projector_1 = MLP3(self.representation_size, self.projection_size, self.projection_hidden_size)
50 | projector_2 = MLP3(self.projection_size, self.projection_size_2, self.projection_hidden_size_2)
51 | predictor_1 = MLP3(self.projection_size, self.projection_size, self.projection_hidden_size)
52 | predictor_2 = MLP3(self.projection_size_2, self.projection_size_2, self.projection_hidden_size_2)
53 | net = MYOW(encoder, projector_1, projector_2, predictor_1, predictor_2, n_neighbors=self.view_miner_k)
54 | return net.to(self.device)
55 |
56 | def update_mined_loss_weight(self, step):
57 | max_w = self.myow_max_weight
58 | min_w = 0.
59 | if step < self.myow_warmup_steps:
60 | self.mined_loss_weight = min_w
61 | elif step > self.myow_rampup_steps:
62 | self.mined_loss_weight = max_w
63 | else:
64 | self.mined_loss_weight = min_w + (max_w - min_w) * (step - self.myow_warmup_steps) / \
65 | (self.myow_rampup_steps - self.myow_warmup_steps)
66 |
67 | def log_schedule(self, loss):
68 | super().log_schedule(loss)
69 | self.writer.add_scalar('myow_weight', self.mined_loss_weight, self.step)
70 |
71 |     def log_correspondence(self, view, view_mined):
72 |         r"""Logs pairs of views and their mined counterparts; currently only supports 2d images."""
73 | img_batch = np.zeros((16, view.shape[1], view.shape[2], view.shape[3]))
74 | for i in range(8):
75 | img_batch[i] = self.untransform_vis(view[i]).detach().cpu().numpy()
76 | img_batch[8+i] = self.untransform_vis(view_mined[i]).detach().cpu().numpy()
77 | self.writer.add_images('correspondence', img_batch, self.step)
78 |
79 | def train_epoch(self):
80 | self.model.train()
81 | if self.view_pool_dataloader is not None:
82 | view_pooler = iter(self.view_pool_dataloader)
83 | for inputs in self.train_dataloader:
84 | # update parameters
85 | self.update_learning_rate(self.step)
86 | self.update_momentum(self.step)
87 | self.update_mined_loss_weight(self.step)
88 | self.optimizer.zero_grad()
89 |
90 | inputs = self.prepare_views(inputs)
91 | view1 = inputs['view1'].to(self.device)
92 | view2 = inputs['view2'].to(self.device)
93 |
94 | if self.transform_1 is not None:
95 | # apply transforms
96 | view1 = self.transform_1(view1)
97 | view2 = self.transform_2(view2)
98 |
99 | # forward
100 |             outputs = self.model({'online_view': view1, 'target_view': view2})
101 | weight = 1 / (1. + self.mined_loss_weight)
102 | if self.symmetric_loss:
103 | weight /= 2.
104 | loss = weight * self.forward_loss(outputs['online_q'], outputs['target_z'])
105 |
106 | if self.distributed and self.mined_loss_weight > 0 and not self.symmetric_loss:
107 | with self.model.no_sync():
108 | loss.backward()
109 | else:
110 | loss.backward()
111 |
112 | if self.symmetric_loss:
113 | outputs = self.model({'online_view': view2, 'target_view': view1})
114 | weight = 1 / (1. + self.mined_loss_weight) / 2.
115 | loss = weight * self.forward_loss(outputs['online_q'], outputs['target_z'])
116 | if self.distributed and self.mined_loss_weight > 0:
117 | with self.model.no_sync():
118 | loss.backward()
119 | else:
120 | loss.backward()
121 |
122 | # mine view
123 | if self.mined_loss_weight > 0:
124 | if self.view_pool_dataloader is not None:
125 | try:
126 | # currently only supports img, label
127 | view_pool, label_pool = next(view_pooler)
128 | view_pool = view_pool.to(self.device).squeeze()
129 | except StopIteration:
130 | # reinit the dataloader
131 | view_pooler = iter(self.view_pool_dataloader)
132 | view_pool, label_pool = next(view_pooler)
133 | view_pool = view_pool.to(self.device).squeeze()
134 | view3 = inputs['view1'].to(self.device)
135 | else:
136 | view3 = inputs['view3'].to(self.device).squeeze() \
137 | if 'view3' in inputs else inputs['view1'].to(self.device).squeeze()
138 | view_pool = inputs['view_pool'].to(self.device).squeeze()
139 |
140 | # apply transform
141 | if self.transform_m is not None:
142 | # apply transforms
143 | view3 = self.transform_m(view3)
144 | view_pool = self.transform_m(view_pool)
145 |
146 | # compute representations
147 | outputs = self.model({'online_view': view3}, get_embedding='encoder')
148 | online_y = outputs['online_y']
149 | outputs_pool = self.model({'target_view': view_pool}, get_embedding='encoder')
150 | target_y_pool = outputs_pool['target_y']
151 |
152 | # mine views
153 | if self.distributed:
154 | gather_list = [torch.zeros_like(target_y_pool) for _ in range(self.world_size)]
155 | dist.all_gather(gather_list, target_y_pool, self.group)
156 | target_y_pool = torch.cat(gather_list, dim=0)
157 | selection_mask = self.model.module.mine_views(online_y, target_y_pool)
158 | else:
159 | selection_mask = self.model.mine_views(online_y, target_y_pool)
160 |
161 | target_y_mined = target_y_pool[selection_mask].contiguous()
162 |                 outputs_mined = self.model({'online_y': online_y, 'target_y': target_y_mined}, get_embedding='predictor_m')
163 | weight = self.mined_loss_weight / (1. + self.mined_loss_weight)
164 | loss = weight * self.forward_loss(outputs_mined['online_q_m'], outputs_mined['target_v'])
165 | loss.backward()
166 |
167 | self.optimizer.step()
168 |
169 | # update moving average
170 | self.update_target_network()
171 |
172 | # log
173 | if self.step % self.log_step == 0 and self.rank == 0:
174 | self.log_schedule(loss=loss.item())
175 |
176 | # log images
177 | if self.mined_loss_weight > 0 and self.log_img_step > 0 and self.step % self.log_img_step == 0 and self.rank == 0:
178 | if self.distributed:
179 | # get image pools from all gpus
180 | gather_list = [torch.zeros_like(view_pool) for _ in range(self.world_size)]
181 | dist.all_gather(gather_list, view_pool, self.group)
182 | view_pool = torch.cat(gather_list, dim=0)
183 |                     self.log_correspondence(view3, view_pool[selection_mask])
184 |
185 | # update parameters
186 | self.step += 1
187 |
188 | return loss.item()
189 |
--------------------------------------------------------------------------------
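The mined-loss weight follows a linear ramp between the warmup and rampup steps, and `train_epoch` combines the two terms as 1/(1+w) * L_BYOL + w/(1+w) * L_mined. A standalone re-implementation of the schedule, for illustration only:

def mined_loss_weight(step, warmup_steps, rampup_steps, max_w=1.0):
    # linear ramp from 0 to max_w, mirroring update_mined_loss_weight above
    if step < warmup_steps:
        return 0.0
    if step > rampup_steps:
        return max_w
    return max_w * (step - warmup_steps) / (rampup_steps - warmup_steps)

print([round(mined_loss_weight(s, 2, 6), 2) for s in range(8)])
# [0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
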
/self_supervised/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from . import neural as neural_transforms
2 |
--------------------------------------------------------------------------------
/self_supervised/transforms/neural.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Compose:
5 | r"""Composes several transforms together.
6 |
7 | Args:
8 | transforms (Callable): List of transforms to compose.
9 | """
10 | def __init__(self, *transforms):
11 | self.transforms = transforms
12 |
13 | def __call__(self, *x):
14 | for transform in self.transforms:
15 | x = transform(*x)
16 | return x
17 |
18 |
19 | class Dropout:
20 | r"""Drops a neuron with a probability of :obj:`p`.
21 |
22 | .. note::
23 |
24 | If more than one tensor is given, the same dropout pattern will be applied to all.
25 |
26 | Args:
27 | p (float, Optional): Probability of dropout. (default: :obj:`0.5`)
28 | apply_p (float, Optional): Probability of applying the transformation. (default: :obj:`1.0`)
31 | """
32 | def __init__(self, p: float = 0.5, apply_p=1.):
33 | self.p = p
34 | self.apply_p = apply_p
35 |
36 | def __call__(self, *x_list):
37 | # create dropout mask (batch_size, num_neurons)
38 | dropout_mask = torch.rand(x_list[0].size()) < 1 - self.p
39 | # create apply mask: (batch_size,)
40 | apply_mask = torch.rand(x_list[0].size(0)) < 1 - self.apply_p
41 | dropout_mask = dropout_mask + apply_mask.view((-1, 1))
42 | return [x * dropout_mask.to(x) for x in x_list]
43 |
44 |
45 | class RandomizedDropout:
46 | r"""Drops a neuron with a random probability uniformly sampled between :obj:`0` and :obj:`p`.
47 |
48 | .. note::
49 |
50 | If more than one tensor is given, the same dropout pattern will be applied to all.
51 |
52 | Args:
53 | p (float, Optional): Upper bound for the probability of dropout. (default: :obj:`0.5`)
54 | apply_p (float, Optional): Probability of applying the transformation. (default: :obj:`1.0`)
57 | """
58 | def __init__(self, p: float = 0.5, apply_p=1.):
59 | self.p = p
60 | self.apply_p = apply_p
61 |
62 | def __call__(self, *x_list):
63 | # generate a random dropout probability for each sample
64 | p = torch.rand(x_list[0].size(0)) * self.p
65 | # generate dropout mask
66 | dropout_mask = torch.rand(x_list[0].size()) < 1 - p.view((-1, 1))
67 | # generate mask for applying dropout
68 | apply_mask = torch.rand(x_list[0].size(0)) < 1 - self.apply_p
69 | dropout_mask = dropout_mask + apply_mask.view((-1, 1))
70 | return [x * dropout_mask.to(x) for x in x_list]
71 |
72 |
73 | class Noise:
74 | r"""Adds Gaussian noise to neural activity. The firing rate vector needs to have already been normalized, and
75 | the Gaussian noise is center and has standard deviation of :obj:`std`.
76 |
77 | .. note::
78 |
79 |         If more than one tensor is given, the same noise will be added to all.
80 |
81 | Args:
82 | std (float): Standard deviation of Gaussian noise.
83 | """
84 | def __init__(self, std):
85 | self.std = std
86 |
87 | def __call__(self, *x_list):
88 | if self.std == 0:
89 | return x_list
90 | noise = torch.normal(0.0, self.std, size=x_list[0].size())
91 | return [x + noise.to(x) for x in x_list]
92 |
93 |
94 | class Pepper:
95 | r"""Adds a constant to the neuron firing rate with a probability of :obj:`p`. The firing rate vector needs to have
96 | already been normalized.
97 |
98 | .. note::
99 |
100 |         If more than one tensor is given, the same pepper pattern will be applied to all.
101 |
102 | Args:
103 | p (float, Optional): Probability of adding pepper. (default: :obj:`0.5`)
104 |         sigma (float, Optional): Constant to be added to neural activity. (default: :obj:`1.0`)
105 |         apply_p (float, Optional): Probability of applying the transformation. (default: :obj:`1.0`)
106 | """
107 | def __init__(self, p=0.5, sigma=1.0, apply_p=1.):
108 | self.p = p
109 | self.sigma = sigma
110 | self.apply_p = apply_p
111 |
112 | def __call__(self, *x_list):
113 | keep_mask = torch.rand(x_list[0].size()) < self.p
114 | random_pepper = self.sigma * keep_mask
115 | apply_mask = torch.rand(x_list[0].size(0)) < self.apply_p
116 | random_pepper = random_pepper * apply_mask.view((-1, 1))
117 | return [x + random_pepper.to(x) for x in x_list]
118 |
119 |
120 | class Normalize:
121 | r"""Normalization transform.
122 |
123 | Args:
124 | mean (torch.Tensor): Mean.
125 | std (torch.Tensor): Standard deviation.
126 | """
127 | def __init__(self, mean, std):
128 | self.mean = mean
129 | self.std = std
130 |
131 | def __call__(self, *x_list):
132 | return [torch.div(x - self.mean.to(x), self.std.to(x)) for x in x_list]
133 |
--------------------------------------------------------------------------------
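Composing the neural transforms; every transform takes and returns a list of tensors, so the same augmentation pattern can be applied to several tensors at once (the input here is random stand-in data):

import torch
from self_supervised.transforms import neural_transforms as T

augment = T.Compose(T.RandomizedDropout(p=0.3), T.Noise(std=0.05))
x = torch.rand(16, 96)  # batch of (already normalized) firing rates
[x_aug] = augment(x)    # Compose returns a list, one entry per input tensor
print(x_aug.shape)      # torch.Size([16, 96])
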
/self_supervised/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .metric_logger import MetricLogger
2 | from .random_seeders import set_random_seeds
3 |
--------------------------------------------------------------------------------
/self_supervised/utils/metric_logger.py:
--------------------------------------------------------------------------------
1 |
2 | class MetricLogger:
3 | r"""Keeps track of training and validation curves, by recording:
4 | - Last value of train and validation metrics.
5 | - Train and validation metrics corresponding to maximum or minimum validation metric value.
6 | - Exponential moving average of train and validation metrics.
7 |
8 | Args:
9 | smoothing_factor (float, Optional): Smoothing factor used in exponential moving average.
10 | (default: :obj:`0.4`).
11 | max (bool, Optional): If :obj:`True`, tracks max value. Otherwise, tracks min value. (default: :obj:`True`).
12 | """
13 | def __init__(self, smoothing_factor=0.4, max=True):
14 | self.smoothing_factor = smoothing_factor
15 | self.max = max
16 |
17 | # init variables
18 | # last
19 | self.train_last = None
20 | self.val_last = None
21 | self.test_last = None
22 |
23 | # moving average
24 | self.train_smooth = None
25 | self.val_smooth = None
26 | self.test_smooth = None
27 |
28 | # max
29 | self.train_minmax = None
30 | self.val_minmax = None
31 | self.test_minmax = None
32 | self.step_minmax = None
33 |
34 | def __repr__(self):
35 | out = "Last: (Train) %.4f (Val) %.4f\n" % (self.train_last, self.val_last)
36 | out += "Smooth: (Train) %.4f (Val) %.4f\n" % (self.train_smooth, self.val_smooth)
37 | out += "Max: (Train) %.4f (Val) %.4f\n" % (self.train_minmax, self.val_minmax)
38 | return out
39 |
40 | def update(self, train_value, val_value, test_value=0., step=None):
41 | # last values
42 | self.train_last = train_value
43 | self.val_last = val_value
44 | self.test_last = test_value
45 |
46 | # exponential moving average
47 | self.train_smooth = self.smoothing_factor * train_value + (1 - self.smoothing_factor) * self.train_smooth \
48 | if self.train_smooth is not None else train_value
49 | self.val_smooth = self.smoothing_factor * val_value + (1 - self.smoothing_factor) * self.val_smooth \
50 | if self.val_smooth is not None else val_value
51 | self.test_smooth = self.smoothing_factor * test_value + (1 - self.smoothing_factor) * self.test_smooth \
52 | if self.test_smooth is not None else test_value
53 |
54 | # max/min validation accuracy
55 | if self.val_minmax is None or (self.max and self.val_minmax < val_value) or \
56 | (not self.max and self.val_minmax > val_value):
57 | self.train_minmax = train_value
58 | self.val_minmax = val_value
59 | self.test_minmax = test_value
60 | if step:
61 | self.step_minmax = step
62 |
63 | def __getattr__(self, item):
64 | if item not in ['train_min', 'train_max', 'val_min', 'val_max', 'test_min', 'test_max']:
65 | raise AttributeError
66 | if self.max and item in ['train_min', 'val_min', 'test_min']:
67 | raise AttributeError('Tracking maximum values, not minimum.')
68 | if not self.max and item in ['train_max', 'val_max', 'test_max']:
69 | raise AttributeError('Tracking minimum values, not maximum.')
70 |
71 | if 'train' in item:
72 | return self.train_minmax
73 | elif 'val' in item:
74 | return self.val_minmax
75 | elif 'test' in item:
76 | return self.test_minmax
77 |
--------------------------------------------------------------------------------
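A quick usage sketch; `val_max` and `train_max` resolve through `__getattr__` to the values recorded at the best validation step (the numbers are arbitrary):

from self_supervised.utils import MetricLogger

acc = MetricLogger(smoothing_factor=0.4, max=True)
acc.update(0.70, 0.65, step=0)
acc.update(0.85, 0.72, step=1)  # new best validation value
acc.update(0.90, 0.68, step=2)  # val got worse; best stays at step 1
print(acc.val_max, acc.train_max)  # 0.72 0.85
print(round(acc.val_smooth, 4))    # 0.6788, exponential moving average of val values
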
/self_supervised/utils/random_seeders.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def set_random_seeds(random_seed=0):
8 | r"""Sets the seed for generating random numbers.
9 |
10 | Args:
11 | random_seed: Desired random seed.
12 | """
13 | torch.manual_seed(random_seed)
14 | torch.cuda.manual_seed(random_seed)
15 | torch.backends.cudnn.deterministic = True
16 | torch.backends.cudnn.benchmark = False
17 | np.random.seed(random_seed)
18 | random.seed(random_seed)
19 |
--------------------------------------------------------------------------------