├── .gitignore
├── LICENSE
├── README.md
├── algorithms
│   ├── distill.py
│   ├── finetune.py
│   ├── linear_eval.py
│   ├── pretrain_dc.py
│   ├── pretrain_frepo.py
│   ├── pretrain_krr_st.py
│   ├── scratch.py
│   ├── wrapper.py
│   └── zeroshot_kd.py
├── assets
│   └── concept.png
├── data
│   ├── aircraft.py
│   ├── augmentation.py
│   ├── cars.py
│   ├── cub2011.py
│   ├── dogs.py
│   └── wrapper.py
├── environment.yml
├── model_pool.py
├── models
│   ├── alexnet.py
│   ├── convnet.py
│   ├── mobilenet.py
│   ├── resnet.py
│   ├── vgg.py
│   └── wrapper.py
├── test.py
├── test_kd.py
├── test_scratch.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | */__pycache__
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Supervised Dataset Distillation for Transfer Learning
2 | This is the official PyTorch implementation of the paper ["**Self-Supervised Dataset Distillation for Transfer Learning**" (ICLR 2024)](https://openreview.net/forum?id=h57gkDO2Yg).
3 |
4 | ## Summary
5 |
6 | Dataset distillation aims to optimize a small set of samples such that a model trained on it achieves performance similar to that of a model trained on the full dataset. While many supervised methods have achieved remarkable success in distilling a large dataset into a small set of representative samples, they are not designed to produce a distilled dataset that can effectively facilitate self-supervised pre-training. To this end, we propose a novel problem of distilling an unlabeled dataset into a set of small synthetic samples for efficient self-supervised learning (SSL). We first prove that the gradient of synthetic samples with respect to an SSL objective in naive bilevel optimization is biased due to the randomness originating from data augmentations or masking in the inner optimization. To address this issue, we propose minimizing the mean squared error (MSE) between a model's representations of the synthetic examples and their corresponding learnable target feature representations as the inner objective, which does not introduce any randomness. Our primary motivation is that the model obtained through the proposed inner optimization can mimic the self-supervised target model. To achieve this, we also minimize the MSE between representations of the inner model and those of the self-supervised target model on the original full dataset as the outer objective. We empirically validate the effectiveness of our method on transfer learning.
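In our notation (a paraphrase of the objectives above, not the paper's exact formulation), with $f_\theta$ the model trained on the distilled data, $g$ the self-supervised target model, and $(X_s, Y_s)$ the learnable synthetic samples and target representations, the bilevel problem reads:

```latex
\min_{X_s,\,Y_s}\; \mathbb{E}_{x \sim \mathcal{D}}\,\big\| f_{\theta^*(X_s, Y_s)}(x) - g(x) \big\|_2^2
\qquad \text{s.t.} \qquad
\theta^*(X_s, Y_s) = \operatorname*{arg\,min}_{\theta}\, \big\| f_\theta(X_s) - Y_s \big\|_F^2 .
```

In the implementation (`algorithms/distill.py`), the inner problem is further relaxed to kernel ridge regression on the model's features, giving the closed-form predictor $K_{rs}(K_{ss} + \lambda I)^{-1} Y_s$ that appears in the outer loss.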
7 |
8 |
9 |
10 | __Contributions of this work__
11 | - We propose a new problem of self-supervised dataset distillation for transfer learning, where we distill an unlabeled dataset into a small set,
12 | pre-train a model on it, and fine-tune it on target tasks.
13 | - We observe training instability when utilizing existing SSL objectives in bilevel optimization for self-supervised dataset distillation, and we prove that the gradient of SSL objectives with data augmentations or input masking is a biased estimator of the true gradient.
14 | - To address this instability, we propose KRR-ST, which uses an MSE objective with no randomness in the inner loop: we minimize the MSE between the model's representations of the synthetic samples and their target representations. In the outer loop, we minimize the MSE between the original-data representations of the model from the inner loop and those of the model pre-trained on the original dataset.
15 | - We extensively validate our proposed method on numerous target datasets and architectures, and show that it outperforms supervised dataset distillation methods.
16 |
17 | ## Dependencies
18 | This code is written in Python. Dependencies include:
19 | * python >= 3.10
20 | * pytorch = 2.1.2
21 | * torchvision = 0.16.2
22 | * tqdm
23 | * kornia = 0.7.1
24 | * transformers = 4.36.2
25 |
26 | ```bash
27 | conda env create -f environment.yml
28 | conda activate dd
29 | ```
30 |
31 | ## Data and Model Checkpoints
32 | * Download the **Full Data** (~40GB) from [here](https://drive.google.com/file/d/1P0zwURUbVsqoVgIRcIZXGAtGrkRvGvH0/view?usp=sharing).
33 | * Download the **Distilled Data** (~702MB) from [here](https://drive.google.com/file/d/1vDghSAUnmdWdGJgx9iK8dMOwoId0nuKF/view?usp=sharing).
34 | * Download the **Target (Teacher) Model Checkpoints** (~158MB) from [here](https://drive.google.com/file/d/1IuN4rhlB5UuJX_jrbVIWEBXo10QWHPBE/view?usp=sharing).
35 |
36 | The directory structure should look like this:
37 | ```shell
38 | ├── datasets/
39 | │   ├── aircraft/
40 | │   │   ├── X_te_32.pth
41 | │   │   ├── ...
42 | │   │   └── Y_tr_224.pth
43 | │   ├── cars/
44 | │   ├── ...
45 | │   └── tinyimagenet/
46 |
47 | ├── synthetic_data/
48 | │   ├── cifar100/
49 | │   │   ├── dm/
50 | │   │   │   ├── x_syn.pt
51 | │   │   │   └── y_syn.pt
52 | │   │   ├── ...
53 | │   │   └── random/
54 | │   ├── ...
55 | │   └── tinyimagenet/
56 |
57 | └── teacher_ckpt/
58 |     ├── barlow_twins_resnet18_cifar100.pt
59 |     ├── ...
60 |     └── teacher_cifar10.pt
61 | ```
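
The distilled data are plain serialized tensors, so they can be inspected directly. A minimal sketch (paths follow the layout above; the `cifar100/dm` pair is just one example):

```python
import torch

# Load one distilled set; adjust the path to the method/dataset you downloaded.
x_syn = torch.load("synthetic_data/cifar100/dm/x_syn.pt", map_location="cpu")
y_syn = torch.load("synthetic_data/cifar100/dm/y_syn.pt", map_location="cpu")
print(x_syn.shape, y_syn.shape)  # synthetic images and their targets
```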
62 |
63 | ## Dataset Distillation
64 | To distill **CIFAR100**, run the following code:
65 | ```bash
66 | python train.py --exp_name EXP_NAME (e.g. "cifar100_exp") --data_name cifar100 --outer_lr 1e-4 --gpu_id N
67 | ```
68 |
69 | To distill **TinyImageNet**, run the following code:
70 | ```bash
71 | python train.py --exp_name EXP_NAME (e.g. "tinyimagenet_exp") --data_name tinyimagenet --outer_lr 1e-5 --gpu_id N
72 | ```
73 |
74 | To distill **ImageNet 64x64**, run the following code:
75 | ```bash
76 | python train.py --exp_name EXP_NAME (e.g. "imagenet_exp") --data_name imagenet --outer_lr 1e-5 --gpu_id N
77 | ```
78 |
79 | ## Transfer Learning
80 | To reproduce **transfer learning with CIFAR100 (Table 1)**, run the following code:
81 | ```bash
82 | python test_scratch.py --source_data_name cifar100 --target_data_name full --gpu_id N
83 | python test.py --source_data_name cifar100 --target_data_name full --method METHOD (["random", "kmeans", "dsa", "dm", "mtt", "kip", "frepo", "krr_st"]) --test_model base --gpu_id N
84 | ```
85 |
86 | To reproduce **transfer learning with TinyImageNet (Table 2)**, run the following code:
87 | ```bash
88 | python test_scratch.py --source_data_name tinyimagenet --target_data_name full --gpu_id N
89 | python test.py --source_data_name tinyimagenet --target_data_name full --method METHOD (["random", "kmeans", "dsa", "dm", "mtt", "frepo", "krr_st"]) --test_model base --gpu_id N
90 | ```
91 |
92 | To reproduce **transfer learning with ImageNet 64x64 (Table 3)**, run the following code:
93 | ```bash
94 | python test_scratch.py --source_data_name imagenet --target_data_name full --gpu_id N
95 | python test.py --source_data_name imagenet --target_data_name full --method METHOD (["random", "frepo", "krr_st"]) --test_model base --gpu_id N
96 | ```
97 |
98 | To reproduce **architecture generalization with TinyImageNet (Figure 3)**, run the following code:
99 | ```bash
100 | python test_scratch.py --source_data_name tinyimagenet --target_data_name aircraft_cars_cub2011_dogs_flowers --test_model ARCHITECTURE (["vgg", "alexnet", "mobilenet", "resnet10"]) --gpu_id N
101 | python test.py --source_data_name tinyimagenet --target_data_name aircraft_cars_cub2011_dogs_flowers --method METHOD (["random", "kmeans", "dsa", "dm", "mtt", "frepo", "krr_st"]) --test_model ARCHITECTURE (["vgg", "alexnet", "mobilenet", "resnet10"]) --gpu_id N
102 | ```
103 |
104 | To reproduce **target data-free knowledge distillation with TinyImageNet (Table 4)**, run the following code:
105 | ```bash
106 | python test_kd.py --source_data_name tinyimagenet --method METHOD (["gaussian", "random", "kmeans", "dsa", "dm", "mtt", "frepo", "krr_st"]) --test_model ARCHITECTURE (["base", "vgg", "alexnet", "mobilenet", "resnet10"]) --gpu_id N
107 | ```
108 |
109 | ## Reference
110 | To cite our paper, please use the following BibTeX entry:
111 | ```bibtex
112 | @inproceedings{lee2024selfsupdd,
113 | title={Self-Supervised Dataset Distillation for Transfer Learning},
114 | author={Dong Bok Lee and Seanie Lee and Joonho Ko and Kenji Kawaguchi and Juho Lee and Sung Ju Hwang},
115 | booktitle={Proceedings of the 12th International Conference on Learning Representations},
116 | year={2024}
117 | }
118 | ```
119 |
--------------------------------------------------------------------------------
/algorithms/distill.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def run(
8 | args, device, target_model,
9 | model_pool, outer_opt,
10 | iter_tr, aug,
11 | x_syn, y_syn
12 | ):
13 | outer_opt.zero_grad()
14 |
15 | # prepare model
16 | idx = np.random.randint(args.num_models)
17 | model = model_pool.models[idx]
18 | model.eval(); model.zero_grad()
19 |
20 | # data
21 | x_real, _ = next(iter_tr)
22 | x_real = x_real.to(device)
23 | target_model.eval()
24 | with torch.no_grad():
25 | x_real = aug(x_real)
26 | y_real = target_model(x_real)
27 | torch.cuda.empty_cache()
28 |
29 | # feature
30 | f_syn = model.embed(x_syn)
31 | f_syn = torch.cat([ f_syn, torch.ones(f_syn.shape[0], 1).to(device) ], dim=1)
32 | with torch.no_grad():
33 | f_real = model.embed(x_real)
34 | f_real = torch.cat([ f_real, torch.ones(f_real.shape[0], 1).to(device) ], dim=1)
35 |
36 | # kernel
37 | K_real_syn = f_real @ f_syn.permute(1, 0).contiguous()
38 | K_syn_syn = f_syn @ f_syn.permute(1, 0).contiguous()
39 |
40 | # lambda and eye
41 | lambda_ = 1e-6 * torch.trace(K_syn_syn.detach())
42 | eye = torch.eye(K_syn_syn.shape[0]).to(device)
43 |
44 | # mse loss
45 | outer_loss = F.mse_loss(
46 | K_real_syn @ torch.linalg.solve(K_syn_syn + lambda_*eye, y_syn),
47 | y_real
48 | )
49 | outer_grad = torch.autograd.grad(outer_loss, [x_syn, y_syn])
50 |
51 | # meta update
52 | if x_syn.grad is None:
53 | x_syn.grad = outer_grad[0]
54 | else:
55 | x_syn.grad.data.copy_(outer_grad[0].data)
56 | if y_syn.grad is None:
57 | y_syn.grad = outer_grad[1]
58 | else:
59 | y_syn.grad.data.copy_(outer_grad[1].data)
60 | if args.outer_grad_norm > 0.:
61 | nn.utils.clip_grad_norm_(x_syn, args.outer_grad_norm)
62 | nn.utils.clip_grad_norm_(y_syn, args.outer_grad_norm)
63 | outer_opt.step()
64 |
65 | model_pool.update(idx, x_syn.detach(), y_syn.detach())
66 |
67 | return outer_loss
68 |
--------------------------------------------------------------------------------
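
The heart of `distill.py` is the differentiable kernel ridge regression solve: the outer loss compares the KRR predictor on a real batch against the target model's representations, and gradients flow back to `x_syn` and `y_syn` through `torch.linalg.solve`. A standalone sanity check with random tensors (a sketch, not repository code):

```python
import torch

n_real, n_syn, d, k = 128, 32, 65, 10   # arbitrary sizes for illustration
f_real = torch.randn(n_real, d)         # features of a real batch (bias column included)
f_syn = torch.randn(n_syn, d, requires_grad=True)
y_syn = torch.randn(n_syn, k, requires_grad=True)

K_real_syn = f_real @ f_syn.T           # (n_real, n_syn)
K_syn_syn = f_syn @ f_syn.T             # (n_syn, n_syn)
lambda_ = 1e-6 * torch.trace(K_syn_syn.detach())

# KRR predictor on the real batch: K_rs (K_ss + lambda I)^{-1} Y_s
y_pred = K_real_syn @ torch.linalg.solve(K_syn_syn + lambda_ * torch.eye(n_syn), y_syn)

y_pred.pow(2).mean().backward()         # any scalar loss works for the check
print(y_pred.shape, f_syn.grad.shape, y_syn.grad.shape)
```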
/algorithms/finetune.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from models.wrapper import get_model
7 | from utils import InfIterator
8 |
9 | def run(
10 | args, device,
11 | model_name, init_model,
12 | dl_tr, dl_te,
13 | aug_tr, aug_te
14 | ):
15 | iter_tr = InfIterator(dl_tr)
16 |
17 | # model
18 | model = get_model(model_name, args.img_shape, args.num_classes).to(device)
19 | if hasattr(init_model, "fc"):
20 | del init_model.fc
21 | model.load_state_dict(init_model.state_dict(), strict=False)
22 |
23 | # opt
24 | if args.test_opt == "sgd":
25 | opt = torch.optim.SGD(model.parameters(), lr=args.test_lr, momentum=0.9, weight_decay=args.test_wd)
26 | elif args.test_opt == "adam":
27 | opt = torch.optim.AdamW(model.parameters(), lr=args.test_lr, weight_decay=args.test_wd)
28 | else:
29 | raise NotImplementedError
30 | sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.test_iteration)
31 |
32 | model.train()
33 | for _ in trange(args.test_iteration):
34 | x_tr, y_tr = next(iter_tr)
35 | x_tr, y_tr = x_tr.to(device), y_tr.to(device)
36 | with torch.no_grad():
37 | x_tr = aug_tr(x_tr)
38 | loss = F.cross_entropy(model(x_tr), y_tr)
39 | opt.zero_grad()
40 | loss.backward()
41 | opt.step()
42 | sch.step()
43 |
44 | model.eval()
45 | with torch.no_grad():
46 | loss, acc, denominator = 0., 0., 0.
47 | for x_te, y_te in dl_te:
48 | x_te, y_te = x_te.to(device), y_te.to(device)
49 | x_te = aug_te(x_te)
50 | l_te = model(x_te)
51 | loss += F.cross_entropy(l_te, y_te, reduction="sum")
52 | acc += torch.eq(l_te.argmax(dim=-1), y_te).float().sum()
53 | denominator += x_te.shape[0]
54 | loss /= denominator; acc /= (denominator/100.)
55 |
56 | del model
57 |
58 | return loss.item(), acc.item()
--------------------------------------------------------------------------------
/algorithms/linear_eval.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.utils.data import TensorDataset, DataLoader
7 |
8 | from models.wrapper import get_model
9 | from utils import InfIterator
10 |
11 | def run(
12 | args, device,
13 | model_name, init_model,
14 | dl_tr, dl_te,
15 | aug_tr, aug_te
16 | ):
17 |
18 | # model
19 | model = get_model(model_name, args.img_shape, args.num_classes).to(device)
20 | if hasattr(init_model, "fc"):
21 | del init_model.fc
22 | model.load_state_dict(init_model.state_dict(), strict=False)
23 | model.fc = nn.Identity()
24 |
25 | model.eval()
26 | with torch.no_grad():
27 | # tr feature
28 | F_tr, Y_tr = [], []
29 | for x_tr, y_tr in dl_tr:
30 | x_tr = x_tr.to(device)
31 | x_tr = aug_te(x_tr)
32 | f_tr = model(x_tr).cpu()
33 | F_tr.append(f_tr); Y_tr.append(y_tr)
34 | F_tr, Y_tr = torch.cat(F_tr, dim=0), torch.cat(Y_tr, dim=0)
35 | num_features = F_tr.shape[-1]
36 | dl_feature_tr = DataLoader(TensorDataset(F_tr, Y_tr), batch_size=args.test_batch_size, shuffle=True, num_workers=0, pin_memory=True)
37 | iter_feature_tr = InfIterator(dl_feature_tr)
38 |
39 | # te feature
40 | F_te, Y_te = [], []
41 | for x_te, y_te in dl_te:
42 | x_te = x_te.to(device)
43 | x_te = aug_te(x_te)
44 | f_te = model(x_te).cpu()
45 | F_te.append(f_te); Y_te.append(y_te)
46 | F_te, Y_te = torch.cat(F_te, dim=0), torch.cat(Y_te, dim=0)
47 | dl_feature_te = DataLoader(TensorDataset(F_te, Y_te), batch_size=args.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
48 |
49 | # opt
50 | linear = nn.Linear(num_features, args.num_classes).to(device)
51 | if args.test_opt == "sgd":
52 | opt = torch.optim.SGD(linear.parameters(), lr=args.test_lr, momentum=0.9, weight_decay=args.test_wd)
53 | elif args.test_opt == "adam":
54 | opt = torch.optim.AdamW(linear.parameters(), lr=args.test_lr, weight_decay=args.test_wd)
55 | else:
56 | raise NotImplementedError
57 | sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.test_iteration)
58 |
59 | linear.train()
60 | for _ in trange(args.test_iteration):
61 | x_tr, y_tr = next(iter_feature_tr)
62 | x_tr, y_tr = x_tr.to(device), y_tr.to(device)
63 | loss = F.cross_entropy(linear(x_tr), y_tr)
64 | opt.zero_grad()
65 | loss.backward()
66 | opt.step()
67 | sch.step()
68 |
69 | linear.eval()
70 | with torch.no_grad():
71 | loss, acc, denominator = 0., 0., 0.
72 | for x_te, y_te in dl_feature_te:
73 | x_te, y_te = x_te.to(device), y_te.to(device)
74 | l_te = linear(x_te)
75 | loss += F.cross_entropy(l_te, y_te, reduction="sum")
76 | acc += torch.eq(l_te.argmax(dim=-1), y_te).float().sum()
77 | denominator += x_te.shape[0]
78 | loss /= denominator; acc /= (denominator/100.)
79 |
80 | del model, linear
81 |
82 | return loss.item(), acc.item()
--------------------------------------------------------------------------------
/algorithms/pretrain_dc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from data.augmentation import DiffAugment
4 | from models.wrapper import get_model
5 | from torch.utils.data import DataLoader, TensorDataset
6 | from tqdm import trange
7 |
8 |
9 | def run(
10 | args, device,
11 | model_name,
12 | x_syn, y_syn
13 | ):
14 | x_syn, y_syn = x_syn.detach(), y_syn.detach()
15 | x_syn, y_syn = x_syn.to(device), y_syn.to(device)
16 | dl_syn = DataLoader(TensorDataset(x_syn, y_syn), batch_size=args.pre_batch_size, shuffle=True, num_workers=0)
17 |
18 | # model and opt
19 | model = get_model(model_name, args.img_shape, args.num_pretrain_classes).to(device)
20 | if args.pre_opt == "sgd":
21 | opt = torch.optim.SGD(model.parameters(), lr=args.pre_lr, momentum=0.9, weight_decay=args.pre_wd)
22 | elif args.pre_opt == "adam":
23 | opt = torch.optim.AdamW(model.parameters(), lr=args.pre_lr, weight_decay=args.pre_wd)
24 | else:
25 | raise NotImplementedError
26 | sch = torch.optim.lr_scheduler.MultiStepLR(
27 | opt, milestones=[args.pre_epoch // 2], gamma=0.1)
28 |
29 | # pretrain
30 | model.train()
31 | for _ in trange(1, args.pre_epoch+1):
32 | for x_syn, y_syn in dl_syn:
33 | # loss
34 | with torch.no_grad():
35 | x_syn = DiffAugment(x_syn, args.dsa_strategy, param=args.dsa_param)
36 | loss = F.cross_entropy(model(x_syn), y_syn)
37 | # update
38 | opt.zero_grad()
39 | loss.backward()
40 | opt.step()
41 |
42 | sch.step()
43 |
44 | return model
45 |
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/algorithms/pretrain_frepo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from data.augmentation import DiffAugment
3 | from models.wrapper import get_model
4 | from torch.utils.data import DataLoader, TensorDataset
5 | from tqdm import trange
6 | from transformers import get_cosine_schedule_with_warmup
7 | from utils import InfIterator
8 |
9 |
10 | def run(
11 | args, device,
12 | model_name,
13 | x_syn, y_syn
14 | ):
15 | x_syn, y_syn = x_syn.detach(), y_syn.detach()
16 | x_syn, y_syn = x_syn.to(device), y_syn.to(device)
17 | dl_syn = DataLoader(TensorDataset(x_syn, y_syn), batch_size=args.pre_batch_size, shuffle=True, num_workers=0)
18 | iter_syn = InfIterator(dl_syn)
19 |
20 | # model and opt
21 | model = get_model(model_name, args.img_shape, args.num_pretrain_classes).to(device)
22 | if args.pre_opt == "sgd":
23 | opt = torch.optim.SGD(model.parameters(), lr=args.pre_lr, momentum=0.9, weight_decay=args.pre_wd)
24 | elif args.pre_opt == "adam":
25 | opt = torch.optim.AdamW(model.parameters(), lr=args.pre_lr, weight_decay=args.pre_wd)
26 | else:
27 | raise NotImplementedError
28 | sch = get_cosine_schedule_with_warmup(opt, 500, args.pre_iteration)
29 |
30 | # pretrain
31 | model.train()
32 | for _ in trange(1, args.pre_iteration+1):
33 | x_syn, y_syn = next(iter_syn)
34 | # loss
35 | with torch.no_grad():
36 | x_syn = DiffAugment(x_syn, args.dsa_strategy, param=args.dsa_param)
37 | loss = torch.mean(torch.sum((model(x_syn) - y_syn) ** 2 * 0.5, dim=-1))
38 | # update
39 | opt.zero_grad()
40 | loss.backward()
41 | opt.step()
42 | sch.step()
43 |
44 | return model
45 |
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/algorithms/pretrain_krr_st.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from models.wrapper import get_model
4 | from torch.utils.data import DataLoader, TensorDataset
5 | from tqdm import trange
6 |
7 |
8 | def run(
9 | args, device,
10 | model_name,
11 | x_syn, y_syn
12 | ):
13 | x_syn, y_syn = x_syn.detach(), y_syn.detach()
14 | x_syn, y_syn = x_syn.to(device), y_syn.to(device)
15 | dl_syn = DataLoader(TensorDataset(x_syn, y_syn), batch_size=args.pre_batch_size, shuffle=True, num_workers=0)
16 |
17 | # model and opt
18 | model = get_model(model_name, args.img_shape, args.num_pretrain_classes).to(device)
19 | if args.pre_opt == "sgd":
20 | opt = torch.optim.SGD(model.parameters(), lr=args.pre_lr, momentum=0.9, weight_decay=args.pre_wd)
21 | elif args.pre_opt == "adam":
22 | opt = torch.optim.AdamW(model.parameters(), lr=args.pre_lr, weight_decay=args.pre_wd)
23 | else:
24 | raise NotImplementedError
25 |
26 | # pretrain
27 | model.train()
28 | for _ in trange(1, args.pre_epoch+1):
29 | for x_syn, y_syn in dl_syn:
30 | # loss
31 | loss = F.mse_loss(model(x_syn), y_syn)
32 | # update
33 | opt.zero_grad()
34 | loss.backward()
35 | opt.step()
36 |
37 | return model
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
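
Pretraining with KRR-ST's distilled data is thus plain MSE regression of the model's outputs onto the distilled target representations. A hedged usage sketch (the `args` fields and the `"base"` model name are inferred from the function body and the README; the actual defaults live in the repo's scripts):

```python
from argparse import Namespace
import torch
from algorithms import pretrain_krr_st

args = Namespace(pre_batch_size=256, img_shape=(3, 32, 32), num_pretrain_classes=512,
                 pre_opt="sgd", pre_lr=0.1, pre_wd=1e-3, pre_epoch=100)

x_syn = torch.randn(1000, *args.img_shape)             # stand-in for distilled images
y_syn = torch.randn(1000, args.num_pretrain_classes)   # stand-in for target representations
model = pretrain_krr_st.run(args, torch.device("cpu"), "base", x_syn, y_syn)
```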
/algorithms/scratch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from models.wrapper import get_model
4 | from tqdm import trange
5 | from utils import InfIterator
6 |
7 |
8 | def run(
9 | args, device,
10 | model_name,
11 | dl_tr, dl_te,
12 | aug_tr, aug_te
13 | ):
14 | iter_tr = InfIterator(dl_tr)
15 |
16 | # model
17 | model = get_model(model_name, args.img_shape, args.num_classes).to(device)
18 |
19 | # opt
20 | if args.test_opt == "sgd":
21 | opt = torch.optim.SGD(model.parameters(), lr=args.test_lr, momentum=0.9, weight_decay=args.test_wd)
22 | elif args.test_opt == "adam":
23 | opt = torch.optim.AdamW(model.parameters(), lr=args.test_lr, weight_decay=args.test_wd)
24 | else:
25 | raise NotImplementedError
26 | sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.test_iteration)
27 |
28 | model.train()
29 | for _ in trange(args.test_iteration):
30 | x_tr, y_tr = next(iter_tr)
31 | x_tr, y_tr = x_tr.to(device), y_tr.to(device)
32 | with torch.no_grad():
33 | x_tr = aug_tr(x_tr)
34 | loss = F.cross_entropy(model(x_tr), y_tr)
35 | opt.zero_grad()
36 | loss.backward()
37 | opt.step()
38 | sch.step()
39 |
40 | model.eval()
41 | with torch.no_grad():
42 | loss, acc, denominator = 0., 0., 0.
43 | for x_te, y_te in dl_te:
44 | x_te, y_te = x_te.to(device), y_te.to(device)
45 | x_te = aug_te(x_te)
46 | l_te = model(x_te)
47 | loss += F.cross_entropy(l_te, y_te, reduction="sum")
48 | acc += torch.eq(l_te.argmax(dim=-1), y_te).float().sum()
49 | denominator += x_te.shape[0]
50 | loss /= denominator; acc /= (denominator/100.)
51 |
52 | del model
53 |
54 | return loss.item(), acc.item()
55 |
56 |
--------------------------------------------------------------------------------
/algorithms/wrapper.py:
--------------------------------------------------------------------------------
1 | from algorithms import distill
2 |
3 | from algorithms import pretrain_dc
4 | from algorithms import pretrain_frepo
5 | from algorithms import pretrain_krr_st
6 |
7 | from algorithms import scratch
8 | from algorithms import linear_eval
9 | from algorithms import finetune
10 | from algorithms import zeroshot_kd
11 |
12 | def get_algorithm(name):
13 | return globals()[name]
--------------------------------------------------------------------------------
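
`get_algorithm` resolves one of the imported modules by name, so the driver scripts can stay declarative. A short usage sketch:

```python
from algorithms.wrapper import get_algorithm

scratch = get_algorithm("scratch")  # returns the algorithms.scratch module
# loss, acc = scratch.run(args, device, model_name, dl_tr, dl_te, aug_tr, aug_te)
```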
/algorithms/zeroshot_kd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from data.augmentation import DiffAugment
4 | from models.wrapper import get_model
5 | from torch.utils.data import DataLoader, TensorDataset
6 | from tqdm import trange
7 |
8 |
9 | def distillation_loss(student, teacher, T=1.0, reduction="mean"):
10 | student = F.log_softmax(student/T, dim=-1)
11 | teacher = F.softmax(teacher/T, dim=-1)
12 | loss = -(teacher * student).sum(dim=-1)
13 |
14 | if reduction == "mean":
15 | return loss.mean(dim=0)
16 | elif reduction == "sum":
17 | return loss.sum(dim=0)
18 | else:
19 | raise NotImplementedError
20 |
21 | def run(
22 | args, device,
23 | model_name, init_model, teacher,
24 | x_syn, dl_te, aug_te
25 | ):
26 |
27 | x_syn = x_syn.detach()
28 | x_syn = x_syn.to(device)
29 | dl_syn = DataLoader(TensorDataset(x_syn), batch_size=args.test_batch_size, shuffle=True, num_workers=0)
30 |
31 | # model
32 | student = get_model(model_name, args.img_shape, args.num_classes).to(device)
33 | if hasattr(init_model, "fc"):
34 | del init_model.fc
35 | student.load_state_dict(init_model.state_dict(), strict=False)
36 |
37 | # opt
38 | if args.test_opt == "sgd":
39 | opt = torch.optim.SGD(student.parameters(), lr=args.test_lr, momentum=0.9, weight_decay=args.test_wd)
40 | elif args.test_opt == "adam":
41 | opt = torch.optim.AdamW(student.parameters(), lr=args.test_lr, weight_decay=args.test_wd)
42 | else:
43 | raise NotImplementedError
44 | sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, args.test_epoch)
45 |
46 | student.train(); teacher.train()
47 | for _ in trange(args.test_epoch):
48 | for x, in dl_syn:
49 | with torch.no_grad():
50 | if args.method == "gaussian":
51 | x = torch.randn_like(x)
52 | else:
53 | x = DiffAugment(x, args.dsa_strategy, param=args.dsa_param)
54 | y = teacher(x)
55 | loss = distillation_loss(student(x), y)
56 | opt.zero_grad()
57 | loss.backward()
58 | opt.step()
59 | sch.step()
60 |
61 | student.eval()
62 | with torch.no_grad():
63 | loss, acc, denominator = 0., 0., 0.
64 | for x_te, y_te in dl_te:
65 | x_te, y_te = x_te.to(device), y_te.to(device)
66 | x_te = aug_te(x_te)
67 | l_te = student(x_te)
68 | loss += F.cross_entropy(l_te, y_te, reduction="sum")
69 | acc += torch.eq(l_te.argmax(dim=-1), y_te).float().sum()
70 | denominator += x_te.shape[0]
71 | loss /= denominator; acc /= (denominator/100.)
72 |
73 | del student
74 |
75 | return loss.item(), acc.item()
76 |
--------------------------------------------------------------------------------
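
`distillation_loss` above is the soft-target cross-entropy between temperature-scaled teacher and student distributions. Since PyTorch's `F.cross_entropy` accepts probability targets (from version 1.10 onward), the two should agree, which makes for a quick sanity check (a standalone sketch):

```python
import torch
import torch.nn.functional as F
from algorithms.zeroshot_kd import distillation_loss

student, teacher, T = torch.randn(8, 10), torch.randn(8, 10), 2.0
builtin = F.cross_entropy(student / T, F.softmax(teacher / T, dim=-1))
print(torch.allclose(distillation_loss(student, teacher, T=T), builtin))  # True
```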
/assets/concept.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/db-Lee/selfsup_dd/4beb4b19c23872394d8f6418938b0e2bb3ce0466/assets/concept.png
--------------------------------------------------------------------------------
/data/aircraft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from torchvision.datasets import VisionDataset
4 | from torchvision.datasets.folder import default_loader
5 | from torchvision.datasets.utils import download_url
6 | from torchvision.datasets.utils import extract_archive
7 |
8 |
9 | class Aircraft(VisionDataset):
10 | """`FGVC-Aircraft `_ Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the dataset.
14 | train (bool, optional): If True, creates dataset from training set, otherwise
15 | creates from test set.
16 | class_type (string, optional): choose from ('variant', 'family', 'manufacturer').
17 | transform (callable, optional): A function/transform that takes in a PIL image
18 | and returns a transformed version. E.g, ``transforms.RandomCrop``
19 | target_transform (callable, optional): A function/transform that takes in the
20 | target and transforms it.
21 | download (bool, optional): If true, downloads the dataset from the internet and
22 | puts it in root directory. If dataset is already downloaded, it is not
23 | downloaded again.
24 | """
25 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
26 | class_types = ('variant', 'family', 'manufacturer')
27 | splits = ('train', 'val', 'trainval', 'test')
28 | img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images')
29 |
30 | def __init__(self, root, train=True, class_type='variant', transform=None,
31 | target_transform=None, download=False):
32 | super(Aircraft, self).__init__(root, transform=transform, target_transform=target_transform)
33 | split = 'trainval' if train else 'test'
34 | if split not in self.splits:
35 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
36 | split, ', '.join(self.splits),
37 | ))
38 | if class_type not in self.class_types:
39 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
40 | class_type, ', '.join(self.class_types),
41 | ))
42 |
43 | self.class_type = class_type
44 | self.split = split
45 | self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data',
46 | 'images_%s_%s.txt' % (self.class_type, self.split))
47 |
48 | if download:
49 | self.download()
50 |
51 | (image_ids, targets, classes, class_to_idx) = self.find_classes()
52 | samples = self.make_dataset(image_ids, targets)
53 |
54 | self.loader = default_loader
55 |
56 | self.samples = samples
57 | self.classes = classes
58 | self.class_to_idx = class_to_idx
59 |
60 | def __getitem__(self, index):
61 | path, target = self.samples[index]
62 | sample = self.loader(path)
63 | if self.transform is not None:
64 | sample = self.transform(sample)
65 | if self.target_transform is not None:
66 | target = self.target_transform(target)
67 | return sample, target
68 |
69 | def __len__(self):
70 | return len(self.samples)
71 |
72 | def _check_exists(self):
73 | return os.path.exists(os.path.join(self.root, self.img_folder)) and \
74 | os.path.exists(self.classes_file)
75 |
76 | def download(self):
77 | if self._check_exists():
78 | return
79 |
80 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
81 | print('Downloading %s...' % self.url)
82 | tar_name = self.url.rpartition('/')[-1]
83 | download_url(self.url, root=self.root, filename=tar_name)
84 | tar_path = os.path.join(self.root, tar_name)
85 | print('Extracting %s...' % tar_path)
86 | extract_archive(tar_path)
87 | print('Done!')
88 |
89 | def find_classes(self):
90 | # read classes file, separating out image IDs and class names
91 | image_ids = []
92 | targets = []
93 | with open(self.classes_file, 'r') as f:
94 | for line in f:
95 | split_line = line.split(' ')
96 | image_ids.append(split_line[0])
97 | targets.append(' '.join(split_line[1:]))
98 |
99 | # index class names
100 | classes = np.unique(targets)
101 | class_to_idx = {classes[i]: i for i in range(len(classes))}
102 | targets = [class_to_idx[c] for c in targets]
103 |
104 | return image_ids, targets, classes, class_to_idx
105 |
106 | def make_dataset(self, image_ids, targets):
107 | assert (len(image_ids) == len(targets))
108 | images = []
109 | for i in range(len(image_ids)):
110 | item = (os.path.join(self.root, self.img_folder,
111 | '%s.jpg' % image_ids[i]), targets[i])
112 | images.append(item)
113 | return images
114 |
115 |
116 | if __name__ == '__main__':
117 | train_dataset = Aircraft('./aircraft', train=True, download=False)
118 | test_dataset = Aircraft('./aircraft', train=False, download=False)
119 |
--------------------------------------------------------------------------------
/data/augmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | import kornia.augmentation as K
8 |
9 | NUM_CLASSES = {
10 | 'svhn': 10,
11 | 'cifar10': 10,
12 | 'cifar100': 100,
13 | 'aircraft': 100,
14 | 'cars': 196,
15 | 'cub2011': 200,
16 | 'dogs': 120,
17 | 'flowers': 102,
18 | 'tinyimagenet': 200,
19 | 'imagenet': 1000,
20 | 'imagenette': 10
21 | }
22 |
23 | MEAN = {
24 | 32:
25 | {'svhn': (0.4377, 0.4438, 0.4728),
26 | 'cifar10': (0.4914, 0.4822, 0.4465),
27 | 'cifar100': (0.5071, 0.4866, 0.4409),
28 | 'aircraft': (0.4804, 0.5116, 0.5349),
29 | 'cars': (0.4706, 0.4600, 0.4548),
30 | 'cub2011': (0.4857, 0.4995, 0.4324),
31 | 'dogs': (0.4765, 0.4516, 0.3911),
32 | 'flowers': (0.4344, 0.3030, 0.2955),
33 | 'tinyimagenet': (0.4802, 0.4481, 0.3975),
34 | 'imagenet': (0.4810, 0.4574, 0.4078)},
35 | 64:
36 | {'svhn': (0.4377, 0.4438, 0.4728),
37 | 'cifar10': (0.4914, 0.4821, 0.4465),
38 | 'cifar100': (0.5070, 0.4865, 0.4409),
39 | 'aircraft': (0.4797, 0.5108, 0.5341),
40 | 'cars': (0.4707, 0.4601, 0.4549),
41 | 'cub2011': (0.4856, 0.4995, 0.4324),
42 | 'dogs': (0.4765, 0.4517, 0.3911),
43 | 'flowers': (0.4344, 0.3830, 0.2955),
44 | 'tinyimagenet': (0.4802, 0.4481, 0.3975),
45 | 'imagenet': (0.4810, 0.4574, 0.4078)},
46 | 224:
47 | {'aircraft': (0.4797, 0.5109, 0.5342),
48 | 'cars': (0.4707, 0.4601, 0.4549),
49 | 'cub2011': (0.4856, 0.4994, 0.4324),
50 | 'dogs': (0.4765, 0.4517, 0.3912),
51 | 'flowers': (0.4344, 0.3830, 0.2955),
52 | 'imagenette': (0.4655, 0.4546, 0.4250)}
53 | }
54 | STD = {
55 | 32:
56 | {'svhn': (0.1980, 0.2010, 0.1970),
57 | 'cifar10': (0.2470, 0.2435, 0.2616),
58 | 'cifar100': (0.2673, 0.2564, 0.2762),
59 | 'aircraft': (0.2021, 0.1953, 0.2297),
60 | 'cars': (0.2746, 0.2740, 0.2831),
61 | 'cub2011': (0.2145, 0.2098, 0.2496),
62 | 'dogs': (0.2490, 0.2435, 0.2479),
63 | 'flowers': (0.2811, 0.2318, 0.2607),
64 | 'tinyimagenet': (0.2770, 0.2691, 0.2821),
65 | 'imagenet': (0.2633, 0.2560, 0.2708)},
66 | 64:
67 | {'svhn': (0.1981, 0.2011, 0.1971),
68 | 'cifar10': (0.2469, 0.2433, 0.2614),
69 | 'cifar100': (0.2671, 0.2562, 0.2760),
70 | 'aircraft': (0.2117, 0.2049, 0.2380),
71 | 'cars': (0.2836, 0.2826, 0.2914),
72 | 'cub2011': (0.2218, 0.2170, 0.2564),
73 | 'dogs': (0.2551, 0.2495, 0.2539),
74 | 'flowers': (0.2878, 0.2390, 0.2674),
75 | 'tinyimagenet': (0.2770, 0.2691, 0.2821),
76 | 'imagenet': (0.2633, 0.2560, 0.2708)},
77 | 224:
78 | {'aircraft': (0.2204, 0.2135, 0.2451),
79 | 'cars': (0.2927, 0.2917, 0.3001),
80 | 'cub2011': (0.2295, 0.2250, 0.2635),
81 | 'dogs': (0.2617, 0.2564, 0.2607),
82 | 'flowers': (0.2928, 0.2449, 0.2726),
83 | 'imagenette': (0.2804, 0.2754, 0.2965)}
84 | }
85 |
86 | def get_aug(data_name, size, aug=True):
87 | if aug:
88 | if data_name == "svhn":
89 | transform_tr = nn.Sequential(
90 | K.RandomCrop(size=(size,size), padding=4),
91 | K.Normalize(MEAN[size][data_name], STD[size][data_name])
92 | )
93 | else:
94 | transform_tr = nn.Sequential(
95 | K.RandomCrop(size=(size,size), padding=4),
96 | K.RandomHorizontalFlip(p=0.5),
97 | K.Normalize(MEAN[size][data_name], STD[size][data_name]),
98 | )
99 | else:
100 | transform_tr = K.Normalize(MEAN[size][data_name], STD[size][data_name])
101 |
102 | transform_te = K.Normalize(MEAN[size][data_name], STD[size][data_name])
103 |
104 | return transform_tr, transform_te
105 |
106 | """DC Augmentation"""
107 | class ParamDiffAug():
108 | def __init__(self):
109 | self.aug_mode = 'S' #'multiple or single'
110 | self.prob_flip = 0.5
111 | self.ratio_scale = 1.2
112 | self.ratio_rotate = 15.0
113 | self.ratio_crop_pad = 0.125
114 | self.ratio_cutout = 0.5 # the size would be 0.5x0.5
115 | self.brightness = 1.0
116 | self.saturation = 2.0
117 | self.contrast = 0.5
118 |
119 |
120 | def set_seed_DiffAug(param):
121 | if param.latestseed == -1:
122 | return
123 | else:
124 | torch.random.manual_seed(param.latestseed)
125 | param.latestseed += 1
126 |
127 |
128 | def DiffAugment(x, strategy='', seed = -1, param = None):
129 | if strategy == 'None' or strategy == 'none' or strategy == '':
130 | return x
131 |
132 | if seed == -1:
133 | param.Siamese = False
134 | else:
135 | param.Siamese = True
136 |
137 | param.latestseed = seed
138 |
139 | if strategy:
140 | if param.aug_mode == 'M': # original
141 | for p in strategy.split('_'):
142 | for f in AUGMENT_FNS[p]:
143 | x = f(x, param)
144 | elif param.aug_mode == 'S':
145 | pbties = strategy.split('_')
146 | set_seed_DiffAug(param)
147 | p = pbties[torch.randint(0, len(pbties), size=(1,)).item()]
148 | for f in AUGMENT_FNS[p]:
149 | x = f(x, param)
150 | else:
151 | exit('unknown augmentation mode: %s'%param.aug_mode)
152 | x = x.contiguous()
153 | return x
154 |
155 |
156 | # We implement the following differentiable augmentation strategies based on the code provided in https://github.com/mit-han-lab/data-efficient-gans.
157 | def rand_scale(x, param):
158 | # x>1, max scale
159 | # sx, sy: (0, +oo), 1: original size, 0.5: enlarge 2 times
160 | ratio = param.ratio_scale
161 | set_seed_DiffAug(param)
162 | sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
163 | set_seed_DiffAug(param)
164 | sy = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
165 | theta = [[[sx[i], 0, 0],
166 | [0, sy[i], 0],] for i in range(x.shape[0])]
167 | theta = torch.tensor(theta, dtype=torch.float)
168 | if param.Siamese: # Siamese augmentation:
169 | theta[:] = theta[0].clone().detach()
170 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device)
171 | x = F.grid_sample(x, grid, align_corners=True)
172 | return x
173 |
174 |
175 | def rand_rotate(x, param): # [-180, 180], 90: anticlockwise 90 degree
176 | ratio = param.ratio_rotate
177 | set_seed_DiffAug(param)
178 | theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi)
179 | theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0],
180 | [torch.sin(theta[i]), torch.cos(theta[i]), 0],] for i in range(x.shape[0])]
181 | theta = torch.tensor(theta, dtype=torch.float)
182 | if param.Siamese: # Siamese augmentation:
183 | theta[:] = theta[0].clone().detach()
184 | grid = F.affine_grid(theta, x.shape, align_corners=True).to(x.device)
185 | x = F.grid_sample(x, grid, align_corners=True)
186 | return x
187 |
188 |
189 | def rand_flip(x, param):
190 | prob = param.prob_flip
191 | set_seed_DiffAug(param)
192 | randf = torch.rand(x.size(0), 1, 1, 1, device=x.device)
193 | if param.Siamese: # Siamese augmentation:
194 | randf[:] = randf[0].clone().detach()
195 | return torch.where(randf < prob, x.flip(3), x)
196 |
197 |
198 | def rand_brightness(x, param):
199 | ratio = param.brightness
200 | set_seed_DiffAug(param)
201 | randb = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
202 | if param.Siamese: # Siamese augmentation:
203 | randb[:] = randb[0].clone().detach()
204 | x = x + (randb - 0.5)*ratio
205 | return x
206 |
207 |
208 | def rand_saturation(x, param):
209 | ratio = param.saturation
210 | x_mean = x.mean(dim=1, keepdim=True)
211 | set_seed_DiffAug(param)
212 | rands = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
213 | if param.Siamese: # Siamese augmentation:
214 | rands[:] = rands[0].clone().detach()
215 | x = (x - x_mean) * (rands * ratio) + x_mean
216 | return x
217 |
218 |
219 | def rand_contrast(x, param):
220 | ratio = param.contrast
221 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
222 | set_seed_DiffAug(param)
223 | randc = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
224 | if param.Siamese: # Siamese augmentation:
225 | randc[:] = randc[0].clone().detach()
226 | x = (x - x_mean) * (randc + ratio) + x_mean
227 | return x
228 |
229 |
230 | def rand_crop(x, param):
231 | # The image is padded on its surrounding and then cropped.
232 | ratio = param.ratio_crop_pad
233 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
234 | set_seed_DiffAug(param)
235 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
236 | set_seed_DiffAug(param)
237 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
238 | if param.Siamese: # Siamese augmentation:
239 | translation_x[:] = translation_x[0].clone().detach()
240 | translation_y[:] = translation_y[0].clone().detach()
241 | grid_batch, grid_x, grid_y = torch.meshgrid(
242 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
243 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
244 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
245 | indexing="ij"
246 | )
247 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
248 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
249 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
250 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
251 | return x
252 |
253 |
254 | def rand_cutout(x, param):
255 | ratio = param.ratio_cutout
256 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
257 | set_seed_DiffAug(param)
258 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
259 | set_seed_DiffAug(param)
260 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
261 | if param.Siamese: # Siamese augmentation:
262 | offset_x[:] = offset_x[0].clone().detach()
263 | offset_y[:] = offset_y[0].clone().detach()
264 | grid_batch, grid_x, grid_y = torch.meshgrid(
265 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
266 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
267 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
268 | indexing="ij"
269 | )
270 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
271 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
272 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
273 | mask[grid_batch, grid_x, grid_y] = 0
274 | x = x * mask.unsqueeze(1)
275 | return x
276 |
277 |
278 | AUGMENT_FNS = {
279 | 'color': [rand_brightness, rand_saturation, rand_contrast],
280 | 'crop': [rand_crop],
281 | 'cutout': [rand_cutout],
282 | 'flip': [rand_flip],
283 | 'scale': [rand_scale],
284 | 'rotate': [rand_rotate],
285 | }
--------------------------------------------------------------------------------
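
Both augmentation paths above operate on batched image tensors. A minimal usage sketch (the strategy string is the standard DSA one from the literature; pass whatever `--dsa_strategy` the scripts actually use):

```python
import torch
from data.augmentation import get_aug, DiffAugment, ParamDiffAug

x = torch.rand(16, 3, 32, 32)  # a batch in [0, 1]

# Kornia pipeline: random crop + flip + normalization for training, normalization only for test
aug_tr, aug_te = get_aug("cifar100", 32)
x_tr, x_te = aug_tr(x), aug_te(x)

# DSA-style differentiable augmentation used by the pretraining scripts
x_dsa = DiffAugment(x, "color_crop_cutout_flip_scale_rotate", param=ParamDiffAug())
print(x_tr.shape, x_te.shape, x_dsa.shape)
```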
/data/cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import scipy.io as sio
3 | from torchvision.datasets import VisionDataset
4 | from torchvision.datasets.folder import default_loader
5 | from torchvision.datasets.utils import download_url
6 | from torchvision.datasets.utils import extract_archive
7 |
8 |
9 | class Cars(VisionDataset):
10 | """`Stanford Cars `_ Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the dataset.
14 | train (bool, optional): If True, creates dataset from training set, otherwise
15 | creates from test set.
16 | transform (callable, optional): A function/transform that takes in a PIL image
17 | and returns a transformed version. E.g, ``transforms.RandomCrop``
18 | target_transform (callable, optional): A function/transform that takes in the
19 | target and transforms it.
20 | download (bool, optional): If true, downloads the dataset from the internet and
21 | puts it in root directory. If dataset is already downloaded, it is not
22 | downloaded again.
23 | """
24 | file_list = {
25 | 'imgs': ('http://imagenet.stanford.edu/internal/car196/car_ims.tgz', 'car_ims.tgz'),
26 | 'annos': ('http://imagenet.stanford.edu/internal/car196/cars_annos.mat', 'cars_annos.mat')
27 | }
28 |
29 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
30 | super(Cars, self).__init__(root, transform=transform, target_transform=target_transform)
31 |
32 | self.loader = default_loader
33 | self.train = train
34 |
35 | if self._check_exists():
36 | print('Files already downloaded and verified.')
37 | elif download:
38 | self._download()
39 | else:
40 | raise RuntimeError(
41 | 'Dataset not found. You can use download=True to download it.')
42 |
43 | loaded_mat = sio.loadmat(os.path.join(self.root, self.file_list['annos'][1]))
44 | loaded_mat = loaded_mat['annotations'][0]
45 | self.samples = []
46 | for item in loaded_mat:
47 | if self.train != bool(item[-1][0]):
48 | path = str(item[0][0])
49 | label = int(item[-2][0]) - 1
50 | self.samples.append((path, label))
51 |
52 | def __getitem__(self, index):
53 | path, target = self.samples[index]
54 | path = os.path.join(self.root, path)
55 |
56 | image = self.loader(path)
57 | if self.transform is not None:
58 | image = self.transform(image)
59 | if self.target_transform is not None:
60 | target = self.target_transform(target)
61 | return image, target
62 |
63 | def __len__(self):
64 | return len(self.samples)
65 |
66 | def _check_exists(self):
67 | return (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
68 | and os.path.exists(os.path.join(self.root, self.file_list['annos'][1])))
69 |
70 | def _download(self):
71 | print('Downloading...')
72 | for url, filename in self.file_list.values():
73 | download_url(url, root=self.root, filename=filename)
74 | print('Extracting...')
75 | archive = os.path.join(self.root, self.file_list['imgs'][1])
76 | extract_archive(archive)
77 |
78 |
79 | if __name__ == '__main__':
80 | train_dataset = Cars('./cars', train=True, download=False)
81 | test_dataset = Cars('./cars', train=False, download=False)
--------------------------------------------------------------------------------
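A minimal usage sketch (not part of the repository) wiring Cars into a standard DataLoader; the root path, resize target, and batch size are illustrative, and the archives are assumed to be already present under ./cars:

    from torch.utils.data import DataLoader
    from torchvision import transforms

    from data.cars import Cars

    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # source images vary in size
        transforms.ToTensor(),
    ])
    dataset = Cars('./cars', train=True, transform=transform, download=False)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
    images, labels = next(iter(loader))  # images: (32, 3, 224, 224); labels in [0, 195]
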
/data/cub2011.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 | from torchvision.datasets import VisionDataset
5 | from torchvision.datasets.folder import default_loader
6 | from torchvision.datasets.utils import download_file_from_google_drive
7 |
8 |
9 | class Cub2011(VisionDataset):
10 | """`CUB-200-2011 `_ Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the dataset.
14 | train (bool, optional): If True, creates dataset from training set, otherwise
15 | creates from test set.
16 |         transform (callable, optional): A function/transform that takes in a PIL image
17 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
18 | target_transform (callable, optional): A function/transform that takes in the
19 | target and transforms it.
20 | download (bool, optional): If true, downloads the dataset from the internet and
21 | puts it in root directory. If dataset is already downloaded, it is not
22 | downloaded again.
23 | """
24 | base_folder = 'CUB_200_2011/images'
25 | # url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
26 | file_id = '1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
27 | filename = 'CUB_200_2011.tgz'
28 | tgz_md5 = '97eceeb196236b17998738112f37df78'
29 |
30 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
31 | super(Cub2011, self).__init__(root, transform=transform, target_transform=target_transform)
32 |
33 | self.loader = default_loader
34 | self.train = train
35 | if download:
36 | self._download()
37 |
38 | if not self._check_integrity():
39 | raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')
40 |
41 | def _load_metadata(self):
42 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
43 | names=['img_id', 'filepath'])
44 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
45 | sep=' ', names=['img_id', 'target'])
46 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
47 | sep=' ', names=['img_id', 'is_training_img'])
48 |
49 | data = images.merge(image_class_labels, on='img_id')
50 | self.data = data.merge(train_test_split, on='img_id')
51 |
52 | class_names = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'classes.txt'),
53 | sep=' ', names=['class_name'], usecols=[1])
54 | self.class_names = class_names['class_name'].to_list()
55 | if self.train:
56 | self.data = self.data[self.data.is_training_img == 1]
57 | else:
58 | self.data = self.data[self.data.is_training_img == 0]
59 |
60 | def _check_integrity(self):
61 | try:
62 | self._load_metadata()
63 | except Exception:
64 | return False
65 |
66 | for index, row in self.data.iterrows():
67 | filepath = os.path.join(self.root, self.base_folder, row.filepath)
68 | if not os.path.isfile(filepath):
69 |                 print(f"missing file: {filepath}")
70 | return False
71 | return True
72 |
73 | def _download(self):
74 | import tarfile
75 |
76 | if self._check_integrity():
77 | print('Files already downloaded and verified')
78 | return
79 |
80 | download_file_from_google_drive(self.file_id, self.root, self.filename, self.tgz_md5)
81 |
82 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
83 | tar.extractall(path=self.root)
84 |
85 | def __len__(self):
86 | return len(self.data)
87 |
88 | def __getitem__(self, idx):
89 | sample = self.data.iloc[idx]
90 | path = os.path.join(self.root, self.base_folder, sample.filepath)
91 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0
92 | img = self.loader(path)
93 |
94 | if self.transform is not None:
95 | img = self.transform(img)
96 | if self.target_transform is not None:
97 | target = self.target_transform(target)
98 | return img, target
99 |
100 |
101 | if __name__ == '__main__':
102 | train_dataset = Cub2011('./cub2011', train=True, download=False)
103 | test_dataset = Cub2011('./cub2011', train=False, download=False)
104 |
--------------------------------------------------------------------------------
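A short sanity-check sketch, assuming the archive is already extracted under ./cub2011 and the official CUB-200-2011 split (5,994 train / 5,794 test images):

    from data.cub2011 import Cub2011

    train_set = Cub2011('./cub2011', train=True, download=False)
    test_set = Cub2011('./cub2011', train=False, download=False)
    print(len(train_set), len(test_set))  # expected: 5994 5794
    img, target = train_set[0]            # PIL image, integer label in [0, 199]
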
/data/dogs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import scipy.io
3 | from os.path import join
4 | from torchvision.datasets import VisionDataset
5 | from torchvision.datasets.folder import default_loader
6 | from torchvision.datasets.utils import download_url, list_dir
7 |
8 |
9 | class Dogs(VisionDataset):
10 | """`Stanford Dogs `_ Dataset.
11 |
12 | Args:
13 | root (string): Root directory of the dataset.
14 | train (bool, optional): If True, creates dataset from training set, otherwise
15 | creates from test set.
16 |         transform (callable, optional): A function/transform that takes in a PIL image
17 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
18 | target_transform (callable, optional): A function/transform that takes in the
19 | target and transforms it.
20 | download (bool, optional): If true, downloads the dataset from the internet and
21 | puts it in root directory. If dataset is already downloaded, it is not
22 | downloaded again.
23 | """
24 | download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'
25 |
26 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
27 | super(Dogs, self).__init__(root, transform=transform, target_transform=target_transform)
28 |
29 | self.loader = default_loader
30 | self.train = train
31 |
32 | if download:
33 | self.download()
34 |
35 | split = self.load_split()
36 |
37 | self.images_folder = join(self.root, 'Images')
38 | self.annotations_folder = join(self.root, 'Annotation')
39 | self._breeds = list_dir(self.images_folder)
40 |
41 | self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]
42 |
43 | self._flat_breed_images = self._breed_images
44 |
45 | def __len__(self):
46 | return len(self._flat_breed_images)
47 |
48 | def __getitem__(self, index):
49 | image_name, target = self._flat_breed_images[index]
50 | image_path = join(self.images_folder, image_name)
51 | image = self.loader(image_path)
52 |
53 | if self.transform is not None:
54 | image = self.transform(image)
55 | if self.target_transform is not None:
56 | target = self.target_transform(target)
57 | return image, target
58 |
59 | def download(self):
60 | import tarfile
61 |
62 | if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
63 | if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
64 | print('Files already downloaded and verified')
65 | return
66 |
67 | for filename in ['images', 'annotation', 'lists']:
68 | tar_filename = filename + '.tar'
69 | url = self.download_url_prefix + '/' + tar_filename
70 | download_url(url, self.root, tar_filename, None)
71 | print('Extracting downloaded file: ' + join(self.root, tar_filename))
72 | with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
73 | tar_file.extractall(self.root)
74 | os.remove(join(self.root, tar_filename))
75 |
76 | def load_split(self):
77 | if self.train:
78 | split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']
79 | labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']
80 | else:
81 | split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']
82 | labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']
83 |
84 | split = [item[0][0] for item in split]
85 | labels = [item[0] - 1 for item in labels]
86 | return list(zip(split, labels))
87 |
88 | def stats(self):
89 | counts = {}
90 | for index in range(len(self._flat_breed_images)):
91 | image_name, target_class = self._flat_breed_images[index]
92 |             if target_class not in counts:
93 |                 counts[target_class] = 1
94 |             else:
95 |                 counts[target_class] += 1
96 |
97 |         num_samples, num_cls = len(self._flat_breed_images), len(counts)
98 |         print(f"{num_samples} samples spanning {num_cls} classes "
99 |               f"(avg {num_samples / num_cls:f} per class)")
100 |
101 | return counts
102 |
103 |
104 | if __name__ == '__main__':
105 | train_dataset = Dogs('./dogs', train=True, download=False)
106 | test_dataset = Dogs('./dogs', train=False, download=False)
107 |
--------------------------------------------------------------------------------
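A usage sketch for the stats() helper, assuming the archives are already extracted under ./dogs; the official Stanford Dogs split covers 120 breeds with 12,000 train and 8,580 test images:

    from data.dogs import Dogs

    train_set = Dogs('./dogs', train=True, download=False)
    counts = train_set.stats()  # prints a summary, returns {label: count}
    assert len(counts) == 120
    assert len(train_set) == 12000
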
/data/wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import TensorDataset, DataLoader
3 |
4 | from data.augmentation import get_aug
5 |
6 | def get_loader(root_dir, data_name, batch_size, size=32, aug=False):
7 | X_tr, Y_tr = torch.load(f"{root_dir}/{data_name}/X_tr_{size}.pth"), torch.load(f"{root_dir}/{data_name}/Y_tr_{size}.pth")
8 | dataset_tr = TensorDataset(X_tr, Y_tr)
9 | dataloader_tr = DataLoader(dataset_tr, batch_size=batch_size, num_workers=0, shuffle=True, pin_memory=True)
10 |
11 | if data_name == "imagenet":
12 | dataloader_te = dataloader_tr
13 | else:
14 | X_te, Y_te = torch.load(f"{root_dir}/{data_name}/X_te_{size}.pth"), torch.load(f"{root_dir}/{data_name}/Y_te_{size}.pth")
15 | dataset_te = TensorDataset(X_te, Y_te)
16 | dataloader_te = DataLoader(dataset_te, batch_size=batch_size, num_workers=0, shuffle=False, pin_memory=True)
17 |
18 |
19 | transform_tr, transform_te = get_aug(data_name, size, aug)
20 |
21 | return dataloader_tr, dataloader_te, transform_tr, transform_te
--------------------------------------------------------------------------------
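get_loader expects pre-resized tensors saved under the fixed layout {root}/{name}/X_tr_{size}.pth (plus X_te/Y_te for the test split). Below is a sketch that fabricates dummy tensors in that layout and loads them back; the shapes are illustrative, and get_aug is assumed to recognize the cifar10 name:

    import os
    import torch

    from data.wrapper import get_loader

    root, name, size = "./datasets", "cifar10", 32
    os.makedirs(f"{root}/{name}", exist_ok=True)
    torch.save(torch.randn(100, 3, size, size), f"{root}/{name}/X_tr_{size}.pth")
    torch.save(torch.randint(0, 10, (100,)), f"{root}/{name}/Y_tr_{size}.pth")
    torch.save(torch.randn(20, 3, size, size), f"{root}/{name}/X_te_{size}.pth")
    torch.save(torch.randint(0, 10, (20,)), f"{root}/{name}/Y_te_{size}.pth")

    dl_tr, dl_te, aug_tr, aug_te = get_loader(root, name, batch_size=16, size=size)
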
/environment.yml:
--------------------------------------------------------------------------------
1 | name: dd
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - blas=1.0=mkl
10 | - brotli-python=1.0.9=py310h6a678d5_7
11 | - bzip2=1.0.8=h5eee18b_5
12 | - ca-certificates=2024.3.11=h06a4308_0
13 | - certifi=2024.2.2=py310h06a4308_0
14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
15 | - cuda-cudart=11.8.89=0
16 | - cuda-cupti=11.8.87=0
17 | - cuda-libraries=11.8.0=0
18 | - cuda-nvrtc=11.8.89=0
19 | - cuda-nvtx=11.8.86=0
20 | - cuda-runtime=11.8.0=0
21 | - ffmpeg=4.3=hf484d3e_0
22 | - filelock=3.13.1=py310h06a4308_0
23 | - freetype=2.12.1=h4a9f257_0
24 | - gmp=6.2.1=h295c915_3
25 | - gmpy2=2.1.2=py310heeb90bb_0
26 | - gnutls=3.6.15=he1e5248_0
27 | - idna=3.4=py310h06a4308_0
28 | - intel-openmp=2023.1.0=hdb19cb5_46306
29 | - jinja2=3.1.3=py310h06a4308_0
30 | - jpeg=9e=h5eee18b_1
31 | - lame=3.100=h7b6447c_0
32 | - lcms2=2.12=h3be6417_0
33 | - ld_impl_linux-64=2.38=h1181459_1
34 | - lerc=3.0=h295c915_0
35 | - libcublas=11.11.3.6=0
36 | - libcufft=10.9.0.58=0
37 | - libcufile=1.9.1.3=0
38 | - libcurand=10.3.5.147=0
39 | - libcusolver=11.4.1.48=0
40 | - libcusparse=11.7.5.86=0
41 | - libdeflate=1.17=h5eee18b_1
42 | - libffi=3.4.4=h6a678d5_0
43 | - libgcc-ng=11.2.0=h1234567_1
44 | - libgomp=11.2.0=h1234567_1
45 | - libiconv=1.16=h7f8727e_2
46 | - libidn2=2.3.4=h5eee18b_0
47 | - libjpeg-turbo=2.0.0=h9bf148f_0
48 | - libnpp=11.8.0.86=0
49 | - libnvjpeg=11.9.0.86=0
50 | - libpng=1.6.39=h5eee18b_0
51 | - libstdcxx-ng=11.2.0=h1234567_1
52 | - libtasn1=4.19.0=h5eee18b_0
53 | - libtiff=4.5.1=h6a678d5_0
54 | - libunistring=0.9.10=h27cfd23_0
55 | - libuuid=1.41.5=h5eee18b_0
56 | - libwebp-base=1.3.2=h5eee18b_0
57 | - llvm-openmp=14.0.6=h9e868ea_0
58 | - lz4-c=1.9.4=h6a678d5_0
59 | - markupsafe=2.1.3=py310h5eee18b_0
60 | - mkl=2023.1.0=h213fc3f_46344
61 | - mkl-service=2.4.0=py310h5eee18b_1
62 | - mkl_fft=1.3.8=py310h5eee18b_0
63 | - mkl_random=1.2.4=py310hdb19cb5_0
64 | - mpc=1.1.0=h10f8cd9_1
65 | - mpfr=4.0.2=hb69a4c5_1
66 | - mpmath=1.3.0=py310h06a4308_0
67 | - ncurses=6.4=h6a678d5_0
68 | - nettle=3.7.3=hbbd107a_1
69 | - networkx=3.1=py310h06a4308_0
70 | - numpy=1.26.4=py310h5f9d8c6_0
71 | - numpy-base=1.26.4=py310hb5e798b_0
72 | - openh264=2.1.1=h4ff587b_0
73 | - openjpeg=2.4.0=h3ad879b_0
74 | - openssl=3.0.13=h7f8727e_0
75 | - pillow=10.2.0=py310h5eee18b_0
76 | - pip=23.3.1=py310h06a4308_0
77 | - pysocks=1.7.1=py310h06a4308_0
78 | - python=3.10.14=h955ad1f_0
79 | - pytorch=2.1.2=py3.10_cuda11.8_cudnn8.7.0_0
80 | - pytorch-cuda=11.8=h7e8668a_5
81 | - pytorch-mutex=1.0=cuda
82 | - pyyaml=6.0.1=py310h5eee18b_0
83 | - readline=8.2=h5eee18b_0
84 | - requests=2.31.0=py310h06a4308_1
85 | - setuptools=68.2.2=py310h06a4308_0
86 | - sqlite=3.41.2=h5eee18b_0
87 | - sympy=1.12=py310h06a4308_0
88 | - tbb=2021.8.0=hdb19cb5_0
89 | - tk=8.6.12=h1ccaba5_0
90 | - torchaudio=2.1.2=py310_cu118
91 | - torchtriton=2.1.0=py310
92 | - torchvision=0.16.2=py310_cu118
93 | - typing_extensions=4.9.0=py310h06a4308_1
94 | - tzdata=2024a=h04d1e81_0
95 | - urllib3=2.1.0=py310h06a4308_1
96 | - wheel=0.41.2=py310h06a4308_0
97 | - xz=5.4.6=h5eee18b_0
98 | - yaml=0.2.5=h7b6447c_0
99 | - zlib=1.2.13=h5eee18b_0
100 | - zstd=1.5.5=hc292b87_0
101 | - pip:
102 | - fsspec==2024.3.1
103 | - huggingface-hub==0.22.2
104 | - kornia==0.7.1
105 | - packaging==24.0
106 | - regex==2023.12.25
107 | - safetensors==0.4.2
108 | - tokenizers==0.15.2
109 | - tqdm==4.66.2
110 | - transformers==4.36.2
111 | prefix: /c2/seanie/anaconda3/envs/dd
112 |
--------------------------------------------------------------------------------
/model_pool.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from tqdm import trange
5 |
6 | from models.wrapper import get_model
7 |
8 |
9 | class ModelPool:
10 | def __init__(self, args, device):
11 | self.device = device
12 | self.online_iteration = args.online_iteration
13 | self.num_models = args.num_models
14 |
15 | # model func
16 | self.model_func = lambda _: get_model(args.train_model, args.img_shape, args.num_pretrain_classes).to(device)
17 |
18 | # opt func
19 | if args.online_opt == "sgd":
20 | self.opt_func = lambda param: torch.optim.SGD(param, lr=args.online_lr, momentum=0.9, weight_decay=args.online_wd)
21 | elif args.online_opt == "adam":
22 | self.opt_func = lambda param: torch.optim.AdamW(param, lr=args.online_lr, weight_decay=args.online_wd)
23 | else:
24 | raise NotImplementedError
25 |
26 |         self.iterations = [0] * self.num_models
27 |         self.models = [self.model_func(None) for _ in range(self.num_models)]
28 |         self.opts = [self.opt_func(self.models[i].parameters()) for i in range(self.num_models)]
29 |
30 | def init(self, x_syn, y_syn):
31 | for idx in range(self.num_models):
32 | online_iteration = np.random.randint(1, self.online_iteration)
33 | self.iterations[idx] = online_iteration
34 | model = self.models[idx]
35 | opt = self.opts[idx]
36 | model.train()
37 | print(f"{idx}-th model init")
38 | for _ in trange(online_iteration):
39 | opt.zero_grad()
40 | loss = F.mse_loss(model(x_syn), y_syn)
41 | loss.backward()
42 | opt.step()
43 |
44 | def update(self, idx, x_syn, y_syn):
45 | # reset
46 | if self.iterations[idx] >= self.online_iteration:
47 | self.models[idx] = self.model_func(None)
48 | self.opts[idx] = self.opt_func(self.models[idx].parameters())
49 | model = self.models[idx]
50 | opt = self.opts[idx]
51 |
52 | # train the model for 1 step
53 | else:
54 | self.iterations[idx] = self.iterations[idx] + 1
55 | model = self.models[idx]
56 | opt = self.opts[idx]
57 |
58 | model.train()
59 | opt.zero_grad()
60 | loss = F.mse_loss(model(x_syn), y_syn)
61 | loss.backward()
62 | opt.step()
63 |
--------------------------------------------------------------------------------
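A sketch of how the pool is driven during distillation (cf. train.py): init() warms each model for a random number of steps, and update() either takes one more step on the synthetic data or resets a model whose step budget is exhausted. The args namespace carries only the fields ModelPool reads; all values are illustrative:

    import argparse
    import torch

    from model_pool import ModelPool

    args = argparse.Namespace(
        online_iteration=100, num_models=4,
        train_model="convnet_128_256_512_bn",
        img_shape=(3, 32, 32), num_pretrain_classes=512,
        online_opt="sgd", online_lr=0.1, online_wd=1e-3,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pool = ModelPool(args, device)

    x_syn = torch.randn(10, 3, 32, 32, device=device)  # distilled images
    y_syn = torch.randn(10, 512, device=device)        # target features for MSE regression
    pool.init(x_syn, y_syn)
    for step in range(8):
        pool.update(step % args.num_models, x_syn, y_syn)  # one step, or a reset
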
/models/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class AlexNet(nn.Module):
6 | def __init__(self, img_size, num_classes):
7 | super(AlexNet, self).__init__()
8 | self.features = nn.Sequential(
9 | nn.Conv2d(3, 128, kernel_size=5, stride=1, padding=2),
10 | nn.BatchNorm2d(128),
11 | nn.ReLU(inplace=True),
12 | nn.MaxPool2d(kernel_size=2, stride=2),
13 | nn.Conv2d(128, 192, kernel_size=5, padding=2),
14 | nn.BatchNorm2d(192),
15 | nn.ReLU(inplace=True),
16 | nn.MaxPool2d(kernel_size=2, stride=2),
17 | nn.Conv2d(192, 256, kernel_size=3, padding=1),
18 | nn.BatchNorm2d(256),
19 | nn.ReLU(inplace=True),
20 | nn.Conv2d(256, 192, kernel_size=3, padding=1),
21 | nn.BatchNorm2d(192),
22 | nn.ReLU(inplace=True),
23 | nn.Conv2d(192, 192, kernel_size=3, padding=1),
24 | nn.BatchNorm2d(192),
25 | nn.ReLU(inplace=True),
26 | nn.MaxPool2d(kernel_size=2, stride=2),
27 | )
28 | img_size = img_size // 8
29 | self.fc = nn.Linear(192 * img_size * img_size, num_classes)
30 |
31 | def forward(self, x):
32 | x = self.features(x)
33 | x = x.view(x.size(0), -1)
34 | x = self.fc(x)
35 | return x
36 |
37 | def embed(self, x):
38 | x = self.features(x)
39 | x = x.view(x.size(0), -1)
40 | return x
--------------------------------------------------------------------------------
/models/convnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class NoneBlock(nn.Module):
6 | def __init__(self, num_channels, affine):
7 | super().__init__()
8 |
9 | def forward(self, x):
10 | return x
11 |
12 | class ConvBlock(nn.Module):
13 | def __init__(self, in_channels, out_channels, norm_layer):
14 | super().__init__()
15 | self.block = nn.Sequential(
16 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
17 | norm_layer(out_channels, affine=True),
18 | nn.ReLU(inplace=False),
19 | nn.AvgPool2d(kernel_size=2, stride=2)
20 | )
21 |
22 | def forward(self, x):
23 | x = self.block(x)
24 | return x
25 |
26 | class ConvNet(nn.Module):
27 | def __init__(
28 | self,
29 | img_shape=(3,32,32),
30 | num_classes=10,
31 | num_channels=[128, 256, 512],
32 | norm="bn"
33 | ):
34 | super(ConvNet, self).__init__()
35 |
36 |
37 | HW = img_shape[1]
38 |
39 | if norm.lower() == "bn":
40 | norm_layer = nn.BatchNorm2d
41 | elif norm.lower() == "in":
42 | norm_layer = nn.InstanceNorm2d
43 | elif norm.lower() == "none":
44 | norm_layer = NoneBlock
45 | else:
46 | raise NotImplementedError
47 |
48 | layers = []
49 | for i in range(len(num_channels)):
50 | if i == 0:
51 | layers.append(ConvBlock(img_shape[0], num_channels[0], norm_layer))
52 | else:
53 | layers.append(ConvBlock(num_channels[i-1], num_channels[i], norm_layer))
54 | HW = HW // 2
55 | self.layers = nn.ModuleList(layers)
56 | self.num_features = HW*HW*num_channels[-1]
57 | self.fc = nn.Linear(self.num_features, num_classes)
58 |
59 | def forward(self, x):
60 | for layer in self.layers:
61 | x = layer(x)
62 | x = x.reshape(x.size(0), -1)
63 | x = self.fc(x)
64 | return x
65 |
66 | def embed(self, x):
67 | for layer in self.layers:
68 | x = layer(x)
69 | x = x.reshape(x.size(0), -1)
70 | return x
71 | """
72 |
73 | class ConvBlock(nn.Module):
74 | def __init__(self, in_channels, out_channels):
75 | super(ConvBlock, self).__init__()
76 | self.block = nn.Sequential(
77 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
78 | nn.InstanceNorm2d(out_channels, affine=True),
79 | #nn.BatchNorm2d(out_channels, affine=True),
80 | nn.ReLU(inplace=True),
81 | nn.AvgPool2d(kernel_size=2, stride=2)
82 | )
83 |
84 | def forward(self, x):
85 | x = self.block(x)
86 | return x
87 |
88 | class ConvNet(nn.Module):
89 | def __init__(
90 | self,
91 | img_shape=(3,32,32),
92 | num_classes=10,
93 | num_layers=3,
94 | num_channels=128
95 | ):
96 | super(ConvNet, self).__init__()
97 |
98 | layers = []
99 | HW = img_shape[1]
100 | for i in range(num_layers):
101 | if i == 0:
102 | layers.append(ConvBlock(img_shape[0], num_channels))
103 | else:
104 | layers.append(ConvBlock(num_channels, num_channels))
105 | HW = HW // 2
106 | self.layers = nn.ModuleList(layers)
107 | self.num_features = HW*HW*num_channels
108 | self.fc = nn.Linear(self.num_features, num_classes)
109 |
110 | def forward(self, x):
111 | for layer in self.layers:
112 | x = layer(x)
113 | x = x.reshape(x.size(0), -1)
114 | x = self.fc(x)
115 | return x
116 |
117 | def embed(self, x):
118 | for layer in self.layers:
119 | x = layer(x)
120 | x = x.reshape(x.size(0), -1)
121 | return x
122 | """
--------------------------------------------------------------------------------
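A shape walk-through sketch: each ConvBlock halves the spatial resolution, so a 32x32 input passed through channels [128, 256, 512] leaves a 4x4x512 map before the linear classifier:

    import torch

    from models.convnet import ConvNet

    net = ConvNet(img_shape=(3, 32, 32), num_classes=10,
                  num_channels=[128, 256, 512], norm="bn")
    x = torch.randn(2, 3, 32, 32)
    print(net.embed(x).shape)  # torch.Size([2, 8192]); 32 -> 16 -> 8 -> 4 and 4*4*512 = 8192
    print(net(x).shape)        # torch.Size([2, 10])
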
/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | '''MobileNetV2 in PyTorch.
2 | See the paper "Inverted Residuals and Linear Bottlenecks:
3 | Mobile Networks for Classification, Detection and Segmentation" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | class Block(nn.Module):
10 | '''expand + depthwise + pointwise'''
11 | def __init__(self, in_planes, out_planes, expansion, stride):
12 | super(Block, self).__init__()
13 | self.stride = stride
14 |
15 | planes = expansion * in_planes
16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
17 | self.bn1 = nn.BatchNorm2d(planes)
18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
19 | self.bn2 = nn.BatchNorm2d(planes)
20 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
21 | self.bn3 = nn.BatchNorm2d(out_planes)
22 |
23 | self.shortcut = nn.Sequential()
24 | if stride == 1 and in_planes != out_planes:
25 | self.shortcut = nn.Sequential(
26 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
27 | nn.BatchNorm2d(out_planes),
28 | )
29 |
30 | def forward(self, x):
31 | out = F.relu(self.bn1(self.conv1(x)))
32 | out = F.relu(self.bn2(self.conv2(out)))
33 | out = self.bn3(self.conv3(out))
34 | out = out + self.shortcut(x) if self.stride==1 else out
35 | return out
36 |
37 |
38 | class MobileNet(nn.Module):
39 | # (expansion, out_planes, num_blocks, stride)
40 | cfg = [(1, 16, 1, 1),
41 | (2, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10
42 | (2, 32, 3, 2),
43 | (2, 64, 4, 2),
44 | (2, 96, 3, 1),
45 | (2, 160, 3, 2),
46 | (2, 320, 1, 1)]
47 |
48 | def __init__(self, num_classes=10):
49 | super(MobileNet, self).__init__()
50 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10
51 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
52 | self.bn1 = nn.BatchNorm2d(32)
53 | self.layers = self._make_layers(in_planes=32)
54 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
55 | self.bn2 = nn.BatchNorm2d(1280)
56 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
57 | self.num_features = 1280
58 | self.fc = nn.Linear(1280, num_classes)
59 |
60 | def _make_layers(self, in_planes):
61 | layers = []
62 | for expansion, out_planes, num_blocks, stride in self.cfg:
63 | strides = [stride] + [1]*(num_blocks-1)
64 | for stride in strides:
65 | layers.append(Block(in_planes, out_planes, expansion, stride))
66 | in_planes = out_planes
67 | return nn.Sequential(*layers)
68 |
69 | def forward(self, x):
70 | x = F.relu(self.bn1(self.conv1(x)))
71 | x = self.layers(x)
72 | x = F.relu(self.bn2(self.conv2(x)))
73 | x = self.avgpool(x)
74 | x = x.reshape(x.size(0), -1)
75 | x = self.fc(x)
76 | return x
77 |
78 | def embed(self, x):
79 | x = F.relu(self.bn1(self.conv1(x)))
80 | x = self.layers(x)
81 | x = F.relu(self.bn2(self.conv2(x)))
82 | x = self.avgpool(x)
83 | x = x.reshape(x.size(0), -1)
84 | return x
--------------------------------------------------------------------------------
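A quick sanity sketch: thanks to the adaptive average pooling, embed() returns 1280-dimensional features regardless of the input resolution:

    import torch

    from models.mobilenet import MobileNet

    net = MobileNet(num_classes=10)
    for hw in (32, 64):
        assert net.embed(torch.randn(2, 3, hw, hw)).shape == (2, 1280)
    print(net(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])
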
/models/resnet.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Any, Callable, List, Optional, Type, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch import Tensor
7 |
8 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
9 | """3x3 convolution with padding"""
10 | return nn.Conv2d(
11 | in_planes,
12 | out_planes,
13 | kernel_size=3,
14 | stride=stride,
15 | padding=dilation,
16 | groups=groups,
17 | bias=False,
18 | dilation=dilation,
19 | )
20 |
21 |
22 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
23 | """1x1 convolution"""
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
25 |
26 |
27 | class BasicBlock(nn.Module):
28 | expansion: int = 1
29 |
30 | def __init__(
31 | self,
32 | inplanes: int,
33 | planes: int,
34 | stride: int = 1,
35 | downsample: Optional[nn.Module] = None,
36 | groups: int = 1,
37 | base_width: int = 64,
38 | dilation: int = 1,
39 | norm_layer: Optional[Callable[..., nn.Module]] = None,
40 | ) -> None:
41 | super().__init__()
42 | if norm_layer is None:
43 | norm_layer = nn.BatchNorm2d
44 | if groups != 1 or base_width != 64:
45 | raise ValueError("BasicBlock only supports groups=1 and base_width=64")
46 | if dilation > 1:
47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
49 | self.conv1 = conv3x3(inplanes, planes, stride)
50 | self.bn1 = norm_layer(planes)
51 | self.relu = nn.ReLU(inplace=True)
52 | self.conv2 = conv3x3(planes, planes)
53 | self.bn2 = norm_layer(planes)
54 | self.downsample = downsample
55 | self.stride = stride
56 |
57 | def forward(self, x: Tensor) -> Tensor:
58 | identity = x
59 |
60 | out = self.conv1(x)
61 | out = self.bn1(out)
62 | out = self.relu(out)
63 |
64 | out = self.conv2(out)
65 | out = self.bn2(out)
66 |
67 | if self.downsample is not None:
68 | identity = self.downsample(x)
69 |
70 | out += identity
71 | out = self.relu(out)
72 |
73 | return out
74 |
75 |
76 | class Bottleneck(nn.Module):
77 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
78 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
79 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
80 | # This variant is also known as ResNet V1.5 and improves accuracy according to
81 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
82 |
83 | expansion: int = 4
84 |
85 | def __init__(
86 | self,
87 | inplanes: int,
88 | planes: int,
89 | stride: int = 1,
90 | downsample: Optional[nn.Module] = None,
91 | groups: int = 1,
92 | base_width: int = 64,
93 | dilation: int = 1,
94 | norm_layer: Optional[Callable[..., nn.Module]] = None,
95 | ) -> None:
96 | super().__init__()
97 | if norm_layer is None:
98 | norm_layer = nn.BatchNorm2d
99 | width = int(planes * (base_width / 64.0)) * groups
100 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
101 | self.conv1 = conv1x1(inplanes, width)
102 | self.bn1 = norm_layer(width)
103 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
104 | self.bn2 = norm_layer(width)
105 | self.conv3 = conv1x1(width, planes * self.expansion)
106 | self.bn3 = norm_layer(planes * self.expansion)
107 | self.relu = nn.ReLU(inplace=True)
108 | self.downsample = downsample
109 | self.stride = stride
110 |
111 | def forward(self, x: Tensor) -> Tensor:
112 | identity = x
113 |
114 | out = self.conv1(x)
115 | out = self.bn1(out)
116 | out = self.relu(out)
117 |
118 | out = self.conv2(out)
119 | out = self.bn2(out)
120 | out = self.relu(out)
121 |
122 | out = self.conv3(out)
123 | out = self.bn3(out)
124 |
125 | if self.downsample is not None:
126 | identity = self.downsample(x)
127 |
128 | out += identity
129 | out = self.relu(out)
130 |
131 | return out
132 |
133 |
134 | class ResNet(nn.Module):
135 | def __init__(
136 | self,
137 | block: Type[Union[BasicBlock, Bottleneck]],
138 | layers: List[int],
139 | num_classes: int = 1000,
140 | zero_init_residual: bool = False,
141 | groups: int = 1,
142 | width_per_group: int = 64,
143 | replace_stride_with_dilation: Optional[List[bool]] = None,
144 | norm_layer: Optional[Callable[..., nn.Module]] = None,
145 | ) -> None:
146 | super().__init__()
147 | if norm_layer is None:
148 | norm_layer = nn.BatchNorm2d
149 | self._norm_layer = norm_layer
150 |
151 | self.inplanes = 64
152 | self.dilation = 1
153 | if replace_stride_with_dilation is None:
154 | # each element in the tuple indicates if we should replace
155 | # the 2x2 stride with a dilated convolution instead
156 | replace_stride_with_dilation = [False, False, False]
157 | if len(replace_stride_with_dilation) != 3:
158 | raise ValueError(
159 | "replace_stride_with_dilation should be None "
160 | f"or a 3-element tuple, got {replace_stride_with_dilation}"
161 | )
162 | self.groups = groups
163 | self.base_width = width_per_group
164 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
165 | self.bn1 = norm_layer(self.inplanes)
166 | self.relu = nn.ReLU(inplace=True)
167 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
168 | self.layer1 = self._make_layer(block, 64, layers[0])
169 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
170 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
171 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
172 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
173 | self.fc = nn.Linear(512 * block.expansion, num_classes)
174 |
175 | for m in self.modules():
176 | if isinstance(m, nn.Conv2d):
177 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
178 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
179 | nn.init.constant_(m.weight, 1)
180 | nn.init.constant_(m.bias, 0)
181 |
182 | # Zero-initialize the last BN in each residual branch,
183 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
184 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
185 | self.num_features = 512
186 | if zero_init_residual:
187 | for m in self.modules():
188 | if isinstance(m, Bottleneck) and m.bn3.weight is not None:
189 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
190 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
191 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
192 |
193 | def _make_layer(
194 | self,
195 | block: Type[Union[BasicBlock, Bottleneck]],
196 | planes: int,
197 | blocks: int,
198 | stride: int = 1,
199 | dilate: bool = False,
200 | ) -> nn.Sequential:
201 | norm_layer = self._norm_layer
202 | downsample = None
203 | previous_dilation = self.dilation
204 | if dilate:
205 | self.dilation *= stride
206 | stride = 1
207 | if stride != 1 or self.inplanes != planes * block.expansion:
208 | downsample = nn.Sequential(
209 | conv1x1(self.inplanes, planes * block.expansion, stride),
210 | norm_layer(planes * block.expansion),
211 | )
212 |
213 | layers = []
214 | layers.append(
215 | block(
216 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
217 | )
218 | )
219 | self.inplanes = planes * block.expansion
220 | for _ in range(1, blocks):
221 | layers.append(
222 | block(
223 | self.inplanes,
224 | planes,
225 | groups=self.groups,
226 | base_width=self.base_width,
227 | dilation=self.dilation,
228 | norm_layer=norm_layer,
229 | )
230 | )
231 |
232 | return nn.Sequential(*layers)
233 |
234 | def _forward_impl(self, x: Tensor) -> Tensor:
235 | # See note [TorchScript super()]
236 | x = self.conv1(x)
237 | x = self.bn1(x)
238 | x = self.relu(x)
239 | x = self.maxpool(x)
240 |
241 | x = self.layer1(x)
242 | x = self.layer2(x)
243 | x = self.layer3(x)
244 | x = self.layer4(x)
245 |
246 | x = self.avgpool(x)
247 | x = torch.flatten(x, 1)
248 | x = self.fc(x)
249 |
250 | return x
251 |
252 | def forward(self, x: Tensor) -> Tensor:
253 | return self._forward_impl(x)
254 |
255 | def embed(self, x: Tensor) -> Tensor:
256 | x = self.conv1(x)
257 | x = self.bn1(x)
258 | x = self.relu(x)
259 | x = self.maxpool(x)
260 |
261 | x = self.layer1(x)
262 | x = self.layer2(x)
263 | x = self.layer3(x)
264 | x = self.layer4(x)
265 |
266 | x = self.avgpool(x)
267 | x = torch.flatten(x, 1)
268 | return x
269 |
270 | def ResNet10(num_classes):
271 | model = ResNet(BasicBlock, [1, 1, 1, 1], num_classes=num_classes, zero_init_residual=True)
272 | model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
273 | model.maxpool = nn.Identity()
274 | return model
275 |
276 | def ResNet18(num_classes):
277 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, zero_init_residual=True)
278 | model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
279 | model.maxpool = nn.Identity()
280 | return model
--------------------------------------------------------------------------------
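The two factory functions keep the torchvision topology but swap in a 3x3 stem and drop the initial max-pool for small images. A sketch confirming the 512-dimensional embedding (BasicBlock expansion is 1):

    import torch

    from models.resnet import ResNet10, ResNet18

    for ctor in (ResNet10, ResNet18):
        net = ctor(num_classes=100)
        x = torch.randn(2, 3, 32, 32)
        assert net.embed(x).shape == (2, 512)
        assert net(x).shape == (2, 100)
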
/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | ''' VGG '''
6 | cfg_vgg = {
7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
11 | }
12 | class VGG(nn.Module):
13 | def __init__(self, vgg_name, img_size, num_classes, norm='batchnorm'):
14 | super(VGG, self).__init__()
15 | img_size = img_size // 32
16 | self.features = self._make_layers(cfg_vgg[vgg_name], norm)
17 | self.fc = nn.Linear(512*img_size*img_size, num_classes)
18 |
19 | def forward(self, x):
20 | x = self.features(x)
21 | x = x.view(x.size(0), -1)
22 | x = self.fc(x)
23 | return x
24 |
25 | def embed(self, x):
26 | x = self.features(x)
27 | x = x.view(x.size(0), -1)
28 | return x
29 |
30 | def _make_layers(self, cfg, norm):
31 | layers = []
32 | in_channels = 3
33 | for ic, x in enumerate(cfg):
34 | if x == 'M':
35 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
36 | else:
37 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
38 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x),
39 | nn.ReLU(inplace=True)]
40 | in_channels = x
41 | return nn.Sequential(*layers)
42 |
43 | def VGG11(img_size, num_classes):
44 | return VGG('VGG11', img_size, num_classes)
45 |
--------------------------------------------------------------------------------
/models/wrapper.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision.models import resnet18, resnet34, resnet50
3 |
4 | from models.convnet import ConvNet
5 | from models.resnet import ResNet10, ResNet18
6 | from models.vgg import VGG11
7 | from models.alexnet import AlexNet
8 | from models.mobilenet import MobileNet
9 |
10 | def get_model(name, img_shape=(3,32,32), num_classes=10, dropout=0.0):
11 | if "convnet" in name.lower():
12 | name_splited = name.lower().split("_")
13 | num_channels = []
14 | for idx, ns in enumerate(name_splited):
15 | if idx != 0:
16 | if ns.isdigit():
17 | num_channels.append(int(ns))
18 | else:
19 | norm = ns
20 | model = ConvNet(img_shape, num_classes, num_channels, norm)
21 | elif "resnet10" == name.lower():
22 | model = ResNet10(num_classes)
23 | elif "resnet18" == name.lower():
24 | model = ResNet18(num_classes)
25 | elif "vgg" == name.lower():
26 | model = VGG11(img_shape[1], num_classes)
27 | elif "alexnet" == name.lower():
28 | model = AlexNet(img_shape[1], num_classes)
29 | elif "mobilenet" == name.lower():
30 | model = MobileNet(num_classes)
31 | else:
32 | raise NotImplementedError
33 | if dropout > 0.0:
34 | model.fc = nn.Sequential(
35 | nn.Dropout(dropout),
36 | model.fc
37 | )
38 | return model
--------------------------------------------------------------------------------
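A sketch of the naming convention get_model parses; the names and values below are illustrative:

    from models.wrapper import get_model

    # "convnet_<c1>_..._<ck>_<norm>" -> ConvNet(num_channels=[c1, ..., ck], norm=<norm>)
    net = get_model("convnet_128_256_512_bn", img_shape=(3, 32, 32), num_classes=10)

    # Fixed architecture names map directly; dropout > 0 wraps the final fc in nn.Dropout.
    net = get_model("resnet18", num_classes=100, dropout=0.5)
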
/test.py:
--------------------------------------------------------------------------------
1 | import random
2 | import argparse
3 | import numpy as np
4 |
5 | import torch
6 |
7 | from data.wrapper import get_loader
8 | from data.augmentation import NUM_CLASSES, ParamDiffAug
9 | from algorithms.wrapper import get_algorithm
10 |
11 | def main(args):
12 | device = torch.device(f"cuda:{args.gpu_id}")
13 | torch.cuda.set_device(device)
14 |
15 | # default augment
16 | args.dsa_param = ParamDiffAug()
17 | args.dsa_strategy = 'color_crop_cutout_flip_scale_rotate'
18 |
19 | # seed
20 | if args.seed is None:
21 | args.seed = random.randint(0, 9999)
22 | random.seed(args.seed)
23 | np.random.seed(args.seed)
24 | torch.manual_seed(args.seed)
25 |
26 | # data
27 | x_syn = torch.load(f"{args.synthetic_data_dir}/{args.source_data_name}/{args.method}/x_syn.pt", map_location="cpu").detach()
28 | y_syn = torch.load(f"{args.synthetic_data_dir}/{args.source_data_name}/{args.method}/y_syn.pt", map_location="cpu").detach()
29 | if args.method == "kip" or args.method == "frepo" or args.method == "krr_st":
30 | y_syn = y_syn.float()
31 | else:
32 | y_syn = y_syn.long()
33 | if args.method == "krr_st" :
34 | args.num_pretrain_classes = y_syn.shape[-1]
35 | else:
36 | args.num_pretrain_classes = NUM_CLASSES[args.source_data_name]
37 |
38 | print(args)
39 |
40 | # algo
41 | if args.method == "random" or args.method == "kmeans" or args.method == "dsa" or args.method == "dm" or args.method == "mtt":
42 | pretrain = get_algorithm("pretrain_dc")
43 | elif args.method == "kip" or args.method == "frepo":
44 | pretrain = get_algorithm("pretrain_frepo")
45 | elif args.method == "krr_st":
46 | pretrain = get_algorithm("pretrain_krr_st")
47 | else:
48 | raise NotImplementedError
49 | test_algo = get_algorithm("finetune")
50 |
51 | # target_data_name
52 | if args.target_data_name == "full":
53 | if args.source_data_name == "cifar100":
54 | data_name_list = ["cifar100", "cifar10", "aircraft", "cars", "cub2011", "dogs", "flowers"]
55 | elif args.source_data_name == "tinyimagenet":
56 | data_name_list = ["tinyimagenet", "cifar10", "aircraft", "cars", "cub2011", "dogs", "flowers"]
57 | elif args.source_data_name == "imagenet":
58 | data_name_list = ["cifar10", "cifar100", "aircraft", "cars", "cub2011", "dogs", "flowers"]
59 | elif args.source_data_name == "imagenette":
60 | data_name_list = ["imagenette", "aircraft", "cars", "cub2011", "dogs", "flowers"]
61 | else:
62 | raise NotImplementedError
63 | else:
64 | data_name_list = args.target_data_name.split("_")
65 |
66 | # train
67 | acc_dict = { data_name: [] for data_name in data_name_list }
68 | for _ in range(args.num_test):
69 | x_syn, y_syn = x_syn.to(device), y_syn.to(device)
70 | init_model = pretrain.run(args, device, args.test_model, x_syn, y_syn)
71 | init_model = init_model.cpu()
72 | x_syn, y_syn = x_syn.cpu(), y_syn.cpu()
73 | for data_name in data_name_list:
74 | args.num_classes = NUM_CLASSES[data_name]
75 | if data_name in ["tinyimagenet", "cifar100", "cifar10"]:
76 | args.test_iteration = 10000
77 | else:
78 | args.test_iteration = 5000
79 | dl_tr, dl_te, aug_tr, aug_te = get_loader(
80 | args.data_dir, data_name, args.test_batch_size, args.img_size, True)
81 | _, acc = test_algo.run(args, device, args.test_model, init_model, dl_tr, dl_te, aug_tr, aug_te)
82 | print(data_name, acc)
83 | acc_dict[data_name].append(acc)
84 |
85 | for data_name in data_name_list:
86 | print(f"{data_name}, mean: {np.mean(acc_dict[data_name])}, std: {np.std(acc_dict[data_name])}")
87 |
88 | if __name__ == '__main__':
89 | parser = argparse.ArgumentParser(description='Parameter Processing')
90 |
91 | # seed
92 | parser.add_argument('--seed', type=int, default=None)
93 |
94 | # data
95 | parser.add_argument('--source_data_name', type=str, default="tinyimagenet")
96 | parser.add_argument('--target_data_name', type=str, default="full")
97 |
98 | # dir
99 | parser.add_argument('--data_dir', type=str, default="./datasets")
100 | parser.add_argument('--synthetic_data_dir', type=str, default="./synthetic_data")
101 | parser.add_argument('--log_dir', type=str, default="./test_log")
102 |
103 | # dc method
104 | parser.add_argument('--method', type=str, default="krr_st")
105 |
106 | # hparams for model
107 | parser.add_argument('--test_model', type=str, default="base")
108 | parser.add_argument('--dropout', type=float, default=0.0)
109 |
110 |     # hparams for test
111 | parser.add_argument('--num_test', type=int, default=3)
112 |
113 | # gpus
114 | parser.add_argument('--gpu_id', type=int, default=0)
115 | args = parser.parse_args()
116 |
117 | # img_size
118 | if args.source_data_name == "cifar100":
119 | args.img_size = 32
120 | if args.test_model == "base":
121 | args.test_model = "convnet_128_256_512_bn"
122 | elif args.source_data_name == "tinyimagenet" or args.source_data_name == "imagenet":
123 | args.img_size = 64
124 | if args.test_model == "base":
125 | args.test_model = "convnet_64_128_256_512_bn"
126 | elif args.source_data_name == "imagenette":
127 | args.img_size = 224
128 | if args.test_model == "base":
129 | args.test_model = "convnet_32_64_128_256_512_bn"
130 | else:
131 | raise NotImplementedError
132 | args.img_shape = (3, args.img_size, args.img_size)
133 |
134 | # pretrain hparams
135 | if args.method == "random" or args.method == "kmeans" or args.method == "dsa" or args.method == "dm" or args.method == "mtt":
136 | args.pre_opt = "sgd"
137 | args.pre_epoch = 1000
138 | args.pre_iteration = None
139 | args.pre_batch_size = 256
140 | args.pre_lr = 0.01
141 | args.pre_wd = 5e-4
142 |
143 | elif args.method == "kip" or args.method == "frepo":
144 | #step_per_prototpyes = {10: 1000, 100: 2000, 200: 20000, 400: 5000, 500: 5000, 1000: 10000, 2000: 40000, 5000: 40000}
145 | args.pre_opt = "adam"
146 | args.pre_epoch = None
147 | if args.source_data_name == "cifar100" or args.source_data_name == "imagenet":
148 | args.pre_iteration = 10000 # 1000
149 | args.pre_batch_size = 500
150 | elif args.source_data_name == "tinyimagenet":
151 | args.pre_iteration = 40000 # 2000
152 | if args.test_model == "mobilenet":
153 | args.pre_batch_size = 256
154 | else:
155 | args.pre_batch_size = 500
156 | elif args.source_data_name == "imagenette":
157 | args.pre_iteration = 1000 # 10
158 | args.pre_batch_size = 10
159 | args.pre_lr = 0.0003
160 | args.pre_wd = 0.0
161 |
162 | elif args.method == "krr_st":
163 | args.pre_opt = "sgd"
164 | args.pre_epoch = 1000
165 | args.pre_batch_size = 256
166 | args.pre_lr = 0.1
167 | args.pre_wd = 1e-3
168 | else:
169 | raise NotImplementedError
170 |
171 | # finetune hyperparams
172 | args.test_opt = "sgd"
173 | args.test_iteration = None
174 | if args.img_size == 224 and args.test_model == "resnet18":
175 | args.test_batch_size = 64
176 | else:
177 | args.test_batch_size = 256
178 | args.test_lr = 0.01
179 | args.test_wd = 5e-4
180 |
181 | main(args)
182 |
--------------------------------------------------------------------------------
/test_kd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from algorithms.wrapper import get_algorithm
8 | from data.augmentation import NUM_CLASSES, ParamDiffAug
9 | from data.wrapper import get_loader
10 | from models.wrapper import get_model
11 |
12 |
13 | def main(args):
14 | device = torch.device(f"cuda:{args.gpu_id}")
15 | torch.cuda.set_device(device)
16 |
17 | # default augment
18 | args.dsa_param = ParamDiffAug()
19 | args.dsa_strategy = 'color_crop_cutout_flip_scale_rotate'
20 |
21 | # seed
22 | if args.seed is None:
23 | args.seed = random.randint(0, 9999)
24 | random.seed(args.seed)
25 | np.random.seed(args.seed)
26 | torch.manual_seed(args.seed)
27 |
28 | # data
29 | if args.method != "gaussian":
30 |         x_syn = torch.load(f"{args.synthetic_data_dir}/{args.source_data_name}/{args.method}/x_syn.pt", map_location="cpu").detach()
31 |         y_syn = torch.load(f"{args.synthetic_data_dir}/{args.source_data_name}/{args.method}/y_syn.pt", map_location="cpu").detach()
32 | if args.method == "kip" or args.method == "frepo" or args.method == "krr_st":
33 | y_syn = y_syn.float()
34 | else:
35 | y_syn = y_syn.long()
36 | else:
37 |         x_syn = torch.load(f"{args.synthetic_data_dir}/{args.source_data_name}/random/x_syn.pt", map_location="cpu").detach()
38 |         y_syn = torch.load(f"{args.synthetic_data_dir}/{args.source_data_name}/random/y_syn.pt", map_location="cpu").detach()
39 |     if args.method == "krr_st":
40 | args.num_pretrain_classes = y_syn.shape[-1]
41 | else:
42 | args.num_pretrain_classes = NUM_CLASSES[args.source_data_name]
43 |
44 | print(args)
45 |
46 | # algo
47 | if args.method == "random" or args.method == "kmeans" or args.method == "dsa" or args.method == "dm" or args.method == "mtt":
48 | pretrain = get_algorithm("pretrain_dc")
49 | elif args.method == "kip" or args.method == "frepo":
50 | pretrain = get_algorithm("pretrain_frepo")
51 | elif args.method == "krr_st":
52 | pretrain = get_algorithm("pretrain_krr_st")
53 | elif args.method == "gaussian":
54 | pass
55 | else:
56 | raise NotImplementedError
57 | test_algo = get_algorithm("zeroshot_kd")
58 |
59 | data_name = "cifar10"
60 | args.img_shape = (3, args.img_size, args.img_size)
61 | dl_tr, dl_te, aug_tr, aug_te = get_loader(
62 | args.data_dir, data_name, args.test_batch_size, args.img_size, True)
63 | data = {
64 | "num_classes": NUM_CLASSES[data_name.lower()],
65 | "dl_tr": dl_tr,
66 | "dl_te": dl_te,
67 | "aug_tr": aug_tr,
68 | "aug_te": aug_te
69 | }
70 |
71 | teacher = get_model("resnet18", args.img_shape, data["num_classes"]).to(device)
72 | ckpt = torch.load(f"./teacher_ckpt/teacher_{data_name}.pt", map_location="cpu")
73 | teacher.load_state_dict(ckpt)
74 | for p in teacher.parameters():
75 | p.requires_grad_(False)
76 |
77 | acc_list = []
78 | for _ in range(args.num_test):
79 | if args.method != "gaussian":
80 | init_model = pretrain.run(args, device, args.test_model, x_syn, y_syn)
81 | else:
82 | init_model = get_model(args.test_model, args.img_shape, 1).to(device)
83 |
84 | args.num_classes = data["num_classes"]
85 | dl_te = data["dl_te"]
86 | aug_te = data["aug_te"]
87 | _, acc = test_algo.run(args, device, args.test_model, init_model, teacher, x_syn, dl_te, aug_te)
88 | print(acc)
89 | acc_list.append(acc)
90 |
91 | print(f"{data_name}, mean: {np.mean(acc_list)}, std: {np.std(acc_list)}")
92 |
93 | if __name__ == '__main__':
94 | parser = argparse.ArgumentParser(description='Parameter Processing')
95 |
96 | # seed
97 | parser.add_argument('--seed', type=int, default=None)
98 |
99 | # data
100 | parser.add_argument('--source_data_name', type=str, default="tinyimagenet")
101 | parser.add_argument('--target_data_name', type=str, default="cifar10")
102 |
103 | # dir
104 |     parser.add_argument('--data_dir', type=str, default="./datasets")
105 | parser.add_argument('--synthetic_data_dir', type=str, default="./synthetic_data")
106 | parser.add_argument('--log_dir', type=str, default="./test_log")
107 |
108 | # dc method
109 | parser.add_argument('--method', type=str, default="krr_st")
110 |
111 | # hparams for model
112 | parser.add_argument('--test_model', type=str, default="base")
113 | parser.add_argument('--dropout', type=float, default=0.0)
114 |
115 |     # hparams for test
116 | parser.add_argument('--num_test', type=int, default=3)
117 |
118 | # gpus
119 | parser.add_argument('--gpu_id', type=int, default=0)
120 | args = parser.parse_args()
121 |
122 | # img_size
123 | if args.source_data_name == "cifar100":
124 | args.img_size = 32
125 | if args.test_model == "base":
126 | args.test_model = "convnet_128_256_512_bn"
127 | elif args.source_data_name == "tinyimagenet" or args.source_data_name == "imagenet":
128 | args.img_size = 64
129 | if args.test_model == "base":
130 | args.test_model = "convnet_64_128_256_512_bn"
131 | elif args.source_data_name == "imagenette":
132 | args.img_size = 224
133 | if args.test_model == "base":
134 | args.test_model = "convnet_32_64_128_256_512_bn"
135 | else:
136 | raise NotImplementedError
137 | args.img_shape = (3, args.img_size, args.img_size)
138 |
139 | # pretrain hparams
140 | if args.method == "gaussian" or args.method == "random" or args.method == "kmeans" or args.method == "dsa" or args.method == "dm" or args.method == "mtt":
141 | args.pre_opt = "sgd"
142 | args.pre_epoch = 1000
143 | args.pre_iteration = None
144 | args.pre_batch_size = 256
145 | args.pre_lr = 0.01
146 | args.pre_wd = 5e-4
147 |
148 | elif args.method == "kip" or args.method == "frepo":
149 | #step_per_prototpyes = {10: 1000, 100: 2000, 200: 20000, 400: 5000, 500: 5000, 1000: 10000, 2000: 40000, 5000: 40000}
150 | args.pre_opt = "adam"
151 | args.pre_epoch = None
152 | if args.source_data_name == "cifar100" or args.source_data_name == "imagenet":
153 | args.pre_iteration = 10000 # 1000
154 | args.pre_batch_size = 500
155 | elif args.source_data_name == "tinyimagenet":
156 | args.pre_iteration = 40000 # 2000
157 | if args.test_model == "mobilenet":
158 | args.pre_batch_size = 128
159 | else:
160 | args.pre_batch_size = 500
161 | elif args.source_data_name == "imagenette":
162 | args.pre_iteration = 1000 # 10
163 | args.pre_batch_size = 10
164 | args.pre_lr = 0.0003
165 | args.pre_wd = 0.0
166 |
167 | elif args.method == "krr_st":
168 | args.pre_opt = "sgd"
169 | args.pre_epoch = 1000
170 | args.pre_batch_size = 256
171 | args.pre_lr = 0.1
172 | args.pre_wd = 1e-3
173 | else:
174 | raise NotImplementedError
175 |
176 | # zeroshot_kd hyperparams
177 | args.test_opt = "adam"
178 | args.test_epoch = 1000
179 | args.test_batch_size = 512
180 | if args.method == "gaussian":
181 | args.test_lr = 1e-3
182 | else:
183 | args.test_lr = 1e-4
184 | args.test_wd = 0.
185 |
186 | main(args)
187 |
--------------------------------------------------------------------------------
/test_scratch.py:
--------------------------------------------------------------------------------
1 | import random
2 | import argparse
3 | import numpy as np
4 |
5 | import torch
6 |
7 | from data.wrapper import get_loader
8 | from data.augmentation import NUM_CLASSES
9 | from algorithms.wrapper import get_algorithm
10 |
11 | def main(args):
12 | device = torch.device(f"cuda:{args.gpu_id}")
13 | torch.cuda.set_device(device)
14 |
15 | # seed
16 | if args.seed is None:
17 | args.seed = random.randint(0, 9999)
18 | random.seed(args.seed)
19 | np.random.seed(args.seed)
20 | torch.manual_seed(args.seed)
21 |
22 | print(args)
23 |
24 | # algo
25 | test_algo = get_algorithm("scratch")
26 |
27 | # target_data_name
28 | if args.target_data_name == "full":
29 | if args.source_data_name == "cifar100":
30 | data_name_list = ["cifar100", "cifar10", "aircraft", "cars", "cub2011", "dogs", "flowers"]
31 | elif args.source_data_name == "tinyimagenet":
32 | data_name_list = ["tinyimagenet", "cifar10", "aircraft", "cars", "cub2011", "dogs", "flowers"]
33 | elif args.source_data_name == "imagenet":
34 | data_name_list = ["cifar10", "cifar100", "aircraft", "cars", "cub2011", "dogs", "flowers"]
35 | elif args.source_data_name == "imagenette":
36 | data_name_list = ["imagenette", "aircraft", "cars", "cub2011", "dogs", "flowers"]
37 | else:
38 | raise NotImplementedError
39 | else:
40 | data_name_list = args.target_data_name.split("_")
41 |
42 | # train
43 | acc_dict = { data_name: [] for data_name in data_name_list }
44 | for _ in range(args.num_test):
45 | for data_name in data_name_list:
46 | args.num_classes = NUM_CLASSES[data_name]
47 | if data_name in ["tinyimagenet", "cifar100", "cifar10"]:
48 | args.test_iteration = 10000
49 | else:
50 | args.test_iteration = 5000
51 | dl_tr, dl_te, aug_tr, aug_te = get_loader(
52 | args.data_dir, data_name, args.test_batch_size, args.img_size, True)
53 | _, acc = test_algo.run(args, device, args.test_model, dl_tr, dl_te, aug_tr, aug_te)
54 | print(data_name, acc)
55 | acc_dict[data_name].append(acc)
56 |
57 | for data_name in data_name_list:
58 | print(f"{data_name}, mean: {np.mean(acc_dict[data_name])}, std: {np.std(acc_dict[data_name])}")
59 |
60 | if __name__ == '__main__':
61 | parser = argparse.ArgumentParser(description='Parameter Processing')
62 |
63 | # seed
64 | parser.add_argument('--seed', type=int, default=None)
65 |
66 | # data
67 | parser.add_argument('--source_data_name', type=str, default="cifar100")
68 | parser.add_argument('--target_data_name', type=str, default="full")
69 |
70 | # dir
71 |     parser.add_argument('--data_dir', type=str, default="./datasets")
72 | parser.add_argument('--synthetic_data_dir', type=str, default="./synthetic_data")
73 | parser.add_argument('--log_dir', type=str, default="./test_log")
74 |
75 | # hparams for model
76 | parser.add_argument('--test_model', type=str, default="base")
77 |
78 |     # hparams for test
79 | parser.add_argument('--num_test', type=int, default=3)
80 |
81 | # gpus
82 | parser.add_argument('--gpu_id', type=int, default=0)
83 | args = parser.parse_args()
84 |
85 | # img_size
86 | if args.source_data_name == "cifar100":
87 | args.img_size = 32
88 | if args.test_model == "base":
89 | args.test_model = "convnet_128_256_512_bn"
90 | elif args.source_data_name == "tinyimagenet" or args.source_data_name == "imagenet":
91 | args.img_size = 64
92 | if args.test_model == "base":
93 | args.test_model = "convnet_64_128_256_512_bn"
94 | elif args.source_data_name == "imagenette":
95 | args.img_size = 224
96 | if args.test_model == "base":
97 | args.test_model = "convnet_32_64_128_256_512_bn"
98 | else:
99 | raise NotImplementedError
100 | args.img_shape = (3, args.img_size, args.img_size)
101 |
102 | # finetune hyperparams
103 | args.test_opt = "sgd"
104 | args.test_iteration = None
105 | if args.img_size == 224 and args.test_model == "resnet18":
106 | args.test_batch_size = 64
107 | else:
108 | args.test_batch_size = 256
109 | args.test_lr = 0.01
110 | args.test_wd = 5e-4
111 |
112 | main(args)
113 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import random
2 | import argparse
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from torchvision.models import resnet18
9 |
10 | from data.wrapper import get_loader
11 | from data.augmentation import NUM_CLASSES
12 | from algorithms.wrapper import get_algorithm
13 | from utils import InfIterator, Logger
14 | from model_pool import ModelPool
15 |
16 | def main(args):
17 | device = torch.device(f"cuda:{args.gpu_id}")
18 | torch.cuda.set_device(device)
19 |
20 | # seed
21 | if args.seed is None:
22 | args.seed = random.randint(0, 9999)
23 | random.seed(args.seed)
24 | np.random.seed(args.seed)
25 | torch.manual_seed(args.seed)
26 |
27 | # data
28 | args.img_shape = (3, args.img_size, args.img_size)
29 | dl, _, _, _ = get_loader(args.data_dir, args.data_name, args.outer_batch_size, args.img_size, False)
30 | iter_tr = InfIterator(dl)
31 | dl_tr, dl_te, aug_tr, aug_te = get_loader(args.data_dir, args.data_name, args.test_batch_size, args.img_size, True)
32 |
33 |     # target model: frozen Barlow Twins teacher whose features serve as regression targets
34 |     if args.data_name == "imagenette":
35 |         target_model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')  # official Barlow Twins ResNet-50
36 |         target_model.fc = nn.Identity()  # expose 2048-d features instead of logits
37 |         target_model = target_model.to(device)
38 |     else:
39 |         target_model = resnet18()
40 |         target_model.fc = nn.Identity()  # expose 512-d features instead of logits
41 |         target_model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)  # 3x3 stem for low-resolution inputs (must match the teacher checkpoint)
42 |         target_model.maxpool = nn.Identity()  # no early downsampling for small images
43 |         target_model = target_model.to(device)
44 |         target_model.load_state_dict(torch.load(f"./teacher_ckpt/barlow_twins_resnet18_{args.data_name}.pt", map_location="cpu"))
45 |
46 |     # build the synthetic set: real images paired with their teacher features
47 |     target_model.eval()
48 |     with torch.no_grad():
49 |         x_syn, y_syn = torch.FloatTensor([]), torch.FloatTensor([])
50 |         while x_syn.shape[0] < args.num_images:
51 |             x, y = next(iter_tr)  # the dataset labels in y are discarded below
52 |             x_syn = torch.cat([x_syn, x], dim=0)
53 |             x = x.to(device)
54 |             x = aug_te(x)  # eval-time transform
55 |             y = target_model(x).cpu()  # teacher features become the regression targets
56 |             torch.cuda.empty_cache()
57 |             y_syn = torch.cat([y_syn, y], dim=0)
58 |
59 |     args.num_pretrain_classes = y_syn.shape[-1]  # teacher feature dimension
60 | args.num_classes = NUM_CLASSES[args.data_name]
61 |
62 |     # keep the first num_images pairs and make them learnable
63 | x_syn, y_syn = x_syn[:args.num_images].to(device), y_syn[:args.num_images].to(device)
64 | x_syn.requires_grad_(True); y_syn.requires_grad_(True)
65 |
66 | # outer opt
67 | if args.outer_opt == "sgd":
68 | outer_opt = torch.optim.SGD([x_syn, y_syn], lr=args.outer_lr, momentum=0.5, weight_decay=args.outer_wd)
69 | elif args.outer_opt == "adam":
70 | outer_opt = torch.optim.AdamW([x_syn, y_syn], lr=args.outer_lr, weight_decay=args.outer_wd)
71 | else:
72 | raise NotImplementedError
73 | outer_sch = torch.optim.lr_scheduler.LinearLR(
74 | outer_opt, start_factor=1.0, end_factor=1e-3, total_iters=args.outer_iteration)
75 |
76 | # model pool
77 | model_pool = ModelPool(args, device)
78 | model_pool.init(x_syn.detach(), y_syn.detach())
79 |
80 | # logger
81 | logger = Logger(
82 | save_dir=f"{args.save_dir}/{args.exp_name}",
83 | save_only_last=True,
84 | print_every=args.print_every,
85 | save_every=args.save_every,
86 | total_step=args.outer_iteration,
87 | print_to_stdout=True
88 | )
89 | logger.register_object_to_save(x_syn, "x_syn")
90 | logger.register_object_to_save(y_syn, "y_syn")
91 | logger.start()
92 |
93 | # algo
94 | train_algo = get_algorithm("distill")
95 | pretrain = get_algorithm("pretrain_krr_st")
96 | test_algo = get_algorithm("linear_eval")
97 |
98 |     # outer loop: optimize the distilled data; periodically pretrain a fresh model on it to evaluate
99 | for outer_step in range(1, args.outer_iteration+1):
100 |
101 | # meta train
102 | loss = train_algo.run(
103 | args, device, target_model, model_pool, outer_opt, iter_tr, aug_tr, x_syn, y_syn)
104 | logger.meter("meta_train", "mse loss", loss)
105 |
106 | # meta test
107 | if outer_step % args.eval_every == 0 or outer_step == args.outer_iteration:
108 | init_model = pretrain.run(args, device, args.test_model, x_syn, y_syn)
109 | loss, acc = test_algo.run(args, device, args.test_model, init_model, dl_tr, dl_te, aug_tr, aug_te)
110 | del init_model
111 |             logger.meter("meta_test", "loss", loss)
112 |             logger.meter("meta_test", "accuracy", acc)
113 |
114 | outer_sch.step()
115 | logger.step()
116 |
117 | logger.finish()
118 |
119 | if __name__ == '__main__':
120 | parser = argparse.ArgumentParser(description='Parameter Processing')
121 |
122 | # seed
123 | parser.add_argument('--seed', type=int, default=None)
124 |
125 | # data
126 |     parser.add_argument('--target_model', type=str, default="resnet18")  # note: the teacher is actually selected by --data_name above
127 | parser.add_argument('--data_name', type=str, default="cifar100")
128 | parser.add_argument('--num_workers', type=int, default=0)
129 |
130 | # dir
131 | parser.add_argument('--data_dir', type=str, default="./datasets")
132 | parser.add_argument('--save_dir', type=str, default="./results")
133 | parser.add_argument('--exp_name', type=str, default=None)
134 |
135 | # algorithm
136 | parser.add_argument('--num_models', type=int, default=10)
137 |
138 |     # hparams for online
139 | parser.add_argument('--online_opt', type=str, default="sgd")
140 | parser.add_argument('--online_iteration', type=int, default=1000)
141 | parser.add_argument('--online_lr', type=float, default=0.1)
142 | parser.add_argument('--online_wd', type=float, default=1e-3)
143 |
144 |     # hparams for pretrain
145 | parser.add_argument('--pre_opt', type=str, default="sgd")
146 | parser.add_argument('--pre_epoch', type=int, default=1000)
147 | parser.add_argument('--pre_batch_size', type=int, default=256)
148 | parser.add_argument('--pre_lr', type=float, default=0.1)
149 | parser.add_argument('--pre_wd', type=float, default=1e-3)
150 |
151 |     # hparams for test
152 | parser.add_argument('--test_opt', type=str, default="sgd")
153 | parser.add_argument('--test_iteration', type=int, default=5000)
154 |     parser.add_argument('--test_batch_size', type=int, default=512)
155 | parser.add_argument('--test_lr', type=float, default=0.2)
156 | parser.add_argument('--test_wd', type=float, default=0.0)
157 |
158 |     # hparams for outer
159 | parser.add_argument('--outer_opt', type=str, default="adam")
160 | parser.add_argument('--outer_iteration', type=int, default=160000)
161 | parser.add_argument('--outer_batch_size', type=int, default=1024)
162 | parser.add_argument('--outer_lr', type=float, default=1e-3)
163 | parser.add_argument('--outer_wd', type=float, default=0.)
164 | parser.add_argument('--outer_grad_norm', type=float, default=0.0)
165 |
166 | # hparams for logger
167 | parser.add_argument('--print_every', type=int, default=100)
168 | parser.add_argument('--eval_every', type=int, default=16000)
169 | parser.add_argument('--save_every', type=int, default=2000)
170 |
171 | # gpus
172 | parser.add_argument('--gpu_id', type=int, default=0)
173 | args = parser.parse_args()
174 |
175 | if args.data_name == "cifar100":
176 | args.img_size = 32
177 | args.num_images = 1000
178 | args.train_model = "convnet_128_256_512_bn"
179 | args.test_model = "convnet_128_256_512_bn"
180 | elif args.data_name == "tinyimagenet":
181 | args.img_size = 64
182 | args.num_images = 2000
183 | args.train_model = "convnet_64_128_256_512_bn"
184 | args.test_model = "convnet_64_128_256_512_bn"
185 | elif args.data_name == "imagenet":
186 | args.img_size = 64
187 | args.num_images = 1000
188 | args.train_model = "convnet_64_128_256_512_bn"
189 | args.test_model = "convnet_64_128_256_512_bn"
190 | elif args.data_name == "imagenette":
191 | args.img_size = 224
192 | args.num_images = 10
193 | args.train_model = "convnet_32_64_128_256_512_bn"
194 | args.test_model = "convnet_32_64_128_256_512_bn"
195 | else:
196 | raise NotImplementedError
197 |
198 | main(args)
199 |
--------------------------------------------------------------------------------
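The essential mechanics of train.py above: the distilled images x_syn and their feature targets y_syn are themselves the parameters being optimized. Below is a minimal, self-contained sketch of the outer loop under simplified assumptions: a single fixed linear student replaces the ModelPool and algorithms/distill.py machinery, and the feature-regression loss is written out directly.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # stand-in student; the real code trains a pool of ConvNets on (x_syn, y_syn)
    student = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512))

    x_syn = torch.randn(16, 3, 32, 32, requires_grad=True)  # learnable images
    y_syn = torch.randn(16, 512, requires_grad=True)         # learnable feature targets
    outer_opt = torch.optim.AdamW([x_syn, y_syn], lr=1e-3)
    outer_sch = torch.optim.lr_scheduler.LinearLR(
        outer_opt, start_factor=1.0, end_factor=1e-3, total_iters=100)

    for outer_step in range(100):
        loss = F.mse_loss(student(x_syn), y_syn)  # feature-regression objective
        outer_opt.zero_grad()
        loss.backward()
        outer_opt.step()
        outer_sch.step()

Since x_syn and y_syn are leaf tensors, the Logger registrations in train.py checkpoint the distilled dataset itself rather than any model weights.
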
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 |
4 | import torch
5 |
6 |
7 | class InfIterator:  # cycles an iterable forever, restarting it on StopIteration
8 | def __init__(self, iterable):
9 | self.iterable = iterable
10 | self.iterator = iter(self.iterable)
11 |
12 | def __next__(self):
13 | try:
14 | return next(self.iterator)
15 | except StopIteration:
16 | self.iterator = iter(self.iterable)
17 | return next(self.iterator)
18 |
19 | class Logger:
20 | def __init__(
21 | self,
22 | save_dir=None,
23 | save_only_last=True,
24 | print_every=100,
25 | save_every=100,
26 | total_step=0,
27 | print_to_stdout=True,
28 | ):
29 | if save_dir is not None:
30 | self.save_dir = save_dir
31 | os.makedirs(self.save_dir, exist_ok=True)
32 | else:
33 | self.save_dir = None
34 |
35 | self.print_every = print_every
36 | self.save_every = save_every
37 | self.save_only_last = save_only_last
38 | self.step_count = 0
39 | self.total_step = total_step
40 | self.print_to_stdout = print_to_stdout
41 |
42 | self.writer = None
43 | self.start_time = None
44 | self.groups = dict()
45 | self.models_to_save = dict()
46 | self.objects_to_save = dict()
47 |
48 | def register_model_to_save(self, model, name):
49 | assert name not in self.models_to_save.keys(), "Name is already registered."
50 |
51 | self.models_to_save[name] = model
52 |
53 |     def register_object_to_save(self, obj, name):
54 |         assert name not in self.objects_to_save.keys(), "Name is already registered."
55 | 
56 |         self.objects_to_save[name] = obj
57 |
58 | def step(self):
59 | self.step_count += 1
60 | if self.step_count % self.print_every == 0:
61 | if self.print_to_stdout:
62 | self.print_log(self.step_count, self.total_step, elapsed_time=datetime.now() - self.start_time)
63 |
64 | if self.step_count % self.save_every == 0:
65 | if self.save_only_last:
66 | self.save_models()
67 | self.save_objects()
68 | else:
69 | self.save_models(self.step_count)
70 | self.save_objects(self.step_count)
71 |
72 | def meter(self, group_name, log_name, value):
73 | if group_name not in self.groups.keys():
74 | self.groups[group_name] = dict()
75 |
76 | if log_name not in self.groups[group_name].keys():
77 | self.groups[group_name][log_name] = Accumulator()
78 |
79 | self.groups[group_name][log_name].update_state(value)
80 |
81 |     def reset_state(self):  # not called by step(); meters accumulate until reset explicitly
82 | for _, group in self.groups.items():
83 | for _, log in group.items():
84 | log.reset_state()
85 |
86 | def print_log(self, step, total_step, elapsed_time=None):
87 | print(f"[Step {step:5d}/{total_step}]", end=" ")
88 |
89 | for name, group in self.groups.items():
90 | print(f"({name})", end=" ")
91 | for log_name, log in group.items():
92 | res = log.result()
93 | if res is None:
94 | continue
95 |
96 | if "acc" in log_name.lower():
97 | print(f"{log_name} {res:.2f}", end=" | ")
98 | else:
99 | print(f"{log_name} {res:.4f}", end=" | ")
100 |
101 | if elapsed_time is not None:
102 | print(f"(Elapsed time) {elapsed_time}")
103 | else:
104 | print()
105 |
106 | def save_models(self, suffix=None):
107 | if self.save_dir is None:
108 | return
109 | for name, model in self.models_to_save.items():
110 | _name = name
111 | if suffix:
112 | _name += f"_{suffix}"
113 | torch.save(model.state_dict(), os.path.join(self.save_dir, f"{_name}.pt"))
114 |
115 | if self.print_to_stdout:
116 | print(f"{name} is saved to {self.save_dir}")
117 |
118 | def save_objects(self, suffix=None):
119 | if self.save_dir is None:
120 | return
121 |
122 | for name, obj in self.objects_to_save.items():
123 | _name = name
124 | if suffix:
125 | _name += f"_{suffix}"
126 | torch.save(obj, os.path.join(self.save_dir, f"{_name}.pt"))
127 |
128 | if self.print_to_stdout:
129 | print(f"{name} is saved to {self.save_dir}")
130 |
131 | def start(self):
132 | if self.print_to_stdout:
133 | print("Training starts!")
134 | self.start_time = datetime.now()
135 |
136 | def finish(self):
137 | if self.step_count % self.save_every != 0:
138 | if self.save_only_last:
139 | self.save_models()
140 | self.save_objects()
141 | else:
142 | self.save_models(self.step_count)
143 | self.save_objects(self.step_count)
144 |
145 | if self.print_to_stdout:
146 | print("Training is finished!")
147 |
148 | class Accumulator:
149 | def __init__(self):
150 | self.data = 0
151 | self.num_data = 0
152 |
153 | def reset_state(self):
154 | self.data = 0
155 | self.num_data = 0
156 |
157 |     def update_state(self, tensor):  # works for tensors and plain numbers alike
158 | with torch.no_grad():
159 | self.data += tensor
160 | self.num_data += 1
161 |
162 | def result(self):
163 | if self.num_data == 0:
164 | return None
165 | data = self.data.item() if hasattr(self.data, 'item') else self.data
166 | return float(data) / self.num_data
167 |
--------------------------------------------------------------------------------
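To close, a hypothetical end-to-end usage of the helpers in utils.py (directory, sizes, and step counts are illustrative):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from utils import InfIterator, Logger

    ds = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
    it = InfIterator(DataLoader(ds, batch_size=4))  # yields batches forever

    logger = Logger(save_dir="./tmp_log", print_every=2, save_every=4, total_step=6)
    logger.register_object_to_save(torch.zeros(3), "dummy")
    logger.start()
    for _ in range(6):
        x, _ = next(it)
        logger.meter("train", "loss", x.pow(2).mean())
        logger.step()
    logger.finish()

Note that step() never calls reset_state(), so each printed value is a running average over all meter updates so far; callers that want per-interval averages must call logger.reset_state() themselves.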