├── .gitignore
├── LICENSE
├── README.md
├── algorithms
│   ├── distill.py
│   ├── finetune.py
│   ├── linear_eval.py
│   ├── pretrain_dc.py
│   ├── pretrain_frepo.py
│   ├── pretrain_krr_st.py
│   ├── scratch.py
│   ├── wrapper.py
│   └── zeroshot_kd.py
├── assets
│   └── concept.png
├── data
│   ├── aircraft.py
│   ├── augmentation.py
│   ├── cars.py
│   ├── cub2011.py
│   ├── dogs.py
│   └── wrapper.py
├── environment.yml
├── model_pool.py
├── models
│   ├── alexnet.py
│   ├── convnet.py
│   ├── mobilenet.py
│   ├── resnet.py
│   ├── vgg.py
│   └── wrapper.py
├── test.py
├── test_kd.py
├── test_scratch.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | */__pycache__ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Supervised Dataset Distillation for Transfer Learning
2 | This is the official PyTorch implementation for the ICLR 2024 paper ["**Self-Supervised Dataset Distillation for Transfer Learning**"](https://openreview.net/forum?id=h57gkDO2Yg).
3 |
4 | ## Summary
5 |
6 | Dataset distillation aims to optimize a small set so that a model trained on the set achieves performance similar to that of a model trained on the full dataset. While many supervised methods have achieved remarkable success in distilling a large dataset into a small set of representative samples, they are not designed to produce a distilled dataset that can effectively facilitate self-supervised pre-training. To this end, we propose the novel problem of distilling an unlabeled dataset into a small set of synthetic samples for efficient self-supervised learning (SSL). We first prove that the gradient of an SSL objective with respect to the synthetic samples in naive bilevel optimization is biased, due to the randomness originating from the data augmentations or masking used in the inner optimization. To address this issue, we propose to minimize, as the inner objective, the mean squared error (MSE) between a model's representations of the synthetic examples and their corresponding learnable target feature representations, which introduces no randomness. Our primary motivation is that the model obtained by the proposed inner optimization can mimic the self-supervised target model. To achieve this, the outer optimization minimizes the MSE between the representations of the inner model and those of the self-supervised target model on the original full dataset. We empirically validate the effectiveness of our method on transfer learning.
7 |
8 |  
9 |
10 | __Contribution of this work__
11 | - We propose a new problem of self-supervised dataset distillation for transfer learning, where we distill an unlabeled dataset into a small set,
12 | pre-train a model on it, and fine-tune it on target tasks.
13 | - We observe training instability when utilizing existing SSL objectives in bilevel optimization for self-supervised dataset distillation, and we prove that the gradient of these SSL objectives under data augmentation or input masking is a biased estimator of the true gradient.
14 | - To address this instability, we propose KRR-ST, which uses an MSE objective that involves no randomness in the inner loop: the inner loop minimizes the MSE between the model's representations of the synthetic samples and their learnable target representations, while the outer loop minimizes the MSE between the representations of the original data produced by the inner-loop model and those produced by a model pre-trained on the original dataset (a condensed sketch of this objective follows below).
15 | - We extensively validate our proposed method on numerous target datasets and architectures, and show that it outperforms supervised dataset distillation methods.
16 |
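For intuition, the inner problem has a closed form: given synthetic features and learnable targets, kernel ridge regression predicts the teacher's representations of real data. The following is a condensed sketch of the outer objective as implemented in `algorithms/distill.py` below (`model.embed` is the repo's feature extractor; `target_model` is the frozen self-supervised teacher; the full training loop additionally augments `x_real` and maintains a pool of inner models):

```python
import torch
import torch.nn.functional as F

def krr_st_outer_loss(model, target_model, x_real, x_syn, y_syn):
    # Features of the synthetic samples, with a bias column appended.
    f_syn = model.embed(x_syn)
    f_syn = torch.cat([f_syn, torch.ones(f_syn.shape[0], 1, device=f_syn.device)], dim=1)
    with torch.no_grad():
        f_real = model.embed(x_real)
        f_real = torch.cat([f_real, torch.ones(f_real.shape[0], 1, device=f_real.device)], dim=1)
        y_real = target_model(x_real)  # teacher representations of real data

    # Kernels, with a ridge regularizer scaled by the kernel trace.
    K_ss = f_syn @ f_syn.T
    K_rs = f_real @ f_syn.T
    lam = 1e-6 * torch.trace(K_ss.detach())
    eye = torch.eye(K_ss.shape[0], device=K_ss.device)

    # Closed-form KRR prediction on real data, matched to the teacher via MSE.
    pred = K_rs @ torch.linalg.solve(K_ss + lam * eye, y_syn)
    return F.mse_loss(pred, y_real)
```

Backpropagating this loss through `x_syn` and `y_syn` yields the meta-gradient used to update the distilled set; since no data augmentation enters the inner problem, the gradient bias discussed above is avoided.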
17 | ## Dependencies
18 | This code is written in Python. Dependencies include:
19 | * python >= 3.10
20 | * pytorch = 2.1.2
21 | * torchvision = 0.16.2
22 | * tqdm
23 | * kornia = 0.7.1
24 | * transformers = 4.36.2
25 |
26 | ```bash
27 | conda env create -f environment.yml
28 | conda activate dd
29 | ```
30 |
31 | ## Data and Model Checkpoints
32 | * Download **Full Data** (~40GB) from [here](https://drive.google.com/file/d/1P0zwURUbVsqoVgIRcIZXGAtGrkRvGvH0/view?usp=sharing).
33 | * Download **Distilled Data** (~702MB) from [here](https://drive.google.com/file/d/1vDghSAUnmdWdGJgx9iK8dMOwoId0nuKF/view?usp=sharing).
34 | * Download **Target (Teacher) Model Checkpoints** (~158MB) from [here](https://drive.google.com/file/d/1IuN4rhlB5UuJX_jrbVIWEBXo10QWHPBE/view?usp=sharing).
35 |
36 | The directory structure should look like this:
37 | ```shell
38 | ┌── datasets/
39 | │   ┌── aircraft/
40 | │   │   ┌── X_te_32.pth
41 | │   │   ├── ...
42 | │   │   └── Y_tr_224.pth
43 | │   ├── cars/
44 | │   ...
45 | │   └── tinyimagenet/
46 | │
47 | ├── synthetic_data/
48 | │   ┌── cifar100/
49 | │   │   ┌── dm/
50 | │   │   │   ┌── x_syn.pt
51 | │   │   │   └── y_syn.pt
52 | │   │   ├── ...
53 | │   │   └── random/
54 | │   ├── ...
55 | │   └── tinyimagenet/
56 | │
57 | └── teacher_ckpt/
58 |     ┌── barlow_twins_resnet18_cifar100.pt
59 |     ├── ...
60 |     └── teacher_cifar10.pt
61 | ```
62 |
63 | ## Dataset Distillation
64 | To distill **CIFAR100**, run the following code:
65 | ```bash
66 | python train.py --exp_name EXP_NAME (e.g. "cifar100_exp") --data_name cifar100 --outer_lr 1e-4 --gpu_id N
67 | ```
68 |
69 | To distill **TinyImageNet**, run the following code:
70 | ```bash
71 | python train.py --exp_name EXP_NAME (e.g. "tinyimagenet_exp") --data_name tinyimagenet --outer_lr 1e-5 --gpu_id N
72 | ```
73 |
74 | To distill **ImageNet 64x64**, run the following code:
75 | ```bash
76 | python train.py --exp_name EXP_NAME (e.g. "imagenet_exp") --data_name imagenet --outer_lr 1e-5 --gpu_id N
77 | ```
78 |
79 | ## Transfer Learning
80 | To reproduce **transfer learning with CIFAR100 (Table 1)**, run the following code:
81 | ```bash
82 | python test_scratch.py --source_data_name cifar100 --target_data_name full --gpu_id N
83 | python test.py --source_data_name cifar100 --target_data_name full --method METHOD (["random", "kmeans", "dsa", "dm", "mtt", "kip", "frepo", "krr_st"]) --test_model base --gpu_id N
84 | ```
85 |
86 | To reproduce **transfer learning with TinyImageNet (Table 2)**, run the following code:
87 | ```bash
88 | python test_scratch.py --source_data_name tinyimagenet --target_data_name full --gpu_id N
89 | python test.py --source_data_name tinyimagenet --target_data_name full --method METHOD (["random", "kmeans", "dsa", "dm", "mtt", "frepo", "krr_st"]) --test_model base --gpu_id N
90 | ```
91 |
92 | To reproduce **transfer learning with ImageNet 64x64 (Table 3)**, run the following code:
93 | ```bash
94 | python test_scratch.py --source_data_name imagenet --target_data_name full --gpu_id N
95 | python test.py --source_data_name imagenet --target_data_name full --method METHOD (["random", "frepo", "krr_st"]) --test_model base --gpu_id N
96 | ```
97 |
98 | To reproduce **architecture generalization with TinyImageNet (Figure 3)**, run the following code:
99 | ```bash
100 | python test_scratch.py --source_data_name tinyimagenet --target_data_name aircraft_cars_cub2011_dogs_flowers --test_model ARCHITECTURE (["vgg", "alexnet", "mobilenet", "resnet10"]) --gpu_id N
101 | python test.py --source_data_name tinyimagenet --target_data_name aircraft_cars_cub2011_dogs_flowers --method METHOD (["random", "kmeans", "dsa", "dm", "mtt", "frepo", "krr_st"]) --test_model ARCHITECTURE (["vgg", "alexnet", "mobilenet", "resnet10"]) --gpu_id N
102 | ```
103 |
104 | To reproduce **target data-free knowledge distillation with TinyImageNet (Table 4)**, run the following code:
105 | ```bash
106 | python test_kd.py --source_data_name tinyimagenet --method METHOD (["gaussian", "random", "kmeans", "dsa", "dm", "mtt", "frepo", "krr_st"]) --test_model ARCHITECTURE (["base", "vgg", "alexnet", "mobilenet", "resnet10"]) --gpu_id N
107 | ```
108 |
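If you want to use the distilled data outside of the provided scripts, pre-training on it reduces to plain MSE regression onto the stored target representations (cf. `algorithms/pretrain_krr_st.py`). Below is a minimal, self-contained sketch; the file paths follow the `synthetic_data/<dataset>/<method>/` layout above but are assumptions (the `krr_st` sub-directory in particular), and the linear backbone is a placeholder for the architectures in `models/`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed paths, following the synthetic_data/<dataset>/<method>/ layout above.
x_syn = torch.load("synthetic_data/cifar100/krr_st/x_syn.pt")  # distilled images
y_syn = torch.load("synthetic_data/cifar100/krr_st/y_syn.pt")  # learned target representations

# Placeholder backbone; its output width must match the target dimension.
model = nn.Sequential(nn.Flatten(), nn.Linear(x_syn[0].numel(), y_syn.shape[-1]))
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

for epoch in range(100):
    perm = torch.randperm(x_syn.shape[0])
    for i in range(0, x_syn.shape[0], 256):
        idx = perm[i:i + 256]
        loss = F.mse_loss(model(x_syn[idx]), y_syn[idx])  # randomness-free inner objective
        opt.zero_grad()
        loss.backward()
        opt.step()
```

The resulting encoder can then be fine-tuned or linearly evaluated on a target task, as done by `test.py`.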
109 | ## Reference
110 | To cite our paper, please use this BibTeX:
111 | ```bibtex
112 | @inproceedings{lee2024selfsupdd,
113 |     title={Self-Supervised Dataset Distillation for Transfer Learning},
114 |     author={Dong Bok Lee and Seanie Lee and Joonho Ko and Kenji Kawaguchi and Juho Lee and Sung Ju Hwang},
115 |     booktitle={Proceedings of the 12th International Conference on Learning Representations},
116 |     year={2024}
117 | }
118 | ```
119 |
--------------------------------------------------------------------------------
/algorithms/distill.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def run(
8 |     args, device, target_model,
9 |     model_pool, outer_opt,
10 |     iter_tr, aug,
11 |     x_syn, y_syn
12 | ):
13 |     outer_opt.zero_grad()
14 |
15 |     # prepare model
16 |     idx = np.random.randint(args.num_models)
17 |     model = model_pool.models[idx]
18 |     model.eval(); model.zero_grad()
19 |
20 |     # data
21 |     x_real, _ = next(iter_tr)
22 |     x_real = x_real.to(device)
23 |     target_model.eval()
24 |     with torch.no_grad():
25 |         x_real = aug(x_real)
26 |         y_real = target_model(x_real)
27 |     torch.cuda.empty_cache()
28 |
29 |     # feature
30 |     f_syn = model.embed(x_syn)
31 |     f_syn = torch.cat([ f_syn, torch.ones(f_syn.shape[0], 1).to(device) ], dim=1)
32 |     with torch.no_grad():
33 |         f_real = model.embed(x_real)
34 |         f_real = torch.cat([ f_real, torch.ones(f_real.shape[0], 1).to(device) ], dim=1)
35 |
36 |     # kernel
37 |     K_real_syn = f_real @ f_syn.permute(1, 0).contiguous()
38 |     K_syn_syn = f_syn @ f_syn.permute(1, 0).contiguous()
39 |
40 |     # lambda and eye
41 |     lambda_ = 1e-6 * torch.trace(K_syn_syn.detach())
42 |     eye = torch.eye(K_syn_syn.shape[0]).to(device)
43 |
44 |     # mse loss
45 |     outer_loss = F.mse_loss(
46 |         K_real_syn @ torch.linalg.solve(K_syn_syn + lambda_*eye, y_syn),
47 |         y_real
48 |     )
49 |     outer_grad = torch.autograd.grad(outer_loss, [x_syn, y_syn])
50 |
51 |     # meta update
52 |     if x_syn.grad is None:
53 |         x_syn.grad = outer_grad[0]
54 |     else:
55 |         x_syn.grad.data.copy_(outer_grad[0].data)
56 |     if y_syn.grad is None:
57 |         y_syn.grad = outer_grad[1]
58 |     else:
59 |         y_syn.grad.data.copy_(outer_grad[1].data)
60 |     if args.outer_grad_norm > 0.:
61 |         nn.utils.clip_grad_norm_(x_syn, args.outer_grad_norm)
62 |         nn.utils.clip_grad_norm_(y_syn, args.outer_grad_norm)
63 |     outer_opt.step()
64 |
65 |     model_pool.update(idx, x_syn.detach(), y_syn.detach())
66 |
67 |     return outer_loss
68 |
--------------------------------------------------------------------------------
/algorithms/finetune.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from models.wrapper import get_model
7 | from utils import InfIterator
8 |
9 | def run(
10 |     args, device,
11 |     model_name, init_model,
12 |     dl_tr, dl_te,
13 |     aug_tr, aug_te
14 | ):
15 |     iter_tr = InfIterator(dl_tr)
16 |
17 |     # model
18 |     model = get_model(model_name, args.img_shape, args.num_classes).to(device)
19 |     if hasattr(init_model, "fc"):
20 |         del init_model.fc
21 |     model.load_state_dict(init_model.state_dict(), strict=False)
22 |
23 |     # opt
24 |     if args.test_opt == "sgd":
25 |         opt = torch.optim.SGD(model.parameters(), lr=args.test_lr, momentum=0.9, weight_decay=args.test_wd)
26 |     elif args.test_opt == "adam":
27 |         opt = torch.optim.AdamW(model.parameters(), lr=args.test_lr,
weight_decay=args.test_wd) 28 | else: 29 | raise NotImplementedError 30 | sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.test_iteration) 31 | 32 | model.train() 33 | for _ in trange(args.test_iteration): 34 | x_tr, y_tr = next(iter_tr) 35 | x_tr, y_tr = x_tr.to(device), y_tr.to(device) 36 | with torch.no_grad(): 37 | x_tr = aug_tr(x_tr) 38 | loss = F.cross_entropy(model(x_tr), y_tr) 39 | opt.zero_grad() 40 | loss.backward() 41 | opt.step() 42 | sch.step() 43 | 44 | model.eval() 45 | with torch.no_grad(): 46 | loss, acc, denominator = 0., 0., 0. 47 | for x_te, y_te in dl_te: 48 | x_te, y_te = x_te.to(device), y_te.to(device) 49 | x_te = aug_te(x_te) 50 | l_te = model(x_te) 51 | loss += F.cross_entropy(l_te, y_te, reduction="sum") 52 | acc += torch.eq(l_te.argmax(dim=-1), y_te).float().sum() 53 | denominator += x_te.shape[0] 54 | loss /= denominator; acc /= (denominator/100.) 55 | 56 | del model 57 | 58 | return loss.item(), acc.item() -------------------------------------------------------------------------------- /algorithms/linear_eval.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.utils.data import TensorDataset, DataLoader 7 | 8 | from models.wrapper import get_model 9 | from utils import InfIterator 10 | 11 | def run( 12 | args, device, 13 | model_name, init_model, 14 | dl_tr, dl_te, 15 | aug_tr, aug_te 16 | ): 17 | 18 | # model 19 | model = get_model(model_name, args.img_shape, args.num_classes).to(device) 20 | if hasattr(init_model, "fc"): 21 | del init_model.fc 22 | model.load_state_dict(init_model.state_dict(), strict=False) 23 | model.fc = nn.Identity() 24 | 25 | model.eval() 26 | with torch.no_grad(): 27 | # tr feature 28 | F_tr, Y_tr = [], [] 29 | for x_tr, y_tr in dl_tr: 30 | x_tr = x_tr.to(device) 31 | x_tr = aug_te(x_tr) 32 | f_tr = model(x_tr).cpu() 33 | F_tr.append(f_tr); Y_tr.append(y_tr) 34 | F_tr, Y_tr = torch.cat(F_tr, dim=0), torch.cat(Y_tr, dim=0) 35 | num_features = F_tr.shape[-1] 36 | dl_feature_tr = DataLoader(TensorDataset(F_tr, Y_tr), batch_size=args.test_batch_size, shuffle=True, num_workers=0, pin_memory=True) 37 | iter_feature_tr = InfIterator(dl_feature_tr) 38 | 39 | # te feature 40 | F_te, Y_te = [], [] 41 | for x_te, y_te in dl_te: 42 | x_te = x_te.to(device) 43 | x_te = aug_te(x_te) 44 | f_te = model(x_te).cpu() 45 | F_te.append(f_te); Y_te.append(y_te) 46 | F_te, Y_te = torch.cat(F_te, dim=0), torch.cat(Y_te, dim=0) 47 | dl_feature_te = DataLoader(TensorDataset(F_te, Y_te), batch_size=args.test_batch_size, shuffle=False, num_workers=0, pin_memory=True) 48 | 49 | # opt 50 | linear = nn.Linear(num_features, args.num_classes).to(device) 51 | if args.test_opt == "sgd": 52 | opt = torch.optim.SGD(linear.parameters(), lr=args.test_lr, momentum=0.9, weight_decay=args.test_wd) 53 | elif args.test_opt == "adam": 54 | opt = torch.optim.AdamW(linear.parameters(), lr=args.test_lr, weight_decay=args.test_wd) 55 | else: 56 | raise NotImplementedError 57 | sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.test_iteration) 58 | 59 | linear.train() 60 | for _ in trange(args.test_iteration): 61 | x_tr, y_tr = next(iter_feature_tr) 62 | x_tr, y_tr = x_tr.to(device), y_tr.to(device) 63 | loss = F.cross_entropy(linear(x_tr), y_tr) 64 | opt.zero_grad() 65 | loss.backward() 66 | opt.step() 67 | sch.step() 68 | 69 | linear.eval() 70 | with torch.no_grad(): 71 | loss, acc, denominator = 0., 0., 0. 
72 | for x_te, y_te in dl_feature_te: 73 | x_te, y_te = x_te.to(device), y_te.to(device) 74 | l_te = linear(x_te) 75 | loss += F.cross_entropy(l_te, y_te, reduction="sum") 76 | acc += torch.eq(l_te.argmax(dim=-1), y_te).float().sum() 77 | denominator += x_te.shape[0] 78 | loss /= denominator; acc /= (denominator/100.) 79 | 80 | del model, linear 81 | 82 | return loss.item(), acc.item() -------------------------------------------------------------------------------- /algorithms/pretrain_dc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from data.augmentation import DiffAugment 4 | from models.wrapper import get_model 5 | from torch.utils.data import DataLoader, TensorDataset 6 | from tqdm import trange 7 | 8 | 9 | def run( 10 | args, device, 11 | model_name, 12 | x_syn, y_syn 13 | ): 14 | x_syn, y_syn = x_syn.detach(), y_syn.detach() 15 | x_syn, y_syn = x_syn.to(device), y_syn.to(device) 16 | dl_syn = DataLoader(TensorDataset(x_syn, y_syn), batch_size=args.pre_batch_size, shuffle=True, num_workers=0) 17 | 18 | # model and opt 19 | model = get_model(model_name, args.img_shape, args.num_pretrain_classes).to(device) 20 | if args.pre_opt == "sgd": 21 | opt = torch.optim.SGD(model.parameters(), lr=args.pre_lr, momentum=0.9, weight_decay=args.pre_wd) 22 | elif args.pre_opt == "adam": 23 | opt = torch.optim.AdamW(model.parameters(), lr=args.pre_lr, weight_decay=args.pre_wd) 24 | else: 25 | raise NotImplementedError 26 | sch = torch.optim.lr_scheduler.MultiStepLR( 27 | opt, milestones=[args.pre_epoch // 2], gamma=0.1) 28 | 29 | # pretrain 30 | model.train() 31 | for _ in trange(1, args.pre_epoch+1): 32 | for x_syn, y_syn in dl_syn: 33 | # loss 34 | with torch.no_grad(): 35 | x_syn = DiffAugment(x_syn, args.dsa_strategy, param=args.dsa_param) 36 | loss = F.cross_entropy(model(x_syn), y_syn) 37 | # update 38 | opt.zero_grad() 39 | loss.backward() 40 | opt.step() 41 | 42 | sch.step() 43 | 44 | return model 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /algorithms/pretrain_frepo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data.augmentation import DiffAugment 3 | from models.wrapper import get_model 4 | from torch.utils.data import DataLoader, TensorDataset 5 | from tqdm import trange 6 | from transformers import get_cosine_schedule_with_warmup 7 | from utils import InfIterator 8 | 9 | 10 | def run( 11 | args, device, 12 | model_name, 13 | x_syn, y_syn 14 | ): 15 | x_syn, y_syn = x_syn.detach(), y_syn.detach() 16 | x_syn, y_syn = x_syn.to(device), y_syn.to(device) 17 | dl_syn = DataLoader(TensorDataset(x_syn, y_syn), batch_size=args.pre_batch_size, shuffle=True, num_workers=0) 18 | iter_syn = InfIterator(dl_syn) 19 | 20 | # model and opt 21 | model = get_model(model_name, args.img_shape, args.num_pretrain_classes).to(device) 22 | if args.pre_opt == "sgd": 23 | opt = torch.optim.SGD(model.parameters(), lr=args.pre_lr, momentum=0.9, weight_decay=args.pre_wd) 24 | elif args.pre_opt == "adam": 25 | opt = torch.optim.AdamW(model.parameters(), lr=args.pre_lr, weight_decay=args.pre_wd) 26 | else: 27 | raise NotImplementedError 28 | sch = get_cosine_schedule_with_warmup(opt, 500, args.pre_iteration) 29 | 30 | # pretrain 31 | model.train() 32 | for _ in trange(1, args.pre_iteration+1): 33 | x_syn, y_syn = next(iter_syn) 34 | # loss 35 | with torch.no_grad(): 36 | x_syn = DiffAugment(x_syn, 
args.dsa_strategy, param=args.dsa_param) 37 | loss = torch.mean(torch.sum((model(x_syn) - y_syn) ** 2 * 0.5, dim=-1)) 38 | # update 39 | opt.zero_grad() 40 | loss.backward() 41 | opt.step() 42 | sch.step() 43 | 44 | return model 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /algorithms/pretrain_krr_st.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from models.wrapper import get_model 4 | from torch.utils.data import DataLoader, TensorDataset 5 | from tqdm import trange 6 | 7 | 8 | def run( 9 | args, device, 10 | model_name, 11 | x_syn, y_syn 12 | ): 13 | x_syn, y_syn = x_syn.detach(), y_syn.detach() 14 | x_syn, y_syn = x_syn.to(device), y_syn.to(device) 15 | dl_syn = DataLoader(TensorDataset(x_syn, y_syn), batch_size=args.pre_batch_size, shuffle=True, num_workers=0) 16 | 17 | # model and opt 18 | model = get_model(model_name, args.img_shape, args.num_pretrain_classes).to(device) 19 | if args.pre_opt == "sgd": 20 | opt = torch.optim.SGD(model.parameters(), lr=args.pre_lr, momentum=0.9, weight_decay=args.pre_wd) 21 | elif args.pre_opt == "adam": 22 | opt = torch.optim.AdamW(model.parameters(), lr=args.pre_lr, weight_decay=args.pre_wd) 23 | else: 24 | raise NotImplementedError 25 | 26 | # pretrain 27 | model.train() 28 | for _ in trange(1, args.pre_epoch+1): 29 | for x_syn, y_syn in dl_syn: 30 | # loss 31 | loss = F.mse_loss(model(x_syn), y_syn) 32 | # update 33 | opt.zero_grad() 34 | loss.backward() 35 | opt.step() 36 | 37 | return model 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /algorithms/scratch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from models.wrapper import get_model 4 | from tqdm import trange 5 | from utils import InfIterator 6 | 7 | 8 | def run( 9 | args, device, 10 | model_name, 11 | dl_tr, dl_te, 12 | aug_tr, aug_te 13 | ): 14 | iter_tr = InfIterator(dl_tr) 15 | 16 | # model 17 | model = get_model(model_name, args.img_shape, args.num_classes).to(device) 18 | 19 | # opt 20 | if args.test_opt == "sgd": 21 | opt = torch.optim.SGD(model.parameters(), lr=args.test_lr, momentum=0.9, weight_decay=args.test_wd) 22 | elif args.test_opt == "adam": 23 | opt = torch.optim.AdamW(model.parameters(), lr=args.test_lr, weight_decay=args.test_wd) 24 | else: 25 | raise NotImplementedError 26 | sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.test_iteration) 27 | 28 | model.train() 29 | for _ in trange(args.test_iteration): 30 | x_tr, y_tr = next(iter_tr) 31 | x_tr, y_tr = x_tr.to(device), y_tr.to(device) 32 | with torch.no_grad(): 33 | x_tr = aug_tr(x_tr) 34 | loss = F.cross_entropy(model(x_tr), y_tr) 35 | opt.zero_grad() 36 | loss.backward() 37 | opt.step() 38 | sch.step() 39 | 40 | model.eval() 41 | with torch.no_grad(): 42 | loss, acc, denominator = 0., 0., 0. 43 | for x_te, y_te in dl_te: 44 | x_te, y_te = x_te.to(device), y_te.to(device) 45 | x_te = aug_te(x_te) 46 | l_te = model(x_te) 47 | loss += F.cross_entropy(l_te, y_te, reduction="sum") 48 | acc += torch.eq(l_te.argmax(dim=-1), y_te).float().sum() 49 | denominator += x_te.shape[0] 50 | loss /= denominator; acc /= (denominator/100.) 
51 |
52 |     del model
53 |
54 |     return loss.item(), acc.item()
55 |
56 |
--------------------------------------------------------------------------------
/algorithms/wrapper.py:
--------------------------------------------------------------------------------
1 | from algorithms import distill
2 |
3 | from algorithms import pretrain_dc
4 | from algorithms import pretrain_frepo
5 | from algorithms import pretrain_krr_st
6 |
7 | from algorithms import scratch
8 | from algorithms import linear_eval
9 | from algorithms import finetune
10 | from algorithms import zeroshot_kd
11 |
12 | def get_algorithm(name):
13 |     return globals()[name]
--------------------------------------------------------------------------------
/algorithms/zeroshot_kd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from data.augmentation import DiffAugment
4 | from models.wrapper import get_model
5 | from torch.utils.data import DataLoader, TensorDataset
6 | from tqdm import trange
7 |
8 |
9 | def distillation_loss(student, teacher, T=1.0, reduction="mean"):
10 |     student = F.log_softmax(student/T, dim=-1)
11 |     teacher = F.softmax(teacher/T, dim=-1)
12 |     loss = -(teacher * student).sum(dim=-1)
13 |
14 |     if reduction == "mean":
15 |         return loss.mean(dim=0)
16 |     elif reduction == "sum":
17 |         return loss.sum(dim=0)
18 |     else:
19 |         raise NotImplementedError
20 |
21 | def run(
22 |     args, device,
23 |     model_name, init_model, teacher,
24 |     x_syn, dl_te, aug_te
25 | ):
26 |
27 |     x_syn = x_syn.detach()
28 |     x_syn = x_syn.to(device)
29 |     dl_syn = DataLoader(TensorDataset(x_syn), batch_size=args.test_batch_size, shuffle=True, num_workers=0)
30 |
31 |     # model
32 |     student = get_model(model_name, args.img_shape, args.num_classes).to(device)
33 |     if hasattr(init_model, "fc"):
34 |         del init_model.fc
35 |     student.load_state_dict(init_model.state_dict(), strict=False)
36 |
37 |     # opt
38 |     if args.test_opt == "sgd":
39 |         opt = torch.optim.SGD(student.parameters(), lr=args.test_lr, momentum=0.9, weight_decay=args.test_wd)
40 |     elif args.test_opt == "adam":
41 |         opt = torch.optim.AdamW(student.parameters(), lr=args.test_lr, weight_decay=args.test_wd)
42 |     else:
43 |         raise NotImplementedError
44 |     sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.test_epoch)
45 |
46 |     student.train(); teacher.train()
47 |     for _ in trange(args.test_epoch):
48 |         for x, in dl_syn:
49 |             with torch.no_grad():
50 |                 if args.method == "gaussian":
51 |                     x = torch.randn_like(x)
52 |                 else:
53 |                     x = DiffAugment(x, args.dsa_strategy, param=args.dsa_param)
54 |                 y = teacher(x)
55 |             loss = distillation_loss(student(x), y)
56 |             opt.zero_grad()
57 |             loss.backward()
58 |             opt.step()
59 |         sch.step()
60 |
61 |     student.eval()
62 |     with torch.no_grad():
63 |         loss, acc, denominator = 0., 0., 0.
64 |         for x_te, y_te in dl_te:
65 |             x_te, y_te = x_te.to(device), y_te.to(device)
66 |             x_te = aug_te(x_te)
67 |             l_te = student(x_te)
68 |             loss += F.cross_entropy(l_te, y_te, reduction="sum")
69 |             acc += torch.eq(l_te.argmax(dim=-1), y_te).float().sum()
70 |             denominator += x_te.shape[0]
71 |         loss /= denominator; acc /= (denominator/100.)
72 | 73 | del student 74 | 75 | return loss.item(), acc.item() 76 | -------------------------------------------------------------------------------- /assets/concept.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/db-Lee/selfsup_dd/4beb4b19c23872394d8f6418938b0e2bb3ce0466/assets/concept.png -------------------------------------------------------------------------------- /data/aircraft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from torchvision.datasets import VisionDataset 4 | from torchvision.datasets.folder import default_loader 5 | from torchvision.datasets.utils import download_url 6 | from torchvision.datasets.utils import extract_archive 7 | 8 | 9 | class Aircraft(VisionDataset): 10 | """`FGVC-Aircraft `_ Dataset. 11 | 12 | Args: 13 | root (string): Root directory of the dataset. 14 | train (bool, optional): If True, creates dataset from training set, otherwise 15 | creates from test set. 16 | class_type (string, optional): choose from ('variant', 'family', 'manufacturer'). 17 | transform (callable, optional): A function/transform that takes in an PIL image 18 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 19 | target_transform (callable, optional): A function/transform that takes in the 20 | target and transforms it. 21 | download (bool, optional): If true, downloads the dataset from the internet and 22 | puts it in root directory. If dataset is already downloaded, it is not 23 | downloaded again. 24 | """ 25 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' 26 | class_types = ('variant', 'family', 'manufacturer') 27 | splits = ('train', 'val', 'trainval', 'test') 28 | img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images') 29 | 30 | def __init__(self, root, train=True, class_type='variant', transform=None, 31 | target_transform=None, download=False): 32 | super(Aircraft, self).__init__(root, transform=transform, target_transform=target_transform) 33 | split = 'trainval' if train else 'test' 34 | if split not in self.splits: 35 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format( 36 | split, ', '.join(self.splits), 37 | )) 38 | if class_type not in self.class_types: 39 | raise ValueError('Class type "{}" not found. 
Valid class types are: {}'.format( 40 | class_type, ', '.join(self.class_types), 41 | )) 42 | 43 | self.class_type = class_type 44 | self.split = split 45 | self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', 46 | 'images_%s_%s.txt' % (self.class_type, self.split)) 47 | 48 | if download: 49 | self.download() 50 | 51 | (image_ids, targets, classes, class_to_idx) = self.find_classes() 52 | samples = self.make_dataset(image_ids, targets) 53 | 54 | self.loader = default_loader 55 | 56 | self.samples = samples 57 | self.classes = classes 58 | self.class_to_idx = class_to_idx 59 | 60 | def __getitem__(self, index): 61 | path, target = self.samples[index] 62 | sample = self.loader(path) 63 | if self.transform is not None: 64 | sample = self.transform(sample) 65 | if self.target_transform is not None: 66 | target = self.target_transform(target) 67 | return sample, target 68 | 69 | def __len__(self): 70 | return len(self.samples) 71 | 72 | def _check_exists(self): 73 | return os.path.exists(os.path.join(self.root, self.img_folder)) and \ 74 | os.path.exists(self.classes_file) 75 | 76 | def download(self): 77 | if self._check_exists(): 78 | return 79 | 80 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz 81 | print('Downloading %s...' % self.url) 82 | tar_name = self.url.rpartition('/')[-1] 83 | download_url(self.url, root=self.root, filename=tar_name) 84 | tar_path = os.path.join(self.root, tar_name) 85 | print('Extracting %s...' % tar_path) 86 | extract_archive(tar_path) 87 | print('Done!') 88 | 89 | def find_classes(self): 90 | # read classes file, separating out image IDs and class names 91 | image_ids = [] 92 | targets = [] 93 | with open(self.classes_file, 'r') as f: 94 | for line in f: 95 | split_line = line.split(' ') 96 | image_ids.append(split_line[0]) 97 | targets.append(' '.join(split_line[1:])) 98 | 99 | # index class names 100 | classes = np.unique(targets) 101 | class_to_idx = {classes[i]: i for i in range(len(classes))} 102 | targets = [class_to_idx[c] for c in targets] 103 | 104 | return image_ids, targets, classes, class_to_idx 105 | 106 | def make_dataset(self, image_ids, targets): 107 | assert (len(image_ids) == len(targets)) 108 | images = [] 109 | for i in range(len(image_ids)): 110 | item = (os.path.join(self.root, self.img_folder, 111 | '%s.jpg' % image_ids[i]), targets[i]) 112 | images.append(item) 113 | return images 114 | 115 | 116 | if __name__ == '__main__': 117 | train_dataset = Aircraft('./aircraft', train=True, download=False) 118 | test_dataset = Aircraft('./aircraft', train=False, download=False) 119 | -------------------------------------------------------------------------------- /data/augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import kornia.augmentation as K 8 | 9 | NUM_CLASSES = { 10 | 'svhn': 10, 11 | 'cifar10': 10, 12 | 'cifar100': 100, 13 | 'aircraft': 100, 14 | 'cars': 196, 15 | 'cub2011': 200, 16 | 'dogs': 120, 17 | 'flowers': 102, 18 | 'tinyimagenet': 200, 19 | 'imagenet': 1000, 20 | 'imagenette': 10 21 | } 22 | 23 | MEAN = { 24 | 32: 25 | {'svhn': (0.4377, 0.4438, 0.4728), 26 | 'cifar10': (0.4914, 0.4822, 0.4465), 27 | 'cifar100': (0.5071, 0.4866, 0.4409), 28 | 'aircraft': (0.4804, 0.5116, 0.5349), 29 | 'cars': (0.4706, 0.4600, 0.4548), 30 | 'cub2011': (0.4857, 0.4995, 0.4324), 31 | 'dogs': (0.4765, 0.4516, 0.3911), 32 | 'flowers': (0.4344, 0.3030, 
0.2955), 33 | 'tinyimagenet': (0.4802, 0.4481, 0.3975), 34 | 'imagenet': (0.4810, 0.4574, 0.4078)}, 35 | 64: 36 | {'svhn': (0.4377, 0.4438, 0.4728), 37 | 'cifar10': (0.4914, 0.4821, 0.4465), 38 | 'cifar100': (0.5070, 0.4865, 0.4409), 39 | 'aircraft': (0.4797, 0.5108, 0.5341), 40 | 'cars': (0.4707, 0.4601, 0.4549), 41 | 'cub2011': (0.4856, 0.4995, 0.4324), 42 | 'dogs': (0.4765, 0.4517, 0.3911), 43 | 'flowers': (0.4344, 0.3830, 0.2955), 44 | 'tinyimagenet': (0.4802, 0.4481, 0.3975), 45 | 'imagenet': (0.4810, 0.4574, 0.4078)}, 46 | 224: 47 | {'aircraft': (0.4797, 0.5109, 0.5342), 48 | 'cars': (0.4707, 0.4601, 0.4549), 49 | 'cub2011': (0.4856, 0.4994, 0.4324), 50 | 'dogs': (0.4765, 0.4517, 0.3912), 51 | 'flowers': (0.4344, 0.3830, 0.2955), 52 | 'imagenette': (0.4655, 0.4546, 0.4250)} 53 | } 54 | STD = { 55 | 32: 56 | {'svhn': (0.1980, 0.2010, 0.1970), 57 | 'cifar10': (0.2470, 0.2435, 0.2616), 58 | 'cifar100': (0.2673, 0.2564, 0.2762), 59 | 'aircraft': (0.2021, 0.1953, 0.2297), 60 | 'cars': (0.2746, 0.2740, 0.2831), 61 | 'cub2011': (0.2145, 0.2098, 0.2496), 62 | 'dogs': (0.2490, 0.2435, 0.2479), 63 | 'flowers': (0.2811, 0.2318, 0.2607), 64 | 'tinyimagenet': (0.2770, 0.2691, 0.2821), 65 | 'imagenet': (0.2633, 0.2560, 0.2708)}, 66 | 64: 67 | {'svhn': (0.1981, 0.2011, 0.1971), 68 | 'cifar10': (0.2469, 0.2433, 0.2614), 69 | 'cifar100': (0.2671, 0.2562, 0.2760), 70 | 'aircraft': (0.2117, 0.2049, 0.2380), 71 | 'cars': (0.2836, 0.2826, 0.2914), 72 | 'cub2011': (0.2218, 0.2170, 0.2564), 73 | 'dogs': (0.2551, 0.2495, 0.2539), 74 | 'flowers': (0.2878, 0.2390, 0.2674), 75 | 'tinyimagenet': (0.2770, 0.2691, 0.2821), 76 | 'imagenet': (0.2633, 0.2560, 0.2708)}, 77 | 224: 78 | {'aircraft': (0.2204, 0.2135, 0.2451), 79 | 'cars': (0.2927, 0.2917, 0.3001), 80 | 'cub2011': (0.2295, 0.2250, 0.2635), 81 | 'dogs': (0.2617, 0.2564, 0.2607), 82 | 'flowers': (0.2928, 0.2449, 0.2726), 83 | 'imagenette': (0.2804, 0.2754, 0.2965)} 84 | } 85 | 86 | def get_aug(data_name, size, aug=True): 87 | if aug: 88 | if data_name == "svhn": 89 | transform_tr = nn.Sequential( 90 | K.RandomCrop(size=(size,size), padding=4), 91 | K.Normalize(MEAN[size][data_name], STD[size][data_name]) 92 | ) 93 | else: 94 | transform_tr = nn.Sequential( 95 | K.RandomCrop(size=(size,size), padding=4), 96 | K.RandomHorizontalFlip(p=0.5), 97 | K.Normalize(MEAN[size][data_name], STD[size][data_name]), 98 | ) 99 | else: 100 | transform_tr = K.Normalize(MEAN[size][data_name], STD[size][data_name]) 101 | 102 | transform_te = K.Normalize(MEAN[size][data_name], STD[size][data_name]) 103 | 104 | return transform_tr, transform_te 105 | 106 | """DC Augmentation""" 107 | class ParamDiffAug(): 108 | def __init__(self): 109 | self.aug_mode = 'S' #'multiple or single' 110 | self.prob_flip = 0.5 111 | self.ratio_scale = 1.2 112 | self.ratio_rotate = 15.0 113 | self.ratio_crop_pad = 0.125 114 | self.ratio_cutout = 0.5 # the size would be 0.5x0.5 115 | self.brightness = 1.0 116 | self.saturation = 2.0 117 | self.contrast = 0.5 118 | 119 | 120 | def set_seed_DiffAug(param): 121 | if param.latestseed == -1: 122 | return 123 | else: 124 | torch.random.manual_seed(param.latestseed) 125 | param.latestseed += 1 126 | 127 | 128 | def DiffAugment(x, strategy='', seed = -1, param = None): 129 | if strategy == 'None' or strategy == 'none' or strategy == '': 130 | return x 131 | 132 | if seed == -1: 133 | param.Siamese = False 134 | else: 135 | param.Siamese = True 136 | 137 | param.latestseed = seed 138 | 139 | if strategy: 140 | if param.aug_mode == 'M': # original 141 | for p in 
strategy.split('_'):
142 |                 for f in AUGMENT_FNS[p]:
143 |                     x = f(x, param)
144 |         elif param.aug_mode == 'S':
145 |             pbties = strategy.split('_')
146 |             set_seed_DiffAug(param)
147 |             p = pbties[torch.randint(0, len(pbties), size=(1,)).item()]
148 |             for f in AUGMENT_FNS[p]:
149 |                 x = f(x, param)
150 |         else:
151 |             raise ValueError('unknown augmentation mode: %s' % param.aug_mode)
152 |         x = x.contiguous()
153 |     return x
154 |
155 |
156 | # We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans.
157 | def rand_scale(x, param):
158 |     # x>1, max scale
159 |     # sx, sy: (0, +oo), 1: original size, 0.5: enlarge 2 times
160 |     ratio = param.ratio_scale
161 |     set_seed_DiffAug(param)
162 |     sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
163 |     set_seed_DiffAug(param)
164 |     sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
165 |     theta = [[[sx[i], 0, 0],
166 |             [0, sy[i], 0],] for i in range(x.shape[0])]
167 |     theta = torch.tensor(theta, dtype=torch.float)
168 |     if param.Siamese:  # Siamese augmentation:
169 |         theta[:] = theta[0].clone().detach()
170 |     grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device)
171 |     x = F.grid_sample(x, grid, align_corners=True)
172 |     return x
173 |
174 |
175 | def rand_rotate(x, param):  # [-180, 180], 90: anticlockwise 90 degree
176 |     ratio = param.ratio_rotate
177 |     set_seed_DiffAug(param)
178 |     theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi)
179 |     theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0],
180 |             [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])]
181 |     theta = torch.tensor(theta, dtype=torch.float)
182 |     if param.Siamese:  # Siamese augmentation:
183 |         theta[:] = theta[0].clone().detach()
184 |     grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device)
185 |     x = F.grid_sample(x, grid, align_corners=True)
186 |     return x
187 |
188 |
189 | def rand_flip(x, param):
190 |     prob = param.prob_flip
191 |     set_seed_DiffAug(param)
192 |     randf = torch.rand(x.size(0), 1, 1, 1, device=x.device)
193 |     if param.Siamese:  # Siamese augmentation:
194 |         randf[:] = randf[0].clone().detach()
195 |     return torch.where(randf < prob, x.flip(3), x)
196 |
197 |
198 | def rand_brightness(x, param):
199 |     ratio = param.brightness
200 |     set_seed_DiffAug(param)
201 |     randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
202 |     if param.Siamese:  # Siamese augmentation:
203 |         randb[:] = randb[0].clone().detach()
204 |     x = x + (randb - 0.5)*ratio
205 |     return x
206 |
207 |
208 | def rand_saturation(x, param):
209 |     ratio = param.saturation
210 |     x_mean = x.mean(dim=1, keepdim=True)
211 |     set_seed_DiffAug(param)
212 |     rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
213 |     if param.Siamese:  # Siamese augmentation:
214 |         rands[:] = rands[0].clone().detach()
215 |     x = (x - x_mean) * (rands * ratio) + x_mean
216 |     return x
217 |
218 |
219 | def rand_contrast(x, param):
220 |     ratio = param.contrast
221 |     x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
222 |     set_seed_DiffAug(param)
223 |     randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
224 |     if param.Siamese:  # Siamese augmentation:
225 |         randc[:] = randc[0].clone().detach()
226 |     x = (x - x_mean) * (randc + ratio) + x_mean
227 |     return x
228 |
229 |
230 | def rand_crop(x, param):
231 |     # The image is padded on its surrounding and then cropped.
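    # Concretely: x is zero-padded by one pixel on each side, a per-sample
    # integer translation is drawn in [-shift, shift] for each spatial axis,
    # and the translation is applied by indexing the padded tensor with a
    # meshgrid offset by it; the indices are clamped to the padded range, so
    # pixels shifted in from outside read the zero border.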
232 | ratio = param.ratio_crop_pad 233 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 234 | set_seed_DiffAug(param) 235 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 236 | set_seed_DiffAug(param) 237 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 238 | if param.Siamese: # Siamese augmentation: 239 | translation_x[:] = translation_x[0].clone().detach() 240 | translation_y[:] = translation_y[0].clone().detach() 241 | grid_batch, grid_x, grid_y = torch.meshgrid( 242 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 243 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 244 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 245 | indexing="ij" 246 | ) 247 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 248 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 249 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 250 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 251 | return x 252 | 253 | 254 | def rand_cutout(x, param): 255 | ratio = param.ratio_cutout 256 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 257 | set_seed_DiffAug(param) 258 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 259 | set_seed_DiffAug(param) 260 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 261 | if param.Siamese: # Siamese augmentation: 262 | offset_x[:] = offset_x[0].clone().detach() 263 | offset_y[:] = offset_y[0].clone().detach() 264 | grid_batch, grid_x, grid_y = torch.meshgrid( 265 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 266 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 267 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 268 | indexing="ij" 269 | ) 270 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 271 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 272 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 273 | mask[grid_batch, grid_x, grid_y] = 0 274 | x = x * mask.unsqueeze(1) 275 | return x 276 | 277 | 278 | AUGMENT_FNS = { 279 | 'color': [rand_brightness, rand_saturation, rand_contrast], 280 | 'crop': [rand_crop], 281 | 'cutout': [rand_cutout], 282 | 'flip': [rand_flip], 283 | 'scale': [rand_scale], 284 | 'rotate': [rand_rotate], 285 | } -------------------------------------------------------------------------------- /data/cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.io as sio 3 | from torchvision.datasets import VisionDataset 4 | from torchvision.datasets.folder import default_loader 5 | from torchvision.datasets.utils import download_url 6 | from torchvision.datasets.utils import extract_archive 7 | 8 | 9 | class Cars(VisionDataset): 10 | """`Stanford Cars `_ Dataset. 11 | 12 | Args: 13 | root (string): Root directory of the dataset. 14 | train (bool, optional): If True, creates dataset from training set, otherwise 15 | creates from test set. 16 | transform (callable, optional): A function/transform that takes in an PIL image 17 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 18 | target_transform (callable, optional): A function/transform that takes in the 19 | target and transforms it. 20 | download (bool, optional): If true, downloads the dataset from the internet and 21 | puts it in root directory. If dataset is already downloaded, it is not 22 | downloaded again. 23 | """ 24 | file_list = { 25 | 'imgs': ('http://imagenet.stanford.edu/internal/car196/car_ims.tgz', 'car_ims.tgz'), 26 | 'annos': ('http://imagenet.stanford.edu/internal/car196/cars_annos.mat', 'cars_annos.mat') 27 | } 28 | 29 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False): 30 | super(Cars, self).__init__(root, transform=transform, target_transform=target_transform) 31 | 32 | self.loader = default_loader 33 | self.train = train 34 | 35 | if self._check_exists(): 36 | print('Files already downloaded and verified.') 37 | elif download: 38 | self._download() 39 | else: 40 | raise RuntimeError( 41 | 'Dataset not found. You can use download=True to download it.') 42 | 43 | loaded_mat = sio.loadmat(os.path.join(self.root, self.file_list['annos'][1])) 44 | loaded_mat = loaded_mat['annotations'][0] 45 | self.samples = [] 46 | for item in loaded_mat: 47 | if self.train != bool(item[-1][0]): 48 | path = str(item[0][0]) 49 | label = int(item[-2][0]) - 1 50 | self.samples.append((path, label)) 51 | 52 | def __getitem__(self, index): 53 | path, target = self.samples[index] 54 | path = os.path.join(self.root, path) 55 | 56 | image = self.loader(path) 57 | if self.transform is not None: 58 | image = self.transform(image) 59 | if self.target_transform is not None: 60 | target = self.target_transform(target) 61 | return image, target 62 | 63 | def __len__(self): 64 | return len(self.samples) 65 | 66 | def _check_exists(self): 67 | return (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1])) 68 | and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))) 69 | 70 | def _download(self): 71 | print('Downloading...') 72 | for url, filename in self.file_list.values(): 73 | download_url(url, root=self.root, filename=filename) 74 | print('Extracting...') 75 | archive = os.path.join(self.root, self.file_list['imgs'][1]) 76 | extract_archive(archive) 77 | 78 | 79 | if __name__ == '__main__': 80 | train_dataset = Cars('./cars', train=True, download=False) 81 | test_dataset = Cars('./cars', train=False, download=False) -------------------------------------------------------------------------------- /data/cub2011.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | from torchvision.datasets import VisionDataset 5 | from torchvision.datasets.folder import default_loader 6 | from torchvision.datasets.utils import download_file_from_google_drive 7 | 8 | 9 | class Cub2011(VisionDataset): 10 | """`CUB-200-2011 `_ Dataset. 11 | 12 | Args: 13 | root (string): Root directory of the dataset. 14 | train (bool, optional): If True, creates dataset from training set, otherwise 15 | creates from test set. 16 | transform (callable, optional): A function/transform that takes in an PIL image 17 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 18 | target_transform (callable, optional): A function/transform that takes in the 19 | target and transforms it. 20 | download (bool, optional): If true, downloads the dataset from the internet and 21 | puts it in root directory. If dataset is already downloaded, it is not 22 | downloaded again. 
-------------------------------------------------------------------------------- /data/cub2011.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | import pandas as pd
4 | from torchvision.datasets import VisionDataset
5 | from torchvision.datasets.folder import default_loader
6 | from torchvision.datasets.utils import download_file_from_google_drive
7 | 
8 | 
9 | class Cub2011(VisionDataset):
10 |     """`CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
11 | 
12 |     Args:
13 |         root (string): Root directory of the dataset.
14 |         train (bool, optional): If True, creates dataset from training set, otherwise
15 |             creates from test set.
16 |         transform (callable, optional): A function/transform that takes in a PIL image
17 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
18 |         target_transform (callable, optional): A function/transform that takes in the
19 |             target and transforms it.
20 |         download (bool, optional): If true, downloads the dataset from the internet and
21 |             puts it in root directory. If dataset is already downloaded, it is not
22 |             downloaded again.
23 |     """
24 |     base_folder = 'CUB_200_2011/images'
25 |     # url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
26 |     file_id = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
27 |     filename = 'CUB_200_2011.tgz'
28 |     tgz_md5 = '97eceeb196236b17998738112f37df78'
29 | 
30 |     def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
31 |         super(Cub2011, self).__init__(root, transform=transform, target_transform=target_transform)
32 | 
33 |         self.loader = default_loader
34 |         self.train = train
35 |         if download:
36 |             self._download()
37 | 
38 |         if not self._check_integrity():
39 |             raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')
40 | 
41 |     def _load_metadata(self):
42 |         images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
43 |                              names=['img_id', 'filepath'])
44 |         image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
45 |                                          sep=' ', names=['img_id', 'target'])
46 |         train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
47 |                                        sep=' ', names=['img_id', 'is_training_img'])
48 | 
49 |         data = images.merge(image_class_labels, on='img_id')
50 |         self.data = data.merge(train_test_split, on='img_id')
51 | 
52 |         class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),
53 |                                   sep=' ', names=['class_name'], usecols=[1])
54 |         self.class_names = class_names['class_name'].to_list()
55 |         if self.train:
56 |             self.data = self.data[self.data.is_training_img == 1]
57 |         else:
58 |             self.data = self.data[self.data.is_training_img == 0]
59 | 
60 |     def _check_integrity(self):
61 |         try:
62 |             self._load_metadata()
63 |         except Exception:
64 |             return False
65 | 
66 |         for index, row in self.data.iterrows():
67 |             filepath = os.path.join(self.root, self.base_folder, row.filepath)
68 |             if not os.path.isfile(filepath):
69 |                 print(filepath)
70 |                 return False
71 |         return True
72 | 
73 |     def _download(self):
74 |         import tarfile
75 | 
76 |         if self._check_integrity():
77 |             print('Files already downloaded and verified')
78 |             return
79 | 
80 |         download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5)
81 | 
82 |         with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
83 |             tar.extractall(path=self.root)
84 | 
85 |     def __len__(self):
86 |         return len(self.data)
87 | 
88 |     def __getitem__(self, idx):
89 |         sample = self.data.iloc[idx]
90 |         path = os.path.join(self.root, self.base_folder, sample.filepath)
91 |         target = sample.target - 1  # Targets start at 1 by default, so shift to 0
92 |         img = self.loader(path)
93 | 
94 |         if self.transform is not None:
95 |             img = self.transform(img)
96 |         if self.target_transform is not None:
97 |             target = self.target_transform(target)
98 |         return img, target
99 | 
100 | 
101 | if __name__ == '__main__':
102 |     train_dataset = Cub2011('./cub2011', train=True, download=False)
103 |     test_dataset = Cub2011('./cub2011', train=False, download=False)
104 | 
-------------------------------------------------------------------------------- /data/dogs.py: --------------------------------------------------------------------------------
1 | import os
2 | import scipy.io
3 | from os.path import join
4 | from torchvision.datasets import VisionDataset
5 | from torchvision.datasets.folder import default_loader
6 | from torchvision.datasets.utils import download_url, list_dir
7 | 
8 | 
9 | class Dogs(VisionDataset):
10 |     """`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.
11 | 
12 |     Args:
13 |         root (string): Root directory of the dataset.
14 |         train (bool, optional): If True, creates dataset from training set, otherwise
15 |             creates from test set.
16 |         transform (callable, optional): A function/transform that takes in a PIL image
17 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
18 |         target_transform (callable, optional): A function/transform that takes in the
19 |             target and transforms it.
20 |         download (bool, optional): If true, downloads the dataset from the internet and
21 |             puts it in root directory. If dataset is already downloaded, it is not
22 |             downloaded again.
23 |     """
24 |     download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'
25 | 
26 |     def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
27 |         super(Dogs, self).__init__(root, transform=transform, target_transform=target_transform)
28 | 
29 |         self.loader = default_loader
30 |         self.train = train
31 | 
32 |         if download:
33 |             self.download()
34 | 
35 |         split = self.load_split()
36 | 
37 |         self.images_folder = join(self.root, 'Images')
38 |         self.annotations_folder = join(self.root, 'Annotation')
39 |         self._breeds = list_dir(self.images_folder)
40 | 
41 |         self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]
42 | 
43 |         self._flat_breed_images = self._breed_images
44 | 
45 |     def __len__(self):
46 |         return len(self._flat_breed_images)
47 | 
48 |     def __getitem__(self, index):
49 |         image_name, target = self._flat_breed_images[index]
50 |         image_path = join(self.images_folder, image_name)
51 |         image = self.loader(image_path)
52 | 
53 |         if self.transform is not None:
54 |             image = self.transform(image)
55 |         if self.target_transform is not None:
56 |             target = self.target_transform(target)
57 |         return image, target
58 | 
59 |     def download(self):
60 |         import tarfile
61 | 
62 |         if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
63 |             if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
64 |                 print('Files already downloaded and verified')
65 |                 return
66 | 
67 |         for filename in ['images', 'annotation', 'lists']:
68 |             tar_filename = filename + '.tar'
69 |             url = self.download_url_prefix + '/' + tar_filename
70 |             download_url(url, self.root, tar_filename, None)
71 |             print('Extracting downloaded file: ' + join(self.root, tar_filename))
72 |             with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
73 |                 tar_file.extractall(self.root)
74 |             os.remove(join(self.root, tar_filename))
75 | 
76 |     def load_split(self):
77 |         if self.train:
78 |             split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']
79 |             labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']
80 |         else:
81 |             split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']
82 |             labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']
83 | 
84 |         split = [item[0][0] for item in split]
85 |         labels = [item[0] - 1 for item in labels]
86 |         return list(zip(split, labels))
87 | 
88 |     def stats(self):
89 |         counts = {}
90 |         for index in range(len(self._flat_breed_images)):
91 |             image_name, target_class = self._flat_breed_images[index]
92 |             if target_class not in counts.keys():
93 |                 counts[target_class] = 1
94 |             else:
95 |                 counts[target_class] += 1
96 | 
97 |         print("%d samples spanning %d classes (avg %f per class)" % (len(self._flat_breed_images), len(counts.keys()),
98 |                                                                      float(len(self._flat_breed_images)) / float(
99 | 
len(counts.keys())))) 100 | 101 | return counts 102 | 103 | 104 | if __name__ == '__main__': 105 | train_dataset = Dogs('./dogs', train=True, download=False) 106 | test_dataset = Dogs('./dogs', train=False, download=False) 107 | -------------------------------------------------------------------------------- /data/wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import TensorDataset, DataLoader 3 | 4 | from data.augmentation import get_aug 5 | 6 | def get_loader(root_dir, data_name, batch_size, size=32, aug=False): 7 | X_tr, Y_tr = torch.load(f"{root_dir}/{data_name}/X_tr_{size}.pth"), torch.load(f"{root_dir}/{data_name}/Y_tr_{size}.pth") 8 | dataset_tr = TensorDataset(X_tr, Y_tr) 9 | dataloader_tr = DataLoader(dataset_tr, batch_size=batch_size, num_workers=0, shuffle=True, pin_memory=True) 10 | 11 | if data_name == "imagenet": 12 | dataloader_te = dataloader_tr 13 | else: 14 | X_te, Y_te = torch.load(f"{root_dir}/{data_name}/X_te_{size}.pth"), torch.load(f"{root_dir}/{data_name}/Y_te_{size}.pth") 15 | dataset_te = TensorDataset(X_te, Y_te) 16 | dataloader_te = DataLoader(dataset_te, batch_size=batch_size, num_workers=0, shuffle=False, pin_memory=True) 17 | 18 | 19 | transform_tr, transform_te = get_aug(data_name, size, aug) 20 | 21 | return dataloader_tr, dataloader_te, transform_tr, transform_te -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dd 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - brotli-python=1.0.9=py310h6a678d5_7 11 | - bzip2=1.0.8=h5eee18b_5 12 | - ca-certificates=2024.3.11=h06a4308_0 13 | - certifi=2024.2.2=py310h06a4308_0 14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 15 | - cuda-cudart=11.8.89=0 16 | - cuda-cupti=11.8.87=0 17 | - cuda-libraries=11.8.0=0 18 | - cuda-nvrtc=11.8.89=0 19 | - cuda-nvtx=11.8.86=0 20 | - cuda-runtime=11.8.0=0 21 | - ffmpeg=4.3=hf484d3e_0 22 | - filelock=3.13.1=py310h06a4308_0 23 | - freetype=2.12.1=h4a9f257_0 24 | - gmp=6.2.1=h295c915_3 25 | - gmpy2=2.1.2=py310heeb90bb_0 26 | - gnutls=3.6.15=he1e5248_0 27 | - idna=3.4=py310h06a4308_0 28 | - intel-openmp=2023.1.0=hdb19cb5_46306 29 | - jinja2=3.1.3=py310h06a4308_0 30 | - jpeg=9e=h5eee18b_1 31 | - lame=3.100=h7b6447c_0 32 | - lcms2=2.12=h3be6417_0 33 | - ld_impl_linux-64=2.38=h1181459_1 34 | - lerc=3.0=h295c915_0 35 | - libcublas=11.11.3.6=0 36 | - libcufft=10.9.0.58=0 37 | - libcufile=1.9.1.3=0 38 | - libcurand=10.3.5.147=0 39 | - libcusolver=11.4.1.48=0 40 | - libcusparse=11.7.5.86=0 41 | - libdeflate=1.17=h5eee18b_1 42 | - libffi=3.4.4=h6a678d5_0 43 | - libgcc-ng=11.2.0=h1234567_1 44 | - libgomp=11.2.0=h1234567_1 45 | - libiconv=1.16=h7f8727e_2 46 | - libidn2=2.3.4=h5eee18b_0 47 | - libjpeg-turbo=2.0.0=h9bf148f_0 48 | - libnpp=11.8.0.86=0 49 | - libnvjpeg=11.9.0.86=0 50 | - libpng=1.6.39=h5eee18b_0 51 | - libstdcxx-ng=11.2.0=h1234567_1 52 | - libtasn1=4.19.0=h5eee18b_0 53 | - libtiff=4.5.1=h6a678d5_0 54 | - libunistring=0.9.10=h27cfd23_0 55 | - libuuid=1.41.5=h5eee18b_0 56 | - libwebp-base=1.3.2=h5eee18b_0 57 | - llvm-openmp=14.0.6=h9e868ea_0 58 | - lz4-c=1.9.4=h6a678d5_0 59 | - markupsafe=2.1.3=py310h5eee18b_0 60 | - mkl=2023.1.0=h213fc3f_46344 61 | - mkl-service=2.4.0=py310h5eee18b_1 62 | - mkl_fft=1.3.8=py310h5eee18b_0 63 | - mkl_random=1.2.4=py310hdb19cb5_0 
64 |   - mpc=1.1.0=h10f8cd9_1
65 |   - mpfr=4.0.2=hb69a4c5_1
66 |   - mpmath=1.3.0=py310h06a4308_0
67 |   - ncurses=6.4=h6a678d5_0
68 |   - nettle=3.7.3=hbbd107a_1
69 |   - networkx=3.1=py310h06a4308_0
70 |   - numpy=1.26.4=py310h5f9d8c6_0
71 |   - numpy-base=1.26.4=py310hb5e798b_0
72 |   - openh264=2.1.1=h4ff587b_0
73 |   - openjpeg=2.4.0=h3ad879b_0
74 |   - openssl=3.0.13=h7f8727e_0
75 |   - pillow=10.2.0=py310h5eee18b_0
76 |   - pip=23.3.1=py310h06a4308_0
77 |   - pysocks=1.7.1=py310h06a4308_0
78 |   - python=3.10.14=h955ad1f_0
79 |   - pytorch=2.1.2=py3.10_cuda11.8_cudnn8.7.0_0
80 |   - pytorch-cuda=11.8=h7e8668a_5
81 |   - pytorch-mutex=1.0=cuda
82 |   - pyyaml=6.0.1=py310h5eee18b_0
83 |   - readline=8.2=h5eee18b_0
84 |   - requests=2.31.0=py310h06a4308_1
85 |   - setuptools=68.2.2=py310h06a4308_0
86 |   - sqlite=3.41.2=h5eee18b_0
87 |   - sympy=1.12=py310h06a4308_0
88 |   - tbb=2021.8.0=hdb19cb5_0
89 |   - tk=8.6.12=h1ccaba5_0
90 |   - torchaudio=2.1.2=py310_cu118
91 |   - torchtriton=2.1.0=py310
92 |   - torchvision=0.16.2=py310_cu118
93 |   - typing_extensions=4.9.0=py310h06a4308_1
94 |   - tzdata=2024a=h04d1e81_0
95 |   - urllib3=2.1.0=py310h06a4308_1
96 |   - wheel=0.41.2=py310h06a4308_0
97 |   - xz=5.4.6=h5eee18b_0
98 |   - yaml=0.2.5=h7b6447c_0
99 |   - zlib=1.2.13=h5eee18b_0
100 |   - zstd=1.5.5=hc292b87_0
101 |   - pip:
102 |     - fsspec==2024.3.1
103 |     - huggingface-hub==0.22.2
104 |     - kornia==0.7.1
105 |     - packaging==24.0
106 |     - regex==2023.12.25
107 |     - safetensors==0.4.2
108 |     - tokenizers==0.15.2
109 |     - tqdm==4.66.2
110 |     - transformers==4.36.2
111 | prefix: /c2/seanie/anaconda3/envs/dd
112 | 
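The pinned environment above can be recreated with `conda env create -f environment.yml` followed by `conda activate dd` (the environment name comes from the name: field at the top of the file); the prefix: line is machine-specific and can be deleted or ignored on other hosts.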
-------------------------------------------------------------------------------- /model_pool.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from tqdm import trange
5 | 
6 | from models.wrapper import get_model
7 | 
8 | 
9 | class ModelPool:
10 |     def __init__(self, args, device):
11 |         self.device = device
12 |         self.online_iteration = args.online_iteration
13 |         self.num_models = args.num_models
14 | 
15 |         # model func
16 |         self.model_func = lambda _: get_model(args.train_model, args.img_shape, args.num_pretrain_classes).to(device)
17 | 
18 |         # opt func
19 |         if args.online_opt == "sgd":
20 |             self.opt_func = lambda param: torch.optim.SGD(param, lr=args.online_lr, momentum=0.9, weight_decay=args.online_wd)
21 |         elif args.online_opt == "adam":
22 |             self.opt_func = lambda param: torch.optim.AdamW(param, lr=args.online_lr, weight_decay=args.online_wd)
23 |         else:
24 |             raise NotImplementedError
25 | 
26 |         self.iterations = [ 0 ] * self.num_models
27 |         self.models = [ self.model_func(None) for _ in range(self.num_models) ]
28 |         self.opts = [ self.opt_func(self.models[i].parameters()) for i in range(self.num_models) ]
29 | 
30 |     def init(self, x_syn, y_syn):
31 |         for idx in range(self.num_models):
32 |             online_iteration = np.random.randint(1, self.online_iteration)
33 |             self.iterations[idx] = online_iteration
34 |             model = self.models[idx]
35 |             opt = self.opts[idx]
36 |             model.train()
37 |             print(f"{idx}-th model init")
38 |             for _ in trange(online_iteration):
39 |                 opt.zero_grad()
40 |                 loss = F.mse_loss(model(x_syn), y_syn)
41 |                 loss.backward()
42 |                 opt.step()
43 | 
44 |     def update(self, idx, x_syn, y_syn):
45 |         # reset
46 |         if self.iterations[idx] >= self.online_iteration:
47 |             self.models[idx] = self.model_func(None)
48 |             self.opts[idx] = self.opt_func(self.models[idx].parameters())
49 |             model = self.models[idx]
50 |             opt = self.opts[idx]
51 | 
52 |         # train the model for 1 step
53 |         else:
54 |             self.iterations[idx] = self.iterations[idx] + 1
55 |             model = self.models[idx]
56 |             opt = self.opts[idx]
57 | 
58 |         model.train()
59 |         opt.zero_grad()
60 |         loss = F.mse_loss(model(x_syn), y_syn)
61 |         loss.backward()
62 |         opt.step()
63 | 
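A minimal sketch of how the pool is presumably driven: train.py constructs and warm-starts it exactly like this, while the per-step index selection below is an assumption, since the actual sampling happens inside algorithms/distill.py, which is not shown here:

import numpy as np

pool = ModelPool(args, device)              # args as parsed in train.py
pool.init(x_syn.detach(), y_syn.detach())   # each model gets a random number of warm-up steps
for outer_step in range(args.outer_iteration):
    # ... update x_syn / y_syn with the outer distillation loss ...
    idx = np.random.randint(pool.num_models)          # assumed sampling scheme
    pool.update(idx, x_syn.detach(), y_syn.detach())  # one MSE/SGD step, or a reset once stale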
-------------------------------------------------------------------------------- /models/alexnet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class AlexNet(nn.Module):
6 |     def __init__(self, img_size, num_classes):
7 |         super(AlexNet, self).__init__()
8 |         self.features = nn.Sequential(
9 |             nn.Conv2d(3, 128, kernel_size=5, stride=1, padding=2),
10 |             nn.BatchNorm2d(128),
11 |             nn.ReLU(inplace=True),
12 |             nn.MaxPool2d(kernel_size=2, stride=2),
13 |             nn.Conv2d(128, 192, kernel_size=5, padding=2),
14 |             nn.BatchNorm2d(192),
15 |             nn.ReLU(inplace=True),
16 |             nn.MaxPool2d(kernel_size=2, stride=2),
17 |             nn.Conv2d(192, 256, kernel_size=3, padding=1),
18 |             nn.BatchNorm2d(256),
19 |             nn.ReLU(inplace=True),
20 |             nn.Conv2d(256, 192, kernel_size=3, padding=1),
21 |             nn.BatchNorm2d(192),
22 |             nn.ReLU(inplace=True),
23 |             nn.Conv2d(192, 192, kernel_size=3, padding=1),
24 |             nn.BatchNorm2d(192),
25 |             nn.ReLU(inplace=True),
26 |             nn.MaxPool2d(kernel_size=2, stride=2),
27 |         )
28 |         img_size = img_size // 8
29 |         self.fc = nn.Linear(192 * img_size * img_size, num_classes)
30 | 
31 |     def forward(self, x):
32 |         x = self.features(x)
33 |         x = x.view(x.size(0), -1)
34 |         x = self.fc(x)
35 |         return x
36 | 
37 |     def embed(self, x):
38 |         x = self.features(x)
39 |         x = x.view(x.size(0), -1)
40 |         return x
-------------------------------------------------------------------------------- /models/convnet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class NoneBlock(nn.Module):
6 |     def __init__(self, num_channels, affine):
7 |         super().__init__()
8 | 
9 |     def forward(self, x):
10 |         return x
11 | 
12 | class ConvBlock(nn.Module):
13 |     def __init__(self, in_channels, out_channels, norm_layer):
14 |         super().__init__()
15 |         self.block = nn.Sequential(
16 |             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
17 |             norm_layer(out_channels, affine=True),
18 |             nn.ReLU(inplace=False),
19 |             nn.AvgPool2d(kernel_size=2, stride=2)
20 |         )
21 | 
22 |     def forward(self, x):
23 |         x = self.block(x)
24 |         return x
25 | 
26 | class ConvNet(nn.Module):
27 |     def __init__(
28 |         self,
29 |         img_shape=(3,32,32),
30 |         num_classes=10,
31 |         num_channels=[128, 256, 512],
32 |         norm="bn"
33 |     ):
34 |         super(ConvNet, self).__init__()
35 | 
36 | 
37 |         HW = img_shape[1]
38 | 
39 |         if norm.lower() == "bn":
40 |             norm_layer = nn.BatchNorm2d
41 |         elif norm.lower() == "in":
42 |             norm_layer = nn.InstanceNorm2d
43 |         elif norm.lower() == "none":
44 |             norm_layer = NoneBlock
45 |         else:
46 |             raise NotImplementedError
47 | 
48 |         layers = []
49 |         for i in range(len(num_channels)):
50 |             if i == 0:
51 |                 layers.append(ConvBlock(img_shape[0], num_channels[0], norm_layer))
52 |             else:
53 |                 layers.append(ConvBlock(num_channels[i-1], num_channels[i], norm_layer))
54 |             HW = HW // 2
55 |         self.layers = nn.ModuleList(layers)
56 |         self.num_features = HW*HW*num_channels[-1]
57 |         self.fc = nn.Linear(self.num_features, num_classes)
58 | 
59 |     def forward(self, x):
60 |         for layer in self.layers:
61 |             x = layer(x)
62 |         x = x.reshape(x.size(0), -1)
63 |         x = self.fc(x)
64 |         return x
65 | 
66 |     def embed(self, x):
67 |         for layer in self.layers:
68 |             x = layer(x)
69 |         x = x.reshape(x.size(0), -1)
70 |         return x
71 | """
72 | 
73 | class ConvBlock(nn.Module):
74 |     def __init__(self, in_channels, out_channels):
75 |         super(ConvBlock, self).__init__()
76 |         self.block = nn.Sequential(
77 |             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
78 |             nn.InstanceNorm2d(out_channels, affine=True),
79 |             #nn.BatchNorm2d(out_channels, affine=True),
80 |             nn.ReLU(inplace=True),
81 |             nn.AvgPool2d(kernel_size=2, stride=2)
82 |         )
83 | 
84 |     def forward(self, x):
85 |         x = self.block(x)
86 |         return x
87 | 
88 | class ConvNet(nn.Module):
89 |     def __init__(
90 |         self,
91 |         img_shape=(3,32,32),
92 |         num_classes=10,
93 |         num_layers=3,
94 |         num_channels=128
95 |     ):
96 |         super(ConvNet, self).__init__()
97 | 
98 |         layers = []
99 |         HW = img_shape[1]
100 |         for i in range(num_layers):
101 |             if i == 0:
102 |                 layers.append(ConvBlock(img_shape[0], num_channels))
103 |             else:
104 |                 layers.append(ConvBlock(num_channels, num_channels))
105 |             HW = HW // 2
106 |         self.layers = nn.ModuleList(layers)
107 |         self.num_features = HW*HW*num_channels
108 |         self.fc = nn.Linear(self.num_features, num_classes)
109 | 
110 |     def forward(self, x):
111 |         for layer in self.layers:
112 |             x = layer(x)
113 |         x = x.reshape(x.size(0), -1)
114 |         x = self.fc(x)
115 |         return x
116 | 
117 |     def embed(self, x):
118 |         for layer in self.layers:
119 |             x = layer(x)
120 |         x = x.reshape(x.size(0), -1)
121 |         return x
122 | """
-------------------------------------------------------------------------------- /models/mobilenet.py: --------------------------------------------------------------------------------
1 | '''MobileNetV2 in PyTorch.
2 | See the paper "Inverted Residuals and Linear Bottlenecks:
3 | Mobile Networks for Classification, Detection and Segmentation" for more details.
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class Block(nn.Module): 10 | '''expand + depthwise + pointwise''' 11 | def __init__(self, in_planes, out_planes, expansion, stride): 12 | super(Block, self).__init__() 13 | self.stride = stride 14 | 15 | planes = expansion * in_planes 16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 21 | self.bn3 = nn.BatchNorm2d(out_planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride == 1 and in_planes != out_planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 27 | nn.BatchNorm2d(out_planes), 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | out = F.relu(self.bn2(self.conv2(out))) 33 | out = self.bn3(self.conv3(out)) 34 | out = out + self.shortcut(x) if self.stride==1 else out 35 | return out 36 | 37 | 38 | class MobileNet(nn.Module): 39 | # (expansion, out_planes, num_blocks, stride) 40 | cfg = [(1, 16, 1, 1), 41 | (2, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 42 | (2, 32, 3, 2), 43 | (2, 64, 4, 2), 44 | (2, 96, 3, 1), 45 | (2, 160, 3, 2), 46 | (2, 320, 1, 1)] 47 | 48 | def __init__(self, num_classes=10): 49 | super(MobileNet, self).__init__() 50 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 51 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 52 | self.bn1 = nn.BatchNorm2d(32) 53 | self.layers = self._make_layers(in_planes=32) 54 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 55 | self.bn2 = nn.BatchNorm2d(1280) 56 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 57 | self.num_features = 1280 58 | self.fc = nn.Linear(1280, num_classes) 59 | 60 | def _make_layers(self, in_planes): 61 | layers = [] 62 | for expansion, out_planes, num_blocks, stride in self.cfg: 63 | strides = [stride] + [1]*(num_blocks-1) 64 | for stride in strides: 65 | layers.append(Block(in_planes, out_planes, expansion, stride)) 66 | in_planes = out_planes 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | x = F.relu(self.bn1(self.conv1(x))) 71 | x = self.layers(x) 72 | x = F.relu(self.bn2(self.conv2(x))) 73 | x = self.avgpool(x) 74 | x = x.reshape(x.size(0), -1) 75 | x = self.fc(x) 76 | return x 77 | 78 | def embed(self, x): 79 | x = F.relu(self.bn1(self.conv1(x))) 80 | x = self.layers(x) 81 | x = F.relu(self.bn2(self.conv2(x))) 82 | x = self.avgpool(x) 83 | x = x.reshape(x.size(0), -1) 84 | return x -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, List, Optional, Type, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d( 11 | in_planes, 12 | out_planes, 13 | kernel_size=3, 14 | stride=stride, 15 | padding=dilation, 16 | groups=groups, 17 | bias=False, 18 | dilation=dilation, 19 | ) 20 
| 21 | 22 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion: int = 1 29 | 30 | def __init__( 31 | self, 32 | inplanes: int, 33 | planes: int, 34 | stride: int = 1, 35 | downsample: Optional[nn.Module] = None, 36 | groups: int = 1, 37 | base_width: int = 64, 38 | dilation: int = 1, 39 | norm_layer: Optional[Callable[..., nn.Module]] = None, 40 | ) -> None: 41 | super().__init__() 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if groups != 1 or base_width != 64: 45 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 46 | if dilation > 1: 47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = norm_layer(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x: Tensor) -> Tensor: 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 78 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 79 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. 80 | # This variant is also known as ResNet V1.5 and improves accuracy according to 81 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
82 | 83 | expansion: int = 4 84 | 85 | def __init__( 86 | self, 87 | inplanes: int, 88 | planes: int, 89 | stride: int = 1, 90 | downsample: Optional[nn.Module] = None, 91 | groups: int = 1, 92 | base_width: int = 64, 93 | dilation: int = 1, 94 | norm_layer: Optional[Callable[..., nn.Module]] = None, 95 | ) -> None: 96 | super().__init__() 97 | if norm_layer is None: 98 | norm_layer = nn.BatchNorm2d 99 | width = int(planes * (base_width / 64.0)) * groups 100 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 101 | self.conv1 = conv1x1(inplanes, width) 102 | self.bn1 = norm_layer(width) 103 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 104 | self.bn2 = norm_layer(width) 105 | self.conv3 = conv1x1(width, planes * self.expansion) 106 | self.bn3 = norm_layer(planes * self.expansion) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.downsample = downsample 109 | self.stride = stride 110 | 111 | def forward(self, x: Tensor) -> Tensor: 112 | identity = x 113 | 114 | out = self.conv1(x) 115 | out = self.bn1(out) 116 | out = self.relu(out) 117 | 118 | out = self.conv2(out) 119 | out = self.bn2(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv3(out) 123 | out = self.bn3(out) 124 | 125 | if self.downsample is not None: 126 | identity = self.downsample(x) 127 | 128 | out += identity 129 | out = self.relu(out) 130 | 131 | return out 132 | 133 | 134 | class ResNet(nn.Module): 135 | def __init__( 136 | self, 137 | block: Type[Union[BasicBlock, Bottleneck]], 138 | layers: List[int], 139 | num_classes: int = 1000, 140 | zero_init_residual: bool = False, 141 | groups: int = 1, 142 | width_per_group: int = 64, 143 | replace_stride_with_dilation: Optional[List[bool]] = None, 144 | norm_layer: Optional[Callable[..., nn.Module]] = None, 145 | ) -> None: 146 | super().__init__() 147 | if norm_layer is None: 148 | norm_layer = nn.BatchNorm2d 149 | self._norm_layer = norm_layer 150 | 151 | self.inplanes = 64 152 | self.dilation = 1 153 | if replace_stride_with_dilation is None: 154 | # each element in the tuple indicates if we should replace 155 | # the 2x2 stride with a dilated convolution instead 156 | replace_stride_with_dilation = [False, False, False] 157 | if len(replace_stride_with_dilation) != 3: 158 | raise ValueError( 159 | "replace_stride_with_dilation should be None " 160 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 161 | ) 162 | self.groups = groups 163 | self.base_width = width_per_group 164 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 165 | self.bn1 = norm_layer(self.inplanes) 166 | self.relu = nn.ReLU(inplace=True) 167 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 168 | self.layer1 = self._make_layer(block, 64, layers[0]) 169 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 170 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 171 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 172 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 173 | self.fc = nn.Linear(512 * block.expansion, num_classes) 174 | 175 | for m in self.modules(): 176 | if isinstance(m, nn.Conv2d): 177 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 178 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 179 | nn.init.constant_(m.weight, 1) 180 | nn.init.constant_(m.bias, 0) 181 | 182 | # 
Zero-initialize the last BN in each residual branch, 183 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 184 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 185 | self.num_features = 512 186 | if zero_init_residual: 187 | for m in self.modules(): 188 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 189 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 190 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 191 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 192 | 193 | def _make_layer( 194 | self, 195 | block: Type[Union[BasicBlock, Bottleneck]], 196 | planes: int, 197 | blocks: int, 198 | stride: int = 1, 199 | dilate: bool = False, 200 | ) -> nn.Sequential: 201 | norm_layer = self._norm_layer 202 | downsample = None 203 | previous_dilation = self.dilation 204 | if dilate: 205 | self.dilation *= stride 206 | stride = 1 207 | if stride != 1 or self.inplanes != planes * block.expansion: 208 | downsample = nn.Sequential( 209 | conv1x1(self.inplanes, planes * block.expansion, stride), 210 | norm_layer(planes * block.expansion), 211 | ) 212 | 213 | layers = [] 214 | layers.append( 215 | block( 216 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 217 | ) 218 | ) 219 | self.inplanes = planes * block.expansion 220 | for _ in range(1, blocks): 221 | layers.append( 222 | block( 223 | self.inplanes, 224 | planes, 225 | groups=self.groups, 226 | base_width=self.base_width, 227 | dilation=self.dilation, 228 | norm_layer=norm_layer, 229 | ) 230 | ) 231 | 232 | return nn.Sequential(*layers) 233 | 234 | def _forward_impl(self, x: Tensor) -> Tensor: 235 | # See note [TorchScript super()] 236 | x = self.conv1(x) 237 | x = self.bn1(x) 238 | x = self.relu(x) 239 | x = self.maxpool(x) 240 | 241 | x = self.layer1(x) 242 | x = self.layer2(x) 243 | x = self.layer3(x) 244 | x = self.layer4(x) 245 | 246 | x = self.avgpool(x) 247 | x = torch.flatten(x, 1) 248 | x = self.fc(x) 249 | 250 | return x 251 | 252 | def forward(self, x: Tensor) -> Tensor: 253 | return self._forward_impl(x) 254 | 255 | def embed(self, x: Tensor) -> Tensor: 256 | x = self.conv1(x) 257 | x = self.bn1(x) 258 | x = self.relu(x) 259 | x = self.maxpool(x) 260 | 261 | x = self.layer1(x) 262 | x = self.layer2(x) 263 | x = self.layer3(x) 264 | x = self.layer4(x) 265 | 266 | x = self.avgpool(x) 267 | x = torch.flatten(x, 1) 268 | return x 269 | 270 | def ResNet10(num_classes): 271 | model = ResNet(BasicBlock, [1, 1, 1, 1], num_classes=num_classes, zero_init_residual=True) 272 | model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False) 273 | model.maxpool = nn.Identity() 274 | return model 275 | 276 | def ResNet18(num_classes): 277 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, zero_init_residual=True) 278 | model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False) 279 | model.maxpool = nn.Identity() 280 | return model -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | ''' VGG ''' 6 | cfg_vgg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 
128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
10 |     'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
11 | }
12 | class VGG(nn.Module):
13 |     def __init__(self, vgg_name, img_size, num_classes, norm='batchnorm'):
14 |         super(VGG, self).__init__()
15 |         img_size = img_size // 32
16 |         self.features = self._make_layers(cfg_vgg[vgg_name], norm)
17 |         self.fc = nn.Linear(512*img_size*img_size, num_classes)
18 | 
19 |     def forward(self, x):
20 |         x = self.features(x)
21 |         x = x.view(x.size(0), -1)
22 |         x = self.fc(x)
23 |         return x
24 | 
25 |     def embed(self, x):
26 |         x = self.features(x)
27 |         x = x.view(x.size(0), -1)
28 |         return x
29 | 
30 |     def _make_layers(self, cfg, norm):
31 |         layers = []
32 |         in_channels = 3
33 |         for ic, x in enumerate(cfg):
34 |             if x == 'M':
35 |                 layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
36 |             else:
37 |                 layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
38 |                            nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x),
39 |                            nn.ReLU(inplace=True)]
40 |                 in_channels = x
41 |         return nn.Sequential(*layers)
42 | 
43 | def VGG11(img_size, num_classes):
44 |     return VGG('VGG11', img_size, num_classes)
45 | 
-------------------------------------------------------------------------------- /models/wrapper.py: --------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision.models import resnet18, resnet34, resnet50
3 | 
4 | from models.convnet import ConvNet
5 | from models.resnet import ResNet10, ResNet18
6 | from models.vgg import VGG11
7 | from models.alexnet import AlexNet
8 | from models.mobilenet import MobileNet
9 | 
10 | def get_model(name, img_shape=(3,32,32), num_classes=10, dropout=0.0):
11 |     if "convnet" in name.lower():
12 |         name_splited = name.lower().split("_")
13 |         num_channels = []
14 |         for idx, ns in enumerate(name_splited):
15 |             if idx != 0:
16 |                 if ns.isdigit():
17 |                     num_channels.append(int(ns))
18 |                 else:
19 |                     norm = ns
20 |         model = ConvNet(img_shape, num_classes, num_channels, norm)
21 |     elif "resnet10" == name.lower():
22 |         model = ResNet10(num_classes)
23 |     elif "resnet18" == name.lower():
24 |         model = ResNet18(num_classes)
25 |     elif "vgg" == name.lower():
26 |         model = VGG11(img_shape[1], num_classes)
27 |     elif "alexnet" == name.lower():
28 |         model = AlexNet(img_shape[1], num_classes)
29 |     elif "mobilenet" == name.lower():
30 |         model = MobileNet(num_classes)
31 |     else:
32 |         raise NotImplementedError
33 |     if dropout > 0.0:
34 |         model.fc = nn.Sequential(
35 |             nn.Dropout(dropout),
36 |             model.fc
37 |         )
38 |     return model
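get_model decodes convnet names token by token: every digit token after "convnet_" becomes a per-block channel width, and the remaining token selects the normalization ("bn", "in", or "none"). A small sketch of the convention; the shapes follow from each ConvBlock halving the spatial size once:

import torch

# "convnet_128_256_512_bn" -> ConvNet(num_channels=[128, 256, 512], norm="bn")
model = get_model("convnet_128_256_512_bn", img_shape=(3, 32, 32), num_classes=100)
feats = model.embed(torch.randn(2, 3, 32, 32))  # (2, 512 * 4 * 4): three blocks halve 32 -> 4
logits = model(torch.randn(2, 3, 32, 32))       # (2, 100)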
-------------------------------------------------------------------------------- /test.py: --------------------------------------------------------------------------------
1 | import random
2 | import argparse
3 | import numpy as np
4 | 
5 | import torch
6 | 
7 | from data.wrapper import get_loader
8 | from data.augmentation import NUM_CLASSES, ParamDiffAug
9 | from algorithms.wrapper import get_algorithm
10 | 
11 | def main(args):
12 |     device = torch.device(f"cuda:{args.gpu_id}")
13 |     torch.cuda.set_device(device)
14 | 
15 |     # default augment
16 |     args.dsa_param = ParamDiffAug()
17 |     args.dsa_strategy = 'color_crop_cutout_flip_scale_rotate'
18 | 
19 |     # seed
20 |     if args.seed is None:
21 |         args.seed = random.randint(0, 9999)
22 |     random.seed(args.seed)
23 |     np.random.seed(args.seed)
24 |     torch.manual_seed(args.seed)
25 | 
26 |     # data
27 |     x_syn = torch.load(f"{args.synthetic_data_dir}/{args.source_data_name}/{args.method}/x_syn.pt", map_location="cpu").detach()
28 |     y_syn = torch.load(f"{args.synthetic_data_dir}/{args.source_data_name}/{args.method}/y_syn.pt", map_location="cpu").detach()
29 |     if args.method == "kip" or args.method == "frepo" or args.method == "krr_st":
30 |         y_syn = y_syn.float()
31 |     else:
32 |         y_syn = y_syn.long()
33 |     if args.method == "krr_st":
34 |         args.num_pretrain_classes = y_syn.shape[-1]
35 |     else:
36 |         args.num_pretrain_classes = NUM_CLASSES[args.source_data_name]
37 | 
38 |     print(args)
39 | 
40 |     # algo
41 |     if args.method == "random" or args.method == "kmeans" or args.method == "dsa" or args.method == "dm" or args.method == "mtt":
42 |         pretrain = get_algorithm("pretrain_dc")
43 |     elif args.method == "kip" or args.method == "frepo":
44 |         pretrain = get_algorithm("pretrain_frepo")
45 |     elif args.method == "krr_st":
46 |         pretrain = get_algorithm("pretrain_krr_st")
47 |     else:
48 |         raise NotImplementedError
49 |     test_algo = get_algorithm("finetune")
50 | 
51 |     # target_data_name
52 |     if args.target_data_name == "full":
53 |         if args.source_data_name == "cifar100":
54 |             data_name_list = ["cifar100", "cifar10", "aircraft", "cars", "cub2011", "dogs", "flowers"]
55 |         elif args.source_data_name == "tinyimagenet":
56 |             data_name_list = ["tinyimagenet", "cifar10", "aircraft", "cars", "cub2011", "dogs", "flowers"]
57 |         elif args.source_data_name == "imagenet":
58 |             data_name_list = ["cifar10", "cifar100", "aircraft", "cars", "cub2011", "dogs", "flowers"]
59 |         elif args.source_data_name == "imagenette":
60 |             data_name_list = ["imagenette", "aircraft", "cars", "cub2011", "dogs", "flowers"]
61 |         else:
62 |             raise NotImplementedError
63 |     else:
64 |         data_name_list = args.target_data_name.split("_")
65 | 
66 |     # train
67 |     acc_dict = { data_name: [] for data_name in data_name_list }
68 |     for _ in range(args.num_test):
69 |         x_syn, y_syn = x_syn.to(device), y_syn.to(device)
70 |         init_model = pretrain.run(args, device, args.test_model, x_syn, y_syn)
71 |         init_model = init_model.cpu()
72 |         x_syn, y_syn = x_syn.cpu(), y_syn.cpu()
73 |         for data_name in data_name_list:
74 |             args.num_classes = NUM_CLASSES[data_name]
75 |             if data_name in ["tinyimagenet", "cifar100", "cifar10"]:
76 |                 args.test_iteration = 10000
77 |             else:
78 |                 args.test_iteration = 5000
79 |             dl_tr, dl_te, aug_tr, aug_te = get_loader(
80 |                 args.data_dir, data_name, args.test_batch_size, args.img_size, True)
81 |             _, acc = test_algo.run(args, device, args.test_model, init_model, dl_tr, dl_te, aug_tr, aug_te)
82 |             print(data_name, acc)
83 |             acc_dict[data_name].append(acc)
84 | 
85 |     for data_name in data_name_list:
86 |         print(f"{data_name}, mean: {np.mean(acc_dict[data_name])}, std: {np.std(acc_dict[data_name])}")
87 | 
88 | if __name__ == '__main__':
89 |     parser = argparse.ArgumentParser(description='Parameter Processing')
90 | 
91 |     # seed
92 |     parser.add_argument('--seed', type=int, default=None)
93 | 
94 |     # data
95 |     parser.add_argument('--source_data_name', type=str, default="tinyimagenet")
96 |     parser.add_argument('--target_data_name', type=str, default="full")
97 | 
98 |     # dir
99 |     parser.add_argument('--data_dir', type=str, default="./datasets")
100 |     parser.add_argument('--synthetic_data_dir', type=str, default="./synthetic_data")
101 |     parser.add_argument('--log_dir', type=str, default="./test_log")
102 | 
103 |     # dc method
104 |     parser.add_argument('--method', type=str, default="krr_st")
105 | 
106 |     # hparams for model
107 |     parser.add_argument('--test_model', type=str, default="base")
108 |     parser.add_argument('--dropout', type=float, default=0.0)
109 | 
110 |     # hparams for test
111 |     parser.add_argument('--num_test', type=int, default=3)
112 | 
113 |     # gpus
114 |     parser.add_argument('--gpu_id', type=int, default=0)
115 |     args = parser.parse_args()
116 | 
117 |     # img_size
118 |     if args.source_data_name == "cifar100":
119 |         args.img_size = 32
120 |         if args.test_model == "base":
121 |             args.test_model = "convnet_128_256_512_bn"
122 |     elif args.source_data_name == "tinyimagenet" or args.source_data_name == "imagenet":
123 |         args.img_size = 64
124 |         if args.test_model == "base":
125 |             args.test_model = "convnet_64_128_256_512_bn"
126 |     elif args.source_data_name == "imagenette":
127 |         args.img_size = 224
128 |         if args.test_model == "base":
129 |             args.test_model = "convnet_32_64_128_256_512_bn"
130 |     else:
131 |         raise NotImplementedError
132 |     args.img_shape = (3, args.img_size, args.img_size)
133 | 
134 |     # pretrain hparams
135 |     if args.method == "random" or args.method == "kmeans" or args.method == "dsa" or args.method == "dm" or args.method == "mtt":
136 |         args.pre_opt = "sgd"
137 |         args.pre_epoch = 1000
138 |         args.pre_iteration = None
139 |         args.pre_batch_size = 256
140 |         args.pre_lr = 0.01
141 |         args.pre_wd = 5e-4
142 | 
143 |     elif args.method == "kip" or args.method == "frepo":
144 |         #steps_per_prototypes = {10: 1000, 100: 2000, 200: 20000, 400: 5000, 500: 5000, 1000: 10000, 2000: 40000, 5000: 40000}
145 |         args.pre_opt = "adam"
146 |         args.pre_epoch = None
147 |         if args.source_data_name == "cifar100" or args.source_data_name == "imagenet":
148 |             args.pre_iteration = 10000 # 1000
149 |             args.pre_batch_size = 500
150 |         elif args.source_data_name == "tinyimagenet":
151 |             args.pre_iteration = 40000 # 2000
152 |             if args.test_model == "mobilenet":
153 |                 args.pre_batch_size = 256
154 |             else:
155 |                 args.pre_batch_size = 500
156 |         elif args.source_data_name == "imagenette":
157 |             args.pre_iteration = 1000 # 10
158 |             args.pre_batch_size = 10
159 |         args.pre_lr = 0.0003
160 |         args.pre_wd = 0.0
161 | 
162 |     elif args.method == "krr_st":
163 |         args.pre_opt = "sgd"
164 |         args.pre_epoch = 1000
165 |         args.pre_batch_size = 256
166 |         args.pre_lr = 0.1
167 |         args.pre_wd = 1e-3
168 |     else:
169 |         raise NotImplementedError
170 | 
171 |     # finetune hyperparams
172 |     args.test_opt = "sgd"
173 |     args.test_iteration = None
174 |     if args.img_size == 224 and args.test_model == "resnet18":
175 |         args.test_batch_size = 64
176 |     else:
177 |         args.test_batch_size = 256
178 |     args.test_lr = 0.01
179 |     args.test_wd = 5e-4
180 | 
181 |     main(args)
182 | 
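Evaluation therefore expects the distilled tensors at {synthetic_data_dir}/{source_data_name}/{method}/x_syn.pt and y_syn.pt; with those in place, a run such as `python test.py --source_data_name cifar100 --method krr_st --target_data_name full` pretrains args.num_test models on the synthetic data and fine-tunes each one on every transfer dataset in the list above.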
-------------------------------------------------------------------------------- /test_kd.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | 
4 | import numpy as np
5 | import torch
6 | 
7 | from algorithms.wrapper import get_algorithm
8 | from data.augmentation import NUM_CLASSES, ParamDiffAug
9 | from data.wrapper import get_loader
10 | from models.wrapper import get_model
11 | 
12 | 
13 | def main(args):
14 |     device = torch.device(f"cuda:{args.gpu_id}")
15 |     torch.cuda.set_device(device)
16 | 
17 |     # default augment
18 |     args.dsa_param = ParamDiffAug()
19 |     args.dsa_strategy = 'color_crop_cutout_flip_scale_rotate'
20 | 
21 |     # seed
22 |     if args.seed is None:
23 |         args.seed = random.randint(0, 9999)
24 |     random.seed(args.seed)
25 |     np.random.seed(args.seed)
26 |     torch.manual_seed(args.seed)
27 | 
28 |     # data
29 |     if args.method != "gaussian":
30 |         x_syn = torch.load(f"./synthetic_data/{args.source_data_name}/{args.method}/x_syn.pt", map_location="cpu").detach()
31 |         y_syn = torch.load(f"./synthetic_data/{args.source_data_name}/{args.method}/y_syn.pt", map_location="cpu").detach()
32 |         if args.method == "kip" or args.method == "frepo" or args.method == "krr_st":
33 |             y_syn = y_syn.float()
34 |         else:
35 |             y_syn = y_syn.long()
36 |     else:
37 |         x_syn = torch.load(f"./synthetic_data/{args.source_data_name}/random/x_syn.pt", map_location="cpu").detach()
38 |         y_syn = torch.load(f"./synthetic_data/{args.source_data_name}/random/y_syn.pt", map_location="cpu").detach()
39 |     if args.method == "krr_st":
40 |         args.num_pretrain_classes = y_syn.shape[-1]
41 |     else:
42 |         args.num_pretrain_classes = NUM_CLASSES[args.source_data_name]
43 | 
44 |     print(args)
45 | 
46 |     # algo
47 |     if args.method == "random" or args.method == "kmeans" or args.method == "dsa" or args.method == "dm" or args.method == "mtt":
48 |         pretrain = get_algorithm("pretrain_dc")
49 |     elif args.method == "kip" or args.method == "frepo":
50 |         pretrain = get_algorithm("pretrain_frepo")
51 |     elif args.method == "krr_st":
52 |         pretrain = get_algorithm("pretrain_krr_st")
53 |     elif args.method == "gaussian":
54 |         pass
55 |     else:
56 |         raise NotImplementedError
57 |     test_algo = get_algorithm("zeroshot_kd")
58 | 
59 |     data_name = "cifar10"
60 |     args.img_shape = (3, args.img_size, args.img_size)
61 |     dl_tr, dl_te, aug_tr, aug_te = get_loader(
62 |         args.data_dir, data_name, args.test_batch_size, args.img_size, True)
63 |     data = {
64 |         "num_classes": NUM_CLASSES[data_name.lower()],
65 |         "dl_tr": dl_tr,
66 |         "dl_te": dl_te,
67 |         "aug_tr": aug_tr,
68 |         "aug_te": aug_te
69 |     }
70 | 
71 |     teacher = get_model("resnet18", args.img_shape, data["num_classes"]).to(device)
72 |     ckpt = torch.load(f"./teacher_ckpt/teacher_{data_name}.pt", map_location="cpu")
73 |     teacher.load_state_dict(ckpt)
74 |     for p in teacher.parameters():
75 |         p.requires_grad_(False)
76 | 
77 |     acc_list = []
78 |     for _ in range(args.num_test):
79 |         if args.method != "gaussian":
80 |             init_model = pretrain.run(args, device, args.test_model, x_syn, y_syn)
81 |         else:
82 |             init_model = get_model(args.test_model, args.img_shape, 1).to(device)
83 | 
84 |         args.num_classes = data["num_classes"]
85 |         dl_te = data["dl_te"]
86 |         aug_te = data["aug_te"]
87 |         _, acc = test_algo.run(args, device, args.test_model, init_model, teacher, x_syn, dl_te, aug_te)
88 |         print(acc)
89 |         acc_list.append(acc)
90 | 
91 |     print(f"{data_name}, mean: {np.mean(acc_list)}, std: {np.std(acc_list)}")
92 | 
93 | if __name__ == '__main__':
94 |     parser = argparse.ArgumentParser(description='Parameter Processing')
95 | 
96 |     # seed
97 |     parser.add_argument('--seed', type=int, default=None)
98 | 
99 |     # data
100 |     parser.add_argument('--source_data_name', type=str, default="tinyimagenet")
101 |     parser.add_argument('--target_data_name', type=str, default="cifar10")
102 | 
103 |     # dir
104 |     parser.add_argument('--data_dir', type=str, default="../evaluation_seanie/datasets")
105 |     parser.add_argument('--synthetic_data_dir', type=str, default="./synthetic_data")
106 |     parser.add_argument('--log_dir', type=str, default="./test_log")
107 | 
108 |     # dc method
109 |     parser.add_argument('--method', type=str, default="krr_st")
110 | 
111 |     # hparams for model
112 |     parser.add_argument('--test_model', type=str, default="base")
113 |     parser.add_argument('--dropout', type=float, default=0.0)
114 | 
115 |     # hparams for test
116 |     parser.add_argument('--num_test', type=int, default=3)
117 | 
118 |     # gpus
119 |     parser.add_argument('--gpu_id', type=int, default=0)
120 |     args = parser.parse_args()
121 | 
122 |     # img_size
123 |     if args.source_data_name == "cifar100":
124 |         args.img_size = 32
125 |         if args.test_model == "base":
126 |             args.test_model = "convnet_128_256_512_bn"
127 |     elif args.source_data_name == "tinyimagenet" or args.source_data_name == "imagenet":
128 |         args.img_size = 64
129 |         if args.test_model == "base":
130 |             args.test_model = "convnet_64_128_256_512_bn"
131 |     elif args.source_data_name == "imagenette":
132 |         args.img_size = 224
133 |         if args.test_model == "base":
134 |             args.test_model = "convnet_32_64_128_256_512_bn"
135 |     else:
136 |         raise NotImplementedError
137 |     args.img_shape = (3, args.img_size, args.img_size)
138 | 
139 |     # pretrain hparams
140 |     if args.method == "gaussian" or args.method == "random" or args.method == "kmeans" or args.method == "dsa" or args.method == "dm" or args.method == "mtt":
141 |         args.pre_opt = "sgd"
142 |         args.pre_epoch = 1000
143 |         args.pre_iteration = None
144 |         args.pre_batch_size = 256
145 |         args.pre_lr = 0.01
146 |         args.pre_wd = 5e-4
147 | 
148 |     elif args.method == "kip" or args.method == "frepo":
149 |         #steps_per_prototypes = {10: 1000, 100: 2000, 200: 20000, 400: 5000, 500: 5000, 1000: 10000, 2000: 40000, 5000: 40000}
150 |         args.pre_opt = "adam"
151 |         args.pre_epoch = None
152 |         if args.source_data_name == "cifar100" or args.source_data_name == "imagenet":
153 |             args.pre_iteration = 10000 # 1000
154 |             args.pre_batch_size = 500
155 |         elif args.source_data_name == "tinyimagenet":
156 |             args.pre_iteration = 40000 # 2000
157 |             if args.test_model == "mobilenet":
158 |                 args.pre_batch_size = 128
159 |             else:
160 |                 args.pre_batch_size = 500
161 |         elif args.source_data_name == "imagenette":
162 |             args.pre_iteration = 1000 # 10
163 |             args.pre_batch_size = 10
164 |         args.pre_lr = 0.0003
165 |         args.pre_wd = 0.0
166 | 
167 |     elif args.method == "krr_st":
168 |         args.pre_opt = "sgd"
169 |         args.pre_epoch = 1000
170 |         args.pre_batch_size = 256
171 |         args.pre_lr = 0.1
172 |         args.pre_wd = 1e-3
173 |     else:
174 |         raise NotImplementedError
175 | 
176 |     # zeroshot_kd hyperparams
177 |     args.test_opt = "adam"
178 |     args.test_epoch = 1000
179 |     args.test_batch_size = 512
180 |     if args.method == "gaussian":
181 |         args.test_lr = 1e-3
182 |     else:
183 |         args.test_lr = 1e-4
184 |     args.test_wd = 0.
185 | 
186 |     main(args)
187 | 
-------------------------------------------------------------------------------- /test_scratch.py: --------------------------------------------------------------------------------
1 | import random
2 | import argparse
3 | import numpy as np
4 | 
5 | import torch
6 | 
7 | from data.wrapper import get_loader
8 | from data.augmentation import NUM_CLASSES
9 | from algorithms.wrapper import get_algorithm
10 | 
11 | def main(args):
12 |     device = torch.device(f"cuda:{args.gpu_id}")
13 |     torch.cuda.set_device(device)
14 | 
15 |     # seed
16 |     if args.seed is None:
17 |         args.seed = random.randint(0, 9999)
18 |     random.seed(args.seed)
19 |     np.random.seed(args.seed)
20 |     torch.manual_seed(args.seed)
21 | 
22 |     print(args)
23 | 
24 |     # algo
25 |     test_algo = get_algorithm("scratch")
26 | 
27 |     # target_data_name
28 |     if args.target_data_name == "full":
29 |         if args.source_data_name == "cifar100":
30 |             data_name_list = ["cifar100", "cifar10", "aircraft", "cars", "cub2011", "dogs", "flowers"]
31 |         elif args.source_data_name == "tinyimagenet":
32 |             data_name_list = ["tinyimagenet", "cifar10", "aircraft", "cars", "cub2011", "dogs", "flowers"]
33 |         elif args.source_data_name == "imagenet":
34 |             data_name_list = ["cifar10", "cifar100", "aircraft", "cars", "cub2011", "dogs", "flowers"]
35 |         elif args.source_data_name == "imagenette":
36 |             data_name_list = ["imagenette", "aircraft", "cars", "cub2011", "dogs", "flowers"]
37 |         else:
38 |             raise NotImplementedError
39 |     else:
40 |         data_name_list = args.target_data_name.split("_")
41 | 
42 |     # train
43 |     acc_dict = { data_name: [] for data_name in data_name_list }
44 |     for _ in range(args.num_test):
45 |         for data_name in data_name_list:
46 |             args.num_classes = NUM_CLASSES[data_name]
47 |             if data_name in ["tinyimagenet", "cifar100", "cifar10"]:
48 |                 args.test_iteration = 10000
49 |             else:
50 |                 args.test_iteration = 5000
51 |             dl_tr, dl_te, aug_tr, aug_te = get_loader(
52 |                 args.data_dir, data_name, args.test_batch_size, args.img_size, True)
53 |             _, acc = test_algo.run(args, device, args.test_model, dl_tr, dl_te, aug_tr, aug_te)
54 |             print(data_name, acc)
55 |             acc_dict[data_name].append(acc)
56 | 
57 |     for data_name in data_name_list:
58 |         print(f"{data_name}, mean: {np.mean(acc_dict[data_name])}, std: {np.std(acc_dict[data_name])}")
59 | 
60 | if __name__ == '__main__':
61 |     parser = argparse.ArgumentParser(description='Parameter Processing')
62 | 
63 |     # seed
64 |     parser.add_argument('--seed', type=int, default=None)
65 | 
66 |     # data
67 |     parser.add_argument('--source_data_name', type=str, default="cifar100")
68 |     parser.add_argument('--target_data_name', type=str, default="full")
69 | 
70 |     # dir
71 |     parser.add_argument('--data_dir', type=str, default="../evaluation_seanie/datasets")
72 |     parser.add_argument('--synthetic_data_dir', type=str, default="./synthetic_data")
73 |     parser.add_argument('--log_dir', type=str, default="./test_log")
74 | 
75 |     # hparams for model
76 |     parser.add_argument('--test_model', type=str, default="base")
77 | 
78 |     # hparams for test
79 |     parser.add_argument('--num_test', type=int, default=3)
80 | 
81 |     # gpus
82 |     parser.add_argument('--gpu_id', type=int, default=0)
83 |     args = parser.parse_args()
84 | 
85 |     # img_size
86 |     if args.source_data_name == "cifar100":
87 |         args.img_size = 32
88 |         if args.test_model == "base":
89 |             args.test_model = "convnet_128_256_512_bn"
90 |     elif args.source_data_name == "tinyimagenet" or args.source_data_name == "imagenet":
91 |         args.img_size = 64
92 |         if args.test_model == "base":
93 | 
args.test_model = "convnet_64_128_256_512_bn" 94 | elif args.source_data_name == "imagenette": 95 | args.img_size = 224 96 | if args.test_model == "base": 97 | args.test_model = "convnet_32_64_128_256_512_bn" 98 | else: 99 | raise NotImplementedError 100 | args.img_shape = (3, args.img_size, args.img_size) 101 | 102 | # finetune hyperparams 103 | args.test_opt = "sgd" 104 | args.test_iteration = None 105 | if args.img_size == 224 and args.test_model == "resnet18": 106 | args.test_batch_size = 64 107 | else: 108 | args.test_batch_size = 256 109 | args.test_lr = 0.01 110 | args.test_wd = 5e-4 111 | 112 | main(args) 113 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import argparse 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from torchvision.models import resnet18 9 | 10 | from data.wrapper import get_loader 11 | from data.augmentation import NUM_CLASSES 12 | from algorithms.wrapper import get_algorithm 13 | from utils import InfIterator, Logger 14 | from model_pool import ModelPool 15 | 16 | def main(args): 17 | device = torch.device(f"cuda:{args.gpu_id}") 18 | torch.cuda.set_device(device) 19 | 20 | # seed 21 | if args.seed is None: 22 | args.seed = random.randint(0, 9999) 23 | random.seed(args.seed) 24 | np.random.seed(args.seed) 25 | torch.manual_seed(args.seed) 26 | 27 | # data 28 | args.img_shape = (3, args.img_size, args.img_size) 29 | dl, _, _, _ = get_loader(args.data_dir, args.data_name, args.outer_batch_size, args.img_size, False) 30 | iter_tr = InfIterator(dl) 31 | dl_tr, dl_te, aug_tr, aug_te = get_loader(args.data_dir, args.data_name, args.test_batch_size, args.img_size, True) 32 | 33 | # target model 34 | if args.data_name == "imagenette": 35 | target_model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50') 36 | target_model.fc = nn.Identity() 37 | target_model = target_model.to(device) 38 | else: 39 | target_model = resnet18() 40 | target_model.fc = nn.Identity() 41 | target_model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False) 42 | target_model.maxpool = nn.Identity() 43 | target_model = target_model.to(device) 44 | target_model.load_state_dict(torch.load(f"./teacher_ckpt/barlow_twins_resnet18_{args.data_name}.pt", map_location="cpu")) 45 | 46 | # get features 47 | target_model.eval() 48 | with torch.no_grad(): 49 | x_syn, y_syn = torch.FloatTensor([]), torch.FloatTensor([]) 50 | while x_syn.shape[0] < args.num_images: 51 | x, y = next(iter_tr) 52 | x_syn = torch.cat([x_syn, x], dim=0) 53 | x = x.to(device) 54 | x = aug_te(x) 55 | y = target_model(x).cpu() 56 | torch.cuda.empty_cache() 57 | y_syn = torch.cat([y_syn, y], dim=0) 58 | 59 | args.num_pretrain_classes = y_syn.shape[-1] 60 | args.num_classes = NUM_CLASSES[args.data_name] 61 | 62 | # x_syn, y_syn 63 | x_syn, y_syn = x_syn[:args.num_images].to(device), y_syn[:args.num_images].to(device) 64 | x_syn.requires_grad_(True); y_syn.requires_grad_(True) 65 | 66 | # outer opt 67 | if args.outer_opt == "sgd": 68 | outer_opt = torch.optim.SGD([x_syn, y_syn], lr=args.outer_lr, momentum=0.5, weight_decay=args.outer_wd) 69 | elif args.outer_opt == "adam": 70 | outer_opt = torch.optim.AdamW([x_syn, y_syn], lr=args.outer_lr, weight_decay=args.outer_wd) 71 | else: 72 | raise NotImplementedError 73 | outer_sch = torch.optim.lr_scheduler.LinearLR( 74 | outer_opt, start_factor=1.0, end_factor=1e-3, 
total_iters=args.outer_iteration)
75 | 
76 |     # model pool
77 |     model_pool = ModelPool(args, device)
78 |     model_pool.init(x_syn.detach(), y_syn.detach())
79 | 
80 |     # logger
81 |     logger = Logger(
82 |         save_dir=f"{args.save_dir}/{args.exp_name}",
83 |         save_only_last=True,
84 |         print_every=args.print_every,
85 |         save_every=args.save_every,
86 |         total_step=args.outer_iteration,
87 |         print_to_stdout=True
88 |     )
89 |     logger.register_object_to_save(x_syn, "x_syn")
90 |     logger.register_object_to_save(y_syn, "y_syn")
91 |     logger.start()
92 | 
93 |     # algo
94 |     train_algo = get_algorithm("distill")
95 |     pretrain = get_algorithm("pretrain_krr_st")
96 |     test_algo = get_algorithm("linear_eval")
97 | 
98 |     # outer loop
99 |     for outer_step in range(1, args.outer_iteration+1):
100 | 
101 |         # meta train
102 |         loss = train_algo.run(
103 |             args, device, target_model, model_pool, outer_opt, iter_tr, aug_tr, x_syn, y_syn)
104 |         logger.meter("meta_train", "mse loss", loss)
105 | 
106 |         # meta test
107 |         if outer_step % args.eval_every == 0 or outer_step == args.outer_iteration:
108 |             init_model = pretrain.run(args, device, args.test_model, x_syn, y_syn)
109 |             loss, acc = test_algo.run(args, device, args.test_model, init_model, dl_tr, dl_te, aug_tr, aug_te)
110 |             del init_model
111 |             logger.meter(f"meta_test", "loss", loss)
112 |             logger.meter(f"meta_test", "accuracy", acc)
113 | 
114 |         outer_sch.step()
115 |         logger.step()
116 | 
117 |     logger.finish()
118 | 
119 | if __name__ == '__main__':
120 |     parser = argparse.ArgumentParser(description='Parameter Processing')
121 | 
122 |     # seed
123 |     parser.add_argument('--seed', type=int, default=None)
124 | 
125 |     # data
126 |     parser.add_argument('--target_model', type=str, default="resnet18")
127 |     parser.add_argument('--data_name', type=str, default="cifar100")
128 |     parser.add_argument('--num_workers', type=int, default=0)
129 | 
130 |     # dir
131 |     parser.add_argument('--data_dir', type=str, default="./datasets")
132 |     parser.add_argument('--save_dir', type=str, default="./results")
133 |     parser.add_argument('--exp_name', type=str, default=None)
134 | 
135 |     # algorithm
136 |     parser.add_argument('--num_models', type=int, default=10)
137 | 
138 |     # hparams for online
139 |     parser.add_argument('--online_opt', type=str, default="sgd")
140 |     parser.add_argument('--online_iteration', type=int, default=1000)
141 |     parser.add_argument('--online_lr', type=float, default=0.1)
142 |     parser.add_argument('--online_wd', type=float, default=1e-3)
143 | 
144 |     # hparams for pretrain
145 |     parser.add_argument('--pre_opt', type=str, default="sgd")
146 |     parser.add_argument('--pre_epoch', type=int, default=1000)
147 |     parser.add_argument('--pre_batch_size', type=int, default=256)
148 |     parser.add_argument('--pre_lr', type=float, default=0.1)
149 |     parser.add_argument('--pre_wd', type=float, default=1e-3)
150 | 
151 |     # hparams for test
152 |     parser.add_argument('--test_opt', type=str, default="sgd")
153 |     parser.add_argument('--test_iteration', type=int, default=5000)
154 |     parser.add_argument('--test_batch_size', type=int, default=512)
155 |     parser.add_argument('--test_lr', type=float, default=0.2)
156 |     parser.add_argument('--test_wd', type=float, default=0.0)
157 | 
158 |     # hparams for outer
159 |     parser.add_argument('--outer_opt', type=str, default="adam")
160 |     parser.add_argument('--outer_iteration', type=int, default=160000)
161 |     parser.add_argument('--outer_batch_size', type=int, default=1024)
162 |     parser.add_argument('--outer_lr', type=float, default=1e-3)
163 |     parser.add_argument('--outer_wd', 
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parameter Processing')

    # seed
    parser.add_argument('--seed', type=int, default=None)

    # data
    parser.add_argument('--target_model', type=str, default="resnet18")
    parser.add_argument('--data_name', type=str, default="cifar100")
    parser.add_argument('--num_workers', type=int, default=0)

    # dir
    parser.add_argument('--data_dir', type=str, default="./datasets")
    parser.add_argument('--save_dir', type=str, default="./results")
    parser.add_argument('--exp_name', type=str, default=None)

    # algorithm
    parser.add_argument('--num_models', type=int, default=10)

    # hparams for online
    parser.add_argument('--online_opt', type=str, default="sgd")
    parser.add_argument('--online_iteration', type=int, default=1000)
    parser.add_argument('--online_lr', type=float, default=0.1)
    parser.add_argument('--online_wd', type=float, default=1e-3)

    # hparams for pretrain
    parser.add_argument('--pre_opt', type=str, default="sgd")
    parser.add_argument('--pre_epoch', type=int, default=1000)
    parser.add_argument('--pre_batch_size', type=int, default=256)
    parser.add_argument('--pre_lr', type=float, default=0.1)
    parser.add_argument('--pre_wd', type=float, default=1e-3)

    # hparams for test
    parser.add_argument('--test_opt', type=str, default="sgd")
    parser.add_argument('--test_iteration', type=int, default=5000)
    parser.add_argument('--test_batch_size', type=int, default=512)
    parser.add_argument('--test_lr', type=float, default=0.2)
    parser.add_argument('--test_wd', type=float, default=0.0)

    # hparams for outer
    parser.add_argument('--outer_opt', type=str, default="adam")
    parser.add_argument('--outer_iteration', type=int, default=160000)
    parser.add_argument('--outer_batch_size', type=int, default=1024)
    parser.add_argument('--outer_lr', type=float, default=1e-3)
    parser.add_argument('--outer_wd', type=float, default=0.0)
    parser.add_argument('--outer_grad_norm', type=float, default=0.0)

    # hparams for logger
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--eval_every', type=int, default=16000)
    parser.add_argument('--save_every', type=int, default=2000)

    # gpus
    parser.add_argument('--gpu_id', type=int, default=0)
    args = parser.parse_args()

    # dataset-dependent defaults: image size, synthetic-set size, and backbone
    if args.data_name == "cifar100":
        args.img_size = 32
        args.num_images = 1000
        args.train_model = "convnet_128_256_512_bn"
        args.test_model = "convnet_128_256_512_bn"
    elif args.data_name == "tinyimagenet":
        args.img_size = 64
        args.num_images = 2000
        args.train_model = "convnet_64_128_256_512_bn"
        args.test_model = "convnet_64_128_256_512_bn"
    elif args.data_name == "imagenet":
        args.img_size = 64
        args.num_images = 1000
        args.train_model = "convnet_64_128_256_512_bn"
        args.test_model = "convnet_64_128_256_512_bn"
    elif args.data_name == "imagenette":
        args.img_size = 224
        args.num_images = 10
        args.train_model = "convnet_32_64_128_256_512_bn"
        args.test_model = "convnet_32_64_128_256_512_bn"
    else:
        raise NotImplementedError

    main(args)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
from datetime import datetime

import torch


class InfIterator:
    """Infinitely cycles over a finite iterable (e.g. a DataLoader)."""

    def __init__(self, iterable):
        self.iterable = iterable
        self.iterator = iter(self.iterable)

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # the underlying iterable is exhausted: restart it
            self.iterator = iter(self.iterable)
            return next(self.iterator)

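# Illustrative usage (names are hypothetical): wrapping a finite iterable such
# as a DataLoader yields an iterator on which next() never raises
# StopIteration, so the training loop can draw batches without epoch handling:
#
#     iter_tr = InfIterator(dl_tr)
#     x, y = next(iter_tr)   # keeps cycling across epochs
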
class Logger:
    def __init__(
        self,
        save_dir=None,
        save_only_last=True,
        print_every=100,
        save_every=100,
        total_step=0,
        print_to_stdout=True,
    ):
        if save_dir is not None:
            self.save_dir = save_dir
            os.makedirs(self.save_dir, exist_ok=True)
        else:
            self.save_dir = None

        self.print_every = print_every
        self.save_every = save_every
        self.save_only_last = save_only_last
        self.step_count = 0
        self.total_step = total_step
        self.print_to_stdout = print_to_stdout

        self.writer = None
        self.start_time = None
        self.groups = dict()
        self.models_to_save = dict()
        self.objects_to_save = dict()

    def register_model_to_save(self, model, name):
        assert name not in self.models_to_save.keys(), "Name is already registered."

        self.models_to_save[name] = model

    def register_object_to_save(self, obj, name):
        assert name not in self.objects_to_save.keys(), "Name is already registered."

        self.objects_to_save[name] = obj

    def step(self):
        self.step_count += 1
        if self.step_count % self.print_every == 0:
            if self.print_to_stdout:
                self.print_log(self.step_count, self.total_step, elapsed_time=datetime.now() - self.start_time)

        if self.step_count % self.save_every == 0:
            if self.save_only_last:
                self.save_models()
                self.save_objects()
            else:
                self.save_models(self.step_count)
                self.save_objects(self.step_count)

    def meter(self, group_name, log_name, value):
        # accumulate `value` under groups[group_name][log_name]; values are
        # averaged until reset_state() is called
        if group_name not in self.groups.keys():
            self.groups[group_name] = dict()

        if log_name not in self.groups[group_name].keys():
            self.groups[group_name][log_name] = Accumulator()

        self.groups[group_name][log_name].update_state(value)

    def reset_state(self):
        for _, group in self.groups.items():
            for _, log in group.items():
                log.reset_state()

    def print_log(self, step, total_step, elapsed_time=None):
        print(f"[Step {step:5d}/{total_step}]", end=" ")

        for name, group in self.groups.items():
            print(f"({name})", end=" ")
            for log_name, log in group.items():
                res = log.result()
                if res is None:
                    continue

                if "acc" in log_name.lower():
                    print(f"{log_name} {res:.2f}", end=" | ")
                else:
                    print(f"{log_name} {res:.4f}", end=" | ")

        if elapsed_time is not None:
            print(f"(Elapsed time) {elapsed_time}")
        else:
            print()

    def save_models(self, suffix=None):
        if self.save_dir is None:
            return

        for name, model in self.models_to_save.items():
            _name = name
            if suffix:
                _name += f"_{suffix}"
            torch.save(model.state_dict(), os.path.join(self.save_dir, f"{_name}.pt"))

            if self.print_to_stdout:
                print(f"{name} is saved to {self.save_dir}")

    def save_objects(self, suffix=None):
        if self.save_dir is None:
            return

        for name, obj in self.objects_to_save.items():
            _name = name
            if suffix:
                _name += f"_{suffix}"
            torch.save(obj, os.path.join(self.save_dir, f"{_name}.pt"))

            if self.print_to_stdout:
                print(f"{name} is saved to {self.save_dir}")

    def start(self):
        if self.print_to_stdout:
            print("Training starts!")
        self.start_time = datetime.now()

    def finish(self):
        # flush a final checkpoint if the last step did not land on save_every
        if self.step_count % self.save_every != 0:
            if self.save_only_last:
                self.save_models()
                self.save_objects()
            else:
                self.save_models(self.step_count)
                self.save_objects(self.step_count)

        if self.print_to_stdout:
            print("Training is finished!")


class Accumulator:
    def __init__(self):
        self.data = 0
        self.num_data = 0

    def reset_state(self):
        self.data = 0
        self.num_data = 0

    def update_state(self, tensor):
        with torch.no_grad():
            self.data += tensor
            self.num_data += 1

    def result(self):
        if self.num_data == 0:
            return None
        # `data` may be a tensor or a plain number
        data = self.data.item() if hasattr(self.data, 'item') else self.data
        return float(data) / self.num_data
--------------------------------------------------------------------------------
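
Usage sketch (editor's addition, not part of the repository): a minimal, self-contained example of driving `Logger` and, through it, `Accumulator`. It assumes utils.py is importable from the repository root; the toy model, directory, and loss values are illustrative only.

    import torch

    from utils import Logger

    # hypothetical toy setup: a tiny model and a handful of fake training steps
    model = torch.nn.Linear(4, 2)
    logger = Logger(save_dir="./results/demo", print_every=2, save_every=4, total_step=4)
    logger.register_model_to_save(model, "toy_model")
    logger.start()

    for step in range(4):
        fake_loss = torch.tensor(1.0 / (step + 1))  # stand-in for a real loss
        logger.meter("train", "loss", fake_loss)    # accumulated and averaged
        logger.step()                               # prints at steps 2 and 4; saves at step 4

    logger.finish()  # no extra save here: the last step already landed on save_every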