├── .gitignore
├── LICENSE
├── README.md
├── constants.py
├── counterfactual_search.py
├── environment.yml
├── features_to_topk_matrix.py
├── models
│   ├── mlpmixer.py
│   ├── mobilenetv2.py
│   ├── resnet.py
│   └── swin.py
├── self_supervised_models
│   ├── launch.sh
│   ├── models
│   │   ├── __init__.py
│   │   ├── barlow.py
│   │   ├── byol.py
│   │   ├── dcl.py
│   │   ├── dclw.py
│   │   ├── dino.py
│   │   ├── moco.py
│   │   ├── nnclr.py
│   │   ├── simclr.py
│   │   ├── simsiam.py
│   │   ├── smog.py
│   │   └── swav.py
│   └── train_ssl.py
└── train_utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | *.pyc
3 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Vasu Singla
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Simple Efficient Data Attribution
2 | 
3 | This repository contains the code for the paper [A Simple and Efficient Baseline for Data Attribution on Images](https://arxiv.org/abs/2311.03386) by Vasu Singla, Pedro Sandoval-Segura, Micah Goldblum, Jonas Geiping, and Tom Goldstein.
4 | 
5 | 
6 | ### Installation
7 | 
8 | This repository requires the [FFCV](https://github.com/libffcv/ffcv) library and [PyTorch](https://pytorch.org/). You can also install a complete environment via the following command. NOTE - this environment is bloated and contains packages not required for this repository.
9 | 
10 | ```
11 | conda env create -f environment.yml
12 | ```
13 | 
14 | ### Data
15 | 
16 | All the data used for the paper is provided at this [Google Drive link](https://drive.google.com/drive/folders/10_WMZ4c8Co_VV-i3isoPcdM-t9q0-VuL?usp=drive_link). We describe the included data below -
17 | 
18 | 
19 | 1. **Top-k Train Samples** - For our repository, we pre-compute the top-k closest training samples for each method and for our baselines. These are provided in the link under the subfolders `cifar10/topk_train_samples` and `imagenet/topk_train_samples` for CIFAR-10 and Imagenet respectively.
20 | 2. **Test Indices** - We randomly selected 100 and 30 test samples for CIFAR-10 and Imagenet respectively, used throughout the paper; these are provided at `cifar10/test_indices` and `imagenet/test_indices`.
21 | 3. **Mislabel Support MetaData** - To compute mislabel support, we also need to specify which class to flip a test sample to. For CIFAR-10, we trained 10 Resnet-9 models for this task, and for Imagenet we trained 4 Resnet-18 models. The average predictions of these models are provided in the link above, along with the dataset labels that the metadata requires.
22 | 4. **Models** - *Note that the models used are not required to run this code, only the top-k training samples are required*. However, for transparency, the link also contains our trained Self-Supervised Models and DataModel Weights for CIFAR-10. All of these are provided at the link [here](https://drive.google.com/drive/folders/1Nh_3lZx_sn0_bANoNJGizfvXfWc5Bmz5?usp=sharing). For the Imagenet MoCo model, you can download it directly from the official [repo](https://github.com/facebookresearch/moco). For reproducing TRAK, you can follow the tutorial in the authors' original [code](https://github.com/MadryLab/trak).
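23 | 
24 | For reference, the top-k files produced by `features_to_topk_matrix.py` are plain NumPy integer matrices with one row of training indices per test sample. A minimal loading sketch (assuming the Google Drive folders are unpacked under a local `data/` directory):
25 | 
26 | ```python
27 | import numpy as np
28 | 
29 | # Indices of the 100 CIFAR-10 test samples used throughout the paper.
30 | test_idxs = np.load('data/test_indices/test_100.npy')
31 | 
32 | # Row i holds the indices of the top-k closest training samples for test sample i.
33 | topk_matrix = np.load('data/topk_train_samples/dmodel_1280.npy')
34 | print(topk_matrix.shape)              # (num_test_samples, k)
35 | print(topk_matrix[test_idxs[0]][:5])  # five closest training indices
36 | ```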
37 | 
38 | ### Counterfactual Estimation on CIFAR-10
39 | 
40 | To perform counterfactual estimation for a single test sample on CIFAR-10, run the following -
41 | 
42 | ```
43 | python counterfactual_search.py --test-idx $test_idx \
44 |                                 --matrix-path $matrix_path \
45 |                                 --results-path $results_path \
46 |                                 --num-tests 5 \
47 |                                 --search-budget 7 \
48 |                                 --arch $arch
49 | ```
50 | 
51 | The arguments are defined as follows -
52 | 
53 | ```
54 | --test-idx       Specifies the test index on which to perform counterfactual estimation
55 | --matrix-path    Path to the matrix containing top-k training indices for each validation sample
56 | --results-path   Path where results for the test sample are dumped as a pickle file
57 | --num-tests      Number of models to train per evaluation during the search
58 | --search-budget  Budget to use for the bisection search
59 | --arch           Model architecture to use {resnet9, mobilenetv2}
60 | --flip-class     Boolean flag; if specified, computes mislabel support instead of removal support
61 | ```
62 | 
63 | When using `--flip-class`, you also need to specify the metadata for the test labels and the second predicted class using `--label-path` and `--rank-path`. This metadata is provided in the data above.
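64 | 
65 | For example, a mislabel-support run using the script's default paths (illustrative; adjust the paths to wherever you unpacked the data) might look like -
66 | 
67 | ```
68 | python counterfactual_search.py --test-idx $test_idx \
69 |                                 --matrix-path data/topk_train_samples/dmodel_1280.npy \
70 |                                 --results-path results/dmodel_1280/ \
71 |                                 --flip-class \
72 |                                 --label-path data/info/test_labels.npy \
73 |                                 --rank-path data/logits/average_rank.npy
74 | ```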
75 | 
76 | ### Counterfactual Estimation on Imagenet
77 | 
78 | TODO. Our current code has some SLURM-specific pieces built in that need to be removed before release. In the meantime, you can adapt the code we used from [FFCV Imagenet](https://github.com/libffcv/ffcv-imagenet/tree/main) to perform counterfactual estimation.
79 | 
80 | ### Self Supervised Models - CIFAR
81 | 
82 | To train CIFAR-10 SSL models, use the `self_supervised_models` subfolder. The `train_ssl.py` script provides the training interface.
83 | 
84 | ### Citation
85 | 
86 | If you find our code useful, please consider citing our work -
87 | 
88 | ```
89 | @misc{singla2023simple,
90 |       title={A Simple and Efficient Baseline for Data Attribution on Images},
91 |       author={Vasu Singla and Pedro Sandoval-Segura and Micah Goldblum and Jonas Geiping and Tom Goldstein},
92 |       year={2023},
93 |       eprint={2311.03386},
94 |       archivePrefix={arXiv},
95 |       primaryClass={cs.CV}
96 | }
97 | ```
98 | 
99 | If you run into any problems, please raise a GitHub issue; we'll be happy to help!
100 | 
101 | ### Acknowledgments
102 | 
103 | The Datamodels weights on CIFAR-10 using 50% of the data were downloaded from [here](https://github.com/MadryLab/datamodels-data). We also trained our own datamodels using the code available [here](https://github.com/MadryLab/datamodels/tree/main).
104 | 
105 | The TRAK models were trained using the code available [here](https://github.com/MadryLab/trak).
106 | 
107 | FFCV Imagenet training code was used from [here](https://github.com/libffcv/ffcv-imagenet/tree/main).
108 | 
109 | The Self-Supervised models were trained using the [Lightly Benchmark Code](https://docs.lightly.ai/self-supervised-learning/getting_started/benchmarks.html).
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | # Note that statistics are w.r.t. the uint8 range, [0, 255].
2 | CIFAR_MEAN = [125.307, 122.961, 113.8575]
3 | CIFAR_STD = [51.5865, 50.847, 51.255]
4 | CIFAR_TRAIN_SIZE = 50000
5 | CIFAR_TEST_SIZE = 10000
6 | 
--------------------------------------------------------------------------------
/counterfactual_search.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import argparse
4 | from train_utils import binary_search
5 | import pickle as pkl
6 | import os
7 | 
8 | def parse_args():
9 |     parser = argparse.ArgumentParser()
10 |     parser.add_argument('--test-idx', type=int, default=0,
11 |                         help='Index of the test example for which data support needs to be estimated.')
12 |     parser.add_argument('--flip-class', action='store_true')
13 |     parser.add_argument('--label-path', type=str, default='data/info/test_labels.npy')
14 |     parser.add_argument('--rank-path', type=str, default='data/logits/average_rank.npy')
15 |     parser.add_argument('--matrix-path', type=str, default='data/topk_train_samples/dmodel_1280.npy')
16 |     parser.add_argument('--results-path', type=str, default='results/dmodel_1280/')
17 |     parser.add_argument('--num-tests', type=int, default=8)
18 |     parser.add_argument('--search-budget', type=int, default=8)
19 |     parser.add_argument('--arch', type=str, default='resnet9')
20 |     return parser.parse_args()
21 | 
22 | def main(args):
23 |     test_idx = args.test_idx
24 |     # Rows of the top-k matrix are indexed by test-set index.
25 |     topk_matrix = np.load(args.matrix_path)
26 |     train_idxs = topk_matrix[test_idx]
27 |     if args.flip_class:
28 |         # Flip to the class ranked second on average, unless that class is the
29 |         # true label, in which case flip to the top-ranked class instead.
30 |         test_labels = np.load(args.label_path)
31 |         rank_info = np.load(args.rank_path)[test_idx]
32 |         sorted_logits = rank_info.argsort()
33 |         flip_class = sorted_logits[1]
34 |         if flip_class == test_labels[test_idx]:
35 |             flip_class = sorted_logits[0]
36 |     else:
37 |         flip_class = None
38 |     data_support = binary_search(train_idxs,
39 |                                  flip_class=flip_class,
40 |                                  eval_idx=test_idx,
41 |                                  search_budget=args.search_budget,
42 |                                  num_tests=args.num_tests,
43 |                                  arch=args.arch)
44 |     os.makedirs(args.results_path, exist_ok=True)
45 |     fname = os.path.join(args.results_path, f'{test_idx}.pkl')
46 |     with open(fname, 'wb') as f:
47 |         pkl.dump(data_support, f)
48 | 
49 | if __name__ == '__main__':
50 |     args = parse_args()
51 |     main(args)
52 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: ffcv
2 | channels:
3 |   - pytorch
4 |   - conda-forge
5 |   - defaults
6 | dependencies:
7 |   - _libgcc_mutex=0.1=conda_forge
8 |   - _openmp_mutex=4.5=2_kmp_llvm
9 |   - absl-py=1.3.0=pyhd8ed1ab_0
10 |   - aiohttp=3.8.3=py39hb9d737c_1
11 |   - aiosignal=1.3.1=pyhd8ed1ab_0
12 |   - alsa-lib=1.2.8=h166bdaf_0
13 |   - anyio=3.5.0=py39h06a4308_0
14 |   - aom=3.5.0=h27087fc_0
15 |   - argon2-cffi=21.3.0=pyhd3eb1b0_0
16 |   - argon2-cffi-bindings=21.2.0=py39h7f8727e_0
17 |   - asttokens=2.0.5=pyhd3eb1b0_0
18 |   - 
async-timeout=4.0.2=pyhd8ed1ab_0 19 | - attr=2.5.1=h166bdaf_1 20 | - attrs=22.1.0=py39h06a4308_0 21 | - babel=2.9.1=pyhd3eb1b0_0 22 | - backcall=0.2.0=pyhd3eb1b0_0 23 | - beautifulsoup4=4.11.1=py39h06a4308_0 24 | - binutils=2.39=hdd6e379_1 25 | - binutils_impl_linux-64=2.39=he00db2b_1 26 | - binutils_linux-64=2.39=h5fc0e48_11 27 | - blas=1.0=mkl 28 | - bleach=4.1.0=pyhd3eb1b0_0 29 | - blinker=1.5=pyhd8ed1ab_0 30 | - brotlipy=0.7.0=py39h27cfd23_1003 31 | - bzip2=1.0.8=h7b6447c_0 32 | - c-ares=1.18.1=h7f8727e_0 33 | - c-compiler=1.5.1=h166bdaf_0 34 | - ca-certificates=2023.7.22=hbcca054_0 35 | - cachetools=5.2.0=pyhd8ed1ab_0 36 | - cairo=1.16.0=ha61ee94_1014 37 | - certifi=2023.7.22=pyhd8ed1ab_0 38 | - cffi=1.15.1=py39h5eee18b_2 39 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 40 | - click=8.1.3=unix_pyhd8ed1ab_2 41 | - compilers=1.5.1=ha770c72_0 42 | - cryptography=38.0.1=py39h9ce1e76_0 43 | - cudatoolkit=11.3.1=h9edb442_11 44 | - cupy=11.3.0=py39hc3c280e_1 45 | - cxx-compiler=1.5.1=h924138e_0 46 | - dbus=1.13.18=hb2f20db_0 47 | - debugpy=1.5.1=py39h295c915_0 48 | - decorator=5.1.1=pyhd3eb1b0_0 49 | - defusedxml=0.7.1=pyhd3eb1b0_0 50 | - einops=0.6.1=pyhd8ed1ab_0 51 | - entrypoints=0.4=py39h06a4308_0 52 | - executing=0.8.3=pyhd3eb1b0_0 53 | - expat=2.5.0=h27087fc_0 54 | - fastrlock=0.8=py39h5a03fae_3 55 | - ffmpeg=5.1.2=gpl_hc51e5dc_103 56 | - fftw=3.3.10=nompi_hf0379b8_106 57 | - flit-core=3.6.0=pyhd3eb1b0_0 58 | - font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0 59 | - font-ttf-inconsolata=2.001=hcb22688_0 60 | - font-ttf-source-code-pro=2.030=hd3eb1b0_0 61 | - font-ttf-ubuntu=0.83=h8b1ccd4_0 62 | - fontconfig=2.14.1=hc2a2eb6_0 63 | - fonts-anaconda=1=h8fa9717_0 64 | - fonts-conda-ecosystem=1=hd3eb1b0_0 65 | - fortran-compiler=1.5.1=h2a4ca65_0 66 | - freeglut=3.2.2=h9c3ff4c_1 67 | - freetype=2.12.1=h4a9f257_0 68 | - frozenlist=1.3.3=py39hb9d737c_0 69 | - gcc=10.4.0=hb92f740_11 70 | - gcc_impl_linux-64=10.4.0=h5231bdf_19 71 | - gcc_linux-64=10.4.0=h9215b83_11 72 | - gettext=0.21.1=h27087fc_0 73 | - gfortran=10.4.0=h0c96582_11 74 | - gfortran_impl_linux-64=10.4.0=h7d168d2_19 75 | - gfortran_linux-64=10.4.0=h69d5af5_11 76 | - gh=2.25.1=ha8f183a_0 77 | - giflib=5.2.1=h7b6447c_0 78 | - glib=2.74.1=h6239696_1 79 | - glib-tools=2.74.1=h6239696_1 80 | - gmp=6.2.1=h295c915_3 81 | - gnutls=3.7.8=hf3e180e_0 82 | - google-auth=2.15.0=pyh1a96a4e_0 83 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 84 | - graphite2=1.3.14=h295c915_1 85 | - grpcio=1.51.1=py39h8c60046_0 86 | - gst-plugins-base=1.21.2=h3e40eee_0 87 | - gstreamer=1.21.2=hd4edc92_0 88 | - gstreamer-orc=0.4.33=h166bdaf_0 89 | - gxx=10.4.0=hb92f740_11 90 | - gxx_impl_linux-64=10.4.0=h5231bdf_19 91 | - gxx_linux-64=10.4.0=h6e491c6_11 92 | - harfbuzz=5.3.0=h418a68e_0 93 | - hdf5=1.12.2=nompi_h4df4325_100 94 | - icu=70.1=h27087fc_0 95 | - idna=3.4=py39h06a4308_0 96 | - importlib-metadata=4.11.3=py39h06a4308_0 97 | - intel-openmp=2022.1.0=h9e868ea_3769 98 | - ipykernel=6.15.2=py39h06a4308_0 99 | - ipython=8.6.0=py39h06a4308_0 100 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 101 | - jack=1.9.21=h583fa2b_2 102 | - jasper=2.0.33=ha77e612_0 103 | - jedi=0.18.1=py39h06a4308_1 104 | - jinja2=3.1.2=py39h06a4308_0 105 | - jpeg=9e=h7f8727e_0 106 | - json5=0.9.6=pyhd3eb1b0_0 107 | - jsonschema=4.16.0=py39h06a4308_0 108 | - jupyter_client=7.4.7=py39h06a4308_0 109 | - jupyter_core=4.11.2=py39h06a4308_0 110 | - jupyter_server=1.18.1=py39h06a4308_0 111 | - jupyterlab=3.5.0=pyhd8ed1ab_0 112 | - jupyterlab_pygments=0.1.2=py_0 113 | - jupyterlab_server=2.16.3=py39h06a4308_0 114 | - 
kernel-headers_linux-64=2.6.32=he073ed8_15 115 | - keyutils=1.6.1=h166bdaf_0 116 | - krb5=1.19.3=h08a2579_0 117 | - lame=3.100=h7b6447c_0 118 | - lcms2=2.12=h3be6417_0 119 | - ld_impl_linux-64=2.39=hcc3a1bd_1 120 | - lerc=3.0=h295c915_0 121 | - libabseil=20220623.0=cxx17_h48a1fff_5 122 | - libblas=3.9.0=16_linux64_mkl 123 | - libcap=2.66=ha37c62d_0 124 | - libcblas=3.9.0=16_linux64_mkl 125 | - libclang=15.0.6=default_h2e3cab8_0 126 | - libclang13=15.0.6=default_h3a83d3e_0 127 | - libcups=2.3.3=h3e49a29_2 128 | - libcurl=7.86.0=h2283fc2_1 129 | - libdb=6.2.32=h6a678d5_1 130 | - libdeflate=1.8=h7f8727e_5 131 | - libdrm=2.4.114=h166bdaf_0 132 | - libedit=3.1.20210910=h7f8727e_0 133 | - libev=4.33=h7f8727e_1 134 | - libevent=2.1.10=h28343ad_4 135 | - libffi=3.4.2=h6a678d5_6 136 | - libflac=1.4.2=h27087fc_0 137 | - libgcc-devel_linux-64=10.4.0=hd38fd1e_19 138 | - libgcc-ng=12.2.0=h65d4601_19 139 | - libgcrypt=1.10.1=h166bdaf_0 140 | - libgfortran-ng=11.2.0=h00389a5_1 141 | - libgfortran5=11.2.0=h1234567_1 142 | - libglib=2.74.1=h606061b_1 143 | - libglu=9.0.0=hf484d3e_1 144 | - libgomp=12.2.0=h65d4601_19 145 | - libgpg-error=1.45=hc0c96e0_0 146 | - libgrpc=1.51.1=h30feacc_0 147 | - libiconv=1.17=h166bdaf_0 148 | - libidn2=2.3.2=h7f8727e_0 149 | - libjpeg-turbo=2.1.4=h166bdaf_0 150 | - liblapack=3.9.0=16_linux64_mkl 151 | - liblapacke=3.9.0=16_linux64_mkl 152 | - libllvm11=11.1.0=h9e868ea_6 153 | - libllvm15=15.0.6=h63197d8_0 154 | - libnghttp2=1.47.0=hff17c54_1 155 | - libnsl=2.0.0=h5eee18b_0 156 | - libogg=1.3.5=h27cfd23_1 157 | - libopencv=4.6.0=py39h9757d25_6 158 | - libopus=1.3.1=h7b6447c_0 159 | - libpciaccess=0.17=h166bdaf_0 160 | - libpng=1.6.39=h753d276_0 161 | - libpq=15.1=h67c24c5_1 162 | - libprotobuf=3.21.10=h6239696_0 163 | - libsanitizer=10.4.0=h5246dfb_19 164 | - libsndfile=1.1.0=h27087fc_0 165 | - libsodium=1.0.18=h7b6447c_0 166 | - libsqlite=3.40.0=h753d276_0 167 | - libssh2=1.10.0=hf14f497_3 168 | - libstdcxx-devel_linux-64=10.4.0=hd38fd1e_19 169 | - libstdcxx-ng=12.2.0=h46fd767_19 170 | - libsystemd0=252=h2a991cd_0 171 | - libtasn1=4.19.0=h166bdaf_0 172 | - libtiff=4.4.0=hecacb30_2 173 | - libtool=2.4.6=h295c915_1008 174 | - libudev1=252=h166bdaf_0 175 | - libunistring=0.9.10=h27cfd23_0 176 | - libuuid=2.32.1=h7f98852_1000 177 | - libva=2.16.0=h166bdaf_0 178 | - libvorbis=1.3.7=h7b6447c_0 179 | - libvpx=1.11.0=h295c915_0 180 | - libwebp=1.2.4=h11a3e52_0 181 | - libwebp-base=1.2.4=h5eee18b_0 182 | - libxcb=1.13=h1bed415_1 183 | - libxkbcommon=1.0.3=he3ba5ed_0 184 | - libxml2=2.10.3=h7463322_0 185 | - libzlib=1.2.13=h166bdaf_4 186 | - llvm-openmp=14.0.6=h9e868ea_0 187 | - llvmlite=0.39.1=py39he621ea3_0 188 | - lz4-c=1.9.3=h295c915_1 189 | - markdown=3.4.1=pyhd8ed1ab_0 190 | - markupsafe=2.1.1=py39h7f8727e_0 191 | - matplotlib-inline=0.1.6=py39h06a4308_0 192 | - mistune=0.8.4=py39h27cfd23_1000 193 | - mkl=2022.1.0=hc2b9512_224 194 | - mpg123=1.30.2=h27087fc_1 195 | - multidict=6.0.2=py39hb9d737c_2 196 | - mysql-common=8.0.31=h26416b9_0 197 | - mysql-libs=8.0.31=hbc51c84_0 198 | - nbclassic=0.4.8=py39h06a4308_0 199 | - nbclient=0.5.13=py39h06a4308_0 200 | - nbconvert=6.4.4=py39h06a4308_0 201 | - nbformat=5.5.0=py39h06a4308_0 202 | - ncurses=6.3=h5eee18b_3 203 | - nest-asyncio=1.5.5=py39h06a4308_0 204 | - nettle=3.8.1=hc379101_1 205 | - notebook=6.5.2=py39h06a4308_0 206 | - notebook-shim=0.2.2=py39h06a4308_0 207 | - nspr=4.35=h27087fc_0 208 | - nss=3.82=he02c5a1_0 209 | - numba=0.56.4=py39h61ddf18_0 210 | - numpy=1.23.5=py39h3d75532_0 211 | - oauthlib=3.2.2=pyhd8ed1ab_0 212 | - 
opencv=4.6.0=py39hf3d152e_6 213 | - openh264=2.3.1=h27087fc_1 214 | - openssl=3.1.0=hd590300_3 215 | - p11-kit=0.24.1=hc5aa10d_0 216 | - pandocfilters=1.5.0=pyhd3eb1b0_0 217 | - parso=0.8.3=pyhd3eb1b0_0 218 | - pcre2=10.40=hc3806b6_0 219 | - pexpect=4.8.0=pyhd3eb1b0_3 220 | - pickleshare=0.7.5=pyhd3eb1b0_1003 221 | - pillow=9.2.0=py39hace64e9_1 222 | - pip=22.2.2=py39h06a4308_0 223 | - pixman=0.40.0=h7f8727e_1 224 | - pkg-config=0.29.2=h36c2ea0_1008 225 | - prometheus_client=0.14.1=py39h06a4308_0 226 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 227 | - ptyprocess=0.7.0=pyhd3eb1b0_2 228 | - pulseaudio=16.1=h126f2b6_0 229 | - pure_eval=0.2.2=pyhd3eb1b0_0 230 | - py-opencv=4.6.0=py39hcca971b_6 231 | - pyasn1=0.4.8=py_0 232 | - pyasn1-modules=0.2.7=py_0 233 | - pycparser=2.21=pyhd3eb1b0_0 234 | - pygments=2.11.2=pyhd3eb1b0_0 235 | - pyjwt=2.6.0=pyhd8ed1ab_0 236 | - pyopenssl=22.0.0=pyhd3eb1b0_0 237 | - pyrsistent=0.18.0=py39heee7806_0 238 | - pysocks=1.7.1=py39h06a4308_0 239 | - python=3.9.15=hba424b6_0_cpython 240 | - python-fastjsonschema=2.16.2=py39h06a4308_0 241 | - python_abi=3.9=3_cp39 242 | - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 243 | - pytorch-mutex=1.0=cuda 244 | - pyu2f=0.1.5=pyhd8ed1ab_0 245 | - pyzmq=23.2.0=py39h6a678d5_0 246 | - qt-main=5.15.6=he99da89_3 247 | - re2=2022.06.01=h27087fc_1 248 | - readline=8.2=h5eee18b_0 249 | - requests=2.28.1=py39h06a4308_0 250 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0 251 | - rsa=4.9=pyhd8ed1ab_0 252 | - send2trash=1.8.0=pyhd3eb1b0_1 253 | - setuptools=65.5.0=py39h06a4308_0 254 | - sniffio=1.2.0=py39h06a4308_1 255 | - soupsieve=2.3.2.post1=py39h06a4308_0 256 | - stack_data=0.2.0=pyhd3eb1b0_0 257 | - svt-av1=1.3.0=h27087fc_0 258 | - sysroot_linux-64=2.12=he073ed8_15 259 | - tensorboard=2.11.0=pyhd8ed1ab_0 260 | - tensorboard-data-server=0.6.1=py39h3ccb8fc_4 261 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 262 | - terminado=0.13.1=py39h06a4308_0 263 | - testpath=0.6.0=py39h06a4308_0 264 | - tk=8.6.12=h1ccaba5_0 265 | - tomli=2.0.1=py39h06a4308_0 266 | - torchaudio=0.12.1=py39_cu113 267 | - torchvision=0.13.1=py39_cu113 268 | - tornado=6.2=py39h5eee18b_0 269 | - traitlets=5.1.1=pyhd3eb1b0_0 270 | - typing-extensions=4.4.0=py39h06a4308_0 271 | - typing_extensions=4.4.0=py39h06a4308_0 272 | - tzdata=2022f=h04d1e81_0 273 | - urllib3=1.26.12=py39h06a4308_0 274 | - wcwidth=0.2.5=pyhd3eb1b0_0 275 | - webencodings=0.5.1=py39h06a4308_1 276 | - websocket-client=0.58.0=py39h06a4308_4 277 | - werkzeug=2.2.2=pyhd8ed1ab_0 278 | - wheel=0.37.1=pyhd3eb1b0_0 279 | - x264=1!164.3095=h166bdaf_2 280 | - x265=3.5=h924138e_3 281 | - xcb-util=0.4.0=h166bdaf_0 282 | - xcb-util-image=0.4.0=h166bdaf_0 283 | - xcb-util-keysyms=0.4.0=h166bdaf_0 284 | - xcb-util-renderutil=0.3.9=h166bdaf_0 285 | - xcb-util-wm=0.4.1=h166bdaf_0 286 | - xorg-fixesproto=5.0=h7f98852_1002 287 | - xorg-inputproto=2.3.2=h7f98852_1002 288 | - xorg-kbproto=1.0.7=h7f98852_1002 289 | - xorg-libice=1.0.10=h7f98852_0 290 | - xorg-libsm=1.2.3=hd9c2040_1000 291 | - xorg-libx11=1.7.2=h7f98852_0 292 | - xorg-libxau=1.0.9=h7f98852_0 293 | - xorg-libxext=1.3.4=h7f98852_1 294 | - xorg-libxfixes=5.0.3=h7f98852_1004 295 | - xorg-libxi=1.7.10=h7f98852_0 296 | - xorg-libxrender=0.9.10=h7f98852_1003 297 | - xorg-renderproto=0.11.1=h7f98852_1002 298 | - xorg-xextproto=7.3.0=h7f98852_1002 299 | - xorg-xproto=7.0.31=h27cfd23_1007 300 | - xz=5.2.8=h5eee18b_0 301 | - yarl=1.8.1=py39hb9d737c_0 302 | - zeromq=4.3.4=h2531618_0 303 | - zipp=3.8.0=py39h06a4308_0 304 | - zlib=1.2.13=h166bdaf_4 305 | - zstd=1.5.2=ha4553b6_0 306 | 
- pip:
307 |     - braceexpand==0.1.7
308 |     - contourpy==1.0.6
309 |     - cycler==0.11.0
310 |     - docker-pycreds==0.4.0
311 |     - ffcv==0.0.3rc1
312 |     - fonttools==4.38.0
313 |     - gitdb==4.0.10
314 |     - gitpython==3.1.29
315 |     - imgcat==0.5.0
316 |     - ipdb==0.13.11
317 |     - joblib==1.2.0
318 |     - kiwisolver==1.4.4
319 |     - matplotlib==3.6.2
320 |     - packaging==21.3
321 |     - pandas==1.5.2
322 |     - pathtools==0.1.2
323 |     - promise==2.3
324 |     - protobuf==4.21.10
325 |     - psutil==5.9.4
326 |     - pyparsing==3.0.9
327 |     - python-dateutil==2.8.2
328 |     - pytorch-pfn-extras==0.6.3
329 |     - pytz==2022.6
330 |     - pyyaml==6.0
331 |     - scikit-learn==1.1.3
332 |     - scipy==1.9.3
333 |     - sentry-sdk==1.11.1
334 |     - setproctitle==1.3.2
335 |     - shortuuid==1.0.11
336 |     - six==1.16.0
337 |     - smmap==5.0.0
338 |     - tabulate==0.9.0
339 |     - terminaltables==3.1.10
340 |     - threadpoolctl==3.1.0
341 |     - torchmetrics==0.11.4
342 |     - tqdm==4.64.1
343 |     - wandb==0.13.6
344 |     - webdataset==0.2.31
345 |     - wget==3.2
346 | 
--------------------------------------------------------------------------------
/features_to_topk_matrix.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from sklearn import svm
7 | from tqdm import tqdm
8 | 
9 | model_dir = '' # Add model directory here
10 | required_test_indices_path = './data/test_indices/test_100.npy'
11 | save_topk_train_samples_dir = './data/topk_train_samples/'
12 | CIFAR10_ROOT = os.getenv('CIFAR10_ROOT')
13 | 
14 | def main():
15 |     # Arguments to modify
16 |     reg_C = 0.1 # SVM regularization parameter
17 |     topk = 1280 # Top k training samples to save
18 |     model_name = 'Moco_800epoch'
19 |     label_conditioning = True
20 |     distance_key = 'esvm' # 'esvm' for exemplar svm or 'l2knn' for euclidean distance KNN
21 | 
22 |     # No need to modify strings below
23 |     run_name = f'tmp_requiredidx_{model_name}_{distance_key}_v2' + (f'_c{str(reg_C).replace(".", "")}' if distance_key in ['esvm', 'svm'] else '')
24 |     conditioning_str = 'class_conditioned' if label_conditioning else 'unconditioned'
25 |     print(f"Model name: {model_name}")
26 |     print(f"Label conditioning: {label_conditioning}")
27 |     print(f"Distance key: {distance_key}")
28 |     print(f"Run name: {run_name}")
29 | 
30 |     feature_dir = os.path.join(model_dir, model_name)
31 |     train_features = np.load(os.path.join(feature_dir, 'train.npy'))
32 |     test_features = np.load(os.path.join(feature_dir, 'test.npy'))
33 | 
34 |     cifar_train_ds = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=True)
35 |     cifar_test_ds = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=False)
36 |     train_labels = np.array([cifar_train_ds[i][1] for i in range(len(cifar_train_ds))])
37 |     test_labels = np.array([cifar_test_ds[i][1] for i in range(len(cifar_test_ds))])
38 | 
39 |     print(f"test features have the following shape: {test_features.shape}")
40 |     print(f"test labels have the following shape: {test_labels.shape}")
41 | 
42 | 
43 |     target_to_nearest_train_idx = np.zeros((10000, topk), dtype=np.int64)
44 |     required_test_indices = np.load(required_test_indices_path)
45 | 
46 |     # For every target sample, save the indices of its top-k closest training samples
47 |     for i, cur_feature in tqdm(enumerate(test_features), desc='Computing topk matrix'):
48 |         if i not in required_test_indices:
49 |             continue
50 | 
51 |         # if label conditioning, we select training samples with the same label as the target
52 |         if label_conditioning:
53 |             cur_label = test_labels[i]
54 |             train_indices_of_same_class = np.where(train_labels == cur_label)[0]
55 |             cur_train_features = train_features[train_indices_of_same_class]
56 |         else:
57 |             cur_train_features = train_features
58 | 
59 |         # add back the first (batch) dimension to the current feature
60 |         cur_feature = np.expand_dims(cur_feature, axis=0)
61 |         distances = get_distance(distance_key, cur_feature, cur_train_features, reg_C)
62 | 
63 |         if label_conditioning:
64 |             topk_indices = train_indices_of_same_class[np.argsort(distances)[:topk]]
65 |         else:
66 |             topk_indices = np.argsort(distances)[:topk]
67 | 
68 |         target_to_nearest_train_idx[i] = topk_indices
69 | 
70 |     print('Created target_to_nearest_train_idx matrix with shape', target_to_nearest_train_idx.shape)
71 | 
72 |     # Save the topk matrix
73 |     save_dir = os.path.join(save_topk_train_samples_dir, conditioning_str)
74 |     os.makedirs(save_dir, exist_ok=True)
75 |     np.save(os.path.join(save_dir, f'{run_name}_{topk}.npy'), target_to_nearest_train_idx)
76 | 
77 |     target_idx = 6218
78 | 
79 |     # plot the target image followed by the 49 nearest training images
80 |     # (the target image occupies the first grid cell and is not overwritten)
81 |     fig, axs = plt.subplots(5, 10, figsize=(20, 10))
82 |     fig.suptitle(run_name, fontsize=16)
83 |     axs[0, 0].imshow(cifar_test_ds[target_idx][0])
84 |     axs[0, 0].set_title(f'Target, Class {cifar_test_ds[target_idx][1]}')
85 |     axs[0, 0].axis('off')
86 | 
87 |     for i in range(49):
88 |         row_idx = i // 10
89 |         column_idx = (i % 10) + 1
90 |         if column_idx == 10:
91 |             column_idx = 0
92 |             row_idx += 1
93 | 
94 |         axs[row_idx, column_idx].imshow(cifar_train_ds[target_to_nearest_train_idx[target_idx][i]][0])
95 |         axs[row_idx, column_idx].set_title(f'#{i+1}, Class {cifar_train_ds[target_to_nearest_train_idx[target_idx][i]][1]}')
96 |         axs[row_idx, column_idx].axis('off')
97 | 
98 |     # Save the plot
99 |     print(f'Saving plot to target{target_idx}_{run_name}_{conditioning_str}.png')
100 |     fig.savefig(f'target{target_idx}_{run_name}_{conditioning_str}.png', dpi=300)
101 | 
102 | def get_distance(distance_key, cur_feature, cur_train_features, reg_C):
103 |     if distance_key == 'esvm':
104 |         return get_exemplar_svm_distance(cur_feature, cur_train_features, reg_C)
105 |     elif distance_key == 'l2knn':
106 |         euclidean_distances = torch.cdist(torch.from_numpy(cur_feature), torch.from_numpy(cur_train_features), p=2)
107 |         euclidean_distances = euclidean_distances[0]
108 |         return euclidean_distances.numpy()
109 |     else:
110 |         raise ValueError(f"Invalid distance key: {distance_key}")
111 | 
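112 | # Exemplar-SVM distance, in brief: fit a linear SVM with the single test
113 | # feature as the only positive example and all candidate training features as
114 | # negatives. The SVM's decision values then rank training samples by their
115 | # similarity to the test "exemplar", so their negation serves as a distance.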
116 | def get_exemplar_svm_distance(test_feature, train_features, reg_C):
117 |     """
118 |     returns svm distance of test_feature to train_features
119 |     """
120 |     svm_x = np.concatenate((test_feature, train_features), axis=0)
121 |     svm_y = np.zeros(svm_x.shape[0])
122 |     svm_y[0] = 1 # the target feature is our only positive example
123 | 
124 |     clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=reg_C)
125 |     clf.fit(svm_x, svm_y)
126 | 
127 |     # higher decision values correspond to training features closer to the exemplar
128 |     similarities = clf.decision_function(train_features)
129 |     return -similarities
130 | 
131 | 
132 | if __name__ == "__main__":
133 |     main()
--------------------------------------------------------------------------------
/models/mlpmixer.py:
-------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/mlp-mixer-pytorch/blob/main/mlp_mixer_pytorch/mlp_mixer_pytorch.py 2 | from torch import nn 3 | from functools import partial 4 | from einops.layers.torch import Rearrange, Reduce 5 | 6 | pair = lambda x: x if isinstance(x, tuple) else (x, x) 7 | 8 | class PreNormResidual(nn.Module): 9 | def __init__(self, dim, fn): 10 | super().__init__() 11 | self.fn = fn 12 | self.norm = nn.LayerNorm(dim) 13 | 14 | def forward(self, x): 15 | return self.fn(self.norm(x)) + x 16 | 17 | def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear): 18 | inner_dim = int(dim * expansion_factor) 19 | return nn.Sequential( 20 | dense(dim, inner_dim), 21 | nn.GELU(), 22 | nn.Dropout(dropout), 23 | dense(inner_dim, dim), 24 | nn.Dropout(dropout) 25 | ) 26 | 27 | def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.): 28 | image_h, image_w = pair(image_size) 29 | assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size' 30 | num_patches = (image_h // patch_size) * (image_w // patch_size) 31 | chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear 32 | 33 | return nn.Sequential( 34 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), 35 | nn.Linear((patch_size ** 2) * channels, dim), 36 | *[nn.Sequential( 37 | PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)), 38 | PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last)) 39 | ) for _ in range(depth)], 40 | nn.LayerNorm(dim), 41 | Reduce('b n c -> b c', 'mean'), 42 | nn.Linear(dim, num_classes) 43 | ) -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''expand + depthwise + pointwise''' 13 | def __init__(self, in_planes, out_planes, expansion, stride): 14 | super(Block, self).__init__() 15 | self.stride = stride 16 | 17 | planes = expansion * in_planes 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 23 | self.bn3 = nn.BatchNorm2d(out_planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride == 1 and in_planes != out_planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.BatchNorm2d(out_planes), 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = F.relu(self.bn2(self.conv2(out))) 35 | out = self.bn3(self.conv3(out)) 36 | out = out + self.shortcut(x) if self.stride==1 else out 37 | return out 38 | 39 | 40 | class MobileNetV2(nn.Module): 41 | # (expansion, out_planes, num_blocks, stride) 42 | cfg = [(1, 16, 1, 1), 43 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 44 | (6, 32, 3, 2), 45 | (6, 64, 4, 2), 46 | (6, 96, 3, 1), 47 | (6, 160, 3, 2), 48 | (6, 320, 1, 1)] 49 | 50 | def __init__(self, num_classes=10): 51 | super(MobileNetV2, self).__init__() 52 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 53 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(32) 55 | self.layers = self._make_layers(in_planes=32) 56 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 57 | self.bn2 = nn.BatchNorm2d(1280) 58 | self.linear = nn.Linear(1280, num_classes) 59 | 60 | def _make_layers(self, in_planes): 61 | layers = [] 62 | for expansion, out_planes, num_blocks, stride in self.cfg: 63 | strides = [stride] + [1]*(num_blocks-1) 64 | for stride in strides: 65 | layers.append(Block(in_planes, out_planes, expansion, stride)) 66 | in_planes = out_planes 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = self.layers(out) 72 | out = F.relu(self.bn2(self.conv2(out))) 73 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 74 | out = F.avg_pool2d(out, 4) 75 | out = out.view(out.size(0), -1) 76 | out = self.linear(out) 77 | return out 78 | 79 | 80 | def test(): 81 | net = MobileNetV2() 82 | x = torch.randn(2,3,32,32) 83 | y = net(x) 84 | print(y.size()) 85 | 86 | # test() -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch as ch 2 | 3 | class Mul(ch.nn.Module): 4 | def __init__(self, weight): 5 | super(Mul, self).__init__() 6 | self.weight = weight 7 | def forward(self, x): return x * self.weight 8 | 9 | class Flatten(ch.nn.Module): 10 | def forward(self, x): return x.view(x.size(0), -1) 11 | 12 | class Residual(ch.nn.Module): 13 | def __init__(self, module): 14 | super(Residual, self).__init__() 15 | self.module = module 16 | def forward(self, x): return x + self.module(x) 17 | 18 | def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, 
groups=1): 19 | return ch.nn.Sequential( 20 | ch.nn.Conv2d(channels_in, channels_out, 21 | kernel_size=kernel_size, stride=stride, padding=padding, 22 | groups=groups, bias=False), 23 | ch.nn.BatchNorm2d(channels_out), 24 | ch.nn.ReLU(inplace=True) 25 | ) 26 | 27 | # NUM_CLASSES = 10 28 | 29 | 30 | def resnet9(num_classes=10): 31 | model = ch.nn.Sequential( 32 | conv_bn(3, 64, kernel_size=3, stride=1, padding=1), 33 | conv_bn(64, 128, kernel_size=5, stride=2, padding=2), 34 | Residual(ch.nn.Sequential(conv_bn(128, 128), conv_bn(128, 128))), 35 | conv_bn(128, 256, kernel_size=3, stride=1, padding=1), 36 | ch.nn.MaxPool2d(2), 37 | Residual(ch.nn.Sequential(conv_bn(256, 256), conv_bn(256, 256))), 38 | conv_bn(256, 128, kernel_size=3, stride=1, padding=0), 39 | ch.nn.AdaptiveMaxPool2d((1, 1)), 40 | Flatten(), 41 | ch.nn.Linear(128, num_classes, bias=False), 42 | Mul(0.2) 43 | ) 44 | return model 45 | -------------------------------------------------------------------------------- /models/swin.py: -------------------------------------------------------------------------------- 1 | # https://github.com/berniwal/swin-transformer-pytorch 2 | 3 | import torch 4 | from torch import nn, einsum 5 | import numpy as np 6 | from einops import rearrange, repeat 7 | 8 | 9 | class CyclicShift(nn.Module): 10 | def __init__(self, displacement): 11 | super().__init__() 12 | self.displacement = displacement 13 | 14 | def forward(self, x): 15 | return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2)) 16 | 17 | 18 | class Residual(nn.Module): 19 | def __init__(self, fn): 20 | super().__init__() 21 | self.fn = fn 22 | 23 | def forward(self, x, **kwargs): 24 | return self.fn(x, **kwargs) + x 25 | 26 | 27 | class PreNorm(nn.Module): 28 | def __init__(self, dim, fn): 29 | super().__init__() 30 | self.norm = nn.LayerNorm(dim) 31 | self.fn = fn 32 | 33 | def forward(self, x, **kwargs): 34 | return self.fn(self.norm(x), **kwargs) 35 | 36 | 37 | class FeedForward(nn.Module): 38 | def __init__(self, dim, hidden_dim): 39 | super().__init__() 40 | self.net = nn.Sequential( 41 | nn.Linear(dim, hidden_dim), 42 | nn.GELU(), 43 | nn.Linear(hidden_dim, dim), 44 | ) 45 | 46 | def forward(self, x): 47 | return self.net(x) 48 | 49 | 50 | def create_mask(window_size, displacement, upper_lower, left_right): 51 | mask = torch.zeros(window_size ** 2, window_size ** 2) 52 | 53 | if upper_lower: 54 | mask[-displacement * window_size:, :-displacement * window_size] = float('-inf') 55 | mask[:-displacement * window_size, -displacement * window_size:] = float('-inf') 56 | 57 | if left_right: 58 | mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size) 59 | mask[:, -displacement:, :, :-displacement] = float('-inf') 60 | mask[:, :-displacement, :, -displacement:] = float('-inf') 61 | mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)') 62 | 63 | return mask 64 | 65 | 66 | def get_relative_distances(window_size): 67 | indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)])) 68 | distances = indices[None, :, :] - indices[:, None, :] 69 | return distances 70 | 71 | 72 | class WindowAttention(nn.Module): 73 | def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding): 74 | super().__init__() 75 | inner_dim = head_dim * heads 76 | 77 | self.heads = heads 78 | self.scale = head_dim ** -0.5 79 | self.window_size = window_size 80 | self.relative_pos_embedding = relative_pos_embedding 81 | self.shifted = shifted 82 | 83 | 
if self.shifted: 84 | displacement = window_size // 2 85 | self.cyclic_shift = CyclicShift(-displacement) 86 | self.cyclic_back_shift = CyclicShift(displacement) 87 | self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement, 88 | upper_lower=True, left_right=False), requires_grad=False) 89 | self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement, 90 | upper_lower=False, left_right=True), requires_grad=False) 91 | 92 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 93 | 94 | if self.relative_pos_embedding: 95 | self.relative_indices = get_relative_distances(window_size) + window_size - 1 96 | self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1)) 97 | else: 98 | self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2)) 99 | 100 | self.to_out = nn.Linear(inner_dim, dim) 101 | 102 | def forward(self, x): 103 | if self.shifted: 104 | x = self.cyclic_shift(x) 105 | 106 | b, n_h, n_w, _, h = *x.shape, self.heads 107 | 108 | qkv = self.to_qkv(x).chunk(3, dim=-1) 109 | nw_h = n_h // self.window_size 110 | nw_w = n_w // self.window_size 111 | 112 | q, k, v = map( 113 | lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d', 114 | h=h, w_h=self.window_size, w_w=self.window_size), qkv) 115 | 116 | dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale 117 | 118 | if self.relative_pos_embedding: 119 | dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]] 120 | else: 121 | dots += self.pos_embedding 122 | 123 | if self.shifted: 124 | dots[:, :, -nw_w:] += self.upper_lower_mask 125 | dots[:, :, nw_w - 1::nw_w] += self.left_right_mask 126 | 127 | attn = dots.softmax(dim=-1) 128 | 129 | out = einsum('b h w i j, b h w j d -> b h w i d', attn, v) 130 | out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)', 131 | h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w) 132 | out = self.to_out(out) 133 | 134 | if self.shifted: 135 | out = self.cyclic_back_shift(out) 136 | return out 137 | 138 | 139 | class SwinBlock(nn.Module): 140 | def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding): 141 | super().__init__() 142 | self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim, 143 | heads=heads, 144 | head_dim=head_dim, 145 | shifted=shifted, 146 | window_size=window_size, 147 | relative_pos_embedding=relative_pos_embedding))) 148 | self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim))) 149 | 150 | def forward(self, x): 151 | x = self.attention_block(x) 152 | x = self.mlp_block(x) 153 | return x 154 | 155 | 156 | class PatchMerging(nn.Module): 157 | def __init__(self, in_channels, out_channels, downscaling_factor): 158 | super().__init__() 159 | self.downscaling_factor = downscaling_factor 160 | self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0) 161 | self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels) 162 | 163 | def forward(self, x): 164 | b, c, h, w = x.shape 165 | new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor 166 | x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1) 167 | x = self.linear(x) 168 | return x 169 | 170 | 171 | class StageModule(nn.Module): 172 | def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, 
num_heads, head_dim, window_size, 173 | relative_pos_embedding): 174 | super().__init__() 175 | assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.' 176 | 177 | self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension, 178 | downscaling_factor=downscaling_factor) 179 | 180 | self.layers = nn.ModuleList([]) 181 | for _ in range(layers // 2): 182 | self.layers.append(nn.ModuleList([ 183 | SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4, 184 | shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding), 185 | SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4, 186 | shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding), 187 | ])) 188 | 189 | def forward(self, x): 190 | x = self.patch_partition(x) 191 | for regular_block, shifted_block in self.layers: 192 | x = regular_block(x) 193 | x = shifted_block(x) 194 | return x.permute(0, 3, 1, 2) 195 | 196 | 197 | class SwinTransformer(nn.Module): 198 | def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7, 199 | downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True): 200 | super().__init__() 201 | 202 | self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0], 203 | downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim, 204 | window_size=window_size, relative_pos_embedding=relative_pos_embedding) 205 | self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1], 206 | downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim, 207 | window_size=window_size, relative_pos_embedding=relative_pos_embedding) 208 | self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2], 209 | downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim, 210 | window_size=window_size, relative_pos_embedding=relative_pos_embedding) 211 | self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3], 212 | downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim, 213 | window_size=window_size, relative_pos_embedding=relative_pos_embedding) 214 | 215 | self.mlp_head = nn.Sequential( 216 | nn.LayerNorm(hidden_dim * 8), 217 | nn.Linear(hidden_dim * 8, num_classes) 218 | ) 219 | 220 | def forward(self, img): 221 | x = self.stage1(img) 222 | x = self.stage2(x) 223 | x = self.stage3(x) 224 | x = self.stage4(x) 225 | x = x.mean(dim=[2, 3]) 226 | return self.mlp_head(x) 227 | 228 | 229 | def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs): 230 | return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs) 231 | 232 | 233 | def swin_s(hidden_dim=96, layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs): 234 | return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs) 235 | 236 | 237 | def swin_b(hidden_dim=128, layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs): 238 | return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs) 239 | 240 | 241 | def swin_l(hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs): 242 | return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs) 
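243 | 
244 | 
245 | # Minimal shape check, mirroring the test() helper in models/mobilenetv2.py.
246 | # NOTE: illustrative only (not part of the paper's pipeline); it assumes
247 | # 224x224 inputs so that feature maps stay divisible by the default
248 | # window_size of 7.
249 | def test():
250 |     net = swin_t(num_classes=1000)
251 |     x = torch.randn(2, 3, 224, 224)
252 |     y = net(x)
253 |     print(y.size())  # torch.Size([2, 1000])
254 | 
255 | # test()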
-------------------------------------------------------------------------------- /self_supervised_models/launch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --qos=high 3 | #SBATCH --gres=gpu:rtxa6000:1 4 | #SBATCH --time=12:00:00 5 | #SBATCH --ntasks=4 6 | #SBATCH --mem=16G 7 | 8 | python train_ssl.py -------------------------------------------------------------------------------- /self_supervised_models/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .barlow import BarlowTwinsModel 2 | from .byol import BYOLModel 3 | from .dcl import DCL 4 | from .dclw import DCLW 5 | from .dino import DINOModel 6 | from .moco import MocoModel 7 | from .nnclr import NNCLRModel 8 | from .simclr import SimCLRModel 9 | from .simsiam import SimSiamModel 10 | from .smog import SMoGModel 11 | from .swav import SwaVModel 12 | 13 | -------------------------------------------------------------------------------- /self_supervised_models/models/barlow.py: -------------------------------------------------------------------------------- 1 | import lightly 2 | import torch 3 | import torch.nn as nn 4 | from lightly.models import modules 5 | from lightly.models.modules import heads 6 | from lightly.models import utils 7 | from lightly.utils import BenchmarkModule 8 | 9 | class BarlowTwinsModel(BenchmarkModule): 10 | def __init__(self, dataloader_kNN, num_classes, 11 | lr_factor=0.1, max_epochs=200): 12 | super().__init__(dataloader_kNN, num_classes) 13 | self.lr_factor = lr_factor 14 | self.max_epochs = max_epochs 15 | # create a ResNet backbone and remove the classification head 16 | resnet = lightly.models.ResNetGenerator('resnet-18') 17 | self.backbone = nn.Sequential( 18 | *list(resnet.children())[:-1], 19 | nn.AdaptiveAvgPool2d(1) 20 | ) 21 | # use a 2-layer projection head for cifar10 as described in the paper 22 | self.projection_head = heads.ProjectionHead([ 23 | ( 24 | 512, 25 | 2048, 26 | nn.BatchNorm1d(2048), 27 | nn.ReLU(inplace=True) 28 | ), 29 | ( 30 | 2048, 31 | 2048, 32 | None, 33 | None 34 | ) 35 | ]) 36 | 37 | self.criterion = lightly.loss.BarlowTwinsLoss(gather_distributed=False) 38 | 39 | def forward(self, x): 40 | x = self.backbone(x).flatten(start_dim=1) 41 | z = self.projection_head(x) 42 | return z 43 | 44 | def training_step(self, batch, batch_index): 45 | (x0, x1), _, _ = batch 46 | z0 = self.forward(x0) 47 | z1 = self.forward(x1) 48 | loss = self.criterion(z0, z1) 49 | self.log('train_loss_ssl', loss) 50 | return loss 51 | 52 | def configure_optimizers(self): 53 | optim = torch.optim.SGD( 54 | self.parameters(), 55 | lr=6e-2 * self.lr_factor, 56 | momentum=0.9, 57 | weight_decay=5e-4 58 | ) 59 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 60 | return [optim], [scheduler] 61 | -------------------------------------------------------------------------------- /self_supervised_models/models/byol.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import lightly 3 | import torch 4 | import torch.nn as nn 5 | from lightly.models import modules 6 | from lightly.models.modules import heads 7 | from lightly.models import utils 8 | from lightly.utils import BenchmarkModule 9 | 10 | class BYOLModel(BenchmarkModule): 11 | def __init__(self, dataloader_kNN, num_classes, 12 | lr_factor=0.1, max_epochs=200, 13 | backbone='resnet-18'): 14 | super().__init__(dataloader_kNN, num_classes) 15 | 
self.lr_factor = lr_factor 16 | self.max_epochs = max_epochs 17 | # create a ResNet backbone and remove the classification head 18 | self.backbone_name = backbone 19 | resnet = lightly.models.ResNetGenerator(backbone) 20 | self.backbone = nn.Sequential( 21 | *list(resnet.children())[:-1], 22 | nn.AdaptiveAvgPool2d(1) 23 | ) 24 | 25 | # create a byol model based on ResNet 26 | self.projection_head = heads.BYOLProjectionHead(512, 1024, 256) 27 | self.prediction_head = heads.BYOLPredictionHead(256, 1024, 256) 28 | 29 | self.backbone_momentum = copy.deepcopy(self.backbone) 30 | self.projection_head_momentum = copy.deepcopy(self.projection_head) 31 | 32 | utils.deactivate_requires_grad(self.backbone_momentum) 33 | utils.deactivate_requires_grad(self.projection_head_momentum) 34 | 35 | self.criterion = lightly.loss.NegativeCosineSimilarity() 36 | 37 | def forward(self, x): 38 | y = self.backbone(x).flatten(start_dim=1) 39 | z = self.projection_head(y) 40 | p = self.prediction_head(z) 41 | return p 42 | 43 | def forward_momentum(self, x): 44 | y = self.backbone_momentum(x).flatten(start_dim=1) 45 | z = self.projection_head_momentum(y) 46 | z = z.detach() 47 | return z 48 | 49 | def training_step(self, batch, batch_idx): 50 | utils.update_momentum(self.backbone, self.backbone_momentum, m=0.99) 51 | utils.update_momentum(self.projection_head, self.projection_head_momentum, m=0.99) 52 | (x0, x1), _, _ = batch 53 | p0 = self.forward(x0) 54 | z0 = self.forward_momentum(x0) 55 | p1 = self.forward(x1) 56 | z1 = self.forward_momentum(x1) 57 | loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0)) 58 | self.log('train_loss_ssl', loss) 59 | return loss 60 | 61 | def configure_optimizers(self): 62 | params = list(self.backbone.parameters()) \ 63 | + list(self.projection_head.parameters()) \ 64 | + list(self.prediction_head.parameters()) 65 | optim = torch.optim.SGD( 66 | params, 67 | lr=6e-2 * self.lr_factor, 68 | momentum=0.9, 69 | weight_decay=5e-4, 70 | ) 71 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 72 | return [optim], [scheduler] 73 | -------------------------------------------------------------------------------- /self_supervised_models/models/dcl.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import lightly 3 | import torch 4 | import torch.nn as nn 5 | from lightly.models import modules 6 | from lightly.models.modules import heads 7 | from lightly.models import utils 8 | from lightly.utils import BenchmarkModule 9 | 10 | class DCL(BenchmarkModule): 11 | def __init__(self, dataloader_kNN, num_classes, 12 | lr_factor=0.1, max_epochs=200): 13 | super().__init__(dataloader_kNN, num_classes) 14 | self.lr_factor = lr_factor 15 | self.max_epochs = max_epochs 16 | # create a ResNet backbone and remove the classification head 17 | resnet = lightly.models.ResNetGenerator('resnet-18') 18 | self.backbone = nn.Sequential( 19 | *list(resnet.children())[:-1], 20 | nn.AdaptiveAvgPool2d(1) 21 | ) 22 | self.projection_head = heads.SimCLRProjectionHead(512, 512, 128) 23 | self.criterion = lightly.loss.DCLLoss() 24 | 25 | def forward(self, x): 26 | x = self.backbone(x).flatten(start_dim=1) 27 | z = self.projection_head(x) 28 | return z 29 | 30 | def training_step(self, batch, batch_index): 31 | (x0, x1), _, _ = batch 32 | z0 = self.forward(x0) 33 | z1 = self.forward(x1) 34 | loss = self.criterion(z0, z1) 35 | self.log('train_loss_ssl', loss) 36 | return loss 37 | 38 | def configure_optimizers(self): 39 | optim = 
torch.optim.SGD( 40 | self.parameters(), 41 | lr=6e-2 * self.lr_factor, 42 | momentum=0.9, 43 | weight_decay=5e-4 44 | ) 45 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 46 | return [optim], [scheduler] 47 | 48 | -------------------------------------------------------------------------------- /self_supervised_models/models/dclw.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import lightly 3 | import torch 4 | import torch.nn as nn 5 | from lightly.models import modules 6 | from lightly.models.modules import heads 7 | from lightly.models import utils 8 | from lightly.utils import BenchmarkModule 9 | 10 | class DCLW(BenchmarkModule): 11 | def __init__(self, dataloader_kNN, num_classes, 12 | lr_factor=0.1, max_epochs=200): 13 | super().__init__(dataloader_kNN, num_classes) 14 | self.lr_factor = lr_factor 15 | self.max_epochs = max_epochs 16 | # create a ResNet backbone and remove the classification head 17 | resnet = lightly.models.ResNetGenerator('resnet-18') 18 | self.backbone = nn.Sequential( 19 | *list(resnet.children())[:-1], 20 | nn.AdaptiveAvgPool2d(1) 21 | ) 22 | self.projection_head = heads.SimCLRProjectionHead(512, 512, 128) 23 | self.criterion = lightly.loss.DCLWLoss() 24 | 25 | def forward(self, x): 26 | x = self.backbone(x).flatten(start_dim=1) 27 | z = self.projection_head(x) 28 | return z 29 | 30 | def training_step(self, batch, batch_index): 31 | (x0, x1), _, _ = batch 32 | z0 = self.forward(x0) 33 | z1 = self.forward(x1) 34 | loss = self.criterion(z0, z1) 35 | self.log('train_loss_ssl', loss) 36 | return loss 37 | 38 | def configure_optimizers(self): 39 | optim = torch.optim.SGD( 40 | self.parameters(), 41 | lr=6e-2 * self.lr_factor, 42 | momentum=0.9, 43 | weight_decay=5e-4 44 | ) 45 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 46 | return [optim], [scheduler] 47 | 48 | -------------------------------------------------------------------------------- /self_supervised_models/models/dino.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import lightly 3 | import torch 4 | import torch.nn as nn 5 | from lightly.models import modules 6 | from lightly.models.modules import heads 7 | from lightly.models import utils 8 | from lightly.utils import BenchmarkModule 9 | 10 | class DINOModel(BenchmarkModule): 11 | def __init__(self, dataloader_kNN, num_classes, 12 | lr_factor=0.1, max_epochs=200, 13 | backbone='resnet-18'): 14 | super().__init__(dataloader_kNN, num_classes) 15 | self.lr_factor = lr_factor 16 | self.max_epochs = max_epochs 17 | # create a ResNet backbone and remove the classification head 18 | self.backbone_name = backbone 19 | resnet = lightly.models.ResNetGenerator(backbone) 20 | self.backbone = nn.Sequential( 21 | *list(resnet.children())[:-1], 22 | nn.AdaptiveAvgPool2d(1) 23 | ) 24 | self.head = self._build_projection_head() 25 | self.teacher_backbone = copy.deepcopy(self.backbone) 26 | self.teacher_head = self._build_projection_head() 27 | 28 | utils.deactivate_requires_grad(self.teacher_backbone) 29 | utils.deactivate_requires_grad(self.teacher_head) 30 | 31 | self.criterion = lightly.loss.DINOLoss(output_dim=2048) 32 | 33 | def _build_projection_head(self): 34 | head = heads.DINOProjectionHead(512, 2048, 256, 2048, batch_norm=True) 35 | # use only 2 layers for cifar10 36 | head.layers = heads.ProjectionHead([ 37 | (512, 2048, nn.BatchNorm1d(2048), nn.GELU()), 38 | (2048, 256, None, 
None), 39 | ]).layers 40 | return head 41 | 42 | def forward(self, x): 43 | y = self.backbone(x).flatten(start_dim=1) 44 | z = self.head(y) 45 | return z 46 | 47 | def forward_teacher(self, x): 48 | y = self.teacher_backbone(x).flatten(start_dim=1) 49 | z = self.teacher_head(y) 50 | return z 51 | 52 | def training_step(self, batch, batch_idx): 53 | utils.update_momentum(self.backbone, self.teacher_backbone, m=0.99) 54 | utils.update_momentum(self.head, self.teacher_head, m=0.99) 55 | views, _, _ = batch 56 | views = [view.to(self.device) for view in views] 57 | global_views = views[:2] 58 | teacher_out = [self.forward_teacher(view) for view in global_views] 59 | student_out = [self.forward(view) for view in views] 60 | loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch) 61 | self.log('train_loss_ssl', loss) 62 | return loss 63 | 64 | def configure_optimizers(self): 65 | param = list(self.backbone.parameters()) \ 66 | + list(self.head.parameters()) 67 | optim = torch.optim.SGD( 68 | param, 69 | lr=6e-2 * self.lr_factor, 70 | momentum=0.9, 71 | weight_decay=5e-4, 72 | ) 73 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 74 | return [optim], [scheduler] 75 | 76 | -------------------------------------------------------------------------------- /self_supervised_models/models/moco.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import lightly 3 | import torch 4 | import torch.nn as nn 5 | from lightly.models import modules 6 | from lightly.models.modules import heads 7 | from lightly.models import utils 8 | from lightly.utils import BenchmarkModule 9 | 10 | class MocoModel(BenchmarkModule): 11 | def __init__(self, dataloader_kNN, num_classes, 12 | lr_factor=0.1, max_epochs=200, 13 | backbone='resnet-18'): 14 | super().__init__(dataloader_kNN, num_classes) 15 | self.lr_factor = lr_factor 16 | self.max_epochs = max_epochs 17 | # TODO: HARDCODED 18 | self.sync_batchnorm = False 19 | self.distributed = False 20 | # create a ResNet backbone and remove the classification head 21 | num_splits = 0 if self.sync_batchnorm else 8 22 | self.backbone_name = backbone 23 | resnet = lightly.models.ResNetGenerator(backbone, num_splits=num_splits) 24 | self.backbone = nn.Sequential( 25 | *list(resnet.children())[:-1], 26 | nn.AdaptiveAvgPool2d(1) 27 | ) 28 | 29 | # create a moco model based on ResNet 30 | self.projection_head = heads.MoCoProjectionHead(512, 512, 128) 31 | self.backbone_momentum = copy.deepcopy(self.backbone) 32 | self.projection_head_momentum = copy.deepcopy(self.projection_head) 33 | utils.deactivate_requires_grad(self.backbone_momentum) 34 | utils.deactivate_requires_grad(self.projection_head_momentum) 35 | 36 | # create our loss with the optional memory bank 37 | self.criterion = lightly.loss.NTXentLoss( 38 | temperature=0.1, 39 | memory_bank_size=4096, 40 | ) 41 | 42 | def forward(self, x): 43 | x = self.backbone(x).flatten(start_dim=1) 44 | return self.projection_head(x) 45 | 46 | def training_step(self, batch, batch_idx): 47 | (x0, x1), _, _ = batch 48 | 49 | # update momentum 50 | utils.update_momentum(self.backbone, self.backbone_momentum, 0.99) 51 | utils.update_momentum(self.projection_head, self.projection_head_momentum, 0.99) 52 | 53 | def step(x0_, x1_): 54 | x1_, shuffle = utils.batch_shuffle(x1_, distributed=self.distributed) 55 | x0_ = self.backbone(x0_).flatten(start_dim=1) 56 | x0_ = self.projection_head(x0_) 57 | 58 | x1_ = 
self.backbone_momentum(x1_).flatten(start_dim=1) 59 | x1_ = self.projection_head_momentum(x1_) 60 | x1_ = utils.batch_unshuffle(x1_, shuffle, distributed=self.distributed) 61 | return x0_, x1_ 62 | 63 | # We use a symmetric loss (model trains faster at little compute overhead) 64 | # https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb 65 | loss_1 = self.criterion(*step(x0, x1)) 66 | loss_2 = self.criterion(*step(x1, x0)) 67 | 68 | loss = 0.5 * (loss_1 + loss_2) 69 | self.log('train_loss_ssl', loss) 70 | return loss 71 | 72 | def configure_optimizers(self): 73 | params = list(self.backbone.parameters()) + list(self.projection_head.parameters()) 74 | optim = torch.optim.SGD( 75 | params, 76 | lr=6e-2 * self.lr_factor, 77 | momentum=0.9, 78 | weight_decay=5e-4, 79 | ) 80 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 81 | return [optim], [scheduler] 82 | -------------------------------------------------------------------------------- /self_supervised_models/models/nnclr.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import lightly 3 | import torch 4 | import torch.nn as nn 5 | from lightly.models import modules 6 | from lightly.models.modules import heads 7 | from lightly.models import utils 8 | from lightly.utils import BenchmarkModule 9 | 10 | class NNCLRModel(BenchmarkModule): 11 | def __init__(self, dataloader_kNN, num_classes, 12 | lr_factor=0.1, max_epochs=200): 13 | super().__init__(dataloader_kNN, num_classes) 14 | self.lr_factor = lr_factor 15 | self.max_epochs = max_epochs 16 | # create a ResNet backbone and remove the classification head 17 | resnet = lightly.models.ResNetGenerator('resnet-18') 18 | self.backbone = nn.Sequential( 19 | *list(resnet.children())[:-1], 20 | nn.AdaptiveAvgPool2d(1) 21 | ) 22 | self.prediction_head = heads.NNCLRPredictionHead(256, 4096, 256) 23 | # use only a 2-layer projection head for cifar10 24 | self.projection_head = heads.ProjectionHead([ 25 | ( 26 | 512, 27 | 2048, 28 | nn.BatchNorm1d(2048), 29 | nn.ReLU(inplace=True) 30 | ), 31 | ( 32 | 2048, 33 | 256, 34 | nn.BatchNorm1d(256), 35 | None 36 | ) 37 | ]) 38 | 39 | self.criterion = lightly.loss.NTXentLoss() 40 | self.memory_bank = modules.NNMemoryBankModule(size=4096) 41 | 42 | def forward(self, x): 43 | y = self.backbone(x).flatten(start_dim=1) 44 | z = self.projection_head(y) 45 | p = self.prediction_head(z) 46 | z = z.detach() 47 | return z, p 48 | 49 | def training_step(self, batch, batch_idx): 50 | (x0, x1), _, _ = batch 51 | z0, p0 = self.forward(x0) 52 | z1, p1 = self.forward(x1) 53 | z0 = self.memory_bank(z0, update=False) 54 | z1 = self.memory_bank(z1, update=True) 55 | loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0)) 56 | return loss 57 | 58 | def configure_optimizers(self): 59 | optim = torch.optim.SGD( 60 | self.parameters(), 61 | lr=6e-2 * self.lr_factor, 62 | momentum=0.9, 63 | weight_decay=5e-4, 64 | ) 65 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 66 | return [optim], [scheduler] 67 | 68 | -------------------------------------------------------------------------------- /self_supervised_models/models/simclr.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import lightly 3 | import torch 4 | import torch.nn as nn 5 | from lightly.models import modules 6 | from lightly.models.modules import heads 7 | from lightly.models import 
utils 8 | from lightly.utils import BenchmarkModule 9 | 10 | class SimCLRModel(BenchmarkModule): 11 | def __init__(self, dataloader_kNN, num_classes, 12 | lr_factor=0.1, max_epochs=200, 13 | backbone='resnet-18'): 14 | super().__init__(dataloader_kNN, num_classes) 15 | self.lr_factor = lr_factor 16 | self.max_epochs = max_epochs 17 | # create a ResNet backbone and remove the classification head 18 | self.backbone_name = backbone 19 | resnet = lightly.models.ResNetGenerator(backbone) 20 | self.backbone = nn.Sequential( 21 | *list(resnet.children())[:-1], 22 | nn.AdaptiveAvgPool2d(1) 23 | ) 24 | self.projection_head = heads.SimCLRProjectionHead(512, 512, 128) 25 | self.criterion = lightly.loss.NTXentLoss() 26 | 27 | def forward(self, x): 28 | x = self.backbone(x).flatten(start_dim=1) 29 | z = self.projection_head(x) 30 | return z 31 | 32 | def training_step(self, batch, batch_index): 33 | (x0, x1), _, _ = batch 34 | z0 = self.forward(x0) 35 | z1 = self.forward(x1) 36 | loss = self.criterion(z0, z1) 37 | self.log('train_loss_ssl', loss) 38 | return loss 39 | 40 | def configure_optimizers(self): 41 | optim = torch.optim.SGD( 42 | self.parameters(), 43 | lr=6e-2 * self.lr_factor, 44 | momentum=0.9, 45 | weight_decay=5e-4 46 | ) 47 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 48 | return [optim], [scheduler] 49 | -------------------------------------------------------------------------------- /self_supervised_models/models/simsiam.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import lightly 3 | import torch 4 | import torch.nn as nn 5 | from lightly.models import modules 6 | from lightly.models.modules import heads 7 | from lightly.models import utils 8 | from lightly.utils import BenchmarkModule 9 | 10 | class SimSiamModel(BenchmarkModule): 11 | def __init__(self, dataloader_kNN, num_classes, 12 | lr_factor=0.1, max_epochs=200): 13 | super().__init__(dataloader_kNN, num_classes) 14 | self.lr_factor = lr_factor 15 | self.max_epochs = max_epochs 16 | # create a ResNet backbone and remove the classification head 17 | resnet = lightly.models.ResNetGenerator('resnet-18') 18 | self.backbone = nn.Sequential( 19 | *list(resnet.children())[:-1], 20 | nn.AdaptiveAvgPool2d(1) 21 | ) 22 | self.prediction_head = heads.SimSiamPredictionHead(2048, 512, 2048) 23 | # use a 2-layer projection head for cifar10 as described in the paper 24 | self.projection_head = heads.ProjectionHead([ 25 | ( 26 | 512, 27 | 2048, 28 | nn.BatchNorm1d(2048), 29 | nn.ReLU(inplace=True) 30 | ), 31 | ( 32 | 2048, 33 | 2048, 34 | nn.BatchNorm1d(2048), 35 | None 36 | ) 37 | ]) 38 | self.criterion = lightly.loss.NegativeCosineSimilarity() 39 | 40 | def forward(self, x): 41 | f = self.backbone(x).flatten(start_dim=1) 42 | z = self.projection_head(f) 43 | p = self.prediction_head(z) 44 | z = z.detach() 45 | return z, p 46 | 47 | def training_step(self, batch, batch_idx): 48 | (x0, x1), _, _ = batch 49 | z0, p0 = self.forward(x0) 50 | z1, p1 = self.forward(x1) 51 | loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0)) 52 | self.log('train_loss_ssl', loss) 53 | return loss 54 | 55 | def configure_optimizers(self): 56 | optim = torch.optim.SGD( 57 | self.parameters(), 58 | lr=6e-2, # no lr-scaling, results in better training stability 59 | momentum=0.9, 60 | weight_decay=5e-4 61 | ) 62 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 63 | return [optim], [scheduler] 64 | 
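All of the SSL wrappers above share the same shape: a ResNet backbone ending in `nn.AdaptiveAvgPool2d(1)` plus a method-specific projection head, and it is the flattened backbone output that serves as the sample embedding downstream. As a minimal, hypothetical sketch (not code from this repository), this is how one might extract L2-normalized features from a trained backbone and rank training samples against each test sample by cosine similarity; the loader's `(images, labels, filenames)` batch layout and `k=50` are illustrative assumptions.

```
import torch
import torch.nn.functional as F

@torch.no_grad()
def extract_features(backbone, dataloader, device='cuda'):
    # Collect L2-normalized backbone features. Assumes the loader yields
    # (images, labels, filenames) triples, as a LightlyDataset loader does.
    backbone = backbone.to(device).eval()
    feats = []
    for ims, _, _ in dataloader:
        f = backbone(ims.to(device)).flatten(start_dim=1)
        feats.append(F.normalize(f, dim=1).cpu())
    return torch.cat(feats)

def topk_train_indices(train_feats, test_feats, k=50):
    # On unit-norm features, cosine similarity is just a dot product.
    sims = test_feats @ train_feats.t()  # (num_test, num_train)
    return sims.topk(k, dim=1).indices   # (num_test, k)
```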
-------------------------------------------------------------------------------- /self_supervised_models/models/smog.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import lightly 4 | import torch 5 | import torch.nn as nn 6 | from lightly.models import modules 7 | from lightly.models.modules import heads 8 | from lightly.models import utils 9 | from lightly.utils import BenchmarkModule 10 | from sklearn.cluster import KMeans 11 | 12 | class SMoGModel(BenchmarkModule): 13 | 14 | def __init__(self, dataloader_kNN, num_classes, 15 | lr_factor=0.1, max_epochs=200): 16 | super().__init__(dataloader_kNN, num_classes) 17 | 18 | self.lr_factor = lr_factor 19 | self.max_epochs = max_epochs 20 | # create a ResNet backbone and remove the classification head 21 | resnet = lightly.models.ResNetGenerator('resnet-18') 22 | self.backbone = nn.Sequential( 23 | *list(resnet.children())[:-1], 24 | nn.AdaptiveAvgPool2d(1) 25 | ) 26 | 27 | # create a model based on ResNet 28 | self.projection_head = heads.SMoGProjectionHead(512, 2048, 128) 29 | self.prediction_head = heads.SMoGPredictionHead(128, 2048, 128) 30 | self.backbone_momentum = copy.deepcopy(self.backbone) 31 | self.projection_head_momentum = copy.deepcopy(self.projection_head) 32 | utils.deactivate_requires_grad(self.backbone_momentum) 33 | utils.deactivate_requires_grad(self.projection_head_momentum) 34 | 35 | # smog 36 | self.n_groups = 300 37 | memory_bank_size = 10000 38 | self.memory_bank = lightly.loss.memory_bank.MemoryBankModule(size=memory_bank_size) 39 | # create our loss 40 | group_features = torch.nn.functional.normalize( 41 | torch.rand(self.n_groups, 128), dim=1 42 | ) 43 | self.smog = heads.SMoGPrototypes(group_features=group_features, beta=0.99) 44 | self.criterion = nn.CrossEntropyLoss() 45 | 46 | def _cluster_features(self, features: torch.Tensor) -> torch.Tensor: 47 | features = features.cpu().numpy() 48 | kmeans = KMeans(self.n_groups).fit(features) 49 | clustered = torch.from_numpy(kmeans.cluster_centers_).float() 50 | clustered = torch.nn.functional.normalize(clustered, dim=1) 51 | return clustered 52 | 53 | def _reset_group_features(self): 54 | # see https://arxiv.org/pdf/2207.06167.pdf Table 7b) 55 | features = self.memory_bank.bank 56 | group_features = self._cluster_features(features.t()) 57 | self.smog.set_group_features(group_features) 58 | 59 | def _reset_momentum_weights(self): 60 | # see https://arxiv.org/pdf/2207.06167.pdf Table 7b) 61 | self.backbone_momentum = copy.deepcopy(self.backbone) 62 | self.projection_head_momentum = copy.deepcopy(self.projection_head) 63 | utils.deactivate_requires_grad(self.backbone_momentum) 64 | utils.deactivate_requires_grad(self.projection_head_momentum) 65 | 66 | def training_step(self, batch, batch_idx): 67 | 68 | if self.global_step > 0 and self.global_step % 300 == 0: 69 | # reset group features and weights every 300 iterations 70 | self._reset_group_features() 71 | self._reset_momentum_weights() 72 | else: 73 | # update momentum 74 | utils.update_momentum(self.backbone, self.backbone_momentum, 0.99) 75 | utils.update_momentum(self.projection_head, self.projection_head_momentum, 0.99) 76 | 77 | (x0, x1), _, _ = batch 78 | 79 | if batch_idx % 2: 80 | # swap batches every second iteration 81 | x0, x1 = x1, x0 82 | 83 | x0_features = self.backbone(x0).flatten(start_dim=1) 84 | x0_encoded = self.projection_head(x0_features) 85 | x0_predicted = self.prediction_head(x0_encoded) 86 | x1_features = 
self.backbone_momentum(x1).flatten(start_dim=1) 87 | x1_encoded = self.projection_head_momentum(x1_features) 88 | 89 | # update group features and get group assignments 90 | assignments = self.smog.assign_groups(x1_encoded) 91 | group_features = self.smog.get_updated_group_features(x0_encoded) 92 | logits = self.smog(x0_predicted, group_features, temperature=0.1) 93 | self.smog.set_group_features(group_features) 94 | 95 | loss = self.criterion(logits, assignments) 96 | 97 | # use memory bank to periodically reset the group features with k-means 98 | self.memory_bank(x0_encoded, update=True) 99 | 100 | return loss 101 | 102 | def configure_optimizers(self): 103 | params = list(self.backbone.parameters()) + list(self.projection_head.parameters()) + list(self.prediction_head.parameters()) 104 | optim = torch.optim.SGD( 105 | params, 106 | lr=0.01, 107 | momentum=0.9, 108 | weight_decay=1e-6, 109 | ) 110 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 111 | return [optim], [scheduler] 112 | 113 | -------------------------------------------------------------------------------- /self_supervised_models/models/swav.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import lightly 4 | import torch 5 | import torch.nn as nn 6 | from lightly.models import modules 7 | from lightly.models.modules import heads 8 | from lightly.models import utils 9 | from lightly.utils import BenchmarkModule 10 | 11 | class SwaVModel(BenchmarkModule): 12 | def __init__(self, dataloader_kNN, num_classes, 13 | lr_factor=0.1, max_epochs=200): 14 | super().__init__(dataloader_kNN, num_classes) 15 | self.lr_factor = lr_factor 16 | self.max_epochs = max_epochs 17 | # self.gather_distr 18 | # create a ResNet backbone and remove the classification head 19 | resnet = lightly.models.ResNetGenerator('resnet-18') 20 | self.backbone = nn.Sequential( 21 | *list(resnet.children())[:-1], 22 | nn.AdaptiveAvgPool2d(1) 23 | ) 24 | 25 | self.projection_head = heads.SwaVProjectionHead(512, 512, 128) 26 | self.prototypes = heads.SwaVPrototypes(128, 512) # use 512 prototypes 27 | 28 | self.criterion = lightly.loss.SwaVLoss(sinkhorn_gather_distributed=False) 29 | 30 | def forward(self, x): 31 | x = self.backbone(x).flatten(start_dim=1) 32 | x = self.projection_head(x) 33 | x = nn.functional.normalize(x, dim=1, p=2) 34 | return self.prototypes(x) 35 | 36 | def training_step(self, batch, batch_idx): 37 | # normalize the prototypes so they are on the unit sphere 38 | self.prototypes.normalize() 39 | 40 | # the multi-crop dataloader returns a list of image crops where the 41 | # first two items are the high resolution crops and the rest are low 42 | # resolution crops 43 | multi_crops, _, _ = batch 44 | multi_crop_features = [self.forward(x) for x in multi_crops] 45 | 46 | # split list of crop features into high and low resolution 47 | high_resolution_features = multi_crop_features[:2] 48 | low_resolution_features = multi_crop_features[2:] 49 | 50 | # calculate the SwaV loss 51 | loss = self.criterion( 52 | high_resolution_features, 53 | low_resolution_features 54 | ) 55 | 56 | self.log('train_loss_ssl', loss) 57 | return loss 58 | 59 | def configure_optimizers(self): 60 | optim = torch.optim.Adam( 61 | self.parameters(), 62 | lr=1e-3 * self.lr_factor, 63 | weight_decay=1e-6, 64 | ) 65 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs) 66 | return [optim], [scheduler] 67 | 
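The momentum-based models above (MoCo, DINO, SMoG) all call `utils.update_momentum(student, teacher, m)` at the start of each training step. Conceptually this is an exponential moving average of the student weights; the following is a simplified stand-in for lightly's implementation, shown only to make the update rule explicit:

```
import torch

@torch.no_grad()
def ema_update(student, teacher, m=0.99):
    # teacher <- m * teacher + (1 - m) * student, applied parameter-wise
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(m).add_(p_s.data, alpha=1.0 - m)
```

With `m` close to 1 the teacher drifts slowly, which is what makes it a stable target; note that the momentum copies also have gradients disabled via `deactivate_requires_grad`, so they are only ever updated through this EMA.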
-------------------------------------------------------------------------------- /self_supervised_models/train_ssl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Benchmark Results 4 | 5 | Updated: 18.02.2022 (6618fa3c36b0c9f3a9d7a21bcdb00bf4fd258ee8)) 6 | 7 | ------------------------------------------------------------------------------------------ 8 | | Model | Batch Size | Epochs | KNN Test Accuracy | Time | Peak GPU Usage | 9 | ------------------------------------------------------------------------------------------ 10 | | BarlowTwins | 128 | 200 | 0.835 | 193.4 Min | 2.2 GByte | 11 | | BYOL | 128 | 200 | 0.872 | 217.0 Min | 2.3 GByte | 12 | | DCL (*) | 128 | 200 | 0.842 | 126.9 Min | 1.7 GByte | 13 | | DCLW (*) | 128 | 200 | 0.833 | 127.5 Min | 1.8 GByte | 14 | | DINO | 128 | 200 | 0.868 | 220.7 Min | 2.3 GByte | 15 | | Moco | 128 | 200 | 0.838 | 229.5 Min | 2.3 GByte | 16 | | NNCLR | 128 | 200 | 0.838 | 198.7 Min | 2.2 GByte | 17 | | SimCLR | 128 | 200 | 0.822 | 182.7 Min | 2.2 GByte | 18 | | SimSiam | 128 | 200 | 0.779 | 182.6 Min | 2.3 GByte | 19 | | SwaV | 128 | 200 | 0.806 | 182.4 Min | 2.2 GByte | 20 | ------------------------------------------------------------------------------------------ 21 | | BarlowTwins | 512 | 200 | 0.827 | 160.7 Min | 7.5 GByte | 22 | | BYOL | 512 | 200 | 0.872 | 188.5 Min | 7.7 GByte | 23 | | DCL (*) | 512 | 200 | 0.834 | 113.6 Min | 6.1 GByte | 24 | | DCLW (*) | 512 | 200 | 0.830 | 113.8 Min | 6.2 GByte | 25 | | DINO | 512 | 200 | 0.862 | 191.1 Min | 7.5 GByte | 26 | | Moco (**) | 512 | 200 | 0.850 | 196.8 Min | 7.8 GByte | 27 | | NNCLR (**) | 512 | 200 | 0.836 | 164.7 Min | 7.6 GByte | 28 | | SimCLR | 512 | 200 | 0.828 | 158.2 Min | 7.5 GByte | 29 | | SimSiam | 512 | 200 | 0.814 | 159.0 Min | 7.6 GByte | 30 | | SwaV | 512 | 200 | 0.833 | 158.4 Min | 7.5 GByte | 31 | ------------------------------------------------------------------------------------------ 32 | | BarlowTwins | 512 | 800 | 0.857 | 641.5 Min | 7.5 GByte | 33 | | BYOL | 512 | 800 | 0.911 | 754.2 Min | 7.8 GByte | 34 | | DCL (*) | 512 | 800 | 0.873 | 459.6 Min | 6.1 GByte | 35 | | DCLW (*) | 512 | 800 | 0.873 | 455.8 Min | 6.1 GByte | 36 | | DINO | 512 | 800 | 0.884 | 765.5 Min | 7.6 GByte | 37 | | Moco (**) | 512 | 800 | 0.900 | 787.7 Min | 7.8 GByte | 38 | | NNCLR (**) | 512 | 800 | 0.896 | 659.2 Min | 7.6 GByte | 39 | | SimCLR | 512 | 800 | 0.875 | 632.5 Min | 7.5 GByte | 40 | | SimSiam | 512 | 800 | 0.906 | 636.5 Min | 7.6 GByte | 41 | | SwaV | 512 | 800 | 0.881 | 634.9 Min | 7.5 GByte | 42 | ------------------------------------------------------------------------------------------ 43 | 44 | (*): Smaller runtime and memory requirements due to different hardware settings 45 | and pytorch version. Runtime and memory requirements are comparable to SimCLR 46 | with the default settings. 47 | (**): Increased size of memory bank from 4096 to 8192 to avoid too quickly 48 | changing memory bank due to larger batch size. 49 | 50 | The benchmarks were created on a single NVIDIA RTX A6000. 51 | 52 | Note that this benchmark also supports a multi-GPU setup. If you run it on 53 | a system with multiple GPUs make sure that you kill all the processes when 54 | killing the application. Due to the way we setup this benchmark the distributed 55 | processes might continue the benchmark if one of the nodes is killed. 
56 | If you know how to fix this don't hesitate to create an issue or PR :) 57 | 58 | """ 59 | import copy 60 | import os 61 | 62 | import time 63 | import lightly 64 | import numpy as np 65 | import pytorch_lightning as pl 66 | import torch 67 | import torch.nn as nn 68 | import torchvision 69 | from lightly.models import modules 70 | from lightly.models.modules import heads 71 | from lightly.models import utils 72 | from lightly.utils import BenchmarkModule 73 | from pytorch_lightning.loggers import TensorBoardLogger 74 | from models import * 75 | import argparse 76 | 77 | parser = argparse.ArgumentParser(description='Benchmark SSL models') 78 | parser.add_argument('--model', type=str, default='SimCLR', help='SSL model to benchmark') 79 | parser.add_argument('--backbone', type=str, default='resnet-18', help='Model backbone') 80 | args = parser.parse_args() 81 | 82 | logs_root_dir = os.path.join(os.getcwd(), 'benchmark_logs') 83 | 84 | # set max_epochs to 800 for long run (takes around 10h on a single V100) 85 | max_epochs = 800 86 | num_workers = 8 87 | knn_k = 200 88 | knn_t = 0.1 89 | classes = 10 90 | 91 | # Set to True to enable Distributed Data Parallel training. 92 | distributed = False 93 | 94 | # Set to True to enable Synchronized Batch Norm (requires distributed=True). 95 | # If enabled the batch norm is calculated over all gpus, otherwise the batch 96 | # norm is only calculated from samples on the same gpu. 97 | sync_batchnorm = False 98 | 99 | # Set to True to gather features from all gpus before calculating 100 | # the loss (requires distributed=True). 101 | # If enabled then the loss on every gpu is calculated with features from all 102 | # gpus, otherwise only features from the same gpu are used. 103 | gather_distributed = False 104 | 105 | # benchmark 106 | n_runs = 1 # optional, increase to create multiple runs and report mean + std 107 | batch_size = 512 108 | lr_factor = batch_size / 128 # scales the learning rate linearly with batch size 109 | 110 | # use a GPU if available 111 | gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 112 | 113 | if distributed: 114 | distributed_backend = 'ddp' 115 | # reduce batch size for distributed training 116 | batch_size = batch_size // gpus 117 | else: 118 | distributed_backend = None 119 | # limit to single gpu if not using distributed training 120 | gpus = min(gpus, 1) 121 | 122 | # Adapted from our MoCo Tutorial on CIFAR-10 123 | # 124 | # Replace the path with the location of your CIFAR-10 dataset. 125 | # We assume we have a train folder with subfolders 126 | # for each class and .png images inside. 127 | # 128 | # You can download `CIFAR-10 in folders from kaggle 129 | # `_. 130 | 131 | # The dataset structure should be like this: 132 | # cifar10/train/ 133 | # L airplane/ 134 | # L 10008_airplane.png 135 | # L ... 
136 | # L automobile/ 137 | # L bird/ 138 | # L cat/ 139 | # L deer/ 140 | # L dog/ 141 | # L frog/ 142 | # L horse/ 143 | # L ship/ 144 | # L truck/ 145 | path_to_train = os.path.join(os.getenv('NEXUS_DIR'), 'datasets/cifar10/cifar10/train/') 146 | path_to_test = os.path.join(os.getenv('NEXUS_DIR'), 'datasets/cifar10/cifar10/test/') 147 | 148 | # Use SimCLR augmentations, additionally, disable blur for cifar10 149 | collate_fn = lightly.data.SimCLRCollateFunction( 150 | input_size=32, 151 | gaussian_blur=0., 152 | ) 153 | 154 | # Multi crop augmentation for SwAV, additionally, disable blur for cifar10 155 | swav_collate_fn = lightly.data.SwaVCollateFunction( 156 | crop_sizes=[32], 157 | crop_counts=[2], # 2 crops @ 32x32px 158 | crop_min_scales=[0.14], 159 | gaussian_blur=0, 160 | ) 161 | 162 | # Multi crop augmentation for DINO, additionally, disable blur for cifar10 163 | dino_collate_fn = lightly.data.DINOCollateFunction( 164 | global_crop_size=32, 165 | n_local_views=0, 166 | gaussian_blur=(0, 0, 0), 167 | ) 168 | 169 | # Two crops for SMoG 170 | smog_collate_function = lightly.data.collate.SMoGCollateFunction( 171 | crop_sizes=[32, 32], 172 | crop_counts=[1, 1], 173 | gaussian_blur_probs=[0., 0.], 174 | crop_min_scales=[0.2, 0.2], 175 | crop_max_scales=[1.0, 1.0], 176 | ) 177 | 178 | # No additional augmentations for the test set 179 | test_transforms = torchvision.transforms.Compose([ 180 | torchvision.transforms.ToTensor(), 181 | torchvision.transforms.Normalize( 182 | mean=lightly.data.collate.imagenet_normalize['mean'], 183 | std=lightly.data.collate.imagenet_normalize['std'], 184 | ) 185 | ]) 186 | 187 | dataset_train_ssl = lightly.data.LightlyDataset( 188 | input_dir=path_to_train 189 | ) 190 | 191 | # we use test transformations for getting the feature for kNN on train data 192 | dataset_train_kNN = lightly.data.LightlyDataset( 193 | input_dir=path_to_train, 194 | transform=test_transforms 195 | ) 196 | 197 | dataset_test = lightly.data.LightlyDataset( 198 | input_dir=path_to_test, 199 | transform=test_transforms 200 | ) 201 | 202 | def get_data_loaders(batch_size: int, model): 203 | """Helper method to create dataloaders for ssl, kNN train and kNN test 204 | 205 | Args: 206 | batch_size: Desired batch size for all dataloaders 207 | """ 208 | col_fn = collate_fn 209 | if model == SwaVModel: 210 | col_fn = swav_collate_fn 211 | elif model == DINOModel: 212 | col_fn = dino_collate_fn 213 | elif model == SMoGModel: 214 | col_fn = smog_collate_function 215 | dataloader_train_ssl = torch.utils.data.DataLoader( 216 | dataset_train_ssl, 217 | batch_size=batch_size, 218 | shuffle=True, 219 | collate_fn=col_fn, 220 | drop_last=True, 221 | num_workers=num_workers 222 | ) 223 | 224 | dataloader_train_kNN = torch.utils.data.DataLoader( 225 | dataset_train_kNN, 226 | batch_size=batch_size, 227 | shuffle=False, 228 | drop_last=False, 229 | num_workers=num_workers 230 | ) 231 | 232 | dataloader_test = torch.utils.data.DataLoader( 233 | dataset_test, 234 | batch_size=batch_size, 235 | shuffle=False, 236 | drop_last=False, 237 | num_workers=num_workers 238 | ) 239 | 240 | return dataloader_train_ssl, dataloader_train_kNN, dataloader_test 241 | 242 | 243 | 244 | models = [ 245 | eval(args.model) 246 | # BarlowTwinsModel, 247 | # BYOLModel, 248 | # DCL, 249 | # DCLW, 250 | # DINOModel, 251 | # MocoModel, 252 | # NNCLRModel, 253 | # SimCLRModel, 254 | # SimSiamModel, 255 | # SwaVModel, 256 | # SMoGModel 257 | ] 258 | bench_results = dict() 259 | 260 | experiment_version = None 261 | # loop 
through configurations and train models 262 | for BenchmarkModel in models: 263 | runs = [] 264 | model_name = BenchmarkModel.__name__.replace('Model', '') 265 | for seed in range(n_runs): 266 | pl.seed_everything(seed) 267 | dataloader_train_ssl, dataloader_train_kNN, dataloader_test = get_data_loaders( 268 | batch_size=batch_size, 269 | model=BenchmarkModel, 270 | ) 271 | benchmark_model = BenchmarkModel(dataloader_train_kNN, classes, 272 | lr_factor, max_epochs, backbone = args.backbone) 273 | 274 | # Save logs to: {CWD}/benchmark_logs/cifar10/{experiment_version}/{model_name}/ 275 | # If multiple runs are specified a subdirectory for each run is created. 276 | sub_dir = model_name if n_runs <= 1 else f'{model_name}/run{seed}' 277 | logger = TensorBoardLogger( 278 | save_dir=os.path.join(logs_root_dir, 'cifar10'), 279 | name='', 280 | sub_dir=sub_dir, 281 | version=experiment_version, 282 | ) 283 | if experiment_version is None: 284 | # Save results of all models under same version directory 285 | experiment_version = logger.version 286 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 287 | dirpath=os.path.join(logger.log_dir, 'checkpoints') 288 | ) 289 | trainer = pl.Trainer( 290 | max_epochs=max_epochs, 291 | gpus=gpus, 292 | default_root_dir=logs_root_dir, 293 | strategy=distributed_backend, 294 | sync_batchnorm=sync_batchnorm, 295 | logger=logger, 296 | callbacks=[checkpoint_callback] 297 | ) 298 | start = time.time() 299 | trainer.fit( 300 | benchmark_model, 301 | train_dataloaders=dataloader_train_ssl, 302 | val_dataloaders=dataloader_test 303 | ) 304 | end = time.time() 305 | run = { 306 | 'model': model_name, 307 | 'batch_size': batch_size, 308 | 'epochs': max_epochs, 309 | 'max_accuracy': benchmark_model.max_accuracy, 310 | 'runtime': end - start, 311 | 'gpu_memory_usage': torch.cuda.max_memory_allocated(), 312 | 'seed': seed, 313 | } 314 | runs.append(run) 315 | print(run) 316 | 317 | # delete model and trainer + free up cuda memory 318 | del benchmark_model 319 | del trainer 320 | torch.cuda.reset_peak_memory_stats() 321 | torch.cuda.empty_cache() 322 | 323 | bench_results[model_name] = runs 324 | 325 | # print results table 326 | header = ( 327 | f"| {'Model':<13} | {'Batch Size':>10} | {'Epochs':>6} " 328 | f"| {'KNN Test Accuracy':>18} | {'Time':>10} | {'Peak GPU Usage':>14} |" 329 | ) 330 | print('-' * len(header)) 331 | print(header) 332 | print('-' * len(header)) 333 | for model, results in bench_results.items(): 334 | runtime = np.array([result['runtime'] for result in results]) 335 | runtime = runtime.mean() / 60 # convert to min 336 | accuracy = np.array([result['max_accuracy'] for result in results]) 337 | gpu_memory_usage = np.array([result['gpu_memory_usage'] for result in results]) 338 | gpu_memory_usage = gpu_memory_usage.max() / (1024**3) # convert to gbyte 339 | 340 | if len(accuracy) > 1: 341 | accuracy_msg = f"{accuracy.mean():>8.3f} +- {accuracy.std():>4.3f}" 342 | else: 343 | accuracy_msg = f"{accuracy.mean():>18.3f}" 344 | 345 | print( 346 | f"| {model:<13} | {batch_size:>10} | {max_epochs:>6} " 347 | f"| {accuracy_msg} | {runtime:>6.1f} Min " 348 | f"| {gpu_memory_usage:>8.1f} GByte |", 349 | flush=True 350 | ) 351 | print('-' * len(header)) 352 | -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch as ch 4 | import torchvision 5 | import numpy as np 6 | from tqdm import tqdm 7 | import time 
8 | import argparse
9 | import ipdb
10 | 
11 | from typing import List, Callable, Tuple, Optional
12 | from ffcv.fields import IntField, RGBImageField
13 | from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
14 | from ffcv.loader import Loader, OrderOption
15 | from ffcv.pipeline.operation import Operation
16 | from ffcv.transforms import RandomHorizontalFlip, Cutout, \
17 |     RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
18 | from ffcv.transforms.common import Squeeze
19 | from ffcv.writer import DatasetWriter
20 | from ffcv.pipeline.state import State
21 | from ffcv.pipeline.operation import AllocationQuery, Operation
22 | from ffcv.pipeline.compiler import Compiler
23 | from torch.cuda.amp import GradScaler, autocast
24 | from torch.nn import CrossEntropyLoss
25 | from torch.optim import SGD, lr_scheduler
26 | 
27 | from models.resnet import resnet9
28 | from models.swin import swin_t
29 | from models.mobilenetv2 import MobileNetV2
30 | from constants import CIFAR_MEAN, CIFAR_STD
31 | 
32 | # def parse_args():
33 | #     parser = argparse.ArgumentParser()
34 | #     parser.add_argument('--batch_size', type=int, default=512)
35 | #     parser.add_argument('--num_workers', type=int, default=4)
36 | #     parser.add_argument('--epochs', type=int, default=24)
37 | #     parser.add_argument('--train-samples-remove')
38 | 
39 | class CorruptFixedLabels(Operation):
40 |     def __init__(self, flip_class, corrupt_idxs=None):
41 |         super().__init__()
42 |         self.flip_class = flip_class
43 |         self.corrupt_idxs = corrupt_idxs
44 | 
45 |     def generate_code(self) -> Callable:
46 |         # dst will be None since we don't ask for an allocation
47 |         parallel_range = Compiler.get_iterator()
48 |         corrupt_idxs, flip_class = self.corrupt_idxs, self.flip_class  # capture locally; the JIT-compiled function cannot reference self
49 |         def corrupt_fixed(labs, _, inds):
50 |             for iter_idx in parallel_range(labs.shape[0]):
51 |                 dset_idx = inds[iter_idx]
52 |                 if dset_idx in corrupt_idxs:
53 |                     labs[iter_idx] = flip_class
54 |                     # if np.random.rand() < 0.2:
55 |                     # They will also be corrupted to a deterministic label:
56 |                     # labs[i] = np.random.randint(low=0, high=10)
57 |             return labs
58 | 
59 |         corrupt_fixed.is_parallel = True
60 |         corrupt_fixed.with_indices = True
61 |         return corrupt_fixed
62 | 
63 |     def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
64 |         # No updates to state or extra memory necessary!
65 |         return previous_state, None
66 | 
67 | def create_model(arch='resnet9'):
68 |     if arch == 'resnet9':
69 |         model = resnet9(num_classes=10)
70 |     elif arch == 'swin_t':
71 |         model = swin_t(window_size=4,
72 |                        num_classes=10,
73 |                        downscaling_factors=(2,2,2,1),)
74 |     elif arch == 'mobilenetv2':
75 |         model = MobileNetV2(num_classes=10)
76 |     else:
77 |         print(f"Model {arch} not supported.")
78 |         raise NotImplementedError
79 |     model = model.to(memory_format=ch.channels_last).cuda()
80 |     return model
81 | 
82 | def create_loaders(train_path=None, train_examples_remove=None,
83 |                    eval_idx=None, batch_size=512, num_workers=4,
84 |                    flip_class=None):
85 |     loaders = {}
86 |     for name in ['train', 'test']:
87 |         if name == 'train' and flip_class is not None:
88 |             label_pipeline: List[Operation] = [IntDecoder(), CorruptFixedLabels(flip_class), ToTensor(), ToDevice('cuda:0')]
89 |         else:
90 |             label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice('cuda:0')]
91 |         if name == 'train':
92 |             label_pipeline.extend([Squeeze()])
93 |         else:
94 |             if eval_idx is not None:
95 |                 label_pipeline.extend([Squeeze(0)])
96 |             else:
97 |                 label_pipeline.extend([Squeeze()])
98 |         image_pipeline: List[Operation] = [SimpleRGBImageDecoder()]
99 | 
100 |         # Add image transforms and normalization
101 |         if name == 'train':
102 |             image_pipeline.extend([
103 |                 RandomHorizontalFlip(),
104 |                 RandomTranslate(padding=2),
105 |                 Cutout(8, tuple(map(int, CIFAR_MEAN))), # Note Cutout is done before normalization.
106 |             ])
107 |         image_pipeline.extend([
108 |             ToTensor(),
109 |             ToDevice('cuda:0', non_blocking=True),
110 |             ToTorchImage(),
111 |             Convert(ch.float16),
112 |             torchvision.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
113 |         ])
114 | 
115 |         # Create loaders
116 |         if name == 'train':
117 |             selected_idx = set(np.arange(50000))
118 |             if train_examples_remove is not None:
119 |                 train_examples_remove = set(train_examples_remove)
120 |                 selected_idx -= train_examples_remove
121 |             selected_idx = list(selected_idx)
122 |         else:
123 |             if eval_idx is not None:
124 |                 selected_idx = [eval_idx]
125 |             else:
126 |                 selected_idx = list(np.arange(10000))
127 |         order = OrderOption.RANDOM if name == 'train' else OrderOption.SEQUENTIAL
128 |         if name == 'train' and train_path is None:
129 |             path = os.path.join(os.getenv('DATA_DIR'), 'cifar_train.beton')
130 |         elif name == 'test':
131 |             path = os.path.join(os.getenv('DATA_DIR'), 'cifar_test.beton')
132 |         else:
133 |             path = train_path
134 |         loaders[name] = Loader(path,
135 |                                batch_size=batch_size,
136 |                                num_workers=num_workers,
137 |                                order=order,
138 |                                drop_last=(name == 'train'),
139 |                                indices=selected_idx,
140 |                                pipelines={'image': image_pipeline,
141 |                                           'label': label_pipeline})
142 |     return loaders
143 | 
144 | def train_slow(model, loaders, batch_size=512, epochs=24, lr=0.1):
145 |     opt = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
146 |     iters_per_epoch = 50000 // batch_size
147 |     scheduler = lr_scheduler.MultiStepLR(opt,
148 |                                          milestones=[0.5 * iters_per_epoch,
149 |                                                      0.75 * iters_per_epoch],
150 |                                          gamma=0.1)  # note: milestones count scheduler steps (batches), so both decays fire within the first epoch
151 |     scaler = GradScaler()
152 |     loss_fn = CrossEntropyLoss(label_smoothing=0.1)
153 | 
154 |     for ep in range(epochs):
155 |         for ims, labs in loaders['train']:
156 |             opt.zero_grad(set_to_none=True)
157 |             with autocast():
158 |                 out = model(ims)
159 |                 loss = loss_fn(out, labs)
160 | 
161 |             scaler.scale(loss).backward()
162 |             scaler.step(opt)
163 |             scaler.update()
164 |             scheduler.step()
165 | 
166 | 
167 | def train(model, loaders, batch_size, epochs=24):
168 |     opt = SGD(model.parameters(), lr=.5,
momentum=0.9, weight_decay=5e-4) 169 | iters_per_epoch = 50000 // batch_size 170 | lr_schedule = np.interp(np.arange((epochs+1) * iters_per_epoch), 171 | [0, 5 * iters_per_epoch, epochs * iters_per_epoch], 172 | [0, 1, 0]) 173 | scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__) 174 | scaler = GradScaler() 175 | loss_fn = CrossEntropyLoss(label_smoothing=0.1) 176 | 177 | for ep in range(epochs): 178 | for ims, labs in loaders['train']: 179 | opt.zero_grad(set_to_none=True) 180 | with autocast(): 181 | out = model(ims) 182 | loss = loss_fn(out, labs) 183 | 184 | scaler.scale(loss).backward() 185 | scaler.step(opt) 186 | scaler.update() 187 | scheduler.step() 188 | 189 | def test(model, loaders): 190 | model.eval() 191 | with ch.no_grad(): 192 | total_correct, total_num = 0., 0. 193 | for ims, labs in loaders['test']: 194 | with autocast(): 195 | out = (model(ims)) 196 | total_correct += out.argmax(1).eq(labs).sum().cpu().item() 197 | total_num += ims.shape[0] 198 | return total_correct, total_num 199 | 200 | def create_train_dataset(train_idxs, flip_class, eval_idx): 201 | # Create new datasets with flipped labels. 202 | train_data = torchvision.datasets.CIFAR10(os.getenv("DATA_DIR"), train=True, download=False) 203 | 204 | # Flip training data labels, based on the given indexes. 205 | for idx in train_idxs: 206 | train_data.targets[idx] = flip_class 207 | 208 | train_path = f'/scratch0/cifar_train_{eval_idx}.beton' 209 | writer = DatasetWriter(train_path, { 210 | 'image': RGBImageField(), 211 | 'label': IntField() 212 | }) 213 | writer.from_indexed_dataset(train_data) 214 | return train_path 215 | 216 | def counterfactual_test(train_idxs=None, flip_class=None, 217 | eval_idx=None, num_tests = 10, batch_size=512, 218 | epochs=24, num_workers=4, 219 | arch='resnet9'): 220 | # Remove train examples, and test if prediction flips on certain test samples. 221 | # Run the test multiple times to get a better estimate. 222 | 223 | # For training, load indexes to remove from the train set. 224 | # For evaluation, only use the given eval idx. 225 | # ipdb.set_trace() 226 | total_correct = 0 227 | if flip_class is None: 228 | loaders = create_loaders(train_examples_remove=train_idxs, 229 | eval_idx=eval_idx, batch_size=batch_size, 230 | num_workers=num_workers) 231 | else: 232 | train_path = create_train_dataset(train_idxs, flip_class, eval_idx) 233 | loaders = create_loaders(train_path=train_path, 234 | eval_idx=eval_idx, batch_size=batch_size, 235 | num_workers=num_workers) 236 | for test_num in range(num_tests): 237 | model = create_model(arch=arch) 238 | train(model, loaders, batch_size, epochs=epochs) 239 | correct, _ = test(model, loaders) 240 | total_correct += correct 241 | return total_correct/num_tests 242 | 243 | def binary_search(train_idxs, flip_class=None, 244 | eval_idx=None, search_budget = 8, num_tests = 5, 245 | arch='resnet9'): 246 | # Fixed budget binary search for the minimum number of examples to remove. 247 | # Note that this is not guaranteed to find the minimum. 248 | # It is possible that the minimum is not in the range [0, len(train_examples_remove)]. 249 | # If we can't flip with all selected training examples, return -1. 250 | # If we can flip it, minimize the number of examples to remove using binary search. 251 | # We could use a linear search, but that would be too slow. 252 | # The binary search idea is supported by monotonicity of train-test influence. 253 | # This monotonicity is also supported by datamodels being linear. 
254 | 
255 |     num_search = search_budget
256 |     # Get the number of examples to remove for each binary search step.
257 |     low = 0
258 |     high = len(train_idxs)
259 |     mid = high
260 | 
261 |     avg_correct = counterfactual_test(train_idxs[:mid], flip_class, eval_idx, num_tests=num_tests,
262 |                                       arch=arch)
263 |     if avg_correct > 0.5:
264 |         print(f"Sample Idx: {eval_idx}, Min samples: -1")
265 |         return -1 # if can't flip with all selected training examples, return -1
266 | 
267 |     min_samples = mid
268 |     while num_search > 0:
269 |         num_search -= 1
270 |         mid = (low + high) // 2
271 |         avg_correct = counterfactual_test(train_idxs[:mid], flip_class, eval_idx, num_tests=num_tests,
272 |                                           arch=arch)
273 |         if avg_correct > 0.5:
274 |             low = mid
275 |         else:
276 |             high = mid
277 |             min_samples = min(mid, min_samples) # update min_samples if successful in flipping
278 |     print(f"Sample Idx: {eval_idx}, Min samples: {min_samples}")
279 |     return min_samples # return the minimum number of samples to remove
280 | 
281 | 
--------------------------------------------------------------------------------
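`binary_search` above is the natural per-test-sample entry point for a driver script. The following hypothetical sketch shows one way to call it; the `.npy` matrix layout (one row of ranked top-k training indices per test sample) and the output filename are assumptions for illustration, not the repository's actual formats.

```
# Hypothetical driver for train_utils.binary_search; file names and the
# top-k matrix layout are assumptions for illustration only.
import pickle
import numpy as np
from train_utils import binary_search

test_idx = 0
topk = np.load('topk_train_samples.npy')   # assumed shape: (num_test, k)
train_idxs = topk[test_idx]

# Removal support: smallest prefix of the ranked training samples whose
# removal flips the prediction on this test sample (-1 if even removing
# all k samples does not flip it).
min_samples = binary_search(train_idxs, flip_class=None, eval_idx=test_idx,
                            search_budget=7, num_tests=5, arch='resnet9')

with open(f'results_{test_idx}.pkl', 'wb') as f:
    pickle.dump({'test_idx': test_idx, 'min_samples': min_samples}, f)
```

Because the budgeted loop halves the interval [0, len(train_idxs)] once per step, a search budget of 7 pins down the flipping threshold on a 100-sample top-k list to within a single sample (100 / 2^7 < 1).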