├── .gitignore ├── LICENSE ├── README.md ├── cv ├── .gitignore ├── LICENSE ├── README.md ├── crd │ ├── __init__.py │ ├── criterion.py │ └── memory.py ├── dataset │ ├── cifar100.py │ ├── cifar_with_held.py │ ├── imagenet.py │ └── meta_cifar100.py ├── distiller_zoo │ ├── AB.py │ ├── AT.py │ ├── CC.py │ ├── FSP.py │ ├── FT.py │ ├── FitNet.py │ ├── KD.py │ ├── KDSVD.py │ ├── MSE.py │ ├── NST.py │ ├── PKT.py │ ├── RKD.py │ ├── SP.py │ ├── VID.py │ └── __init__.py ├── helper │ ├── __init__.py │ ├── loops.py │ ├── meta_loops.py │ ├── pretrain.py │ └── util.py ├── models │ ├── ShuffleNetv1.py │ ├── ShuffleNetv2.py │ ├── __init__.py │ ├── classifier.py │ ├── meta_resnet.py │ ├── meta_resnet_v2.py │ ├── meta_vgg.py │ ├── mobilenetv2.py │ ├── resnet.py │ ├── resnetv2.py │ ├── util.py │ ├── vgg.py │ └── wrn.py ├── scripts │ ├── fetch_pretrained_teachers.sh │ ├── run_cifar_distill.sh │ └── run_cifar_vanilla.sh ├── train_student.py ├── train_student_debug.py ├── train_student_meta.py └── train_teacher.py └── nlp ├── .gitignore ├── LICENSE ├── README.md ├── distillation_meta.py ├── download_glue_data.py ├── functional_forward_bert.py ├── mrpc_hyperparameters.json ├── requirements.txt ├── run_glue.py ├── run_glue_distillation_meta.py ├── run_hyperparameter_tuning.py └── utils_glue.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kevin Canwen Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MetaDistil 2 | Code for ACL 2022 paper ["BERT Learns to Teach: Knowledge Distillation with Meta Learning"](https://arxiv.org/abs/2106.04570). 3 | 4 | ## ⚠️ Read before use 5 | Since the release of this paper on arXiv, we have received a lot of requests for the code. Thus, we want to first release the code without cleaning up. We know implementing a second-order approach is non-trivial so we want to help you but please note that the current code may contain bugs, useless codes, incorrect settings etc. **Please use at your own risk.** We'll later verify the code and clean it up once we have the chance to do so. 6 | 7 | ## Acknowledgments 8 | The implementation of image classification is based on https://github.com/HobbitLong/RepDistiller 9 | 10 | The implementation of text classification is based on https://github.com/bzantium/pytorch-PKD-for-BERT-compression 11 | 12 | Shout out to the authors of these two repos. 13 | 14 | ## How to use the code 15 | To be added. For now, please see `/nlp/run_glue_distillation_meta.py` and `/cv/train_student_meta.py`. 
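A minimal, self-contained sketch of the meta step those two entry points implement may help until full instructions are added. Everything below is illustrative only and **not** the repository's code: the toy linear models, learning rates, and random batches are placeholder assumptions (the real scripts use BERT/ResNet students and teachers with functional forwards; see `nlp/functional_forward_bert.py` and `cv/helper/meta_loops.py`).

```
# Sketch of one MetaDistil step with toy linear models (placeholder, not repo code).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x_train = torch.randn(32, 10)                                     # distillation batch
x_quiz, y_quiz = torch.randn(32, 10), torch.randint(0, 4, (32,))  # held-out "quiz" batch

w_s = torch.randn(10, 4, requires_grad=True)   # student parameters (toy)
w_t = torch.randn(10, 4, requires_grad=True)   # teacher parameters (toy)
teacher_opt = torch.optim.SGD([w_t], lr=1e-2)
T, inner_lr = 4.0, 1e-1                        # KD temperature, student step size

def kd_loss(s_logits, t_logits):
    # Temperature-scaled KL distillation loss (Hinton et al.).
    p_s = F.log_softmax(s_logits / T, dim=1)
    p_t = F.softmax(t_logits / T, dim=1)
    return F.kl_div(p_s, p_t, reduction="batchmean") * T ** 2

# 1) Pilot update: a differentiable student step on the KD loss; keeping the
#    graph (create_graph=True) makes the updated student a function of the teacher.
g_s = torch.autograd.grad(kd_loss(x_train @ w_s, x_train @ w_t),
                          w_s, create_graph=True)[0]
w_s_pilot = w_s - inner_lr * g_s

# 2) Meta update: the pilot student's quiz loss is back-propagated *through*
#    the pilot step into the teacher (this is the second-order part).
quiz_loss = F.cross_entropy(x_quiz @ w_s_pilot, y_quiz)
teacher_opt.zero_grad()
quiz_loss.backward()
teacher_opt.step()

# 3) Real student update against the improved teacher (teacher detached here).
g_s = torch.autograd.grad(kd_loss(x_train @ w_s, (x_train @ w_t).detach()), w_s)[0]
with torch.no_grad():
    w_s -= inner_lr * g_s
```

The key design point is step 2: because the pilot student depends on the teacher through the KD gradient, the teacher receives a learning signal that directly measures how much its teaching improved the student on held-out data.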
16 | -------------------------------------------------------------------------------- /cv/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | output*/ 4 | ckpts/ 5 | *.pth 6 | *.t7 7 | *.png 8 | *.jpg 9 | tmp*.py 10 | 11 | *.pdf 12 | 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | -------------------------------------------------------------------------------- /cv/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Yonglong Tian 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /cv/README.md: -------------------------------------------------------------------------------- 1 | # RepDistiller 2 | 3 | This repo: 4 | 5 | **(1) covers the implementation of the following ICLR 2020 paper:** 6 | 7 | "Contrastive Representation Distillation" (CRD). [Paper](http://arxiv.org/abs/1910.10699), [Project Page](http://hobbitlong.github.io/CRD/). 8 | 9 |
10 | 11 |
12 | 13 | **(2) benchmarks 12 state-of-the-art knowledge distillation methods in PyTorch, including:** 14 | 15 | (KD) - Distilling the Knowledge in a Neural Network 16 | (FitNet) - Fitnets: hints for thin deep nets 17 | (AT) - Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks 18 | via Attention Transfer 19 | (SP) - Similarity-Preserving Knowledge Distillation 20 | (CC) - Correlation Congruence for Knowledge Distillation 21 | (VID) - Variational Information Distillation for Knowledge Transfer 22 | (RKD) - Relational Knowledge Distillation 23 | (PKT) - Probabilistic Knowledge Transfer for deep representation learning 24 | (AB) - Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons 25 | (FT) - Paraphrasing Complex Network: Network Compression via Factor Transfer 26 | (FSP) - A Gift from Knowledge Distillation: 27 | Fast Optimization, Network Minimization and Transfer Learning 28 | (NST) - Like what you like: knowledge distill via neuron selectivity transfer 29 | 30 | ## Installation 31 | 32 | This repo was tested with Ubuntu 16.04.5 LTS, Python 3.5, PyTorch 0.4.0, and CUDA 9.0. But it should be runnable with recent PyTorch versions >=0.4.0 33 | 34 | ## Running 35 | 36 | 1. Fetch the pretrained teacher models by: 37 | 38 | ``` 39 | sh scripts/fetch_pretrained_teachers.sh 40 | ``` 41 | which will download and save the models to `save/models` 42 | 43 | 2. Run distillation by following commands in `scripts/run_cifar_distill.sh`. An example of running Geoffrey's original Knowledge Distillation (KD) is given by: 44 | 45 | ``` 46 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill kd --model_s resnet8x4 -r 0.1 -a 0.9 -b 0 --trial 1 47 | ``` 48 | where the flags are explained as: 49 | - `--path_t`: specify the path of the teacher model 50 | - `--model_s`: specify the student model, see 'models/\_\_init\_\_.py' to check the available model types. 51 | - `--distill`: specify the distillation method 52 | - `-r`: the weight of the cross-entropy loss between logit and ground truth, default: `1` 53 | - `-a`: the weight of the KD loss, default: `None` 54 | - `-b`: the weight of other distillation losses, default: `None` 55 | - `--trial`: specify the experimental id to differentiate between multiple runs. 56 | 57 | Therefore, the command for running CRD is something like: 58 | ``` 59 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill crd --model_s resnet8x4 -a 0 -b 0.8 --trial 1 60 | ``` 61 | 62 | 3. Combining a distillation objective with KD is simply done by setting `-a` as a non-zero value, which results in the following example (combining CRD with KD) 63 | ``` 64 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill crd --model_s resnet8x4 -a 1 -b 0.8 --trial 1 65 | ``` 66 | 67 | 4. (optional) Train teacher networks from scratch. Example commands are in `scripts/run_cifar_vanilla.sh` 68 | 69 | Note: the default setting is for a single-GPU training. If you would like to play this repo with multiple GPUs, you might need to tune the learning rate, which empirically needs to be scaled up linearly with the batch size, see [this paper](https://arxiv.org/abs/1706.02677) 70 | 71 | ## Benchmark Results on CIFAR-100: 72 | 73 | Performance is measured by classification accuracy (%) 74 | 75 | 1. Teacher and student are of the **same** architectural type. 76 | 77 | | Teacher
Student | wrn-40-2 <br> wrn-16-2 | wrn-40-2 <br> wrn-40-1 | resnet56 <br> resnet20 | resnet110 <br> resnet20 | resnet110 <br> resnet32 | resnet32x4 <br> resnet8x4 | vgg13 <br> vgg8 | 78 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:------------------:|:------------------:|:--------------------:|:-----------:| 79 | | Teacher <br> Student | 75.61 <br> 73.26 | 75.61 <br> 71.98 | 72.34 <br> 69.06 | 74.31 <br> 69.06 | 74.31 <br> 71.14 | 79.42 <br> 72.50 | 74.64 <br>
70.36 | 80 | | KD | 74.92 | 73.54 | 70.66 | 70.67 | 73.08 | 73.33 | 72.98 | 81 | | FitNet | 73.58 | 72.24 | 69.21 | 68.99 | 71.06 | 73.50 | 71.02 | 82 | | AT | 74.08 | 72.77 | 70.55 | 70.22 | 72.31 | 73.44 | 71.43 | 83 | | SP | 73.83 | 72.43 | 69.67 | 70.04 | 72.69 | 72.94 | 72.68 | 84 | | CC | 73.56 | 72.21 | 69.63 | 69.48 | 71.48 | 72.97 | 70.71 | 85 | | VID | 74.11 | 73.30 | 70.38 | 70.16 | 72.61 | 73.09 | 71.23 | 86 | | RKD | 73.35 | 72.22 | 69.61 | 69.25 | 71.82 | 71.90 | 71.48 | 87 | | PKT | 74.54 | 73.45 | 70.34 | 70.25 | 72.61 | 73.64 | 72.88 | 88 | | AB | 72.50 | 72.38 | 69.47 | 69.53 | 70.98 | 73.17 | 70.94 | 89 | | FT | 73.25 | 71.59 | 69.84 | 70.22 | 72.37 | 72.86 | 70.58 | 90 | | FSP | 72.91 | 0.00 | 69.95 | 70.11 | 71.89 | 72.62 | 70.23 | 91 | | NST | 73.68 | 72.24 | 69.60 | 69.53 | 71.96 | 73.30 | 71.53 | 92 | | **CRD** | **75.48** | **74.14** | **71.16** | **71.46** | **73.48** | **75.51** | **73.94** | 93 | 94 | 2. Teacher and student are of **different** architectural type. 95 | 96 | | Teacher
Student | vgg13 <br> MobileNetV2 | ResNet50 <br> MobileNetV2 | ResNet50 <br> vgg8 | resnet32x4 <br> ShuffleNetV1 | resnet32x4 <br> ShuffleNetV2 | wrn-40-2 <br> ShuffleNetV1 | 97 | |:---------------:|:-----------------:|:--------------------:|:-------------:|:-----------------------:|:-----------------------:|:---------------------:| 98 | | Teacher <br> Student | 74.64 <br> 64.60 | 79.34 <br> 64.60 | 79.34 <br> 70.36 | 79.42 <br> 70.50 | 79.42 <br> 71.82 | 75.61 <br>
70.50 | 99 | | KD | 67.37 | 67.35 | 73.81 | 74.07 | 74.45 | 74.83 | 100 | | FitNet | 64.14 | 63.16 | 70.69 | 73.59 | 73.54 | 73.73 | 101 | | AT | 59.40 | 58.58 | 71.84 | 71.73 | 72.73 | 73.32 | 102 | | SP | 66.30 | 68.08 | 73.34 | 73.48 | 74.56 | 74.52 | 103 | | CC | 64.86 | 65.43 | 70.25 | 71.14 | 71.29 | 71.38 | 104 | | VID | 65.56 | 67.57 | 70.30 | 73.38 | 73.40 | 73.61 | 105 | | RKD | 64.52 | 64.43 | 71.50 | 72.28 | 73.21 | 72.21 | 106 | | PKT | 67.13 | 66.52 | 73.01 | 74.10 | 74.69 | 73.89 | 107 | | AB | 66.06 | 67.20 | 70.65 | 73.55 | 74.31 | 73.34 | 108 | | FT | 61.78 | 60.99 | 70.29 | 71.75 | 72.50 | 72.03 | 109 | | NST | 58.16 | 64.96 | 71.28 | 74.12 | 74.68 | 74.89 | 110 | | **CRD** | **69.73** | **69.11** | **74.30** | **75.11** | **75.65** | **76.05** | 111 | 112 | ## Citation 113 | 114 | If you find this repo useful for your research, please consider citing the paper 115 | 116 | ``` 117 | @inproceedings{tian2019crd, 118 | title={Contrastive Representation Distillation}, 119 | author={Yonglong Tian and Dilip Krishnan and Phillip Isola}, 120 | booktitle={International Conference on Learning Representations}, 121 | year={2020} 122 | } 123 | ``` 124 | For any questions, please contact Yonglong Tian (yonglong@mit.edu). 125 | 126 | ## Acknowledgement 127 | 128 | Thanks to Baoyun Peng for providing the code of CC and to Frederick Tung for verifying our reimplementation of SP. Thanks also go to authors of other papers who make their code publicly available. 129 | -------------------------------------------------------------------------------- /cv/crd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetRunner/MetaDistil/80e60c11de531b10d1f06ceb2b71c70665bb6aff/cv/crd/__init__.py -------------------------------------------------------------------------------- /cv/crd/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .memory import ContrastMemory 4 | 5 | eps = 1e-7 6 | 7 | 8 | class CRDLoss(nn.Module): 9 | """CRD Loss function 10 | includes two symmetric parts: 11 | (a) using teacher as anchor, choose positive and negatives over the student side 12 | (b) using student as anchor, choose positive and negatives over the teacher side 13 | 14 | Args: 15 | opt.s_dim: the dimension of student's feature 16 | opt.t_dim: the dimension of teacher's feature 17 | opt.feat_dim: the dimension of the projection space 18 | opt.nce_k: number of negatives paired with each positive 19 | opt.nce_t: the temperature 20 | opt.nce_m: the momentum for updating the memory buffer 21 | opt.n_data: the number of samples in the training set, therefor the memory buffer is: opt.n_data x opt.feat_dim 22 | """ 23 | def __init__(self, opt): 24 | super(CRDLoss, self).__init__() 25 | self.embed_s = Embed(opt.s_dim, opt.feat_dim) 26 | self.embed_t = Embed(opt.t_dim, opt.feat_dim) 27 | self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m) 28 | self.criterion_t = ContrastLoss(opt.n_data) 29 | self.criterion_s = ContrastLoss(opt.n_data) 30 | 31 | def forward(self, f_s, f_t, idx, contrast_idx=None): 32 | """ 33 | Args: 34 | f_s: the feature of student network, size [batch_size, s_dim] 35 | f_t: the feature of teacher network, size [batch_size, t_dim] 36 | idx: the indices of these positive samples in the dataset, size [batch_size] 37 | contrast_idx: the indices of negative samples, size [batch_size, nce_k] 38 | 39 | 
Returns: 40 | The contrastive loss 41 | """ 42 | f_s = self.embed_s(f_s) 43 | f_t = self.embed_t(f_t) 44 | out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx) 45 | s_loss = self.criterion_s(out_s) 46 | t_loss = self.criterion_t(out_t) 47 | loss = s_loss + t_loss 48 | return loss 49 | 50 | 51 | class ContrastLoss(nn.Module): 52 | """ 53 | contrastive loss, corresponding to Eq (18) 54 | """ 55 | def __init__(self, n_data): 56 | super(ContrastLoss, self).__init__() 57 | self.n_data = n_data 58 | 59 | def forward(self, x): 60 | bsz = x.shape[0] 61 | m = x.size(1) - 1 62 | 63 | # noise distribution 64 | Pn = 1 / float(self.n_data) 65 | 66 | # loss for positive pair 67 | P_pos = x.select(1, 0) 68 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_() 69 | 70 | # loss for K negative pair 71 | P_neg = x.narrow(1, 1, m) 72 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_() 73 | 74 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz 75 | 76 | return loss 77 | 78 | 79 | class Embed(nn.Module): 80 | """Embedding module""" 81 | def __init__(self, dim_in=1024, dim_out=128): 82 | super(Embed, self).__init__() 83 | self.linear = nn.Linear(dim_in, dim_out) 84 | self.l2norm = Normalize(2) 85 | 86 | def forward(self, x): 87 | x = x.view(x.shape[0], -1) 88 | x = self.linear(x) 89 | x = self.l2norm(x) 90 | return x 91 | 92 | 93 | class Normalize(nn.Module): 94 | """normalization layer""" 95 | def __init__(self, power=2): 96 | super(Normalize, self).__init__() 97 | self.power = power 98 | 99 | def forward(self, x): 100 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 101 | out = x.div(norm) 102 | return out 103 | -------------------------------------------------------------------------------- /cv/crd/memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | 6 | class ContrastMemory(nn.Module): 7 | """ 8 | memory buffer that supplies large amount of negative samples. 9 | """ 10 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5): 11 | super(ContrastMemory, self).__init__() 12 | self.nLem = outputSize 13 | self.unigrams = torch.ones(self.nLem) 14 | self.multinomial = AliasMethod(self.unigrams) 15 | self.multinomial.cuda() 16 | self.K = K 17 | 18 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum])) 19 | stdv = 1. 
/ math.sqrt(inputSize / 3) 20 | self.register_buffer('memory_v1', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 21 | self.register_buffer('memory_v2', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 22 | 23 | def forward(self, v1, v2, y, idx=None): 24 | K = int(self.params[0].item()) 25 | T = self.params[1].item() 26 | Z_v1 = self.params[2].item() 27 | Z_v2 = self.params[3].item() 28 | 29 | momentum = self.params[4].item() 30 | batchSize = v1.size(0) 31 | outputSize = self.memory_v1.size(0) 32 | inputSize = self.memory_v1.size(1) 33 | 34 | # original score computation 35 | if idx is None: 36 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1) 37 | idx.select(1, 0).copy_(y.data) 38 | # sample 39 | weight_v1 = torch.index_select(self.memory_v1, 0, idx.view(-1)).detach() 40 | weight_v1 = weight_v1.view(batchSize, K + 1, inputSize) 41 | out_v2 = torch.bmm(weight_v1, v2.view(batchSize, inputSize, 1)) 42 | out_v2 = torch.exp(torch.div(out_v2, T)) 43 | # sample 44 | weight_v2 = torch.index_select(self.memory_v2, 0, idx.view(-1)).detach() 45 | weight_v2 = weight_v2.view(batchSize, K + 1, inputSize) 46 | out_v1 = torch.bmm(weight_v2, v1.view(batchSize, inputSize, 1)) 47 | out_v1 = torch.exp(torch.div(out_v1, T)) 48 | 49 | # set Z if haven't been set yet 50 | if Z_v1 < 0: 51 | self.params[2] = out_v1.mean() * outputSize 52 | Z_v1 = self.params[2].clone().detach().item() 53 | print("normalization constant Z_v1 is set to {:.1f}".format(Z_v1)) 54 | if Z_v2 < 0: 55 | self.params[3] = out_v2.mean() * outputSize 56 | Z_v2 = self.params[3].clone().detach().item() 57 | print("normalization constant Z_v2 is set to {:.1f}".format(Z_v2)) 58 | 59 | # compute out_v1, out_v2 60 | out_v1 = torch.div(out_v1, Z_v1).contiguous() 61 | out_v2 = torch.div(out_v2, Z_v2).contiguous() 62 | 63 | # update memory 64 | with torch.no_grad(): 65 | l_pos = torch.index_select(self.memory_v1, 0, y.view(-1)) 66 | l_pos.mul_(momentum) 67 | l_pos.add_(torch.mul(v1, 1 - momentum)) 68 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) 69 | updated_v1 = l_pos.div(l_norm) 70 | self.memory_v1.index_copy_(0, y, updated_v1) 71 | 72 | ab_pos = torch.index_select(self.memory_v2, 0, y.view(-1)) 73 | ab_pos.mul_(momentum) 74 | ab_pos.add_(torch.mul(v2, 1 - momentum)) 75 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) 76 | updated_v2 = ab_pos.div(ab_norm) 77 | self.memory_v2.index_copy_(0, y, updated_v2) 78 | 79 | return out_v1, out_v2 80 | 81 | 82 | class AliasMethod(object): 83 | """ 84 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 85 | """ 86 | def __init__(self, probs): 87 | 88 | if probs.sum() > 1: 89 | probs.div_(probs.sum()) 90 | K = len(probs) 91 | self.prob = torch.zeros(K) 92 | self.alias = torch.LongTensor([0]*K) 93 | 94 | # Sort the data into the outcomes with probabilities 95 | # that are larger and smaller than 1/K. 96 | smaller = [] 97 | larger = [] 98 | for kk, prob in enumerate(probs): 99 | self.prob[kk] = K*prob 100 | if self.prob[kk] < 1.0: 101 | smaller.append(kk) 102 | else: 103 | larger.append(kk) 104 | 105 | # Loop though and create little binary mixtures that 106 | # appropriately allocate the larger outcomes over the 107 | # overall uniform mixture. 
108 | while len(smaller) > 0 and len(larger) > 0: 109 | small = smaller.pop() 110 | large = larger.pop() 111 | 112 | self.alias[small] = large 113 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small] 114 | 115 | if self.prob[large] < 1.0: 116 | smaller.append(large) 117 | else: 118 | larger.append(large) 119 | 120 | for last_one in smaller+larger: 121 | self.prob[last_one] = 1 122 | 123 | def cuda(self): 124 | self.prob = self.prob.cuda() 125 | self.alias = self.alias.cuda() 126 | 127 | def draw(self, N): 128 | """ Draw N samples from multinomial """ 129 | K = self.alias.size(0) 130 | 131 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K) 132 | prob = self.prob.index_select(0, kk) 133 | alias = self.alias.index_select(0, kk) 134 | # b is whether a random number is greater than q 135 | b = torch.bernoulli(prob) 136 | oq = kk.mul(b.long()) 137 | oj = alias.mul((1-b).long()) 138 | 139 | return oq + oj -------------------------------------------------------------------------------- /cv/dataset/cifar100.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import socket 5 | import numpy as np 6 | from torch.utils.data import DataLoader 7 | from torchvision import datasets, transforms 8 | from PIL import Image 9 | 10 | """ 11 | mean = { 12 | 'cifar100': (0.5071, 0.4867, 0.4408), 13 | } 14 | 15 | std = { 16 | 'cifar100': (0.2675, 0.2565, 0.2761), 17 | } 18 | """ 19 | 20 | 21 | def get_data_folder(): 22 | """ 23 | return server-dependent path to store the data 24 | """ 25 | hostname = socket.gethostname() 26 | if hostname.startswith('visiongpu'): 27 | data_folder = '/data/vision/phillipi/rep-learn/datasets' 28 | elif hostname.startswith('yonglong-home'): 29 | data_folder = '/home/yonglong/Data/data' 30 | else: 31 | data_folder = './data/' 32 | 33 | if not os.path.isdir(data_folder): 34 | os.makedirs(data_folder) 35 | 36 | return data_folder 37 | 38 | 39 | class CIFAR100Instance(datasets.CIFAR100): 40 | """CIFAR100Instance Dataset. 
41 | """ 42 | def __getitem__(self, index): 43 | if self.train: 44 | img, target = self.train_data[index], self.train_labels[index] 45 | else: 46 | img, target = self.test_data[index], self.test_labels[index] 47 | 48 | # doing this so that it is consistent with all other datasets 49 | # to return a PIL Image 50 | img = Image.fromarray(img) 51 | 52 | if self.transform is not None: 53 | img = self.transform(img) 54 | 55 | if self.target_transform is not None: 56 | target = self.target_transform(target) 57 | 58 | return img, target, index 59 | 60 | 61 | def get_cifar100_dataloaders(batch_size=128, num_workers=8, is_instance=False): 62 | """ 63 | cifar 100 64 | """ 65 | data_folder = get_data_folder() 66 | 67 | train_transform = transforms.Compose([ 68 | transforms.RandomCrop(32, padding=4), 69 | transforms.RandomHorizontalFlip(), 70 | transforms.ToTensor(), 71 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 72 | ]) 73 | test_transform = transforms.Compose([ 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 76 | ]) 77 | 78 | if is_instance: 79 | train_set = CIFAR100Instance(root=data_folder, 80 | download=True, 81 | train=True, 82 | transform=train_transform) 83 | n_data = len(train_set) 84 | else: 85 | train_set = datasets.CIFAR100(root=data_folder, 86 | download=True, 87 | train=True, 88 | transform=train_transform) 89 | train_loader = DataLoader(train_set, 90 | batch_size=batch_size, 91 | shuffle=True, 92 | num_workers=num_workers) 93 | 94 | test_set = datasets.CIFAR100(root=data_folder, 95 | download=True, 96 | train=False, 97 | transform=test_transform) 98 | test_loader = DataLoader(test_set, 99 | batch_size=int(batch_size/2), 100 | shuffle=False, 101 | num_workers=int(num_workers/2)) 102 | 103 | if is_instance: 104 | return train_loader, test_loader, n_data 105 | else: 106 | return train_loader, test_loader 107 | 108 | 109 | class CIFAR100InstanceSample(datasets.CIFAR100): 110 | """ 111 | CIFAR100Instance+Sample Dataset 112 | """ 113 | def __init__(self, root, train=True, 114 | transform=None, target_transform=None, 115 | download=False, k=4096, mode='exact', is_sample=True, percent=1.0): 116 | super().__init__(root=root, train=train, download=download, 117 | transform=transform, target_transform=target_transform) 118 | self.k = k 119 | self.mode = mode 120 | self.is_sample = is_sample 121 | 122 | num_classes = 100 123 | if self.train: 124 | num_samples = len(self.train_data) 125 | label = self.train_labels 126 | else: 127 | num_samples = len(self.test_data) 128 | label = self.test_labels 129 | 130 | self.cls_positive = [[] for i in range(num_classes)] 131 | for i in range(num_samples): 132 | self.cls_positive[label[i]].append(i) 133 | 134 | self.cls_negative = [[] for i in range(num_classes)] 135 | for i in range(num_classes): 136 | for j in range(num_classes): 137 | if j == i: 138 | continue 139 | self.cls_negative[i].extend(self.cls_positive[j]) 140 | 141 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)] 142 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)] 143 | 144 | if 0 < percent < 1: 145 | n = int(len(self.cls_negative[0]) * percent) 146 | self.cls_negative = [np.random.permutation(self.cls_negative[i])[0:n] 147 | for i in range(num_classes)] 148 | 149 | self.cls_positive = np.asarray(self.cls_positive) 150 | self.cls_negative = np.asarray(self.cls_negative) 151 | 152 | def __getitem__(self, index): 153 | if self.train: 154 | 
img, target = self.train_data[index], self.train_labels[index] 155 | else: 156 | img, target = self.test_data[index], self.test_labels[index] 157 | 158 | # doing this so that it is consistent with all other datasets 159 | # to return a PIL Image 160 | img = Image.fromarray(img) 161 | 162 | if self.transform is not None: 163 | img = self.transform(img) 164 | 165 | if self.target_transform is not None: 166 | target = self.target_transform(target) 167 | 168 | if not self.is_sample: 169 | # directly return 170 | return img, target, index 171 | else: 172 | # sample contrastive examples 173 | if self.mode == 'exact': 174 | pos_idx = index 175 | elif self.mode == 'relax': 176 | pos_idx = np.random.choice(self.cls_positive[target], 1) 177 | pos_idx = pos_idx[0] 178 | else: 179 | raise NotImplementedError(self.mode) 180 | replace = True if self.k > len(self.cls_negative[target]) else False 181 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace) 182 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 183 | return img, target, index, sample_idx 184 | 185 | 186 | def get_cifar100_dataloaders_sample(batch_size=128, num_workers=8, k=4096, mode='exact', 187 | is_sample=True, percent=1.0): 188 | """ 189 | cifar 100 190 | """ 191 | data_folder = get_data_folder() 192 | 193 | train_transform = transforms.Compose([ 194 | transforms.RandomCrop(32, padding=4), 195 | transforms.RandomHorizontalFlip(), 196 | transforms.ToTensor(), 197 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 198 | ]) 199 | test_transform = transforms.Compose([ 200 | transforms.ToTensor(), 201 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 202 | ]) 203 | 204 | train_set = CIFAR100InstanceSample(root=data_folder, 205 | download=True, 206 | train=True, 207 | transform=train_transform, 208 | k=k, 209 | mode=mode, 210 | is_sample=is_sample, 211 | percent=percent) 212 | n_data = len(train_set) 213 | train_loader = DataLoader(train_set, 214 | batch_size=batch_size, 215 | shuffle=True, 216 | num_workers=num_workers) 217 | 218 | test_set = datasets.CIFAR100(root=data_folder, 219 | download=True, 220 | train=False, 221 | transform=test_transform) 222 | test_loader = DataLoader(test_set, 223 | batch_size=int(batch_size/2), 224 | shuffle=False, 225 | num_workers=int(num_workers/2)) 226 | 227 | return train_loader, test_loader, n_data 228 | -------------------------------------------------------------------------------- /cv/dataset/cifar_with_held.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import sys 7 | 8 | if sys.version_info[0] == 2: 9 | import cPickle as pickle 10 | else: 11 | import pickle 12 | 13 | from torchvision.datasets.vision import VisionDataset 14 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive 15 | 16 | 17 | class CIFAR100WithHeld(VisionDataset): 18 | """`CIFAR10 `_ Dataset. 19 | 20 | Args: 21 | root (string): Root directory of dataset where directory 22 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 23 | train (bool, optional): If True, creates dataset from training set, otherwise 24 | creates from test set. 25 | transform (callable, optional): A function/transform that takes in an PIL image 26 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 27 | target_transform (callable, optional): A function/transform that takes in the 28 | target and transforms it. 29 | download (bool, optional): If true, downloads the dataset from the internet and 30 | puts it in root directory. If dataset is already downloaded, it is not 31 | downloaded again. 32 | 33 | """ 34 | base_folder = 'cifar-100-python' 35 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 36 | filename = "cifar-100-python.tar.gz" 37 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 38 | train_list = [ 39 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 40 | ] 41 | 42 | test_list = [ 43 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 44 | ] 45 | meta = { 46 | 'filename': 'meta', 47 | 'key': 'fine_label_names', 48 | 'md5': '7973b15100ade9c7d40fb424638fde48', 49 | } 50 | 51 | def __init__(self, root, train=True, transform=None, target_transform=None, 52 | download=False, held=False, held_samples=0): 53 | 54 | super(CIFAR100WithHeld, self).__init__(root, transform=transform, 55 | target_transform=target_transform) 56 | 57 | self.train = train # training set or test set 58 | self.held = held 59 | self.held_samples = held_samples 60 | 61 | if download: 62 | self.download() 63 | 64 | if not self._check_integrity(): 65 | raise RuntimeError('Dataset not found or corrupted.' + 66 | ' You can use download=True to download it') 67 | 68 | if self.held: 69 | assert self.train, "Held set is a subset of train set" 70 | 71 | if self.train: 72 | downloaded_list = self.train_list 73 | else: 74 | downloaded_list = self.test_list 75 | 76 | self.data = [] 77 | self.targets = [] 78 | 79 | # now load the picked numpy arrays 80 | for file_name, checksum in downloaded_list: 81 | file_path = os.path.join(self.root, self.base_folder, file_name) 82 | with open(file_path, 'rb') as f: 83 | if sys.version_info[0] == 2: 84 | entry = pickle.load(f) 85 | else: 86 | entry = pickle.load(f, encoding='latin1') 87 | self.data.append(entry['data']) 88 | if 'labels' in entry: 89 | self.targets.extend(entry['labels']) 90 | else: 91 | self.targets.extend(entry['fine_labels']) 92 | 93 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 94 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 95 | 96 | if self.held: 97 | self.data = self.data[-self.held_samples:] 98 | self.targets = self.targets[-self.held_samples:] 99 | elif self.train: 100 | self.data = self.data[:-self.held_samples] 101 | self.targets = self.targets[:-self.held_samples] 102 | 103 | self._load_meta() 104 | 105 | def _load_meta(self): 106 | path = os.path.join(self.root, self.base_folder, self.meta['filename']) 107 | if not check_integrity(path, self.meta['md5']): 108 | raise RuntimeError('Dataset metadata file not found or corrupted.' + 109 | ' You can use download=True to download it') 110 | with open(path, 'rb') as infile: 111 | if sys.version_info[0] == 2: 112 | data = pickle.load(infile) 113 | else: 114 | data = pickle.load(infile, encoding='latin1') 115 | self.classes = data[self.meta['key']] 116 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 117 | 118 | def __getitem__(self, index): 119 | """ 120 | Args: 121 | index (int): Index 122 | 123 | Returns: 124 | tuple: (image, target) where target is index of the target class. 
125 | """ 126 | img, target = self.data[index], self.targets[index] 127 | 128 | # doing this so that it is consistent with all other datasets 129 | # to return a PIL Image 130 | img = Image.fromarray(img) 131 | 132 | if self.transform is not None: 133 | img = self.transform(img) 134 | 135 | if self.target_transform is not None: 136 | target = self.target_transform(target) 137 | 138 | return img, target 139 | 140 | def __len__(self): 141 | return len(self.data) 142 | 143 | def _check_integrity(self): 144 | root = self.root 145 | for fentry in (self.train_list + self.test_list): 146 | filename, md5 = fentry[0], fentry[1] 147 | fpath = os.path.join(root, self.base_folder, filename) 148 | if not check_integrity(fpath, md5): 149 | return False 150 | return True 151 | 152 | def download(self): 153 | if self._check_integrity(): 154 | print('Files already downloaded and verified') 155 | return 156 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) 157 | 158 | def extra_repr(self): 159 | return "Split: {}".format("Train" if self.train is True else "Test") 160 | -------------------------------------------------------------------------------- /cv/dataset/imagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | get data loaders 3 | """ 4 | from __future__ import print_function 5 | 6 | import os 7 | import socket 8 | import numpy as np 9 | from torch.utils.data import DataLoader 10 | from torchvision import datasets 11 | from torchvision import transforms 12 | 13 | 14 | def get_data_folder(): 15 | """ 16 | return server-dependent path to store the data 17 | """ 18 | hostname = socket.gethostname() 19 | if hostname.startswith('visiongpu'): 20 | data_folder = '/data/vision/phillipi/rep-learn/datasets/imagenet' 21 | elif hostname.startswith('yonglong-home'): 22 | data_folder = '/home/yonglong/Data/data/imagenet' 23 | else: 24 | data_folder = './data/imagenet' 25 | 26 | if not os.path.isdir(data_folder): 27 | os.makedirs(data_folder) 28 | 29 | return data_folder 30 | 31 | 32 | class ImageFolderInstance(datasets.ImageFolder): 33 | """: Folder datasets which returns the index of the image as well:: 34 | """ 35 | def __getitem__(self, index): 36 | """ 37 | Args: 38 | index (int): Index 39 | Returns: 40 | tuple: (image, target) where target is class_index of the target class. 
41 | """ 42 | path, target = self.imgs[index] 43 | img = self.loader(path) 44 | if self.transform is not None: 45 | img = self.transform(img) 46 | if self.target_transform is not None: 47 | target = self.target_transform(target) 48 | 49 | return img, target, index 50 | 51 | 52 | class ImageFolderSample(datasets.ImageFolder): 53 | """: Folder datasets which returns (img, label, index, contrast_index): 54 | """ 55 | def __init__(self, root, transform=None, target_transform=None, 56 | is_sample=False, k=4096): 57 | super().__init__(root=root, transform=transform, target_transform=target_transform) 58 | 59 | self.k = k 60 | self.is_sample = is_sample 61 | 62 | print('stage1 finished!') 63 | 64 | if self.is_sample: 65 | num_classes = len(self.classes) 66 | num_samples = len(self.samples) 67 | label = np.zeros(num_samples, dtype=np.int32) 68 | for i in range(num_samples): 69 | path, target = self.imgs[i] 70 | label[i] = target 71 | 72 | self.cls_positive = [[] for i in range(num_classes)] 73 | for i in range(num_samples): 74 | self.cls_positive[label[i]].append(i) 75 | 76 | self.cls_negative = [[] for i in range(num_classes)] 77 | for i in range(num_classes): 78 | for j in range(num_classes): 79 | if j == i: 80 | continue 81 | self.cls_negative[i].extend(self.cls_positive[j]) 82 | 83 | self.cls_positive = [np.asarray(self.cls_positive[i], dtype=np.int32) for i in range(num_classes)] 84 | self.cls_negative = [np.asarray(self.cls_negative[i], dtype=np.int32) for i in range(num_classes)] 85 | 86 | print('dataset initialized!') 87 | 88 | def __getitem__(self, index): 89 | """ 90 | Args: 91 | index (int): Index 92 | Returns: 93 | tuple: (image, target) where target is class_index of the target class. 94 | """ 95 | path, target = self.imgs[index] 96 | img = self.loader(path) 97 | if self.transform is not None: 98 | img = self.transform(img) 99 | if self.target_transform is not None: 100 | target = self.target_transform(target) 101 | 102 | if self.is_sample: 103 | # sample contrastive examples 104 | pos_idx = index 105 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True) 106 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 107 | return img, target, index, sample_idx 108 | else: 109 | return img, target, index 110 | 111 | 112 | def get_test_loader(dataset='imagenet', batch_size=128, num_workers=8): 113 | """get the test data loader""" 114 | 115 | if dataset == 'imagenet': 116 | data_folder = get_data_folder() 117 | else: 118 | raise NotImplementedError('dataset not supported: {}'.format(dataset)) 119 | 120 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 121 | std=[0.229, 0.224, 0.225]) 122 | test_transform = transforms.Compose([ 123 | transforms.Resize(256), 124 | transforms.CenterCrop(224), 125 | transforms.ToTensor(), 126 | normalize, 127 | ]) 128 | 129 | test_folder = os.path.join(data_folder, 'val') 130 | test_set = datasets.ImageFolder(test_folder, transform=test_transform) 131 | test_loader = DataLoader(test_set, 132 | batch_size=batch_size, 133 | shuffle=False, 134 | num_workers=num_workers, 135 | pin_memory=True) 136 | 137 | return test_loader 138 | 139 | 140 | def get_dataloader_sample(dataset='imagenet', batch_size=128, num_workers=8, is_sample=False, k=4096): 141 | """Data Loader for ImageNet""" 142 | 143 | if dataset == 'imagenet': 144 | data_folder = get_data_folder() 145 | else: 146 | raise NotImplementedError('dataset not supported: {}'.format(dataset)) 147 | 148 | # add data transform 149 | normalize = transforms.Normalize(mean=[0.485, 
0.456, 0.406], 150 | std=[0.229, 0.224, 0.225]) 151 | train_transform = transforms.Compose([ 152 | transforms.RandomResizedCrop(224), 153 | transforms.RandomHorizontalFlip(), 154 | transforms.ToTensor(), 155 | normalize, 156 | ]) 157 | test_transform = transforms.Compose([ 158 | transforms.Resize(256), 159 | transforms.CenterCrop(224), 160 | transforms.ToTensor(), 161 | normalize, 162 | ]) 163 | train_folder = os.path.join(data_folder, 'train') 164 | test_folder = os.path.join(data_folder, 'val') 165 | 166 | train_set = ImageFolderSample(train_folder, transform=train_transform, is_sample=is_sample, k=k) 167 | test_set = datasets.ImageFolder(test_folder, transform=test_transform) 168 | 169 | train_loader = DataLoader(train_set, 170 | batch_size=batch_size, 171 | shuffle=True, 172 | num_workers=num_workers, 173 | pin_memory=True) 174 | test_loader = DataLoader(test_set, 175 | batch_size=batch_size, 176 | shuffle=False, 177 | num_workers=num_workers, 178 | pin_memory=True) 179 | 180 | print('num_samples', len(train_set.samples)) 181 | print('num_class', len(train_set.classes)) 182 | 183 | return train_loader, test_loader, len(train_set), len(train_set.classes) 184 | 185 | 186 | def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16, is_instance=False): 187 | """ 188 | Data Loader for imagenet 189 | """ 190 | if dataset == 'imagenet': 191 | data_folder = get_data_folder() 192 | else: 193 | raise NotImplementedError('dataset not supported: {}'.format(dataset)) 194 | 195 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 196 | std=[0.229, 0.224, 0.225]) 197 | train_transform = transforms.Compose([ 198 | transforms.RandomResizedCrop(224), 199 | transforms.RandomHorizontalFlip(), 200 | transforms.ToTensor(), 201 | normalize, 202 | ]) 203 | test_transform = transforms.Compose([ 204 | transforms.Resize(256), 205 | transforms.CenterCrop(224), 206 | transforms.ToTensor(), 207 | normalize, 208 | ]) 209 | 210 | train_folder = os.path.join(data_folder, 'train') 211 | test_folder = os.path.join(data_folder, 'val') 212 | 213 | if is_instance: 214 | train_set = ImageFolderInstance(train_folder, transform=train_transform) 215 | n_data = len(train_set) 216 | else: 217 | train_set = datasets.ImageFolder(train_folder, transform=train_transform) 218 | 219 | test_set = datasets.ImageFolder(test_folder, transform=test_transform) 220 | 221 | train_loader = DataLoader(train_set, 222 | batch_size=batch_size, 223 | shuffle=True, 224 | num_workers=num_workers, 225 | pin_memory=True) 226 | 227 | test_loader = DataLoader(test_set, 228 | batch_size=batch_size, 229 | shuffle=False, 230 | num_workers=num_workers//2, 231 | pin_memory=True) 232 | 233 | if is_instance: 234 | return train_loader, test_loader, n_data 235 | else: 236 | return train_loader, test_loader 237 | -------------------------------------------------------------------------------- /cv/dataset/meta_cifar100.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import socket 5 | import numpy as np 6 | from torch.utils.data import DataLoader, RandomSampler 7 | from torchvision import transforms 8 | from .cifar_with_held import CIFAR100WithHeld 9 | from PIL import Image 10 | 11 | """ 12 | mean = { 13 | 'cifar100': (0.5071, 0.4867, 0.4408), 14 | } 15 | 16 | std = { 17 | 'cifar100': (0.2675, 0.2565, 0.2761), 18 | } 19 | """ 20 | 21 | 22 | def get_data_folder(): 23 | """ 24 | return server-dependent path to store the data 25 | """ 26 | 
hostname = socket.gethostname() 27 | if hostname.startswith('visiongpu'): 28 | data_folder = '/data/vision/phillipi/rep-learn/datasets' 29 | elif hostname.startswith('yonglong-home'): 30 | data_folder = '/home/yonglong/Data/data' 31 | else: 32 | data_folder = './data/' 33 | 34 | if not os.path.isdir(data_folder): 35 | os.makedirs(data_folder) 36 | 37 | return data_folder 38 | 39 | 40 | def get_cifar100_dataloaders(batch_size=128, num_workers=8, held_size=0, num_held_samples=0): 41 | """ 42 | cifar 100 43 | """ 44 | data_folder = get_data_folder() 45 | 46 | train_transform = transforms.Compose([ 47 | transforms.RandomCrop(32, padding=4), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 51 | ]) 52 | test_transform = transforms.Compose([ 53 | transforms.ToTensor(), 54 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 55 | ]) 56 | 57 | train_set = CIFAR100WithHeld(root=data_folder, 58 | download=True, 59 | train=True, 60 | held=False, 61 | held_samples=held_size, 62 | transform=train_transform) 63 | train_loader = DataLoader(train_set, 64 | batch_size=batch_size, 65 | shuffle=True, 66 | num_workers=num_workers) 67 | 68 | held_set = CIFAR100WithHeld(root=data_folder, 69 | download=True, 70 | train=True, 71 | held=True, 72 | held_samples=held_size, 73 | transform=train_transform) 74 | 75 | if num_held_samples == 0: 76 | held_loader = DataLoader(held_set, 77 | batch_size=batch_size, 78 | shuffle=True, 79 | num_workers=num_workers) 80 | else: 81 | held_sampler = RandomSampler(held_set, 82 | replacement=True, 83 | num_samples=num_held_samples) 84 | held_loader = DataLoader(held_set, 85 | sampler=held_sampler, 86 | batch_size=batch_size) 87 | 88 | test_set = CIFAR100WithHeld(root=data_folder, 89 | download=True, 90 | train=False, 91 | transform=test_transform) 92 | test_loader = DataLoader(test_set, 93 | batch_size=int(batch_size/2), 94 | shuffle=False, 95 | num_workers=int(num_workers/2)) 96 | 97 | return train_loader, held_loader, test_loader 98 | -------------------------------------------------------------------------------- /cv/distiller_zoo/AB.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ABLoss(nn.Module): 8 | """Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons 9 | code: https://github.com/bhheo/AB_distillation 10 | """ 11 | def __init__(self, feat_num, margin=1.0): 12 | super(ABLoss, self).__init__() 13 | self.w = [2**(i-feat_num+1) for i in range(feat_num)] 14 | self.margin = margin 15 | 16 | def forward(self, g_s, g_t): 17 | bsz = g_s[0].shape[0] 18 | losses = [self.criterion_alternative_l2(s, t) for s, t in zip(g_s, g_t)] 19 | losses = [w * l for w, l in zip(self.w, losses)] 20 | # loss = sum(losses) / bsz 21 | # loss = loss / 1000 * 3 22 | losses = [l / bsz for l in losses] 23 | losses = [l / 1000 * 3 for l in losses] 24 | return losses 25 | 26 | def criterion_alternative_l2(self, source, target): 27 | loss = ((source + self.margin) ** 2 * ((source > -self.margin) & (target <= 0)).float() + 28 | (source - self.margin) ** 2 * ((source <= self.margin) & (target > 0)).float()) 29 | return torch.abs(loss).sum() 30 | -------------------------------------------------------------------------------- /cv/distiller_zoo/AT.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Attention(nn.Module): 8 | """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks 9 | via Attention Transfer 10 | code: https://github.com/szagoruyko/attention-transfer""" 11 | def __init__(self, p=2): 12 | super(Attention, self).__init__() 13 | self.p = p 14 | 15 | def forward(self, g_s, g_t): 16 | return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)] 17 | 18 | def at_loss(self, f_s, f_t): 19 | s_H, t_H = f_s.shape[2], f_t.shape[2] 20 | if s_H > t_H: 21 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) 22 | elif s_H < t_H: 23 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) 24 | else: 25 | pass 26 | return (self.at(f_s) - self.at(f_t)).pow(2).mean() 27 | 28 | def at(self, f): 29 | return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1)) 30 | -------------------------------------------------------------------------------- /cv/distiller_zoo/CC.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class Correlation(nn.Module): 8 | """Correlation Congruence for Knowledge Distillation, ICCV 2019. 9 | The authors nicely shared the code with me. I restructured their code to be 10 | compatible with my running framework. Credits go to the original author""" 11 | def __init__(self): 12 | super(Correlation, self).__init__() 13 | 14 | def forward(self, f_s, f_t): 15 | delta = torch.abs(f_s - f_t) 16 | loss = torch.mean((delta[:-1] * delta[1:]).sum(1)) 17 | return loss 18 | 19 | 20 | # class Correlation(nn.Module): 21 | # """Similarity-preserving loss. 
My origianl own reimplementation 22 | # based on the paper before emailing the original authors.""" 23 | # def __init__(self): 24 | # super(Correlation, self).__init__() 25 | # 26 | # def forward(self, f_s, f_t): 27 | # return self.similarity_loss(f_s, f_t) 28 | # # return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)] 29 | # 30 | # def similarity_loss(self, f_s, f_t): 31 | # bsz = f_s.shape[0] 32 | # f_s = f_s.view(bsz, -1) 33 | # f_t = f_t.view(bsz, -1) 34 | # 35 | # G_s = torch.mm(f_s, torch.t(f_s)) 36 | # G_s = G_s / G_s.norm(2) 37 | # G_t = torch.mm(f_t, torch.t(f_t)) 38 | # G_t = G_t / G_t.norm(2) 39 | # 40 | # G_diff = G_t - G_s 41 | # loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz) 42 | # return loss 43 | -------------------------------------------------------------------------------- /cv/distiller_zoo/FSP.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FSP(nn.Module): 9 | """A Gift from Knowledge Distillation: 10 | Fast Optimization, Network Minimization and Transfer Learning""" 11 | def __init__(self, s_shapes, t_shapes): 12 | super(FSP, self).__init__() 13 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' 14 | s_c = [s[1] for s in s_shapes] 15 | t_c = [t[1] for t in t_shapes] 16 | if np.any(np.asarray(s_c) != np.asarray(t_c)): 17 | raise ValueError('num of channels not equal (error in FSP)') 18 | 19 | def forward(self, g_s, g_t): 20 | s_fsp = self.compute_fsp(g_s) 21 | t_fsp = self.compute_fsp(g_t) 22 | loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)] 23 | return loss_group 24 | 25 | @staticmethod 26 | def compute_loss(s, t): 27 | return (s - t).pow(2).mean() 28 | 29 | @staticmethod 30 | def compute_fsp(g): 31 | fsp_list = [] 32 | for i in range(len(g) - 1): 33 | bot, top = g[i], g[i + 1] 34 | b_H, t_H = bot.shape[2], top.shape[2] 35 | if b_H > t_H: 36 | bot = F.adaptive_avg_pool2d(bot, (t_H, t_H)) 37 | elif b_H < t_H: 38 | top = F.adaptive_avg_pool2d(top, (b_H, b_H)) 39 | else: 40 | pass 41 | bot = bot.unsqueeze(1) 42 | top = top.unsqueeze(2) 43 | bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1) 44 | top = top.view(top.shape[0], top.shape[1], top.shape[2], -1) 45 | 46 | fsp = (bot * top).mean(-1) 47 | fsp_list.append(fsp) 48 | return fsp_list 49 | -------------------------------------------------------------------------------- /cv/distiller_zoo/FT.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class FactorTransfer(nn.Module): 8 | """Paraphrasing Complex Network: Network Compression via Factor Transfer, NeurIPS 2018""" 9 | def __init__(self, p1=2, p2=1): 10 | super(FactorTransfer, self).__init__() 11 | self.p1 = p1 12 | self.p2 = p2 13 | 14 | def forward(self, f_s, f_t): 15 | return self.factor_loss(f_s, f_t) 16 | 17 | def factor_loss(self, f_s, f_t): 18 | s_H, t_H = f_s.shape[2], f_t.shape[2] 19 | if s_H > t_H: 20 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) 21 | elif s_H < t_H: 22 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) 23 | else: 24 | pass 25 | if self.p2 == 1: 26 | return (self.factor(f_s) - self.factor(f_t)).abs().mean() 27 | else: 28 | return (self.factor(f_s) - self.factor(f_t)).pow(self.p2).mean() 29 | 30 | def factor(self, f): 31 | return 
F.normalize(f.pow(self.p1).mean(1).view(f.size(0), -1)) 32 | -------------------------------------------------------------------------------- /cv/distiller_zoo/FitNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class HintLoss(nn.Module): 7 | """Fitnets: hints for thin deep nets, ICLR 2015""" 8 | def __init__(self): 9 | super(HintLoss, self).__init__() 10 | self.crit = nn.MSELoss() 11 | 12 | def forward(self, f_s, f_t): 13 | loss = self.crit(f_s, f_t) 14 | return loss 15 | -------------------------------------------------------------------------------- /cv/distiller_zoo/KD.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | 8 | class DistillKL(nn.Module): 9 | """Distilling the Knowledge in a Neural Network""" 10 | def __init__(self, T): 11 | super(DistillKL, self).__init__() 12 | self.T = T 13 | 14 | def forward(self, y_s, y_t): 15 | p_s = F.log_softmax(y_s/self.T, dim=1) 16 | p_t = F.softmax(y_t/self.T, dim=1) 17 | loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0] 18 | return loss 19 | 20 | 21 | class CustomDistillKL(nn.Module): 22 | """Distilling the Knowledge in a Neural Network""" 23 | def __init__(self, T): 24 | super(CustomDistillKL, self).__init__() 25 | self.T = T 26 | 27 | def forward(self, y_s, y_t): 28 | p_s = F.log_softmax(y_s/self.T, dim=1) 29 | p_t = F.softmax(y_t/self.T, dim=1) 30 | loss = (p_t * (p_t.log() - p_s)).sum(dim=1).mean(dim=0) * (self.T ** 2) 31 | return loss 32 | 33 | 34 | if __name__ == '__main__': 35 | kl_1 = DistillKL(3) 36 | kl_2 = CustomDistillKL(3) 37 | student = torch.tensor([[3., 7., 1.], [4., 8., 5.], [0.25, 0.76, 0.2], [0.42, 0.99, 0.8]]) 38 | teacher = torch.tensor([[4., 6., 5.], [2., 7., 6.], [0.28, 0.94, 0.1], [0.49, 0.75, 0.0]]) 39 | print(kl_1(student, teacher), kl_2(student, teacher)) 40 | assert kl_1(student, teacher) - kl_2(student, teacher) < 0.000001 41 | -------------------------------------------------------------------------------- /cv/distiller_zoo/KDSVD.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class KDSVD(nn.Module): 9 | """ 10 | Self-supervised Knowledge Distillation using Singular Value Decomposition 11 | original Tensorflow code: https://github.com/sseung0703/SSKD_SVD 12 | """ 13 | def __init__(self, k=1): 14 | super(KDSVD, self).__init__() 15 | self.k = k 16 | 17 | def forward(self, g_s, g_t): 18 | v_sb = None 19 | v_tb = None 20 | losses = [] 21 | for i, f_s, f_t in zip(range(len(g_s)), g_s, g_t): 22 | 23 | u_t, s_t, v_t = self.svd(f_t, self.k) 24 | u_s, s_s, v_s = self.svd(f_s, self.k + 3) 25 | v_s, v_t = self.align_rsv(v_s, v_t) 26 | s_t = s_t.unsqueeze(1) 27 | v_t = v_t * s_t 28 | v_s = v_s * s_t 29 | 30 | if i > 0: 31 | s_rbf = torch.exp(-(v_s.unsqueeze(2) - v_sb.unsqueeze(1)).pow(2) / 8) 32 | t_rbf = torch.exp(-(v_t.unsqueeze(2) - v_tb.unsqueeze(1)).pow(2) / 8) 33 | 34 | l2loss = (s_rbf - t_rbf.detach()).pow(2) 35 | l2loss = torch.where(torch.isfinite(l2loss), l2loss, torch.zeros_like(l2loss)) 36 | losses.append(l2loss.sum()) 37 | 38 | v_tb = v_t 39 | v_sb = v_s 40 | 41 | bsz = g_s[0].shape[0] 42 | losses = [l / bsz for l in losses] 43 | 
return losses 44 | 45 | def svd(self, feat, n=1): 46 | size = feat.shape 47 | assert len(size) == 4 48 | 49 | x = feat.view(-1, size[1], size[2] * size[2]).transpose(-2, -1) 50 | u, s, v = torch.svd(x) 51 | 52 | u = self.removenan(u) 53 | s = self.removenan(s) 54 | v = self.removenan(v) 55 | 56 | if n > 0: 57 | u = F.normalize(u[:, :, :n], dim=1) 58 | s = F.normalize(s[:, :n], dim=1) 59 | v = F.normalize(v[:, :, :n], dim=1) 60 | 61 | return u, s, v 62 | 63 | @staticmethod 64 | def removenan(x): 65 | x = torch.where(torch.isfinite(x), x, torch.zeros_like(x)) 66 | return x 67 | 68 | @staticmethod 69 | def align_rsv(a, b): 70 | cosine = torch.matmul(a.transpose(-2, -1), b) 71 | max_abs_cosine, _ = torch.max(torch.abs(cosine), 1, keepdim=True) 72 | mask = torch.where(torch.eq(max_abs_cosine, torch.abs(cosine)), 73 | torch.sign(cosine), torch.zeros_like(cosine)) 74 | a = torch.matmul(a, mask) 75 | return a, b 76 | -------------------------------------------------------------------------------- /cv/distiller_zoo/MSE.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | 8 | class MSEWithTemperature(nn.Module): 9 | """MSE between temperature-scaled logits (an L2 variant of logit distillation)""" 10 | def __init__(self, T): 11 | super(MSEWithTemperature, self).__init__() 12 | self.T = T 13 | 14 | def forward(self, y_s, y_t): 15 | p_s = y_s / self.T 16 | p_t = y_t / self.T 17 | loss = F.mse_loss(p_s, p_t) 18 | return loss 19 | -------------------------------------------------------------------------------- /cv/distiller_zoo/NST.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class NSTLoss(nn.Module): 8 | """Like What You Like: Knowledge Distill via Neuron Selectivity Transfer""" 9 | def __init__(self): 10 | super(NSTLoss, self).__init__() 11 | pass 12 | 13 | def forward(self, g_s, g_t): 14 | return [self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)] 15 | 16 | def nst_loss(self, f_s, f_t): 17 | s_H, t_H = f_s.shape[2], f_t.shape[2] 18 | if s_H > t_H: 19 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) 20 | elif s_H < t_H: 21 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) 22 | else: 23 | pass 24 | 25 | f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1) 26 | f_s = F.normalize(f_s, dim=2) 27 | f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1) 28 | f_t = F.normalize(f_t, dim=2) 29 | 30 | # set full_loss as False to avoid unnecessary computation 31 | full_loss = True 32 | if full_loss: 33 | return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean() 34 | - 2 * self.poly_kernel(f_s, f_t).mean()) 35 | else: 36 | return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean() 37 | 38 | def poly_kernel(self, a, b): 39 | a = a.unsqueeze(1) 40 | b = b.unsqueeze(2) 41 | res = (a * b).sum(-1).pow(2) 42 | return res -------------------------------------------------------------------------------- /cv/distiller_zoo/PKT.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class PKT(nn.Module): 8 | """Probabilistic Knowledge Transfer for deep representation learning 9 | Code from author: https://github.com/passalis/probabilistic_kt""" 10 | def __init__(self):
super(PKT, self).__init__() 12 | 13 | def forward(self, f_s, f_t): 14 | return self.cosine_similarity_loss(f_s, f_t) 15 | 16 | @staticmethod 17 | def cosine_similarity_loss(output_net, target_net, eps=0.0000001): 18 | # Normalize each vector by its norm 19 | output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True)) 20 | output_net = output_net / (output_net_norm + eps) 21 | output_net[output_net != output_net] = 0 22 | 23 | target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True)) 24 | target_net = target_net / (target_net_norm + eps) 25 | target_net[target_net != target_net] = 0 26 | 27 | # Calculate the cosine similarity 28 | model_similarity = torch.mm(output_net, output_net.transpose(0, 1)) 29 | target_similarity = torch.mm(target_net, target_net.transpose(0, 1)) 30 | 31 | # Scale cosine similarity to 0..1 32 | model_similarity = (model_similarity + 1.0) / 2.0 33 | target_similarity = (target_similarity + 1.0) / 2.0 34 | 35 | # Transform them into probabilities 36 | model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True) 37 | target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True) 38 | 39 | # Calculate the KL-divergence 40 | loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps))) 41 | 42 | return loss 43 | -------------------------------------------------------------------------------- /cv/distiller_zoo/RKD.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class RKDLoss(nn.Module): 9 | """Relational Knowledge Disitllation, CVPR2019""" 10 | def __init__(self, w_d=25, w_a=50): 11 | super(RKDLoss, self).__init__() 12 | self.w_d = w_d 13 | self.w_a = w_a 14 | 15 | def forward(self, f_s, f_t): 16 | student = f_s.view(f_s.shape[0], -1) 17 | teacher = f_t.view(f_t.shape[0], -1) 18 | 19 | # RKD distance loss 20 | with torch.no_grad(): 21 | t_d = self.pdist(teacher, squared=False) 22 | mean_td = t_d[t_d > 0].mean() 23 | t_d = t_d / mean_td 24 | 25 | d = self.pdist(student, squared=False) 26 | mean_d = d[d > 0].mean() 27 | d = d / mean_d 28 | 29 | loss_d = F.smooth_l1_loss(d, t_d) 30 | 31 | # RKD Angle loss 32 | with torch.no_grad(): 33 | td = (teacher.unsqueeze(0) - teacher.unsqueeze(1)) 34 | norm_td = F.normalize(td, p=2, dim=2) 35 | t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1) 36 | 37 | sd = (student.unsqueeze(0) - student.unsqueeze(1)) 38 | norm_sd = F.normalize(sd, p=2, dim=2) 39 | s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1) 40 | 41 | loss_a = F.smooth_l1_loss(s_angle, t_angle) 42 | 43 | loss = self.w_d * loss_d + self.w_a * loss_a 44 | 45 | return loss 46 | 47 | @staticmethod 48 | def pdist(e, squared=False, eps=1e-12): 49 | e_square = e.pow(2).sum(dim=1) 50 | prod = e @ e.t() 51 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps) 52 | 53 | if not squared: 54 | res = res.sqrt() 55 | 56 | res = res.clone() 57 | res[range(len(e)), range(len(e))] = 0 58 | return res 59 | -------------------------------------------------------------------------------- /cv/distiller_zoo/SP.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class 
Similarity(nn.Module): 9 | """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author""" 10 | def __init__(self): 11 | super(Similarity, self).__init__() 12 | 13 | def forward(self, g_s, g_t): 14 | return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)] 15 | 16 | def similarity_loss(self, f_s, f_t): 17 | bsz = f_s.shape[0] 18 | f_s = f_s.view(bsz, -1) 19 | f_t = f_t.view(bsz, -1) 20 | 21 | G_s = torch.mm(f_s, torch.t(f_s)) 22 | # G_s = G_s / G_s.norm(2) 23 | G_s = torch.nn.functional.normalize(G_s) 24 | G_t = torch.mm(f_t, torch.t(f_t)) 25 | # G_t = G_t / G_t.norm(2) 26 | G_t = torch.nn.functional.normalize(G_t) 27 | 28 | G_diff = G_t - G_s 29 | loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz) 30 | return loss 31 | -------------------------------------------------------------------------------- /cv/distiller_zoo/VID.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | 9 | class VIDLoss(nn.Module): 10 | """Variational Information Distillation for Knowledge Transfer (CVPR 2019), 11 | code from author: https://github.com/ssahn0215/variational-information-distillation""" 12 | def __init__(self, 13 | num_input_channels, 14 | num_mid_channel, 15 | num_target_channels, 16 | init_pred_var=5.0, 17 | eps=1e-5): 18 | super(VIDLoss, self).__init__() 19 | 20 | def conv1x1(in_channels, out_channels, stride=1): 21 | return nn.Conv2d( 22 | in_channels, out_channels, 23 | kernel_size=1, padding=0, 24 | bias=False, stride=stride) 25 | 26 | self.regressor = nn.Sequential( 27 | conv1x1(num_input_channels, num_mid_channel), 28 | nn.ReLU(), 29 | conv1x1(num_mid_channel, num_mid_channel), 30 | nn.ReLU(), 31 | conv1x1(num_mid_channel, num_target_channels), 32 | ) 33 | self.log_scale = torch.nn.Parameter( 34 | np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels) 35 | ) 36 | self.eps = eps 37 | 38 | def forward(self, input, target): 39 | # pool for dimension match 40 | s_H, t_H = input.shape[2], target.shape[2] 41 | if s_H > t_H: 42 | input = F.adaptive_avg_pool2d(input, (t_H, t_H)) 43 | elif s_H < t_H: 44 | target = F.adaptive_avg_pool2d(target, (s_H, s_H)) 45 | else: 46 | pass 47 | pred_mean = self.regressor(input) 48 | pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps 49 | pred_var = pred_var.view(1, -1, 1, 1) 50 | neg_log_prob = 0.5*( 51 | (pred_mean-target)**2/pred_var+torch.log(pred_var) 52 | ) 53 | loss = torch.mean(neg_log_prob) 54 | return loss 55 | -------------------------------------------------------------------------------- /cv/distiller_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .AB import ABLoss 2 | from .AT import Attention 3 | from .CC import Correlation 4 | from .FitNet import HintLoss 5 | from .FSP import FSP 6 | from .FT import FactorTransfer 7 | from .KD import DistillKL 8 | from .KDSVD import KDSVD 9 | from .NST import NSTLoss 10 | from .PKT import PKT 11 | from .RKD import RKDLoss 12 | from .SP import Similarity 13 | from .VID import VIDLoss 14 | -------------------------------------------------------------------------------- /cv/helper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetRunner/MetaDistil/80e60c11de531b10d1f06ceb2b71c70665bb6aff/cv/helper/__init__.py
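The distiller zoo above shares a simple calling convention: logit-based criteria take a `(y_s, y_t)` pair of student/teacher logits, while feature-based criteria take feature maps or lists of them. A minimal usage sketch before the full training loops below, assuming this repo's model interface (`is_feat=True` returns `(features, logits)`); `model_s` and `model_t` are placeholder student/teacher models, not names defined here:

import torch
from distiller_zoo import DistillKL, Similarity

criterion_div = DistillKL(T=4)   # KL between temperature-softened logits
criterion_kd = Similarity()      # similarity-preserving loss over batch Gram matrices

x = torch.randn(8, 3, 32, 32)                  # CIFAR-sized dummy batch
feat_s, logit_s = model_s(x, is_feat=True)     # model_s / model_t: placeholders
with torch.no_grad():
    feat_t, logit_t = model_t(x, is_feat=True)

# combine a logit loss with a feature loss, as train_distill does at scale below
loss = criterion_div(logit_s, logit_t) + sum(criterion_kd([feat_s[-2]], [feat_t[-2]]))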
-------------------------------------------------------------------------------- /cv/helper/loops.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import sys 4 | import time 5 | import torch 6 | 7 | from .util import AverageMeter, accuracy 8 | 9 | 10 | def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt): 11 | """vanilla training""" 12 | model.train() 13 | 14 | batch_time = AverageMeter() 15 | data_time = AverageMeter() 16 | losses = AverageMeter() 17 | top1 = AverageMeter() 18 | top5 = AverageMeter() 19 | 20 | end = time.time() 21 | for idx, (input, target) in enumerate(train_loader): 22 | data_time.update(time.time() - end) 23 | 24 | input = input.float() 25 | if torch.cuda.is_available(): 26 | input = input.cuda() 27 | target = target.cuda() 28 | 29 | # ===================forward===================== 30 | output = model(input) 31 | loss = criterion(output, target) 32 | 33 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 34 | losses.update(loss.item(), input.size(0)) 35 | top1.update(acc1[0], input.size(0)) 36 | top5.update(acc5[0], input.size(0)) 37 | 38 | # ===================backward===================== 39 | optimizer.zero_grad() 40 | loss.backward() 41 | optimizer.step() 42 | 43 | # ===================meters===================== 44 | batch_time.update(time.time() - end) 45 | end = time.time() 46 | 47 | # tensorboard logger 48 | pass 49 | 50 | # print info 51 | if idx % opt.print_freq == 0: 52 | print('Epoch: [{0}][{1}/{2}]\t' 53 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 54 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 55 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 56 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 57 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 58 | epoch, idx, len(train_loader), batch_time=batch_time, 59 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 60 | sys.stdout.flush() 61 | 62 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 63 | .format(top1=top1, top5=top5)) 64 | 65 | return top1.avg, losses.avg 66 | 67 | 68 | def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, opt): 69 | """One epoch distillation""" 70 | # set modules as train() 71 | for module in module_list: 72 | module.train() 73 | # set teacher as eval() 74 | module_list[-1].eval() 75 | 76 | if opt.distill == 'abound': 77 | module_list[1].eval() 78 | elif opt.distill == 'factor': 79 | module_list[2].eval() 80 | 81 | criterion_cls = criterion_list[0] 82 | criterion_div = criterion_list[1] 83 | criterion_kd = criterion_list[2] 84 | 85 | model_s = module_list[0] 86 | model_t = module_list[-1] 87 | 88 | batch_time = AverageMeter() 89 | data_time = AverageMeter() 90 | losses = AverageMeter() 91 | top1 = AverageMeter() 92 | top5 = AverageMeter() 93 | 94 | end = time.time() 95 | for idx, data in enumerate(train_loader): 96 | if opt.distill in ['crd']: 97 | input, target, index, contrast_idx = data 98 | else: 99 | input, target, index = data 100 | data_time.update(time.time() - end) 101 | 102 | input = input.float() 103 | if torch.cuda.is_available(): 104 | input = input.cuda() 105 | target = target.cuda() 106 | index = index.cuda() 107 | if opt.distill in ['crd']: 108 | contrast_idx = contrast_idx.cuda() 109 | 110 | # ===================forward===================== 111 | preact = False 112 | if opt.distill in ['abound']: 113 | preact = True 114 | feat_s, logit_s = model_s(input, is_feat=True, preact=preact) 115 | with 
torch.no_grad(): 116 | feat_t, logit_t = model_t(input, is_feat=True, preact=preact) 117 | feat_t = [f.detach() for f in feat_t] 118 | 119 | # cls + kl div 120 | loss_cls = criterion_cls(logit_s, target) 121 | loss_div = criterion_div(logit_s, logit_t) 122 | 123 | # other kd beyond KL divergence 124 | if opt.distill == 'kd': 125 | loss_kd = 0 126 | elif opt.distill == 'hint': 127 | f_s = module_list[1](feat_s[opt.hint_layer]) 128 | f_t = feat_t[opt.hint_layer] 129 | loss_kd = criterion_kd(f_s, f_t) 130 | elif opt.distill == 'crd': 131 | f_s = feat_s[-1] 132 | f_t = feat_t[-1] 133 | loss_kd = criterion_kd(f_s, f_t, index, contrast_idx) 134 | elif opt.distill == 'attention': 135 | g_s = feat_s[1:-1] 136 | g_t = feat_t[1:-1] 137 | loss_group = criterion_kd(g_s, g_t) 138 | loss_kd = sum(loss_group) 139 | elif opt.distill == 'nst': 140 | g_s = feat_s[1:-1] 141 | g_t = feat_t[1:-1] 142 | loss_group = criterion_kd(g_s, g_t) 143 | loss_kd = sum(loss_group) 144 | elif opt.distill == 'similarity': 145 | g_s = [feat_s[-2]] 146 | g_t = [feat_t[-2]] 147 | loss_group = criterion_kd(g_s, g_t) 148 | loss_kd = sum(loss_group) 149 | elif opt.distill == 'rkd': 150 | f_s = feat_s[-1] 151 | f_t = feat_t[-1] 152 | loss_kd = criterion_kd(f_s, f_t) 153 | elif opt.distill == 'pkt': 154 | f_s = feat_s[-1] 155 | f_t = feat_t[-1] 156 | loss_kd = criterion_kd(f_s, f_t) 157 | elif opt.distill == 'kdsvd': 158 | g_s = feat_s[1:-1] 159 | g_t = feat_t[1:-1] 160 | loss_group = criterion_kd(g_s, g_t) 161 | loss_kd = sum(loss_group) 162 | elif opt.distill == 'correlation': 163 | f_s = module_list[1](feat_s[-1]) 164 | f_t = module_list[2](feat_t[-1]) 165 | loss_kd = criterion_kd(f_s, f_t) 166 | elif opt.distill == 'vid': 167 | g_s = feat_s[1:-1] 168 | g_t = feat_t[1:-1] 169 | loss_group = [c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, criterion_kd)] 170 | loss_kd = sum(loss_group) 171 | elif opt.distill == 'abound': 172 | # can also add loss to this stage 173 | loss_kd = 0 174 | elif opt.distill == 'fsp': 175 | # can also add loss to this stage 176 | loss_kd = 0 177 | elif opt.distill == 'factor': 178 | factor_s = module_list[1](feat_s[-2]) 179 | factor_t = module_list[2](feat_t[-2], is_factor=True) 180 | loss_kd = criterion_kd(factor_s, factor_t) 181 | else: 182 | raise NotImplementedError(opt.distill) 183 | 184 | loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd 185 | 186 | acc1, acc5 = accuracy(logit_s, target, topk=(1, 5)) 187 | losses.update(loss.item(), input.size(0)) 188 | top1.update(acc1[0], input.size(0)) 189 | top5.update(acc5[0], input.size(0)) 190 | 191 | # ===================backward===================== 192 | optimizer.zero_grad() 193 | loss.backward() 194 | optimizer.step() 195 | 196 | # ===================meters===================== 197 | batch_time.update(time.time() - end) 198 | end = time.time() 199 | 200 | # print info 201 | if idx % opt.print_freq == 0: 202 | print('Epoch: [{0}][{1}/{2}]\t' 203 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 204 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 205 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 206 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 207 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 208 | epoch, idx, len(train_loader), batch_time=batch_time, 209 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 210 | sys.stdout.flush() 211 | 212 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 213 | .format(top1=top1, top5=top5)) 214 | 215 | return top1.avg, losses.avg 216 | 217 | 218 | def validate(val_loader, model, 
criterion, opt): 219 | """validation""" 220 | batch_time = AverageMeter() 221 | losses = AverageMeter() 222 | top1 = AverageMeter() 223 | top5 = AverageMeter() 224 | 225 | # switch to evaluate mode 226 | model.eval() 227 | 228 | with torch.no_grad(): 229 | end = time.time() 230 | for idx, (input, target) in enumerate(val_loader): 231 | 232 | input = input.float() 233 | if torch.cuda.is_available(): 234 | input = input.cuda() 235 | target = target.cuda() 236 | 237 | # compute output 238 | output = model(input) 239 | loss = criterion(output, target) 240 | 241 | # measure accuracy and record loss 242 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 243 | losses.update(loss.item(), input.size(0)) 244 | top1.update(acc1[0], input.size(0)) 245 | top5.update(acc5[0], input.size(0)) 246 | 247 | # measure elapsed time 248 | batch_time.update(time.time() - end) 249 | end = time.time() 250 | 251 | if idx % opt.print_freq == 0: 252 | print('Test: [{0}/{1}]\t' 253 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 254 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 255 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 256 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 257 | idx, len(val_loader), batch_time=batch_time, loss=losses, 258 | top1=top1, top5=top5)) 259 | 260 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 261 | .format(top1=top1, top5=top5)) 262 | 263 | return top1.avg, top5.avg, losses.avg 264 | -------------------------------------------------------------------------------- /cv/helper/meta_loops.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import sys 4 | import time 5 | from collections import OrderedDict 6 | 7 | import torch 8 | from copy import deepcopy as cp 9 | 10 | from .util import AverageMeter, accuracy 11 | 12 | 13 | def train_distill(epoch, train_loader, held_loader, module_list, criterion_list, s_optimizer, t_optimizer, opt): 14 | """One epoch distillation""" 15 | 16 | criterion_cls = criterion_list[0] 17 | criterion_kd = criterion_list[1] 18 | 19 | s_model = module_list[0] 20 | t_model = module_list[-1] 21 | 22 | batch_time = AverageMeter() 23 | 24 | assume_losses = AverageMeter() 25 | assume_top1 = AverageMeter() 26 | assume_top5 = AverageMeter() 27 | 28 | held_losses = AverageMeter() 29 | 30 | real_losses = AverageMeter() 31 | real_top1 = AverageMeter() 32 | real_top5 = AverageMeter() 33 | 34 | end = time.time() 35 | 36 | total_steps_one_epoch = len(train_loader) 37 | batches_buffer = [] 38 | round_counter = 0 39 | 40 | for d_idx, d_data in enumerate(train_loader): 41 | 42 | batches_buffer.append((d_idx, d_data)) 43 | 44 | if (d_idx + 1) % opt.num_meta_batches != 0 and (d_idx + 1) != total_steps_one_epoch: 45 | continue 46 | 47 | ######################################### 48 | # Step 1: Assume S' # 49 | ######################################### 50 | 51 | # Time machine! 
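# How the "time machine" works: the next two lines (1) seed fast_weights with
# the student's current parameters and (2) back up the student's state_dict and
# optimizer state so Step 3 can rewind the pilot update. fast_weights is a plain
# OrderedDict of tensors, so the inner SGD step below (taken with
# create_graph=True, and with the teacher's logits computed *with* grad) stays
# differentiable w.r.t. the teacher's parameters.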
52 | fast_weights = OrderedDict((name, param) for (name, param) in s_model.named_parameters()) 53 | s_model_backup_state_dict, s_optimizer_backup_state_dict = cp(s_model.state_dict()), cp(s_optimizer.state_dict()) 54 | 55 | s_model.train() 56 | t_model.eval() 57 | 58 | for idx, data in batches_buffer: 59 | 60 | input, target = data 61 | 62 | input = input.float() 63 | if torch.cuda.is_available(): 64 | input = input.cuda() 65 | target = target.cuda() 66 | 67 | logit_s = s_model(input, params=None if idx == 0 else fast_weights) 68 | logit_t = t_model(input) 69 | 70 | assume_loss_cls = criterion_cls(logit_s, target) 71 | assume_loss_kd = criterion_kd(logit_s, logit_t) 72 | 73 | assume_loss = opt.alpha * assume_loss_kd + (1 - opt.alpha) * assume_loss_cls 74 | 75 | acc1, acc5 = accuracy(logit_s, target, topk=(1, 5)) 76 | assume_losses.update(assume_loss.item(), input.size(0)) 77 | assume_top1.update(acc1[0], input.size(0)) 78 | assume_top5.update(acc5[0], input.size(0)) 79 | 80 | grads = torch.autograd.grad(assume_loss, s_model.parameters() if idx == 0 else fast_weights.values(), 81 | create_graph=True, retain_graph=True) 82 | 83 | fast_weights = OrderedDict( 84 | (name, param - opt.assume_s_step_size * grad) for ((name, param), grad) in 85 | zip(fast_weights.items(), grads)) 86 | 87 | ######################################### 88 | # Step 2: Train T with S' on HELD set # 89 | ######################################### 90 | 91 | s_prime_loss = None 92 | 93 | t_model.train() 94 | held_batch_num = 0 95 | 96 | for idx, data in enumerate(held_loader): 97 | input, target = data 98 | 99 | input = input.float() 100 | if torch.cuda.is_available(): 101 | input = input.cuda() 102 | target = target.cuda() 103 | 104 | logit_s_prime = s_model(input, params=fast_weights) 105 | s_prime_step_loss = criterion_cls(logit_s_prime, target) 106 | 107 | if s_prime_loss is None: 108 | s_prime_loss = s_prime_step_loss 109 | else: 110 | s_prime_loss += s_prime_step_loss 111 | 112 | held_batch_num += 1 113 | 114 | s_prime_loss /= held_batch_num 115 | t_grads = torch.autograd.grad(s_prime_loss, t_model.parameters()) 116 | 117 | for p, gr in zip(t_model.parameters(), t_grads): 118 | p.grad = gr 119 | 120 | held_losses.update(s_prime_loss.item(), 1) 121 | 122 | t_optimizer.step() 123 | 124 | # Manual zero_grad 125 | for p in t_model.parameters(): 126 | p.grad = None 127 | 128 | for p in s_model.parameters(): 129 | p.grad = None 130 | 131 | del t_grads 132 | del grads 133 | del fast_weights 134 | 135 | ######################################### 136 | # Step 3: Actually update S # 137 | ######################################### 138 | 139 | # We use the Time Machine! 
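# Rewind to the snapshot taken in Step 1: the pilot update to S existed only to
# expose the teacher's meta-gradient on the held-out set, so the student and its
# optimizer are restored before the real first-order update below, where the
# teacher's logits are recomputed under torch.no_grad().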
140 | s_model.load_state_dict(s_model_backup_state_dict) 141 | s_optimizer.load_state_dict(s_optimizer_backup_state_dict) 142 | 143 | del s_model_backup_state_dict, s_optimizer_backup_state_dict 144 | 145 | s_model.train() 146 | t_model.eval() 147 | 148 | for idx, data in batches_buffer: 149 | 150 | input, target = data 151 | 152 | input = input.float() 153 | if torch.cuda.is_available(): 154 | input = input.cuda() 155 | target = target.cuda() 156 | 157 | logit_s = s_model(input) 158 | with torch.no_grad(): 159 | logit_t = t_model(input) 160 | 161 | real_loss_cls = criterion_cls(logit_s, target) 162 | real_loss_kd = criterion_kd(logit_s, logit_t) 163 | 164 | real_loss = opt.alpha * real_loss_kd + (1 - opt.alpha) * real_loss_cls 165 | 166 | acc1, acc5 = accuracy(logit_s, target, topk=(1, 5)) 167 | real_losses.update(real_loss.item(), input.size(0)) 168 | real_top1.update(acc1[0], input.size(0)) 169 | real_top5.update(acc5[0], input.size(0)) 170 | 171 | real_loss.backward() 172 | 173 | s_optimizer.step() 174 | s_optimizer.zero_grad() 175 | 176 | round_counter += 1 177 | batches_buffer = [] 178 | 179 | # ===================meters===================== 180 | batch_time.update(time.time() - end) 181 | end = time.time() 182 | 183 | # print info 184 | if round_counter % opt.print_freq == 0: 185 | print('Epoch: [{0}][{1}/{2}]\t' 186 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 187 | 'Real_Loss {loss.val:.4f} ({loss.avg:.4f})\t' 188 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 189 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 190 | epoch, idx, len(train_loader), batch_time=batch_time, 191 | loss=real_losses, top1=real_top1, top5=real_top5)) 192 | sys.stdout.flush() 193 | 194 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 195 | .format(top1=real_top1, top5=real_top5)) 196 | 197 | return real_top1.avg, real_losses.avg 198 | 199 | 200 | def validate(val_loader, model, criterion, opt): 201 | """validation""" 202 | batch_time = AverageMeter() 203 | losses = AverageMeter() 204 | top1 = AverageMeter() 205 | top5 = AverageMeter() 206 | 207 | # switch to evaluate mode 208 | model.eval() 209 | 210 | with torch.no_grad(): 211 | end = time.time() 212 | for idx, (input, target) in enumerate(val_loader): 213 | 214 | input = input.float() 215 | if torch.cuda.is_available(): 216 | input = input.cuda() 217 | target = target.cuda() 218 | 219 | # compute output 220 | output = model(input) 221 | loss = criterion(output, target) 222 | 223 | # measure accuracy and record loss 224 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 225 | losses.update(loss.item(), input.size(0)) 226 | top1.update(acc1[0], input.size(0)) 227 | top5.update(acc5[0], input.size(0)) 228 | 229 | # measure elapsed time 230 | batch_time.update(time.time() - end) 231 | end = time.time() 232 | 233 | if idx % opt.print_freq == 0: 234 | print('Test: [{0}/{1}]\t' 235 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 236 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 237 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 238 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 239 | idx, len(val_loader), batch_time=batch_time, loss=losses, 240 | top1=top1, top5=top5)) 241 | 242 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 243 | .format(top1=top1, top5=top5)) 244 | 245 | return top1.avg, top5.avg, losses.avg 246 | -------------------------------------------------------------------------------- /cv/helper/pretrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
print_function, division 2 | 3 | import time 4 | import sys 5 | import torch 6 | import torch.optim as optim 7 | import torch.backends.cudnn as cudnn 8 | from .util import AverageMeter 9 | 10 | 11 | def init(model_s, model_t, init_modules, criterion, train_loader, logger, opt): 12 | model_t.eval() 13 | model_s.eval() 14 | init_modules.train() 15 | 16 | if torch.cuda.is_available(): 17 | model_s.cuda() 18 | model_t.cuda() 19 | init_modules.cuda() 20 | cudnn.benchmark = True 21 | 22 | if opt.model_s in ['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 23 | 'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2'] and \ 24 | opt.distill == 'factor': 25 | lr = 0.01 26 | else: 27 | lr = opt.learning_rate 28 | optimizer = optim.SGD(init_modules.parameters(), 29 | lr=lr, 30 | momentum=opt.momentum, 31 | weight_decay=opt.weight_decay) 32 | 33 | batch_time = AverageMeter() 34 | data_time = AverageMeter() 35 | losses = AverageMeter() 36 | for epoch in range(1, opt.init_epochs + 1): 37 | batch_time.reset() 38 | data_time.reset() 39 | losses.reset() 40 | end = time.time() 41 | for idx, data in enumerate(train_loader): 42 | if opt.distill in ['crd']: 43 | input, target, index, contrast_idx = data 44 | else: 45 | input, target, index = data 46 | data_time.update(time.time() - end) 47 | 48 | input = input.float() 49 | if torch.cuda.is_available(): 50 | input = input.cuda() 51 | target = target.cuda() 52 | index = index.cuda() 53 | if opt.distill in ['crd']: 54 | contrast_idx = contrast_idx.cuda() 55 | 56 | # ============= forward ============== 57 | preact = (opt.distill == 'abound') 58 | feat_s, _ = model_s(input, is_feat=True, preact=preact) 59 | with torch.no_grad(): 60 | feat_t, _ = model_t(input, is_feat=True, preact=preact) 61 | feat_t = [f.detach() for f in feat_t] 62 | 63 | if opt.distill == 'abound': 64 | g_s = init_modules[0](feat_s[1:-1]) 65 | g_t = feat_t[1:-1] 66 | loss_group = criterion(g_s, g_t) 67 | loss = sum(loss_group) 68 | elif opt.distill == 'factor': 69 | f_t = feat_t[-2] 70 | _, f_t_rec = init_modules[0](f_t) 71 | loss = criterion(f_t_rec, f_t) 72 | elif opt.distill == 'fsp': 73 | loss_group = criterion(feat_s[:-1], feat_t[:-1]) 74 | loss = sum(loss_group) 75 | else: 76 | raise NotImplementedError('Not supported in init training: {}'.format(opt.distill)) 77 | 78 | losses.update(loss.item(), input.size(0)) 79 | 80 | # ===================backward===================== 81 | optimizer.zero_grad() 82 | loss.backward() 83 | optimizer.step() 84 | 85 | batch_time.update(time.time() - end) 86 | end = time.time() 87 | 88 | # end of epoch 89 | logger.log_value('init_train_loss', losses.avg, epoch) 90 | print('Epoch: [{0}/{1}]\t' 91 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 92 | 'losses: {losses.val:.3f} ({losses.avg:.3f})'.format( 93 | epoch, opt.init_epochs, batch_time=batch_time, losses=losses)) 94 | sys.stdout.flush() 95 | -------------------------------------------------------------------------------- /cv/helper/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def adjust_learning_rate_new(epoch, optimizer, LUT): 8 | """ 9 | new learning rate schedule according to RotNet 10 | """ 11 | lr = next((lr for (max_epoch, lr) in LUT if max_epoch > epoch), LUT[-1][1]) 12 | for param_group in optimizer.param_groups: 13 | param_group['lr'] = lr 14 | 15 | 16 | def adjust_learning_rate(epoch, opt,
optimizer): 17 | """Sets the learning rate to the initial LR decayed by decay rate every steep step""" 18 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs)) 19 | if steps > 0: 20 | new_lr = opt.lr * (opt.lr_decay_rate ** steps) 21 | for param_group in optimizer.param_groups: 22 | param_group['lr'] = new_lr 23 | 24 | 25 | class AverageMeter(object): 26 | """Computes and stores the average and current value""" 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = 0 32 | self.avg = 0 33 | self.sum = 0 34 | self.count = 0 35 | 36 | def update(self, val, n=1): 37 | self.val = val 38 | self.sum += val * n 39 | self.count += n 40 | self.avg = self.sum / self.count 41 | 42 | 43 | def accuracy(output, target, topk=(1,)): 44 | """Computes the accuracy over the k top predictions for the specified values of k""" 45 | with torch.no_grad(): 46 | maxk = max(topk) 47 | batch_size = target.size(0) 48 | 49 | _, pred = output.topk(maxk, 1, True, True) 50 | pred = pred.t() 51 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 52 | 53 | res = [] 54 | for k in topk: 55 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 56 | res.append(correct_k.mul_(100.0 / batch_size)) 57 | return res 58 | 59 | 60 | if __name__ == '__main__': 61 | 62 | pass 63 | -------------------------------------------------------------------------------- /cv/models/ShuffleNetv1.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ShuffleBlock(nn.Module): 10 | def __init__(self, groups): 11 | super(ShuffleBlock, self).__init__() 12 | self.groups = groups 13 | 14 | def forward(self, x): 15 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 16 | N,C,H,W = x.size() 17 | g = self.groups 18 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | def __init__(self, in_planes, out_planes, stride, groups, is_last=False): 23 | super(Bottleneck, self).__init__() 24 | self.is_last = is_last 25 | self.stride = stride 26 | 27 | mid_planes = int(out_planes/4) 28 | g = 1 if in_planes == 24 else groups 29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 30 | self.bn1 = nn.BatchNorm2d(mid_planes) 31 | self.shuffle1 = ShuffleBlock(groups=g) 32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 33 | self.bn2 = nn.BatchNorm2d(mid_planes) 34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 35 | self.bn3 = nn.BatchNorm2d(out_planes) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride == 2: 39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.shuffle1(out) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | res = self.shortcut(x) 47 | preact = torch.cat([out, res], 1) if self.stride == 2 else out+res 48 | out = F.relu(preact) 49 | # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res) 50 | if self.is_last: 51 | return out, preact 52 | else: 53 | return out 54 | 55 | 56 | class ShuffleNet(nn.Module): 57 | def 
__init__(self, cfg, num_classes=10): 58 | super(ShuffleNet, self).__init__() 59 | out_planes = cfg['out_planes'] 60 | num_blocks = cfg['num_blocks'] 61 | groups = cfg['groups'] 62 | 63 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(24) 65 | self.in_planes = 24 66 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 67 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 68 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 69 | self.linear = nn.Linear(out_planes[2], num_classes) 70 | 71 | def _make_layer(self, out_planes, num_blocks, groups): 72 | layers = [] 73 | for i in range(num_blocks): 74 | stride = 2 if i == 0 else 1 75 | cat_planes = self.in_planes if i == 0 else 0 76 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, 77 | stride=stride, 78 | groups=groups, 79 | is_last=(i == num_blocks - 1))) 80 | self.in_planes = out_planes 81 | return nn.Sequential(*layers) 82 | 83 | def get_feat_modules(self): 84 | feat_m = nn.ModuleList([]) 85 | feat_m.append(self.conv1) 86 | feat_m.append(self.bn1) 87 | feat_m.append(self.layer1) 88 | feat_m.append(self.layer2) 89 | feat_m.append(self.layer3) 90 | return feat_m 91 | 92 | def get_bn_before_relu(self): 93 | raise NotImplementedError('ShuffleNet currently is not supported for "Overhaul" teacher') 94 | 95 | def forward(self, x, is_feat=False, preact=False): 96 | out = F.relu(self.bn1(self.conv1(x))) 97 | f0 = out 98 | out, f1_pre = self.layer1(out) 99 | f1 = out 100 | out, f2_pre = self.layer2(out) 101 | f2 = out 102 | out, f3_pre = self.layer3(out) 103 | f3 = out 104 | out = F.avg_pool2d(out, 4) 105 | out = out.view(out.size(0), -1) 106 | f4 = out 107 | out = self.linear(out) 108 | 109 | if is_feat: 110 | if preact: 111 | return [f0, f1_pre, f2_pre, f3_pre, f4], out 112 | else: 113 | return [f0, f1, f2, f3, f4], out 114 | else: 115 | return out 116 | 117 | 118 | def ShuffleV1(**kwargs): 119 | cfg = { 120 | 'out_planes': [240, 480, 960], 121 | 'num_blocks': [4, 8, 4], 122 | 'groups': 3 123 | } 124 | return ShuffleNet(cfg, **kwargs) 125 | 126 | 127 | if __name__ == '__main__': 128 | 129 | x = torch.randn(2, 3, 32, 32) 130 | net = ShuffleV1(num_classes=100) 131 | import time 132 | a = time.time() 133 | feats, logit = net(x, is_feat=True, preact=True) 134 | b = time.time() 135 | print(b - a) 136 | for f in feats: 137 | print(f.shape, f.min().item()) 138 | print(logit.shape) 139 | -------------------------------------------------------------------------------- /cv/models/ShuffleNetv2.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNetV2 in PyTorch. 2 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 
3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ShuffleBlock(nn.Module): 10 | def __init__(self, groups=2): 11 | super(ShuffleBlock, self).__init__() 12 | self.groups = groups 13 | 14 | def forward(self, x): 15 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 16 | N, C, H, W = x.size() 17 | g = self.groups 18 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 19 | 20 | 21 | class SplitBlock(nn.Module): 22 | def __init__(self, ratio): 23 | super(SplitBlock, self).__init__() 24 | self.ratio = ratio 25 | 26 | def forward(self, x): 27 | c = int(x.size(1) * self.ratio) 28 | return x[:, :c, :, :], x[:, c:, :, :] 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | def __init__(self, in_channels, split_ratio=0.5, is_last=False): 33 | super(BasicBlock, self).__init__() 34 | self.is_last = is_last 35 | self.split = SplitBlock(split_ratio) 36 | in_channels = int(in_channels * split_ratio) 37 | self.conv1 = nn.Conv2d(in_channels, in_channels, 38 | kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(in_channels) 40 | self.conv2 = nn.Conv2d(in_channels, in_channels, 41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 42 | self.bn2 = nn.BatchNorm2d(in_channels) 43 | self.conv3 = nn.Conv2d(in_channels, in_channels, 44 | kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(in_channels) 46 | self.shuffle = ShuffleBlock() 47 | 48 | def forward(self, x): 49 | x1, x2 = self.split(x) 50 | out = F.relu(self.bn1(self.conv1(x2))) 51 | out = self.bn2(self.conv2(out)) 52 | preact = self.bn3(self.conv3(out)) 53 | out = F.relu(preact) 54 | # out = F.relu(self.bn3(self.conv3(out))) 55 | preact = torch.cat([x1, preact], 1) 56 | out = torch.cat([x1, out], 1) 57 | out = self.shuffle(out) 58 | if self.is_last: 59 | return out, preact 60 | else: 61 | return out 62 | 63 | 64 | class DownBlock(nn.Module): 65 | def __init__(self, in_channels, out_channels): 66 | super(DownBlock, self).__init__() 67 | mid_channels = out_channels // 2 68 | # left 69 | self.conv1 = nn.Conv2d(in_channels, in_channels, 70 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 71 | self.bn1 = nn.BatchNorm2d(in_channels) 72 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 73 | kernel_size=1, bias=False) 74 | self.bn2 = nn.BatchNorm2d(mid_channels) 75 | # right 76 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 77 | kernel_size=1, bias=False) 78 | self.bn3 = nn.BatchNorm2d(mid_channels) 79 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 80 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 81 | self.bn4 = nn.BatchNorm2d(mid_channels) 82 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 83 | kernel_size=1, bias=False) 84 | self.bn5 = nn.BatchNorm2d(mid_channels) 85 | 86 | self.shuffle = ShuffleBlock() 87 | 88 | def forward(self, x): 89 | # left 90 | out1 = self.bn1(self.conv1(x)) 91 | out1 = F.relu(self.bn2(self.conv2(out1))) 92 | # right 93 | out2 = F.relu(self.bn3(self.conv3(x))) 94 | out2 = self.bn4(self.conv4(out2)) 95 | out2 = F.relu(self.bn5(self.conv5(out2))) 96 | # concat 97 | out = torch.cat([out1, out2], 1) 98 | out = self.shuffle(out) 99 | return out 100 | 101 | 102 | class ShuffleNetV2(nn.Module): 103 | def __init__(self, net_size, num_classes=10): 104 | super(ShuffleNetV2, self).__init__() 105 | out_channels = configs[net_size]['out_channels'] 106 | num_blocks = configs[net_size]['num_blocks'] 107 | 108 | # self.conv1 = nn.Conv2d(3, 24, 
kernel_size=3, 109 | # stride=1, padding=1, bias=False) 110 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(24) 112 | self.in_channels = 24 113 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 114 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 115 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 116 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 117 | kernel_size=1, stride=1, padding=0, bias=False) 118 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 119 | self.linear = nn.Linear(out_channels[3], num_classes) 120 | 121 | def _make_layer(self, out_channels, num_blocks): 122 | layers = [DownBlock(self.in_channels, out_channels)] 123 | for i in range(num_blocks): 124 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) 125 | self.in_channels = out_channels 126 | return nn.Sequential(*layers) 127 | 128 | def get_feat_modules(self): 129 | feat_m = nn.ModuleList([]) 130 | feat_m.append(self.conv1) 131 | feat_m.append(self.bn1) 132 | feat_m.append(self.layer1) 133 | feat_m.append(self.layer2) 134 | feat_m.append(self.layer3) 135 | return feat_m 136 | 137 | def get_bn_before_relu(self): 138 | raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher') 139 | 140 | def forward(self, x, is_feat=False, preact=False): 141 | out = F.relu(self.bn1(self.conv1(x))) 142 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 143 | f0 = out 144 | out, f1_pre = self.layer1(out) 145 | f1 = out 146 | out, f2_pre = self.layer2(out) 147 | f2 = out 148 | out, f3_pre = self.layer3(out) 149 | f3 = out 150 | out = F.relu(self.bn2(self.conv2(out))) 151 | out = F.avg_pool2d(out, 4) 152 | out = out.view(out.size(0), -1) 153 | f4 = out 154 | out = self.linear(out) 155 | if is_feat: 156 | if preact: 157 | return [f0, f1_pre, f2_pre, f3_pre, f4], out 158 | else: 159 | return [f0, f1, f2, f3, f4], out 160 | else: 161 | return out 162 | 163 | 164 | configs = { 165 | 0.2: { 166 | 'out_channels': (40, 80, 160, 512), 167 | 'num_blocks': (3, 3, 3) 168 | }, 169 | 170 | 0.3: { 171 | 'out_channels': (40, 80, 160, 512), 172 | 'num_blocks': (3, 7, 3) 173 | }, 174 | 175 | 0.5: { 176 | 'out_channels': (48, 96, 192, 1024), 177 | 'num_blocks': (3, 7, 3) 178 | }, 179 | 180 | 1: { 181 | 'out_channels': (116, 232, 464, 1024), 182 | 'num_blocks': (3, 7, 3) 183 | }, 184 | 1.5: { 185 | 'out_channels': (176, 352, 704, 1024), 186 | 'num_blocks': (3, 7, 3) 187 | }, 188 | 2: { 189 | 'out_channels': (224, 488, 976, 2048), 190 | 'num_blocks': (3, 7, 3) 191 | } 192 | } 193 | 194 | 195 | def ShuffleV2(**kwargs): 196 | model = ShuffleNetV2(net_size=1, **kwargs) 197 | return model 198 | 199 | 200 | if __name__ == '__main__': 201 | net = ShuffleV2(num_classes=100) 202 | x = torch.randn(3, 3, 32, 32) 203 | import time 204 | a = time.time() 205 | feats, logit = net(x, is_feat=True, preact=True) 206 | b = time.time() 207 | print(b - a) 208 | for f in feats: 209 | print(f.shape, f.min().item()) 210 | print(logit.shape) 211 | -------------------------------------------------------------------------------- /cv/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .meta_resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4 2 | from .resnetv2 import ResNet50 3 | from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2 4 | from .meta_vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn 5 | from 
.mobilenetv2 import mobile_half 6 | from .ShuffleNetv1 import ShuffleV1 7 | from .ShuffleNetv2 import ShuffleV2 8 | 9 | model_dict = { 10 | 'resnet8': resnet8, 11 | 'resnet14': resnet14, 12 | 'resnet20': resnet20, 13 | 'resnet32': resnet32, 14 | 'resnet44': resnet44, 15 | 'resnet56': resnet56, 16 | 'resnet110': resnet110, 17 | 'resnet8x4': resnet8x4, 18 | 'resnet32x4': resnet32x4, 19 | 'ResNet50': ResNet50, 20 | 'wrn_16_1': wrn_16_1, 21 | 'wrn_16_2': wrn_16_2, 22 | 'wrn_40_1': wrn_40_1, 23 | 'wrn_40_2': wrn_40_2, 24 | 'vgg8': vgg8_bn, 25 | 'vgg11': vgg11_bn, 26 | 'vgg13': vgg13_bn, 27 | 'vgg16': vgg16_bn, 28 | 'vgg19': vgg19_bn, 29 | 'MobileNetV2': mobile_half, 30 | 'ShuffleV1': ShuffleV1, 31 | 'ShuffleV2': ShuffleV2, 32 | } 33 | -------------------------------------------------------------------------------- /cv/models/classifier.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | 5 | 6 | ######################################### 7 | # ===== Classifiers ===== # 8 | ######################################### 9 | 10 | class LinearClassifier(nn.Module): 11 | 12 | def __init__(self, dim_in, n_label=10): 13 | super(LinearClassifier, self).__init__() 14 | 15 | self.net = nn.Linear(dim_in, n_label) 16 | 17 | def forward(self, x): 18 | return self.net(x) 19 | 20 | 21 | class NonLinearClassifier(nn.Module): 22 | 23 | def __init__(self, dim_in, n_label=10, p=0.1): 24 | super(NonLinearClassifier, self).__init__() 25 | 26 | self.net = nn.Sequential( 27 | nn.Linear(dim_in, 200), 28 | nn.Dropout(p=p), 29 | nn.BatchNorm1d(200), 30 | nn.ReLU(inplace=True), 31 | nn.Linear(200, n_label), 32 | ) 33 | 34 | def forward(self, x): 35 | return self.net(x) 36 | -------------------------------------------------------------------------------- /cv/models/meta_resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported from 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchmeta.modules import * 13 | import math 14 | 15 | 16 | __all__ = ['resnet'] 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return MetaConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(MetaModule): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 29 | super(BasicBlock, self).__init__() 30 | self.is_last = is_last 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = MetaBatchNorm2d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = MetaBatchNorm2d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x, params=None): 40 | residual = x 41 | 42 | out = self.conv1(x, params=self.get_subdict(params, 'conv1')) 43 | out = self.bn1(out, params=self.get_subdict(params, 'bn1')) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out, params=self.get_subdict(params, 'conv2')) 47 | out = self.bn2(out, params=self.get_subdict(params, 'bn2')) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x, params=self.get_subdict(params, 'downsample')) 51 | 52 | out += residual 53 | preact = out 54 | out = F.relu(out) 55 | if self.is_last: 56 | return out, preact 57 | else: 58 | return out 59 | 60 | 61 | class Bottleneck(MetaModule): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 65 | super(Bottleneck, self).__init__() 66 | self.is_last = is_last 67 | self.conv1 = MetaConv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = MetaBatchNorm2d(planes) 69 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = MetaBatchNorm2d(planes) 72 | self.conv3 = MetaConv2d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = MetaBatchNorm2d(planes * 4) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x, params=None): 79 | residual = x 80 | 81 | out = self.conv1(x, params=self.get_subdict(params, 'conv1')) 82 | out = self.bn1(out, params=self.get_subdict(params, 'bn1')) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out, params=self.get_subdict(params, 'conv2')) 86 | out = self.bn2(out, params=self.get_subdict(params, 'bn2')) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out, params=self.get_subdict(params, 'conv3')) 90 | out = self.bn3(out, params=self.get_subdict(params, 'bn3')) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x, params=self.get_subdict(params, 'downsample')) 94 | 95 | out += residual 96 | preact = out 97 | out = F.relu(out) 98 | if self.is_last: 99 | return out, preact 100 | else: 101 | return out 102 | 103 | 104 | class ResNet(MetaModule): 105 | 106 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10): 107 | super(ResNet, self).__init__() 108 | # Model type specifies number of layers for CIFAR-10 model 109 | if block_name.lower() == 'basicblock': 110 | assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g.
20, 32, 44, 56, 110, 1202' 111 | n = (depth - 2) // 6 112 | block = BasicBlock 113 | elif block_name.lower() == 'bottleneck': 114 | assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' 115 | n = (depth - 2) // 9 116 | block = Bottleneck 117 | else: 118 | raise ValueError('block_name should be BasicBlock or Bottleneck') 119 | 120 | self.inplanes = num_filters[0] 121 | self.conv1 = MetaConv2d(3, num_filters[0], kernel_size=3, padding=1, 122 | bias=False) 123 | self.bn1 = MetaBatchNorm2d(num_filters[0]) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.layer1 = self._make_layer(block, num_filters[1], n) 126 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2) 127 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2) 128 | self.avgpool = nn.AvgPool2d(8) 129 | self.fc = MetaLinear(num_filters[3] * block.expansion, num_classes) 130 | 131 | for m in self.modules(): 132 | if isinstance(m, MetaConv2d): 133 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 134 | elif isinstance(m, (MetaBatchNorm2d, nn.GroupNorm)): 135 | nn.init.constant_(m.weight, 1) 136 | nn.init.constant_(m.bias, 0) 137 | 138 | def _make_layer(self, block, planes, blocks, stride=1): 139 | downsample = None 140 | if stride != 1 or self.inplanes != planes * block.expansion: 141 | downsample = MetaSequential( 142 | MetaConv2d(self.inplanes, planes * block.expansion, 143 | kernel_size=1, stride=stride, bias=False), 144 | MetaBatchNorm2d(planes * block.expansion), 145 | ) 146 | 147 | layers = list([]) 148 | layers.append(block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1))) 149 | self.inplanes = planes * block.expansion 150 | for i in range(1, blocks): 151 | layers.append(block(self.inplanes, planes, is_last=(i == blocks-1))) 152 | 153 | return MetaSequential(*layers) 154 | 155 | def get_feat_modules(self): 156 | feat_m = nn.ModuleList([]) 157 | feat_m.append(self.conv1) 158 | feat_m.append(self.bn1) 159 | feat_m.append(self.relu) 160 | feat_m.append(self.layer1) 161 | feat_m.append(self.layer2) 162 | feat_m.append(self.layer3) 163 | return feat_m 164 | 165 | def get_bn_before_relu(self): 166 | if isinstance(self.layer1[0], Bottleneck): 167 | bn1 = self.layer1[-1].bn3 168 | bn2 = self.layer2[-1].bn3 169 | bn3 = self.layer3[-1].bn3 170 | elif isinstance(self.layer1[0], BasicBlock): 171 | bn1 = self.layer1[-1].bn2 172 | bn2 = self.layer2[-1].bn2 173 | bn3 = self.layer3[-1].bn2 174 | else: 175 | raise NotImplementedError('ResNet unknown block error !!!') 176 | 177 | return [bn1, bn2, bn3] 178 | 179 | def forward(self, x, is_feat=False, preact=False, params=None): 180 | x = self.conv1(x, params=self.get_subdict(params, 'conv1')) 181 | x = self.bn1(x, params=self.get_subdict(params, 'bn1')) 182 | x = self.relu(x) # 32x32 183 | f0 = x 184 | 185 | x, f1_pre = self.layer1(x, params=self.get_subdict(params, 'layer1')) # 32x32 186 | f1 = x 187 | x, f2_pre = self.layer2(x, params=self.get_subdict(params, 'layer2')) # 16x16 188 | f2 = x 189 | x, f3_pre = self.layer3(x, params=self.get_subdict(params, 'layer3')) # 8x8 190 | f3 = x 191 | 192 | x = self.avgpool(x) 193 | x = x.view(x.size(0), -1) 194 | f4 = x 195 | x = self.fc(x, params=self.get_subdict(params, 'fc')) 196 | 197 | if is_feat: 198 | if preact: 199 | return [f0, f1_pre, f2_pre, f3_pre, f4], x 200 | else: 201 | return [f0, f1, f2, f3, f4], x 202 | else: 203 | return x 204 | 205 | 206 | def resnet8(**kwargs): 207 | return ResNet(8, [16, 16, 32, 64], 'basicblock',
**kwargs) 208 | 209 | 210 | def resnet14(**kwargs): 211 | return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs) 212 | 213 | 214 | def resnet20(**kwargs): 215 | return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs) 216 | 217 | 218 | def resnet32(**kwargs): 219 | return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs) 220 | 221 | 222 | def resnet44(**kwargs): 223 | return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs) 224 | 225 | 226 | def resnet56(**kwargs): 227 | return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs) 228 | 229 | 230 | def resnet110(**kwargs): 231 | return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs) 232 | 233 | 234 | def resnet8x4(**kwargs): 235 | return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs) 236 | 237 | 238 | def resnet32x4(**kwargs): 239 | return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs) 240 | 241 | 242 | if __name__ == '__main__': 243 | import torch 244 | 245 | x = torch.randn(2, 3, 32, 32) 246 | net = resnet8x4(num_classes=20) 247 | feats, logit = net(x, is_feat=True, preact=True) 248 | 249 | for f in feats: 250 | print(f.shape, f.min().item()) 251 | print(logit.shape) 252 | 253 | for m in net.get_bn_before_relu(): 254 | if isinstance(m, MetaBatchNorm2d): 255 | print('pass') 256 | else: 257 | print('warning') 258 | -------------------------------------------------------------------------------- /cv/models/meta_vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG for CIFAR10. FC layers are removed. 2 | (c) YANG, Wei 3 | ''' 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | from torchmeta.modules import * 8 | 9 | 10 | __all__ = [ 11 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 12 | 'vgg19_bn', 'vgg19', 13 | ] 14 | 15 | 16 | model_urls = { 17 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 18 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 19 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 20 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 21 | } 22 | 23 | 24 | class VGG(MetaModule): 25 | 26 | def __init__(self, cfg, batch_norm=False, num_classes=1000): 27 | super(VGG, self).__init__() 28 | self.block0 = self._make_layers(cfg[0], batch_norm, 3) 29 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) 30 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) 31 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) 32 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) 33 | 34 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) 35 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 36 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 37 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 38 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) 39 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 40 | 41 | self.classifier = MetaLinear(512, num_classes) 42 | self._initialize_weights() 43 | 44 | def get_feat_modules(self): 45 | feat_m = nn.ModuleList([]) 46 | feat_m.append(self.block0) 47 | feat_m.append(self.pool0) 48 | feat_m.append(self.block1) 49 | feat_m.append(self.pool1) 50 | feat_m.append(self.block2) 51 | feat_m.append(self.pool2) 52 | feat_m.append(self.block3) 53 | feat_m.append(self.pool3) 54 | feat_m.append(self.block4) 55 | feat_m.append(self.pool4) 56 | return feat_m 57 | 58 | def get_bn_before_relu(self): 59 | bn1 = self.block1[-1] 60 | bn2 = self.block2[-1] 61 | 
bn3 = self.block3[-1] 62 | bn4 = self.block4[-1] 63 | return [bn1, bn2, bn3, bn4] 64 | 65 | def forward(self, x, is_feat=False, preact=False, params=None): 66 | h = x.shape[2] 67 | x = F.relu(self.block0(x, params=self.get_subdict(params, 'block0'))) 68 | f0 = x 69 | x = self.pool0(x) 70 | x = self.block1(x, params=self.get_subdict(params, 'block1')) 71 | f1_pre = x 72 | x = F.relu(x) 73 | f1 = x 74 | x = self.pool1(x) 75 | x = self.block2(x, params=self.get_subdict(params, 'block2')) 76 | f2_pre = x 77 | x = F.relu(x) 78 | f2 = x 79 | x = self.pool2(x) 80 | x = self.block3(x, params=self.get_subdict(params, 'block3')) 81 | f3_pre = x 82 | x = F.relu(x) 83 | f3 = x 84 | if h == 64: 85 | x = self.pool3(x) 86 | x = self.block4(x) 87 | f4_pre = x 88 | x = F.relu(x) 89 | f4 = x 90 | x = self.pool4(x) 91 | x = x.view(x.size(0), -1) 92 | f5 = x 93 | x = self.classifier(x, params=self.get_subdict(params, 'classifier')) 94 | 95 | if is_feat: 96 | if preact: 97 | return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], x 98 | else: 99 | return [f0, f1, f2, f3, f4, f5], x 100 | else: 101 | return x 102 | 103 | @staticmethod 104 | def _make_layers(cfg, batch_norm=False, in_channels=3): 105 | layers = [] 106 | for v in cfg: 107 | if v == 'M': 108 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 109 | else: 110 | conv2d = MetaConv2d(in_channels, v, kernel_size=3, padding=1) 111 | if batch_norm: 112 | layers += [conv2d, MetaBatchNorm2d(v), nn.ReLU(inplace=True)] 113 | else: 114 | layers += [conv2d, nn.ReLU(inplace=True)] 115 | in_channels = v 116 | layers = layers[:-1] 117 | return MetaSequential(*layers) 118 | 119 | def _initialize_weights(self): 120 | for m in self.modules(): 121 | if isinstance(m, MetaConv2d): 122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 123 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | elif isinstance(m, MetaBatchNorm2d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | elif isinstance(m, MetaLinear): 130 | n = m.weight.size(1) 131 | m.weight.data.normal_(0, 0.01) 132 | m.bias.data.zero_() 133 | 134 | 135 | cfg = { 136 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]], 137 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], 138 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], 139 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]], 140 | 'S': [[64], [128], [256], [512], [512]], 141 | } 142 | 143 | 144 | def vgg8(**kwargs): 145 | """VGG 8-layer model (configuration "S") 146 | Args: 147 | pretrained (bool): If True, returns a model pre-trained on ImageNet 148 | """ 149 | model = VGG(cfg['S'], **kwargs) 150 | return model 151 | 152 | 153 | def vgg8_bn(**kwargs): 154 | """VGG 8-layer model (configuration "S") 155 | Args: 156 | pretrained (bool): If True, returns a model pre-trained on ImageNet 157 | """ 158 | model = VGG(cfg['S'], batch_norm=True, **kwargs) 159 | return model 160 | 161 | 162 | def vgg11(**kwargs): 163 | """VGG 11-layer model (configuration "A") 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | model = VGG(cfg['A'], **kwargs) 168 | return model 169 | 170 | 171 | def vgg11_bn(**kwargs): 172 | """VGG 11-layer model (configuration "A") with batch normalization""" 173 | model = VGG(cfg['A'], batch_norm=True, **kwargs) 174 | return model 175 | 176 | 177 | def vgg13(**kwargs): 178 | """VGG 13-layer model (configuration "B") 179 | Args: 180 | pretrained (bool): If True, returns a model pre-trained on ImageNet 181 | """ 182 | model = VGG(cfg['B'], **kwargs) 183 | return model 184 | 185 | 186 | def vgg13_bn(**kwargs): 187 | """VGG 13-layer model (configuration "B") with batch normalization""" 188 | model = VGG(cfg['B'], batch_norm=True, **kwargs) 189 | return model 190 | 191 | 192 | def vgg16(**kwargs): 193 | """VGG 16-layer model (configuration "D") 194 | Args: 195 | pretrained (bool): If True, returns a model pre-trained on ImageNet 196 | """ 197 | model = VGG(cfg['D'], **kwargs) 198 | return model 199 | 200 | 201 | def vgg16_bn(**kwargs): 202 | """VGG 16-layer model (configuration "D") with batch normalization""" 203 | model = VGG(cfg['D'], batch_norm=True, **kwargs) 204 | return model 205 | 206 | 207 | def vgg19(**kwargs): 208 | """VGG 19-layer model (configuration "E") 209 | Args: 210 | pretrained (bool): If True, returns a model pre-trained on ImageNet 211 | """ 212 | model = VGG(cfg['E'], **kwargs) 213 | return model 214 | 215 | 216 | def vgg19_bn(**kwargs): 217 | """VGG 19-layer model (configuration 'E') with batch normalization""" 218 | model = VGG(cfg['E'], batch_norm=True, **kwargs) 219 | return model 220 | 221 | 222 | if __name__ == '__main__': 223 | import torch 224 | 225 | x = torch.randn(2, 3, 32, 32) 226 | net = vgg19_bn(num_classes=100) 227 | feats, logit = net(x, is_feat=True, preact=True) 228 | 229 | for f in feats: 230 | print(f.shape, f.min().item()) 231 | print(logit.shape) 232 | 233 | for m in net.get_bn_before_relu(): 234 | if isinstance(m, MetaBatchNorm2d): 235 | print('pass') 236 | else: 237 | print('warning') 238 | -------------------------------------------------------------------------------- /cv/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
MobileNetV2 implementation used in 3 | 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | 10 | __all__ = ['mobilenetv2_T_w', 'mobile_half'] 11 | 12 | BN = None 13 | 14 | 15 | def conv_bn(inp, oup, stride): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 18 | nn.BatchNorm2d(oup), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | 23 | def conv_1x1_bn(inp, oup): 24 | return nn.Sequential( 25 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 26 | nn.BatchNorm2d(oup), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | 31 | class InvertedResidual(nn.Module): 32 | def __init__(self, inp, oup, stride, expand_ratio): 33 | super(InvertedResidual, self).__init__() 34 | self.blockname = None 35 | 36 | self.stride = stride 37 | assert stride in [1, 2] 38 | 39 | self.use_res_connect = self.stride == 1 and inp == oup 40 | 41 | self.conv = nn.Sequential( 42 | # pw 43 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(inp * expand_ratio), 45 | nn.ReLU(inplace=True), 46 | # dw 47 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 48 | nn.BatchNorm2d(inp * expand_ratio), 49 | nn.ReLU(inplace=True), 50 | # pw-linear 51 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 52 | nn.BatchNorm2d(oup), 53 | ) 54 | self.names = ['0', '1', '2', '3', '4', '5', '6', '7'] 55 | 56 | def forward(self, x): 57 | t = x 58 | if self.use_res_connect: 59 | return t + self.conv(x) 60 | else: 61 | return self.conv(x) 62 | 63 | 64 | class MobileNetV2(nn.Module): 65 | """mobilenetV2""" 66 | def __init__(self, T, 67 | feature_dim, 68 | input_size=32, 69 | width_mult=1., 70 | remove_avg=False): 71 | super(MobileNetV2, self).__init__() 72 | self.remove_avg = remove_avg 73 | 74 | # setting of inverted residual blocks 75 | self.interverted_residual_setting = [ 76 | # t, c, n, s 77 | [1, 16, 1, 1], 78 | [T, 24, 2, 1], 79 | [T, 32, 3, 2], 80 | [T, 64, 4, 2], 81 | [T, 96, 3, 1], 82 | [T, 160, 3, 2], 83 | [T, 320, 1, 1], 84 | ] 85 | 86 | # building first layer 87 | assert input_size % 32 == 0 88 | input_channel = int(32 * width_mult) 89 | self.conv1 = conv_bn(3, input_channel, 2) 90 | 91 | # building inverted residual blocks 92 | self.blocks = nn.ModuleList([]) 93 | for t, c, n, s in self.interverted_residual_setting: 94 | output_channel = int(c * width_mult) 95 | layers = [] 96 | strides = [s] + [1] * (n - 1) 97 | for stride in strides: 98 | layers.append( 99 | InvertedResidual(input_channel, output_channel, stride, t) 100 | ) 101 | input_channel = output_channel 102 | self.blocks.append(nn.Sequential(*layers)) 103 | 104 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 105 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel) 106 | 107 | # building classifier 108 | self.classifier = nn.Sequential( 109 | # nn.Dropout(0.5), 110 | nn.Linear(self.last_channel, feature_dim), 111 | ) 112 | 113 | H = input_size // (32//2) 114 | self.avgpool = nn.AvgPool2d(H, ceil_mode=True) 115 | 116 | self._initialize_weights() 117 | print(T, width_mult) 118 | 119 | def get_bn_before_relu(self): 120 | bn1 = self.blocks[1][-1].conv[-1] 121 | bn2 = self.blocks[2][-1].conv[-1] 122 | bn3 = self.blocks[4][-1].conv[-1] 123 | bn4 = self.blocks[6][-1].conv[-1] 124 | return [bn1, bn2, bn3, bn4] 125 | 126 | def get_feat_modules(self): 127 | feat_m = nn.ModuleList([]) 128 | feat_m.append(self.conv1) 129 | feat_m.append(self.blocks) 130 | return feat_m 131 | 132 | def forward(self, x, is_feat=False, preact=False): 133 
| 134 | out = self.conv1(x) 135 | f0 = out 136 | 137 | out = self.blocks[0](out) 138 | out = self.blocks[1](out) 139 | f1 = out 140 | out = self.blocks[2](out) 141 | f2 = out 142 | out = self.blocks[3](out) 143 | out = self.blocks[4](out) 144 | f3 = out 145 | out = self.blocks[5](out) 146 | out = self.blocks[6](out) 147 | f4 = out 148 | 149 | out = self.conv2(out) 150 | 151 | if not self.remove_avg: 152 | out = self.avgpool(out) 153 | out = out.view(out.size(0), -1) 154 | f5 = out 155 | out = self.classifier(out) 156 | 157 | if is_feat: 158 | return [f0, f1, f2, f3, f4, f5], out 159 | else: 160 | return out 161 | 162 | def _initialize_weights(self): 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 166 | m.weight.data.normal_(0, math.sqrt(2. / n)) 167 | if m.bias is not None: 168 | m.bias.data.zero_() 169 | elif isinstance(m, nn.BatchNorm2d): 170 | m.weight.data.fill_(1) 171 | m.bias.data.zero_() 172 | elif isinstance(m, nn.Linear): 173 | n = m.weight.size(1) 174 | m.weight.data.normal_(0, 0.01) 175 | m.bias.data.zero_() 176 | 177 | 178 | def mobilenetv2_T_w(T, W, feature_dim=100): 179 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) 180 | return model 181 | 182 | 183 | def mobile_half(num_classes): 184 | return mobilenetv2_T_w(6, 0.5, num_classes) 185 | 186 | 187 | if __name__ == '__main__': 188 | x = torch.randn(2, 3, 32, 32) 189 | 190 | net = mobile_half(100) 191 | 192 | feats, logit = net(x, is_feat=True, preact=True) 193 | for f in feats: 194 | print(f.shape, f.min().item()) 195 | print(logit.shape) 196 | 197 | for m in net.get_bn_before_relu(): 198 | if isinstance(m, nn.BatchNorm2d): 199 | print('pass') 200 | else: 201 | print('warning') 202 | 203 | -------------------------------------------------------------------------------- /cv/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported from 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import math 13 | 14 | 15 | __all__ = ['resnet'] 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 28 | super(BasicBlock, self).__init__() 29 | self.is_last = is_last 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | preact = out 53 | out = F.relu(out) 54 | if self.is_last: 55 | return out, preact 56 | else: 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 64 | super(Bottleneck, self).__init__() 65 | self.is_last = is_last 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 69 | padding=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(planes * 4) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | preact = out 96 | out = F.relu(out) 97 | if self.is_last: 98 | return out, preact 99 | else: 100 | return out 101 | 102 | 103 | class ResNet(nn.Module): 104 | 105 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10): 106 | super(ResNet, self).__init__() 107 | # Model type specifies number of layers for CIFAR-10 model 108 | if block_name.lower() == 'basicblock': 109 | assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 110 | n = (depth - 2) // 6 111 | block = BasicBlock 112 | elif block_name.lower() == 'bottleneck': 113 | assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 
20, 29, 47, 56, 110, 1199' 114 | n = (depth - 2) // 9 115 | block = Bottleneck 116 | else: 117 | raise ValueError('block_name should be BasicBlock or Bottleneck') 118 | 119 | self.inplanes = num_filters[0] 120 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, 121 | bias=False) 122 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 123 | self.relu = nn.ReLU(inplace=True) 124 | self.layer1 = self._make_layer(block, num_filters[1], n) 125 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2) 126 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2) 127 | self.avgpool = nn.AvgPool2d(8) 128 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes) 129 | 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 133 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 134 | nn.init.constant_(m.weight, 1) 135 | nn.init.constant_(m.bias, 0) 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | nn.Conv2d(self.inplanes, planes * block.expansion, 142 | kernel_size=1, stride=stride, bias=False), 143 | nn.BatchNorm2d(planes * block.expansion), 144 | ) 145 | 146 | layers = list([]) 147 | layers.append(block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1))) 148 | self.inplanes = planes * block.expansion 149 | for i in range(1, blocks): 150 | layers.append(block(self.inplanes, planes, is_last=(i == blocks-1))) 151 | 152 | return nn.Sequential(*layers) 153 | 154 | def get_feat_modules(self): 155 | feat_m = nn.ModuleList([]) 156 | feat_m.append(self.conv1) 157 | feat_m.append(self.bn1) 158 | feat_m.append(self.relu) 159 | feat_m.append(self.layer1) 160 | feat_m.append(self.layer2) 161 | feat_m.append(self.layer3) 162 | return feat_m 163 | 164 | def get_bn_before_relu(self): 165 | if isinstance(self.layer1[0], Bottleneck): 166 | bn1 = self.layer1[-1].bn3 167 | bn2 = self.layer2[-1].bn3 168 | bn3 = self.layer3[-1].bn3 169 | elif isinstance(self.layer1[0], BasicBlock): 170 | bn1 = self.layer1[-1].bn2 171 | bn2 = self.layer2[-1].bn2 172 | bn3 = self.layer3[-1].bn2 173 | else: 174 | raise NotImplementedError('ResNet unknown block error !!!') 175 | 176 | return [bn1, bn2, bn3] 177 | 178 | def forward(self, x, is_feat=False, preact=False): 179 | x = self.conv1(x) 180 | x = self.bn1(x) 181 | x = self.relu(x) # 32x32 182 | f0 = x 183 | 184 | x, f1_pre = self.layer1(x) # 32x32 185 | f1 = x 186 | x, f2_pre = self.layer2(x) # 16x16 187 | f2 = x 188 | x, f3_pre = self.layer3(x) # 8x8 189 | f3 = x 190 | 191 | x = self.avgpool(x) 192 | x = x.view(x.size(0), -1) 193 | f4 = x 194 | x = self.fc(x) 195 | 196 | if is_feat: 197 | if preact: 198 | return [f0, f1_pre, f2_pre, f3_pre, f4], x 199 | else: 200 | return [f0, f1, f2, f3, f4], x 201 | else: 202 | return x 203 | 204 | 205 | def resnet8(**kwargs): 206 | return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs) 207 | 208 | 209 | def resnet14(**kwargs): 210 | return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs) 211 | 212 | 213 | def resnet20(**kwargs): 214 | return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs) 215 | 216 | 217 | def resnet32(**kwargs): 218 | return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs) 219 | 220 | 221 | def resnet44(**kwargs): 222 | return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs) 223 | 224 | 225 | def resnet56(**kwargs): 
226 | return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs) 227 | 228 | 229 | def resnet110(**kwargs): 230 | return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs) 231 | 232 | 233 | def resnet8x4(**kwargs): 234 | return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs) 235 | 236 | 237 | def resnet32x4(**kwargs): 238 | return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs) 239 | 240 | 241 | if __name__ == '__main__': 242 | import torch 243 | 244 | x = torch.randn(2, 3, 32, 32) 245 | net = resnet8x4(num_classes=20) 246 | feats, logit = net(x, is_feat=True, preact=True) 247 | 248 | for f in feats: 249 | print(f.shape, f.min().item()) 250 | print(logit.shape) 251 | 252 | for m in net.get_bn_before_relu(): 253 | if isinstance(m, nn.BatchNorm2d): 254 | print('pass') 255 | else: 256 | print('warning') 257 | -------------------------------------------------------------------------------- /cv/models/resnetv2.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1, is_last=False): 16 | super(BasicBlock, self).__init__() 17 | self.is_last = is_last 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != self.expansion * planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 27 | nn.BatchNorm2d(self.expansion * planes) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | out = self.bn2(self.conv2(out)) 33 | out += self.shortcut(x) 34 | preact = out 35 | out = F.relu(out) 36 | if self.is_last: 37 | return out, preact 38 | else: 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1, is_last=False): 46 | super(Bottleneck, self).__init__() 47 | self.is_last = is_last 48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 54 | 55 | self.shortcut = nn.Sequential() 56 | if stride != 1 or in_planes != self.expansion * planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 59 | nn.BatchNorm2d(self.expansion * planes) 60 | ) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(self.conv1(x))) 64 | out = F.relu(self.bn2(self.conv2(out))) 65 | out = self.bn3(self.conv3(out)) 66 | out += self.shortcut(x) 67 | preact = out 68 | out = F.relu(out) 69 | if self.is_last: 70 | return out, preact 71 | else: 72 | return out 73 | 74 | 75 | class 
ResNet(nn.Module): 76 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): 77 | super(ResNet, self).__init__() 78 | self.in_planes = 64 79 | 80 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 81 | self.bn1 = nn.BatchNorm2d(64) 82 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 83 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 84 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 85 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 86 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 87 | self.linear = nn.Linear(512 * block.expansion, num_classes) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 92 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 93 | nn.init.constant_(m.weight, 1) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | # Zero-initialize the last BN in each residual branch, 97 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 98 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 99 | if zero_init_residual: 100 | for m in self.modules(): 101 | if isinstance(m, Bottleneck): 102 | nn.init.constant_(m.bn3.weight, 0) 103 | elif isinstance(m, BasicBlock): 104 | nn.init.constant_(m.bn2.weight, 0) 105 | 106 | def get_feat_modules(self): 107 | feat_m = nn.ModuleList([]) 108 | feat_m.append(self.conv1) 109 | feat_m.append(self.bn1) 110 | feat_m.append(self.layer1) 111 | feat_m.append(self.layer2) 112 | feat_m.append(self.layer3) 113 | feat_m.append(self.layer4) 114 | return feat_m 115 | 116 | def get_bn_before_relu(self): 117 | if isinstance(self.layer1[0], Bottleneck): 118 | bn1 = self.layer1[-1].bn3 119 | bn2 = self.layer2[-1].bn3 120 | bn3 = self.layer3[-1].bn3 121 | bn4 = self.layer4[-1].bn3 122 | elif isinstance(self.layer1[0], BasicBlock): 123 | bn1 = self.layer1[-1].bn2 124 | bn2 = self.layer2[-1].bn2 125 | bn3 = self.layer3[-1].bn2 126 | bn4 = self.layer4[-1].bn2 127 | else: 128 | raise NotImplementedError('ResNet unknown block error !!!') 129 | 130 | return [bn1, bn2, bn3, bn4] 131 | 132 | def _make_layer(self, block, planes, num_blocks, stride): 133 | strides = [stride] + [1] * (num_blocks - 1) 134 | layers = [] 135 | for i in range(num_blocks): 136 | stride = strides[i] 137 | layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1)) 138 | self.in_planes = planes * block.expansion 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x, is_feat=False, preact=False): 142 | out = F.relu(self.bn1(self.conv1(x))) 143 | f0 = out 144 | out, f1_pre = self.layer1(out) 145 | f1 = out 146 | out, f2_pre = self.layer2(out) 147 | f2 = out 148 | out, f3_pre = self.layer3(out) 149 | f3 = out 150 | out, f4_pre = self.layer4(out) 151 | f4 = out 152 | out = self.avgpool(out) 153 | out = out.view(out.size(0), -1) 154 | f5 = out 155 | out = self.linear(out) 156 | if is_feat: 157 | if preact: 158 | return [[f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], out] 159 | else: 160 | return [f0, f1, f2, f3, f4, f5], out 161 | else: 162 | return out 163 | 164 | 165 | def ResNet18(**kwargs): 166 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 167 | 168 | 169 | def ResNet34(**kwargs): 170 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 171 | 172 | 173 | def ResNet50(**kwargs): 174 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 175 | 176 | 177 | def 
ResNet101(**kwargs): 178 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 179 | 180 | 181 | def ResNet152(**kwargs): 182 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 183 | 184 | 185 | if __name__ == '__main__': 186 | net = ResNet18(num_classes=100) 187 | x = torch.randn(2, 3, 32, 32) 188 | feats, logit = net(x, is_feat=True, preact=True) 189 | 190 | for f in feats: 191 | print(f.shape, f.min().item()) 192 | print(logit.shape) 193 | 194 | for m in net.get_bn_before_relu(): 195 | if isinstance(m, nn.BatchNorm2d): 196 | print('pass') 197 | else: 198 | print('warning') 199 | -------------------------------------------------------------------------------- /cv/models/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class Paraphraser(nn.Module): 8 | """Paraphrasing Complex Network: Network Compression via Factor Transfer""" 9 | def __init__(self, t_shape, k=0.5, use_bn=False): 10 | super(Paraphraser, self).__init__() 11 | in_channel = t_shape[1] 12 | out_channel = int(t_shape[1] * k) 13 | self.encoder = nn.Sequential( 14 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 15 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 16 | nn.LeakyReLU(0.1, inplace=True), 17 | nn.Conv2d(in_channel, out_channel, 3, 1, 1), 18 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 19 | nn.LeakyReLU(0.1, inplace=True), 20 | nn.Conv2d(out_channel, out_channel, 3, 1, 1), 21 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 22 | nn.LeakyReLU(0.1, inplace=True), 23 | ) 24 | self.decoder = nn.Sequential( 25 | nn.ConvTranspose2d(out_channel, out_channel, 3, 1, 1), 26 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 27 | nn.LeakyReLU(0.1, inplace=True), 28 | nn.ConvTranspose2d(out_channel, in_channel, 3, 1, 1), 29 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 30 | nn.LeakyReLU(0.1, inplace=True), 31 | nn.ConvTranspose2d(in_channel, in_channel, 3, 1, 1), 32 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 33 | nn.LeakyReLU(0.1, inplace=True), 34 | ) 35 | 36 | def forward(self, f_s, is_factor=False): 37 | factor = self.encoder(f_s) 38 | if is_factor: 39 | return factor 40 | rec = self.decoder(factor) 41 | return factor, rec 42 | 43 | 44 | class Translator(nn.Module): 45 | def __init__(self, s_shape, t_shape, k=0.5, use_bn=True): 46 | super(Translator, self).__init__() 47 | in_channel = s_shape[1] 48 | out_channel = int(t_shape[1] * k) 49 | self.encoder = nn.Sequential( 50 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 51 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 52 | nn.LeakyReLU(0.1, inplace=True), 53 | nn.Conv2d(in_channel, out_channel, 3, 1, 1), 54 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 55 | nn.LeakyReLU(0.1, inplace=True), 56 | nn.Conv2d(out_channel, out_channel, 3, 1, 1), 57 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 58 | nn.LeakyReLU(0.1, inplace=True), 59 | ) 60 | 61 | def forward(self, f_s): 62 | return self.encoder(f_s) 63 | 64 | 65 | class Connector(nn.Module): 66 | """Connect for Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons""" 67 | def __init__(self, s_shapes, t_shapes): 68 | super(Connector, self).__init__() 69 | self.s_shapes = s_shapes 70 | self.t_shapes = t_shapes 71 | 72 | self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes)) 73 | 74 | @staticmethod 75 | def 
_make_conenctors(s_shapes, t_shapes): 76 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' 77 | connectors = [] 78 | for s, t in zip(s_shapes, t_shapes): 79 | if s[1] == t[1] and s[2] == t[2]: 80 | connectors.append(nn.Sequential()) 81 | else: 82 | connectors.append(ConvReg(s, t, use_relu=False)) 83 | return connectors 84 | 85 | def forward(self, g_s): 86 | out = [] 87 | for i in range(len(g_s)): 88 | out.append(self.connectors[i](g_s[i])) 89 | 90 | return out 91 | 92 | 93 | class ConnectorV2(nn.Module): 94 | """A Comprehensive Overhaul of Feature Distillation (ICCV 2019)""" 95 | def __init__(self, s_shapes, t_shapes): 96 | super(ConnectorV2, self).__init__() 97 | self.s_shapes = s_shapes 98 | self.t_shapes = t_shapes 99 | 100 | self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes)) 101 | 102 | def _make_conenctors(self, s_shapes, t_shapes): 103 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' 104 | t_channels = [t[1] for t in t_shapes] 105 | s_channels = [s[1] for s in s_shapes] 106 | connectors = nn.ModuleList([self._build_feature_connector(t, s) 107 | for t, s in zip(t_channels, s_channels)]) 108 | return connectors 109 | 110 | @staticmethod 111 | def _build_feature_connector(t_channel, s_channel): 112 | C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False), 113 | nn.BatchNorm2d(t_channel)] 114 | for m in C: 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | return nn.Sequential(*C) 122 | 123 | def forward(self, g_s): 124 | out = [] 125 | for i in range(len(g_s)): 126 | out.append(self.connectors[i](g_s[i])) 127 | 128 | return out 129 | 130 | 131 | class ConvReg(nn.Module): 132 | """Convolutional regression for FitNet""" 133 | def __init__(self, s_shape, t_shape, use_relu=True): 134 | super(ConvReg, self).__init__() 135 | self.use_relu = use_relu 136 | s_N, s_C, s_H, s_W = s_shape 137 | t_N, t_C, t_H, t_W = t_shape 138 | if s_H == 2 * t_H: 139 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1) 140 | elif s_H * 2 == t_H: 141 | self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1) 142 | elif s_H >= t_H: 143 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W)) 144 | else: 145 | raise NotImplementedError('student size {}, teacher size {}'.format(s_H, t_H)) 146 | self.bn = nn.BatchNorm2d(t_C) 147 | self.relu = nn.ReLU(inplace=True) 148 | 149 | def forward(self, x): 150 | x = self.conv(x) 151 | if self.use_relu: 152 | return self.relu(self.bn(x)) 153 | else: 154 | return self.bn(x) 155 | 156 | 157 | class Regress(nn.Module): 158 | """Simple Linear Regression for hints""" 159 | def __init__(self, dim_in=1024, dim_out=1024): 160 | super(Regress, self).__init__() 161 | self.linear = nn.Linear(dim_in, dim_out) 162 | self.relu = nn.ReLU(inplace=True) 163 | 164 | def forward(self, x): 165 | x = x.view(x.shape[0], -1) 166 | x = self.linear(x) 167 | x = self.relu(x) 168 | return x 169 | 170 | 171 | class Embed(nn.Module): 172 | """Embedding module""" 173 | def __init__(self, dim_in=1024, dim_out=128): 174 | super(Embed, self).__init__() 175 | self.linear = nn.Linear(dim_in, dim_out) 176 | self.l2norm = Normalize(2) 177 | 178 | def forward(self, x): 179 | x = x.view(x.shape[0], -1) 180 | x = self.linear(x) 181 | x = self.l2norm(x) 182 | return x 183 | 184 | 
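# Usage sketch for Embed (shapes here are hypothetical, for illustration only):
#   f_s = torch.randn(2, 64, 4, 4)                 # e.g. a student feature map
#   embed = Embed(dim_in=64 * 4 * 4, dim_out=128)  # dim_in must match the flattened size
#   z_s = embed(f_s)                               # -> (2, 128), each row L2-normalized
# The flatten -> linear -> L2-normalize pipeline places student and teacher
# features on the unit hypersphere, which is what embedding-based losses
# such as CRD compare.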
185 | class LinearEmbed(nn.Module): 186 | """Linear Embedding""" 187 | def __init__(self, dim_in=1024, dim_out=128): 188 | super(LinearEmbed, self).__init__() 189 | self.linear = nn.Linear(dim_in, dim_out) 190 | 191 | def forward(self, x): 192 | x = x.view(x.shape[0], -1) 193 | x = self.linear(x) 194 | return x 195 | 196 | 197 | class MLPEmbed(nn.Module): 198 | """non-linear embed by MLP""" 199 | def __init__(self, dim_in=1024, dim_out=128): 200 | super(MLPEmbed, self).__init__() 201 | self.linear1 = nn.Linear(dim_in, 2 * dim_out) 202 | self.relu = nn.ReLU(inplace=True) 203 | self.linear2 = nn.Linear(2 * dim_out, dim_out) 204 | self.l2norm = Normalize(2) 205 | 206 | def forward(self, x): 207 | x = x.view(x.shape[0], -1) 208 | x = self.relu(self.linear1(x)) 209 | x = self.l2norm(self.linear2(x)) 210 | return x 211 | 212 | 213 | class Normalize(nn.Module): 214 | """normalization layer""" 215 | def __init__(self, power=2): 216 | super(Normalize, self).__init__() 217 | self.power = power 218 | 219 | def forward(self, x): 220 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 221 | out = x.div(norm) 222 | return out 223 | 224 | 225 | class Flatten(nn.Module): 226 | """flatten module""" 227 | def __init__(self): 228 | super(Flatten, self).__init__() 229 | 230 | def forward(self, feat): 231 | return feat.view(feat.size(0), -1) 232 | 233 | 234 | class PoolEmbed(nn.Module): 235 | """pool and embed""" 236 | def __init__(self, layer=0, dim_out=128, pool_type='avg'): 237 | super().__init__() 238 | if layer == 0: 239 | pool_size = 8 240 | nChannels = 16 241 | elif layer == 1: 242 | pool_size = 8 243 | nChannels = 16 244 | elif layer == 2: 245 | pool_size = 6 246 | nChannels = 32 247 | elif layer == 3: 248 | pool_size = 4 249 | nChannels = 64 250 | elif layer == 4: 251 | pool_size = 1 252 | nChannels = 64 253 | else: 254 | raise NotImplementedError('layer not supported: {}'.format(layer)) 255 | 256 | self.embed = nn.Sequential() 257 | if layer <= 3: 258 | if pool_type == 'max': 259 | self.embed.add_module('MaxPool', nn.AdaptiveMaxPool2d((pool_size, pool_size))) 260 | elif pool_type == 'avg': 261 | self.embed.add_module('AvgPool', nn.AdaptiveAvgPool2d((pool_size, pool_size))) 262 | 263 | self.embed.add_module('Flatten', Flatten()) 264 | self.embed.add_module('Linear', nn.Linear(nChannels*pool_size*pool_size, dim_out)) 265 | self.embed.add_module('Normalize', Normalize(2)) 266 | 267 | def forward(self, x): 268 | return self.embed(x) 269 | 270 | 271 | if __name__ == '__main__': 272 | import torch 273 | 274 | g_s = [ 275 | torch.randn(2, 16, 16, 16), 276 | torch.randn(2, 32, 8, 8), 277 | torch.randn(2, 64, 4, 4), 278 | ] 279 | g_t = [ 280 | torch.randn(2, 32, 16, 16), 281 | torch.randn(2, 64, 8, 8), 282 | torch.randn(2, 128, 4, 4), 283 | ] 284 | s_shapes = [s.shape for s in g_s] 285 | t_shapes = [t.shape for t in g_t] 286 | 287 | net = ConnectorV2(s_shapes, t_shapes) 288 | out = net(g_s) 289 | for f in out: 290 | print(f.shape) 291 | -------------------------------------------------------------------------------- /cv/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG for CIFAR10. FC layers are removed. 
2 | (c) YANG, Wei 3 | ''' 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | model_urls = { 16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 20 | } 21 | 22 | 23 | class VGG(nn.Module): 24 | 25 | def __init__(self, cfg, batch_norm=False, num_classes=1000): 26 | super(VGG, self).__init__() 27 | self.block0 = self._make_layers(cfg[0], batch_norm, 3) 28 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) 29 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) 30 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) 31 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) 32 | 33 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) 34 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 35 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 36 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 37 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) 38 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 39 | 40 | self.classifier = nn.Linear(512, num_classes) 41 | self._initialize_weights() 42 | 43 | def get_feat_modules(self): 44 | feat_m = nn.ModuleList([]) 45 | feat_m.append(self.block0) 46 | feat_m.append(self.pool0) 47 | feat_m.append(self.block1) 48 | feat_m.append(self.pool1) 49 | feat_m.append(self.block2) 50 | feat_m.append(self.pool2) 51 | feat_m.append(self.block3) 52 | feat_m.append(self.pool3) 53 | feat_m.append(self.block4) 54 | feat_m.append(self.pool4) 55 | return feat_m 56 | 57 | def get_bn_before_relu(self): 58 | bn1 = self.block1[-1] 59 | bn2 = self.block2[-1] 60 | bn3 = self.block3[-1] 61 | bn4 = self.block4[-1] 62 | return [bn1, bn2, bn3, bn4] 63 | 64 | def forward(self, x, is_feat=False, preact=False): 65 | h = x.shape[2] 66 | x = F.relu(self.block0(x)) 67 | f0 = x 68 | x = self.pool0(x) 69 | x = self.block1(x) 70 | f1_pre = x 71 | x = F.relu(x) 72 | f1 = x 73 | x = self.pool1(x) 74 | x = self.block2(x) 75 | f2_pre = x 76 | x = F.relu(x) 77 | f2 = x 78 | x = self.pool2(x) 79 | x = self.block3(x) 80 | f3_pre = x 81 | x = F.relu(x) 82 | f3 = x 83 | if h == 64: 84 | x = self.pool3(x) 85 | x = self.block4(x) 86 | f4_pre = x 87 | x = F.relu(x) 88 | f4 = x 89 | x = self.pool4(x) 90 | x = x.view(x.size(0), -1) 91 | f5 = x 92 | x = self.classifier(x) 93 | 94 | if is_feat: 95 | if preact: 96 | return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], x 97 | else: 98 | return [f0, f1, f2, f3, f4, f5], x 99 | else: 100 | return x 101 | 102 | @staticmethod 103 | def _make_layers(cfg, batch_norm=False, in_channels=3): 104 | layers = [] 105 | for v in cfg: 106 | if v == 'M': 107 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 108 | else: 109 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 110 | if batch_norm: 111 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 112 | else: 113 | layers += [conv2d, nn.ReLU(inplace=True)] 114 | in_channels = v 115 | layers = layers[:-1] 116 | return nn.Sequential(*layers) 117 | 118 | def _initialize_weights(self): 119 | for m in self.modules(): 120 | if isinstance(m, nn.Conv2d): 121 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 122 | m.weight.data.normal_(0, 
math.sqrt(2. / n)) 123 | if m.bias is not None: 124 | m.bias.data.zero_() 125 | elif isinstance(m, nn.BatchNorm2d): 126 | m.weight.data.fill_(1) 127 | m.bias.data.zero_() 128 | elif isinstance(m, nn.Linear): 129 | n = m.weight.size(1) 130 | m.weight.data.normal_(0, 0.01) 131 | m.bias.data.zero_() 132 | 133 | 134 | cfg = { 135 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]], 136 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], 137 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], 138 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]], 139 | 'S': [[64], [128], [256], [512], [512]], 140 | } 141 | 142 | 143 | def vgg8(**kwargs): 144 | """VGG 8-layer model (configuration "S") 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | """ 148 | model = VGG(cfg['S'], **kwargs) 149 | return model 150 | 151 | 152 | def vgg8_bn(**kwargs): 153 | """VGG 8-layer model (configuration "S") 154 | Args: 155 | pretrained (bool): If True, returns a model pre-trained on ImageNet 156 | """ 157 | model = VGG(cfg['S'], batch_norm=True, **kwargs) 158 | return model 159 | 160 | 161 | def vgg11(**kwargs): 162 | """VGG 11-layer model (configuration "A") 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | """ 166 | model = VGG(cfg['A'], **kwargs) 167 | return model 168 | 169 | 170 | def vgg11_bn(**kwargs): 171 | """VGG 11-layer model (configuration "A") with batch normalization""" 172 | model = VGG(cfg['A'], batch_norm=True, **kwargs) 173 | return model 174 | 175 | 176 | def vgg13(**kwargs): 177 | """VGG 13-layer model (configuration "B") 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = VGG(cfg['B'], **kwargs) 182 | return model 183 | 184 | 185 | def vgg13_bn(**kwargs): 186 | """VGG 13-layer model (configuration "B") with batch normalization""" 187 | model = VGG(cfg['B'], batch_norm=True, **kwargs) 188 | return model 189 | 190 | 191 | def vgg16(**kwargs): 192 | """VGG 16-layer model (configuration "D") 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = VGG(cfg['D'], **kwargs) 197 | return model 198 | 199 | 200 | def vgg16_bn(**kwargs): 201 | """VGG 16-layer model (configuration "D") with batch normalization""" 202 | model = VGG(cfg['D'], batch_norm=True, **kwargs) 203 | return model 204 | 205 | 206 | def vgg19(**kwargs): 207 | """VGG 19-layer model (configuration "E") 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = VGG(cfg['E'], **kwargs) 212 | return model 213 | 214 | 215 | def vgg19_bn(**kwargs): 216 | """VGG 19-layer model (configuration 'E') with batch normalization""" 217 | model = VGG(cfg['E'], batch_norm=True, **kwargs) 218 | return model 219 | 220 | 221 | if __name__ == '__main__': 222 | import torch 223 | 224 | x = torch.randn(2, 3, 32, 32) 225 | net = vgg19_bn(num_classes=100) 226 | feats, logit = net(x, is_feat=True, preact=True) 227 | 228 | for f in feats: 229 | print(f.shape, f.min().item()) 230 | print(logit.shape) 231 | 232 | for m in net.get_bn_before_relu(): 233 | if isinstance(m, nn.BatchNorm2d): 234 | print('pass') 235 | else: 236 | print('warning') 237 | -------------------------------------------------------------------------------- /cv/models/wrn.py: -------------------------------------------------------------------------------- 1 | import 
math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | """ 7 | Original Author: Wei Yang 8 | """ 9 | 10 | __all__ = ['wrn'] 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 15 | super(BasicBlock, self).__init__() 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.relu1 = nn.ReLU(inplace=True) 18 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(out_planes) 21 | self.relu2 = nn.ReLU(inplace=True) 22 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 23 | padding=1, bias=False) 24 | self.droprate = dropRate 25 | self.equalInOut = (in_planes == out_planes) 26 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 27 | padding=0, bias=False) or None 28 | 29 | def forward(self, x): 30 | if not self.equalInOut: 31 | x = self.relu1(self.bn1(x)) 32 | else: 33 | out = self.relu1(self.bn1(x)) 34 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 35 | if self.droprate > 0: 36 | out = F.dropout(out, p=self.droprate, training=self.training) 37 | out = self.conv2(out) 38 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 39 | 40 | 41 | class NetworkBlock(nn.Module): 42 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 43 | super(NetworkBlock, self).__init__() 44 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 45 | 46 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 47 | layers = [] 48 | for i in range(nb_layers): 49 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 50 | return nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | return self.layer(x) 54 | 55 | 56 | class WideResNet(nn.Module): 57 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 58 | super(WideResNet, self).__init__() 59 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 60 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 61 | n = (depth - 4) // 6 62 | block = BasicBlock 63 | # 1st conv before any network block 64 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 65 | padding=1, bias=False) 66 | # 1st block 67 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 68 | # 2nd block 69 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 70 | # 3rd block 71 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 72 | # global average pooling and classifier 73 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.fc = nn.Linear(nChannels[3], num_classes) 76 | self.nChannels = nChannels[3] 77 | 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 81 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 82 | elif isinstance(m, nn.BatchNorm2d): 83 | m.weight.data.fill_(1) 84 | m.bias.data.zero_() 85 | elif isinstance(m, nn.Linear): 86 | m.bias.data.zero_() 87 | 88 | def get_feat_modules(self): 89 | feat_m = nn.ModuleList([]) 90 | feat_m.append(self.conv1) 91 | feat_m.append(self.block1) 92 | feat_m.append(self.block2) 93 | feat_m.append(self.block3) 94 | return feat_m 95 | 96 | def get_bn_before_relu(self): 97 | bn1 = self.block2.layer[0].bn1 98 | bn2 = self.block3.layer[0].bn1 99 | bn3 = self.bn1 100 | 101 | return [bn1, bn2, bn3] 102 | 103 | def forward(self, x, is_feat=False, preact=False): 104 | out = self.conv1(x) 105 | f0 = out 106 | out = self.block1(out) 107 | f1 = out 108 | out = self.block2(out) 109 | f2 = out 110 | out = self.block3(out) 111 | f3 = out 112 | out = self.relu(self.bn1(out)) 113 | out = F.avg_pool2d(out, 8) 114 | out = out.view(-1, self.nChannels) 115 | f4 = out 116 | out = self.fc(out) 117 | if is_feat: 118 | if preact: 119 | f1 = self.block2.layer[0].bn1(f1) 120 | f2 = self.block3.layer[0].bn1(f2) 121 | f3 = self.bn1(f3) 122 | return [f0, f1, f2, f3, f4], out 123 | else: 124 | return out 125 | 126 | 127 | def wrn(**kwargs): 128 | """ 129 | Constructs a Wide Residual Networks. 130 | """ 131 | model = WideResNet(**kwargs) 132 | return model 133 | 134 | 135 | def wrn_40_2(**kwargs): 136 | model = WideResNet(depth=40, widen_factor=2, **kwargs) 137 | return model 138 | 139 | 140 | def wrn_40_1(**kwargs): 141 | model = WideResNet(depth=40, widen_factor=1, **kwargs) 142 | return model 143 | 144 | 145 | def wrn_16_2(**kwargs): 146 | model = WideResNet(depth=16, widen_factor=2, **kwargs) 147 | return model 148 | 149 | 150 | def wrn_16_1(**kwargs): 151 | model = WideResNet(depth=16, widen_factor=1, **kwargs) 152 | return model 153 | 154 | 155 | if __name__ == '__main__': 156 | import torch 157 | 158 | x = torch.randn(2, 3, 32, 32) 159 | net = wrn_40_2(num_classes=100) 160 | feats, logit = net(x, is_feat=True, preact=True) 161 | 162 | for f in feats: 163 | print(f.shape, f.min().item()) 164 | print(logit.shape) 165 | 166 | for m in net.get_bn_before_relu(): 167 | if isinstance(m, nn.BatchNorm2d): 168 | print('pass') 169 | else: 170 | print('warning') 171 | -------------------------------------------------------------------------------- /cv/scripts/fetch_pretrained_teachers.sh: -------------------------------------------------------------------------------- 1 | # fetch pre-trained teacher models 2 | 3 | mkdir -p save/models/ 4 | 5 | cd save/models 6 | 7 | mkdir -p wrn_40_2_vanilla 8 | wget http://shape2prog.csail.mit.edu/repo/wrn_40_2_vanilla/ckpt_epoch_240.pth 9 | mv ckpt_epoch_240.pth wrn_40_2_vanilla/ 10 | 11 | mkdir -p resnet56_vanilla 12 | wget http://shape2prog.csail.mit.edu/repo/resnet56_vanilla/ckpt_epoch_240.pth 13 | mv ckpt_epoch_240.pth resnet56_vanilla/ 14 | 15 | mkdir -p resnet110_vanilla 16 | wget http://shape2prog.csail.mit.edu/repo/resnet110_vanilla/ckpt_epoch_240.pth 17 | mv ckpt_epoch_240.pth resnet110_vanilla/ 18 | 19 | mkdir -p resnet32x4_vanilla 20 | wget http://shape2prog.csail.mit.edu/repo/resnet32x4_vanilla/ckpt_epoch_240.pth 21 | mv ckpt_epoch_240.pth resnet32x4_vanilla/ 22 | 23 | mkdir -p vgg13_vanilla 24 | wget http://shape2prog.csail.mit.edu/repo/vgg13_vanilla/ckpt_epoch_240.pth 25 | mv ckpt_epoch_240.pth vgg13_vanilla/ 26 | 27 | mkdir -p ResNet50_vanilla 28 | wget http://shape2prog.csail.mit.edu/repo/ResNet50_vanilla/ckpt_epoch_240.pth 29 | mv ckpt_epoch_240.pth ResNet50_vanilla/ 30 | 31 | cd ../.. 
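Once fetched, a teacher checkpoint can be loaded the same way load_teacher in train_student_meta.py does it. A minimal sketch, assuming it is run from the cv/ directory so that models.model_dict and the save/models layout created above are available:

    import torch
    from models import model_dict  # name -> constructor mapping used by the training scripts

    # checkpoints saved by the training scripts keep the weights under the 'model' key
    ckpt = torch.load('save/models/resnet32x4_vanilla/ckpt_epoch_240.pth',
                      map_location='cpu')
    teacher = model_dict['resnet32x4'](num_classes=100)  # CIFAR-100 teacher
    teacher.load_state_dict(ckpt['model'])
    teacher.eval()  # teachers are put in eval mode before distillation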
-------------------------------------------------------------------------------- /cv/scripts/run_cifar_distill.sh: -------------------------------------------------------------------------------- 1 | # sample scripts for running the distillation code 2 | # use resnet32x4 and resnet8x4 as an example 3 | 4 | # kd 5 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill kd --model_s resnet8x4 -r 0.1 -a 0.9 -b 0 --trial 1 6 | # FitNet 7 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill hint --model_s resnet8x4 -a 0 -b 100 --trial 1 8 | # AT 9 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill attention --model_s resnet8x4 -a 0 -b 1000 --trial 1 10 | # SP 11 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill similarity --model_s resnet8x4 -a 0 -b 3000 --trial 1 12 | # CC 13 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill correlation --model_s resnet8x4 -a 0 -b 0.02 --trial 1 14 | # VID 15 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill vid --model_s resnet8x4 -a 0 -b 1 --trial 1 16 | # RKD 17 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill rkd --model_s resnet8x4 -a 0 -b 1 --trial 1 18 | # PKT 19 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill pkt --model_s resnet8x4 -a 0 -b 30000 --trial 1 20 | # AB 21 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill abound --model_s resnet8x4 -a 0 -b 1 --trial 1 22 | # FT 23 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill factor --model_s resnet8x4 -a 0 -b 200 --trial 1 24 | # FSP 25 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill fsp --model_s resnet8x4 -a 0 -b 50 --trial 1 26 | # NST 27 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill nst --model_s resnet8x4 -a 0 -b 50 --trial 1 28 | # CRD 29 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill crd --model_s resnet8x4 -a 0 -b 0.8 --trial 1 30 | 31 | # CRD+KD 32 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill crd --model_s resnet8x4 -a 1 -b 0.8 --trial 1 -------------------------------------------------------------------------------- /cv/scripts/run_cifar_vanilla.sh: -------------------------------------------------------------------------------- 1 | # sample scripts for training vanilla teacher models 2 | 3 | python train_teacher.py --model wrn_40_2 4 | 5 | python train_teacher.py --model resnet56 6 | 7 | python train_teacher.py --model resnet110 8 | 9 | python train_teacher.py --model resnet32x4 10 | 11 | python train_teacher.py --model vgg13 12 | 13 | python train_teacher.py --model ResNet50 14 | -------------------------------------------------------------------------------- /cv/train_student_meta.py: -------------------------------------------------------------------------------- 1 | """ 2 | the general training framework 3 | """ 4 | 5 | from __future__ import print_function 6 | 7 | import os 8 | import argparse 9 | import socket 10 | import time 11 | 12 | import tensorboard_logger as tb_logger 13 | import torch 14 | import torch.optim as optim 15 | 
import torch.nn as nn 16 | import torch.backends.cudnn as cudnn 17 | 18 | 19 | from models import model_dict 20 | from models.util import Embed, ConvReg, LinearEmbed 21 | from models.util import Connector, Translator, Paraphraser 22 | 23 | from dataset.meta_cifar100 import get_cifar100_dataloaders 24 | 25 | from helper.util import adjust_learning_rate 26 | 27 | from distiller_zoo.KD import CustomDistillKL 28 | from distiller_zoo.MSE import MSEWithTemperature 29 | from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss 30 | from distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss 31 | from crd.criterion import CRDLoss 32 | 33 | from helper.meta_loops import train_distill as train, validate 34 | from helper.pretrain import init 35 | 36 | 37 | def parse_option(): 38 | 39 | hostname = socket.gethostname() 40 | 41 | parser = argparse.ArgumentParser('argument for training') 42 | 43 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency') 44 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency') 45 | parser.add_argument('--save_freq', type=int, default=40, help='save frequency') 46 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 47 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use') 48 | parser.add_argument('--epochs', type=int, default=240, help='number of training epochs') 49 | parser.add_argument('--init_epochs', type=int, default=30, help='init training for two-stage methods') 50 | 51 | # optimization 52 | parser.add_argument('--lr', type=float, default=0.05, help='learning rate') 53 | parser.add_argument('--teacher_lr', type=float, default=0.05, help='teacher learning rate') 54 | parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='where to decay lr, can be a list') 55 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate') 56 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay') 57 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 58 | parser.add_argument('--loss_type', type=str, choices=['mse', 'kl']) 59 | 60 | # held set 61 | parser.add_argument('--held_size', type=int, help="the size of held set") 62 | parser.add_argument('--num_held_samples', type=int, help="num of held samples used for one teacher update") 63 | parser.add_argument('--num_meta_batches', type=int, default=1, help="num of meta batches used for one teacher update") 64 | parser.add_argument('--assume_s_step_size', type=float, default=0.05, help="assume student grad update lr") 65 | 66 | # dataset 67 | parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100'], help='dataset') 68 | 69 | # model 70 | parser.add_argument('--model_s', type=str, default='resnet8', 71 | choices=['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 72 | 'resnet8x4', 'resnet32x4', 'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19']) 73 | # TODO: Add more support 74 | # 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2', 'ResNet50', 'MobileNetV2', 'ShuffleV1', 'ShuffleV2']) 75 | parser.add_argument('--path_t', type=str, default=None, help='teacher model snapshot') 76 | 77 | # distillation 78 | parser.add_argument('--kd_T', type=float, default=4, help='temperature for KD distillation') 79 | parser.add_argument('--trial', type=str, default='1', help='trial id') 80 | 81 | parser.add_argument('-a', 
'--alpha', type=float, default=None, help='weight balance for KD') 82 | 83 | opt = parser.parse_args() 84 | 85 | # set different learning rate from these 4 models 86 | if opt.model_s in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']: 87 | opt.lr = 0.01 88 | 89 | # set the path according to the environment 90 | if hostname.startswith('visiongpu'): 91 | opt.model_path = '/path/to/my/student_model' 92 | opt.tb_path = '/path/to/my/student_tensorboards' 93 | else: 94 | opt.model_path = './save/student_model' 95 | opt.tb_path = './save/student_tensorboards' 96 | 97 | iterations = opt.lr_decay_epochs.split(',') 98 | opt.lr_decay_epochs = list([]) 99 | for it in iterations: 100 | opt.lr_decay_epochs.append(int(it)) 101 | 102 | opt.model_t = get_teacher_name(opt.path_t) 103 | 104 | opt.model_name = 'S:{}_T:{}_{}_{}_a:{}_{}'.format(opt.model_s, opt.model_t, opt.dataset, 'mlkd', 105 | opt.alpha, opt.trial) 106 | 107 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) 108 | if not os.path.isdir(opt.tb_folder): 109 | os.makedirs(opt.tb_folder) 110 | 111 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 112 | if not os.path.isdir(opt.save_folder): 113 | os.makedirs(opt.save_folder) 114 | 115 | return opt 116 | 117 | 118 | def get_teacher_name(model_path): 119 | """parse teacher name""" 120 | segments = model_path.split('/')[-2].split('_') 121 | if segments[0] != 'wrn': 122 | return segments[0] 123 | else: 124 | return segments[0] + '_' + segments[1] + '_' + segments[2] 125 | 126 | 127 | def load_teacher(model_path, n_cls): 128 | print('==> loading teacher model') 129 | model_t = get_teacher_name(model_path) 130 | model = model_dict[model_t](num_classes=n_cls) 131 | if torch.cuda.is_available(): 132 | model.load_state_dict(torch.load(model_path)['model']) 133 | else: 134 | model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model']) 135 | print('==> done') 136 | return model 137 | 138 | 139 | def main(): 140 | best_acc = 0 141 | 142 | opt = parse_option() 143 | 144 | # tensorboard logger 145 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2) 146 | 147 | # dataloader 148 | if opt.dataset == 'cifar100': 149 | train_loader, held_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size, 150 | num_workers=opt.num_workers, 151 | held_size=opt.held_size, 152 | num_held_samples=opt.num_held_samples, 153 | ) 154 | n_cls = 100 155 | else: 156 | raise NotImplementedError(opt.dataset) 157 | 158 | # model 159 | model_t = load_teacher(opt.path_t, n_cls) 160 | model_s = model_dict[opt.model_s](num_classes=n_cls) 161 | 162 | data = torch.randn(2, 3, 32, 32) 163 | model_t.eval() 164 | model_s.eval() 165 | feat_t, _ = model_t(data, is_feat=True) 166 | feat_s, _ = model_s(data, is_feat=True) 167 | 168 | module_list = nn.ModuleList([]) 169 | module_list.append(model_s) 170 | trainable_list = nn.ModuleList([]) 171 | trainable_list.append(model_s) 172 | 173 | criterion_cls = nn.CrossEntropyLoss() 174 | 175 | if opt.loss_type == 'mse': 176 | criterion_kd = MSEWithTemperature(T=opt.kd_T) 177 | elif opt.loss_type == 'kl': 178 | criterion_kd = CustomDistillKL(T=opt.kd_T) 179 | else: 180 | raise NotImplementedError() 181 | 182 | criterion_list = nn.ModuleList([]) 183 | criterion_list.append(criterion_cls) # classification loss 184 | criterion_list.append(criterion_kd) # other knowledge distillation loss 185 | 186 | # optimizer 187 | s_optimizer = optim.SGD(model_s.parameters(), 188 | lr=opt.lr, 189 | momentum=opt.momentum, 190 | 
weight_decay=opt.weight_decay) 191 | t_optimizer = optim.SGD(model_t.parameters(), 192 | lr=opt.teacher_lr, 193 | momentum=opt.momentum, 194 | weight_decay=opt.weight_decay) 195 | 196 | # append teacher after optimizer to avoid weight_decay 197 | module_list.append(model_t) 198 | 199 | if torch.cuda.is_available(): 200 | module_list.cuda() 201 | criterion_list.cuda() 202 | cudnn.benchmark = True 203 | 204 | # validate teacher accuracy 205 | teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt) 206 | print('teacher accuracy: ', teacher_acc) 207 | 208 | # the teacher is kept frozen (lr = 0) for the first 150 epochs 209 | for param_group in t_optimizer.param_groups: 210 | param_group['lr'] = 0 211 | # routine 212 | for epoch in range(1, opt.epochs + 1): 213 | 214 | if epoch == 150: 215 | for param_group in t_optimizer.param_groups: 216 | param_group['lr'] = opt.teacher_lr 217 | opt.assume_s_step_size *= opt.lr_decay_rate 218 | if epoch == 180: 219 | for param_group in t_optimizer.param_groups: 220 | param_group['lr'] *= opt.lr_decay_rate 221 | opt.assume_s_step_size *= opt.lr_decay_rate 222 | if epoch == 210: 223 | for param_group in t_optimizer.param_groups: 224 | param_group['lr'] *= opt.lr_decay_rate 225 | opt.assume_s_step_size *= opt.lr_decay_rate 226 | adjust_learning_rate(epoch, opt, s_optimizer) 227 | # adjust_learning_rate(epoch, opt, t_optimizer) 228 | 229 | print("==> training...") 230 | 231 | time1 = time.time() 232 | train_acc, train_loss = train(epoch, train_loader, held_loader, module_list, criterion_list, s_optimizer, t_optimizer, opt) 233 | time2 = time.time() 234 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 235 | 236 | logger.log_value('train_acc', train_acc, epoch) 237 | logger.log_value('train_loss', train_loss, epoch) 238 | 239 | test_acc, test_acc_top5, test_loss = validate(val_loader, model_s, criterion_cls, opt) 240 | 241 | logger.log_value('test_acc', test_acc, epoch) 242 | logger.log_value('test_loss', test_loss, epoch) 243 | logger.log_value('test_acc_top5', test_acc_top5, epoch) 244 | 245 | # save the best model 246 | if test_acc > best_acc: 247 | best_acc = test_acc 248 | state = { 249 | 'epoch': epoch, 250 | 'model': model_s.state_dict(), 251 | 'best_acc': best_acc, 252 | } 253 | save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model_s)) 254 | print('saving the best model!') 255 | torch.save(state, save_file) 256 | 257 | # regular saving 258 | if epoch % opt.save_freq == 0: 259 | print('==> Saving...') 260 | state = { 261 | 'epoch': epoch, 262 | 'model': model_s.state_dict(), 263 | 'accuracy': test_acc, 264 | } 265 | save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 266 | torch.save(state, save_file) 267 | 268 | # This best accuracy is only for printing purposes. 269 | # The results reported in the paper/README are from the last epoch.
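# NOTE on the meta step: the per-batch logic lives in helper/meta_loops.py
# (train_distill, imported above as `train`). As a rough sketch (descriptive,
# not the exact code), each iteration:
#   1. copies the student and applies one SGD step of size opt.assume_s_step_size
#      on the distillation loss under the current teacher (the "pilot" update);
#   2. evaluates the pilot student on quiz batches drawn from held_loader
#      (opt.num_meta_batches of them);
#   3. backpropagates the quiz loss through the pilot step into the teacher and
#      steps t_optimizer (a second-order update);
#   4. updates the real student with s_optimizer.
# Holding t_optimizer's lr at 0 until epoch 150 (above) keeps the teacher fixed
# while the student is still poorly trained.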
270 | print('best accuracy:', best_acc) 271 | 272 | # save model 273 | state = { 274 | 'opt': opt, 275 | 'model': model_s.state_dict(), 276 | } 277 | save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model_s)) 278 | torch.save(state, save_file) 279 | 280 | 281 | if __name__ == '__main__': 282 | main() 283 | -------------------------------------------------------------------------------- /cv/train_teacher.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import argparse 5 | import socket 6 | import time 7 | 8 | import tensorboard_logger as tb_logger 9 | import torch 10 | import torch.optim as optim 11 | import torch.nn as nn 12 | import torch.backends.cudnn as cudnn 13 | 14 | from models import model_dict 15 | 16 | from dataset.cifar100 import get_cifar100_dataloaders 17 | 18 | from helper.util import adjust_learning_rate, accuracy, AverageMeter 19 | from helper.loops import train_vanilla as train, validate 20 | 21 | 22 | def parse_option(): 23 | 24 | hostname = socket.gethostname() 25 | 26 | parser = argparse.ArgumentParser('argument for training') 27 | 28 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency') 29 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency') 30 | parser.add_argument('--save_freq', type=int, default=40, help='save frequency') 31 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 32 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use') 33 | parser.add_argument('--epochs', type=int, default=240, help='number of training epochs') 34 | 35 | # optimization 36 | parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate') 37 | parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='where to decay lr, can be a list') 38 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate') 39 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay') 40 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 41 | 42 | # dataset 43 | parser.add_argument('--model', type=str, default='resnet110', 44 | choices=['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 45 | 'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2', 46 | 'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 47 | 'MobileNetV2', 'ShuffleV1', 'ShuffleV2', ]) 48 | parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100'], help='dataset') 49 | 50 | parser.add_argument('-t', '--trial', type=int, default=0, help='the experiment id') 51 | 52 | opt = parser.parse_args() 53 | 54 | # set different learning rate from these 4 models 55 | if opt.model in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']: 56 | opt.learning_rate = 0.01 57 | 58 | # set the path according to the environment 59 | if hostname.startswith('visiongpu'): 60 | opt.model_path = '/path/to/my/model' 61 | opt.tb_path = '/path/to/my/tensorboard' 62 | else: 63 | opt.model_path = './save/models' 64 | opt.tb_path = './save/tensorboard' 65 | 66 | iterations = opt.lr_decay_epochs.split(',') 67 | opt.lr_decay_epochs = list([]) 68 | for it in iterations: 69 | opt.lr_decay_epochs.append(int(it)) 70 | 71 | opt.model_name = '{}_{}_lr_{}_decay_{}_trial_{}'.format(opt.model, opt.dataset, opt.learning_rate, 72 | opt.weight_decay, opt.trial) 73 | 
74 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) 75 | if not os.path.isdir(opt.tb_folder): 76 | os.makedirs(opt.tb_folder) 77 | 78 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 79 | if not os.path.isdir(opt.save_folder): 80 | os.makedirs(opt.save_folder) 81 | 82 | return opt 83 | 84 | 85 | def main(): 86 | best_acc = 0 87 | 88 | opt = parse_option() 89 | 90 | # dataloader 91 | if opt.dataset == 'cifar100': 92 | train_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size, num_workers=opt.num_workers) 93 | n_cls = 100 94 | else: 95 | raise NotImplementedError(opt.dataset) 96 | 97 | # model 98 | model = model_dict[opt.model](num_classes=n_cls) 99 | 100 | # optimizer 101 | optimizer = optim.SGD(model.parameters(), 102 | lr=opt.learning_rate, 103 | momentum=opt.momentum, 104 | weight_decay=opt.weight_decay) 105 | 106 | criterion = nn.CrossEntropyLoss() 107 | 108 | if torch.cuda.is_available(): 109 | model = model.cuda() 110 | criterion = criterion.cuda() 111 | cudnn.benchmark = True 112 | 113 | # tensorboard 114 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2) 115 | 116 | # routine 117 | for epoch in range(1, opt.epochs + 1): 118 | 119 | adjust_learning_rate(epoch, opt, optimizer) 120 | print("==> training...") 121 | 122 | time1 = time.time() 123 | train_acc, train_loss = train(epoch, train_loader, model, criterion, optimizer, opt) 124 | time2 = time.time() 125 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 126 | 127 | logger.log_value('train_acc', train_acc, epoch) 128 | logger.log_value('train_loss', train_loss, epoch) 129 | 130 | test_acc, test_acc_top5, test_loss = validate(val_loader, model, criterion, opt) 131 | 132 | logger.log_value('test_acc', test_acc, epoch) 133 | logger.log_value('test_acc_top5', test_acc_top5, epoch) 134 | logger.log_value('test_loss', test_loss, epoch) 135 | 136 | # save the best model 137 | if test_acc > best_acc: 138 | best_acc = test_acc 139 | state = { 140 | 'epoch': epoch, 141 | 'model': model.state_dict(), 142 | 'best_acc': best_acc, 143 | 'optimizer': optimizer.state_dict(), 144 | } 145 | save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model)) 146 | print('saving the best model!') 147 | torch.save(state, save_file) 148 | 149 | # regular saving 150 | if epoch % opt.save_freq == 0: 151 | print('==> Saving...') 152 | state = { 153 | 'epoch': epoch, 154 | 'model': model.state_dict(), 155 | 'accuracy': test_acc, 156 | 'optimizer': optimizer.state_dict(), 157 | } 158 | save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 159 | torch.save(state, save_file) 160 | 161 | # This best accuracy is only for printing purposes. 162 | # The results reported in the paper/README are from the last epoch.
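# NOTE: train_student_meta.py consumes the checkpoints written by this script
# via --path_t and recovers the teacher architecture by parsing the checkpoint's
# parent folder name (see get_teacher_name there), so keep the save_folder
# naming scheme from parse_option() intact if you move checkpoints around.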
163 | print('best accuracy:', best_acc) 164 | 165 | # save model 166 | state = { 167 | 'opt': opt, 168 | 'model': model.state_dict(), 169 | 'optimizer': optimizer.state_dict(), 170 | } 171 | save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model)) 172 | torch.save(state, save_file) 173 | 174 | 175 | if __name__ == '__main__': 176 | main() 177 | -------------------------------------------------------------------------------- /nlp/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | .idea 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /nlp/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Canwen Xu and Wangchunshu Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /nlp/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-PKD-for-BERT-Compression 2 | 3 | PyTorch implementation of the distillation method described in the following paper: [**Patient Knowledge Distillation for BERT Model Compression**](https://arxiv.org/abs/1908.09355). This repository builds heavily on [Pytorch-Transformers](https://github.com/huggingface/pytorch-transformers) by Hugging Face. 4 | 5 | ## Steps to run the code 6 | ### 1. Download GLUE data 7 | ``` 8 | $ python download_glue_data.py 9 | ``` 10 | 11 | ### 2. Fine-tune teacher BERT model 12 | Run the following command to fine-tune the teacher model and save it: 13 | ``` 14 | python run_glue.py \ 15 | --model_type bert \ 16 | --model_name_or_path bert-base-uncased \ 17 | --task_name $TASK_NAME \ 18 | --do_train \ 19 | --do_eval \ 20 | --do_lower_case \ 21 | --data_dir $GLUE_DIR/$TASK_NAME \ 22 | --max_seq_length 128 \ 23 | --per_gpu_eval_batch_size=8 \ 24 | --per_gpu_train_batch_size=8 \ 25 | --learning_rate 2e-5 \ 26 | --num_train_epochs 3.0 \ 27 | --output_dir /tmp/$TASK_NAME/ 28 | ``` 29 | 30 | ### 3. Distill the student model with the teacher BERT 31 | $TEACHER_MODEL is the folder containing your fine-tuned teacher model.
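Note that this step is inherited from the PKD codebase and run_glue_distillation.py below is not shipped in this repository; the MetaDistil entry point is run_glue_distillation_meta.py, which takes a similar set of arguments plus the meta-learning options (see the argparse definitions in that file).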
32 | ``` 33 | python run_glue_distillation.py \ 34 | --model_type bert \ 35 | --teacher_model $TEACHER_MODEL \ 36 | --student_model bert-base-uncased \ 37 | --task_name $TASK_NAME \ 38 | --num_hidden_layers 6 \ 39 | --alpha 0.5 \ 40 | --beta 100.0 \ 41 | --do_train \ 42 | --do_eval \ 43 | --do_lower_case \ 44 | --data_dir $GLUE_DIR/$TASK_NAME \ 45 | --max_seq_length 128 \ 46 | --per_gpu_eval_batch_size=8 \ 47 | --per_gpu_train_batch_size=8 \ 48 | --learning_rate 2e-5 \ 49 | --num_train_epochs 4.0 \ 50 | --output_dir /tmp/$TASK_NAME/ 51 | ``` 52 | 53 | ## Experimental Results on dev set 54 | model | num_layers | SST-2 | MRPC-f1/acc | QQP-f1/acc | MNLI-m/mm | QNLI | RTE 55 | -- | -- | -- | -- | -- | -- | -- | -- 56 | base | 12 | 0.9232 | 0.89/0.8358 | 0.8818/0.9121 | 0.8432/0.8479 | 0.916 | 0.6751 57 | finetuned | 6 | 0.9002 | 0.8741/0.8186 | 0.8672/0.901 | 0.8051/0.8033 | 0.8662 | 0.6101 58 | distill | 6 | 0.9071 | 0.8885/0.8382 | 0.8704/0.9016 | 0.8153/0.821 | 0.8642 | 0.6318 59 | -------------------------------------------------------------------------------- /nlp/distillation_meta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from operator import itemgetter 5 | from functional_forward_bert import functional_bert_for_classification 6 | from transformers import PreTrainedModel 7 | 8 | 9 | class MetaPatientDistillation(nn.Module): 10 | def __init__(self, t_config, s_config): 11 | super(MetaPatientDistillation, self).__init__() 12 | self.t_config = t_config 13 | self.s_config = s_config 14 | 15 | def forward(self, t_model, s_model, order, input_ids, token_type_ids, attention_mask, labels, args, teacher_grad): 16 | if teacher_grad: 17 | t_outputs = t_model(input_ids=input_ids, 18 | token_type_ids=token_type_ids, 19 | attention_mask=attention_mask) 20 | else: 21 | with torch.no_grad(): 22 | t_outputs = t_model(input_ids=input_ids, 23 | token_type_ids=token_type_ids, 24 | attention_mask=attention_mask) 25 | 26 | if isinstance(s_model, PreTrainedModel): 27 | s_outputs = s_model(input_ids=input_ids, 28 | token_type_ids=token_type_ids, 29 | attention_mask=attention_mask, 30 | labels=labels) 31 | else: 32 | s_outputs = functional_bert_for_classification( 33 | s_model, 34 | self.s_config, 35 | input_ids=input_ids, 36 | token_type_ids=token_type_ids, 37 | attention_mask=attention_mask, 38 | labels=labels 39 | ) 40 | 41 | t_logits, t_features = t_outputs[0], t_outputs[-1] 42 | train_loss, s_logits, s_features = s_outputs[0], s_outputs[1], s_outputs[-1] 43 | 44 | if args.logits_mse: 45 | soft_loss = F.mse_loss(t_logits, s_logits) 46 | else: 47 | T = args.temperature 48 | soft_targets = F.softmax(t_logits / T, dim=-1) 49 | 50 | probs = F.softmax(s_logits / T, dim=-1) 51 | soft_loss = F.mse_loss(soft_targets, probs) * T * T 52 | 53 | if args.beta == 0: # if beta=0, we don't even compute pkd_loss to save some time 54 | pkd_loss = torch.zeros_like(soft_loss) 55 | else: 56 | t_features = torch.cat(t_features[1:-1], dim=0).view(self.t_config.num_hidden_layers - 1, 57 | -1, 58 | args.max_seq_length, 59 | self.t_config.hidden_size)[:, :, 0] 60 | 61 | s_features = torch.cat(s_features[1:-1], dim=0).view(self.s_config.num_hidden_layers - 1, 62 | -1, 63 | args.max_seq_length, 64 | self.s_config.hidden_size)[:, :, 0] 65 | 66 | t_features = itemgetter(order)(t_features) 67 | t_features = t_features / t_features.norm(dim=-1).unsqueeze(-1) 68 | s_features = s_features / 
s_features.norm(dim=-1).unsqueeze(-1) 69 | pkd_loss = F.mse_loss(s_features, t_features, reduction="mean") 70 | 71 | return train_loss, soft_loss, pkd_loss 72 | 73 | def s_prime_forward(self, s_prime, input_ids, token_type_ids, attention_mask, labels, args): 74 | 75 | s_outputs = functional_bert_for_classification( 76 | s_prime, 77 | self.s_config, 78 | input_ids=input_ids, 79 | token_type_ids=token_type_ids, 80 | attention_mask=attention_mask, 81 | labels=labels, 82 | is_train=False 83 | ) 84 | 85 | train_loss, s_logits, s_features = s_outputs[0], s_outputs[1], s_outputs[-1] 86 | 87 | return train_loss 88 | -------------------------------------------------------------------------------- /nlp/download_glue_data.py: -------------------------------------------------------------------------------- 1 | ''' Script for downloading all GLUE data. 2 | Note: for legal reasons, we are unable to host MRPC. 3 | You can either use the version hosted by the SentEval team, which is already tokenized, 4 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually. 5 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example). 6 | You should then rename and place specific files in a folder (see below for an example). 7 | mkdir MRPC 8 | cabextract MSRParaphraseCorpus.msi -d MRPC 9 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt 10 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt 11 | rm MRPC/_* 12 | rm MSRParaphraseCorpus.msi 13 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now. 14 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray! 
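Typical usage, with the flags defined in this script's argparse:
    python download_glue_data.py --data_dir glue_data --tasks all
    python download_glue_data.py --tasks MRPC --path_to_mrpc ./MRPC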
15 | ''' 16 | 17 | import os 18 | import sys 19 | import shutil 20 | import argparse 21 | import tempfile 22 | import urllib.request 23 | import zipfile 24 | 25 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 26 | TASK2PATH = { 27 | "CoLA": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 28 | "SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 29 | "MRPC": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 30 | "QQP": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 31 | "STS": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 32 | "MNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 33 | "SNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 34 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601', 35 | "RTE": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 36 | "WNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 37 | "diagnostic": 'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 38 | 39 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' 40 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' 41 | 42 | 43 | def download_and_extract(task, data_dir): 44 | print("Downloading and extracting %s..." 
% task) 45 | data_file = "%s.zip" % task 46 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 47 | with zipfile.ZipFile(data_file) as zip_ref: 48 | zip_ref.extractall(data_dir) 49 | os.remove(data_file) 50 | print("\tCompleted!") 51 | 52 | 53 | def format_mrpc(data_dir, path_to_data): 54 | print("Processing MRPC...") 55 | mrpc_dir = os.path.join(data_dir, "MRPC") 56 | if not os.path.isdir(mrpc_dir): 57 | os.mkdir(mrpc_dir) 58 | if path_to_data: 59 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 60 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 61 | else: 62 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN) 63 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 64 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 65 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 66 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 67 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 68 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 69 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 70 | 71 | dev_ids = [] 72 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: 73 | for row in ids_fh: 74 | dev_ids.append(row.strip().split('\t')) 75 | 76 | with open(mrpc_train_file, encoding="utf8") as data_fh, \ 77 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ 78 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: 79 | header = data_fh.readline() 80 | train_fh.write(header) 81 | dev_fh.write(header) 82 | for row in data_fh: 83 | label, id1, id2, s1, s2 = row.strip().split('\t') 84 | if [id1, id2] in dev_ids: 85 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 86 | else: 87 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 88 | 89 | with open(mrpc_test_file, encoding="utf8") as data_fh, \ 90 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: 91 | header = data_fh.readline() 92 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 93 | for idx, row in enumerate(data_fh): 94 | label, id1, id2, s1, s2 = row.strip().split('\t') 95 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 96 | print("\tCompleted!") 97 | 98 | 99 | def download_diagnostic(data_dir): 100 | print("Downloading and extracting diagnostic...") 101 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 102 | os.mkdir(os.path.join(data_dir, "diagnostic")) 103 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 104 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 105 | print("\tCompleted!") 106 | return 107 | 108 | 109 | def get_tasks(task_names): 110 | task_names = task_names.split(',') 111 | if "all" in task_names: 112 | tasks = TASKS 113 | else: 114 | tasks = [] 115 | for task_name in task_names: 116 | assert task_name in TASKS, "Task %s not found!" 
% task_name 117 | tasks.append(task_name) 118 | return tasks 119 | 120 | 121 | def main(arguments): 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') 124 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 125 | type=str, default='all') 126 | parser.add_argument('--path_to_mrpc', 127 | help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt', 128 | type=str, default='') 129 | args = parser.parse_args(arguments) 130 | 131 | if not os.path.isdir(args.data_dir): 132 | os.mkdir(args.data_dir) 133 | tasks = get_tasks(args.tasks) 134 | 135 | for task in tasks: 136 | if task == 'MRPC': 137 | format_mrpc(args.data_dir, args.path_to_mrpc) 138 | elif task == 'diagnostic': 139 | download_diagnostic(args.data_dir) 140 | else: 141 | download_and_extract(task, args.data_dir) 142 | 143 | 144 | if __name__ == '__main__': 145 | sys.exit(main(sys.argv[1:])) 146 | -------------------------------------------------------------------------------- /nlp/mrpc_hyperparameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": [42, 2022, 12], 3 | "max_seq_length": [128], 4 | "per_gpu_train_batch_size": [4,8], 5 | "per_gpu_eval_batch_size": [8], 6 | "gradient_accumulation_steps": [1], 7 | "num_train_epochs": [4,6,8], 8 | "learning_rate": [2e-5, 1e-5], 9 | "evaluate_during_training": ["True"], 10 | "logging_steps": [50], 11 | "save_steps": [50], 12 | "model_type": ["bert"], 13 | "model_name_or_path": ["bert-base-uncased"], 14 | "task_name": ["MRPC"], 15 | "data_dir": ["glue_data/MRPC"], 16 | "warmup_steps": [100,150] 17 | } -------------------------------------------------------------------------------- /nlp/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | -------------------------------------------------------------------------------- /nlp/run_hyperparameter_tuning.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | # import sys 5 | # sys.path.append("/".join(os.getcwd().split("/")[:-1]) + "/") 6 | 7 | def convertTextToParams(input_file): 8 | param_dict_seq = [] 9 | with open(input_file, "r") as f: 10 | header = None 11 | for cur_line in f: 12 | if header is None: 13 | header = list(cur_line.split()) 14 | else: 15 | cur_row = list(cur_line.split()) 16 | cur_dict = {} 17 | for i in range(len(header)): 18 | cur_dict[header[i]] = str(cur_row[i]) 19 | param_dict_seq.append(cur_dict) 20 | 21 | return param_dict_seq 22 | 23 | 24 | param_dict_seq_global = [] 25 | def computeAllParamCombinations(key_list, cur_key_idx, param_dict, cur_param_seq_dict): 26 | global param_dict_seq_global 27 | if cur_key_idx >= len(key_list): 28 | param_dict_seq_global.append(cur_param_seq_dict) 29 | return 30 | 31 | cur_key = key_list[cur_key_idx] 32 | for cur_val in param_dict[cur_key]: 33 | cur_param_seq_dict[cur_key] = str(cur_val) 34 | computeAllParamCombinations(key_list, cur_key_idx + 1, param_dict, cur_param_seq_dict.copy()) 35 | 36 | def convertJsonToParams(input_file): 37 | global param_dict_seq_global 38 | with open(input_file, "r") as f: 39 | param_dict = json.load(f) 40 | 41 | param_keys = list(param_dict.keys()) 42 | computeAllParamCombinations(param_keys, 0, param_dict, {}) 43 | return param_dict_seq_global 44 | 45 | def 
convertDictToCmdArgs(input_dict): 46 | out_string = "" 47 | for key, val in input_dict.items(): 48 | out_string += "--" + str(key) + " " + str(val) + " " 49 | return out_string 50 | 51 | def createFolderNameFromParamDict(input_dict): 52 | out_string = "" 53 | for key, val in input_dict.items(): 54 | key_split = "".join([x[0] for x in key.lower().split("_")]) 55 | val = val.replace("/", "_") 56 | out_string += key_split + "_" + str(val).lower() + "_" 57 | 58 | return out_string[:-1] 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('--finetune_file', type=str, required=True, 63 | help='Finetuning file') 64 | parser.add_argument('--param_file', type=str, required=True, 65 | help='Hyperparameter file. Can be in ".json" or ".tsv" format') 66 | parser.add_argument('--root_output_dir', type=str, default="temp", 67 | help='Root output directory') 68 | args = parser.parse_known_args()[0] 69 | 70 | if args.param_file.endswith(".json"): 71 | param_dict_seq = convertJsonToParams(args.param_file) 72 | else: 73 | param_dict_seq = convertTextToParams(args.param_file) 74 | 75 | total_config_counts = len(param_dict_seq) 76 | cur_config_count = 0 77 | for cur_param_seq in param_dict_seq: 78 | cur_config_count += 1 79 | print("Running configuration {} out of {}:".format(cur_config_count, total_config_counts)) 80 | print(cur_param_seq, "\n") 81 | 82 | cur_output_folder = os.path.join(args.root_output_dir, createFolderNameFromParamDict(cur_param_seq)) 83 | 84 | # Finetuning: 85 | finetune_out_dir = os.path.join(cur_output_folder, 'finetuning') 86 | 87 | # Create a folder if output_dir doesn't exists: (needed for storing the logs file) 88 | if not os.path.exists(finetune_out_dir): 89 | os.makedirs(finetune_out_dir) 90 | 91 | finetune_log_file = os.path.join(finetune_out_dir, 'logs.txt') 92 | 93 | finetune_cmd = "CUDA_VISIBLE_DEVICES=0 python3 " + args.finetune_file + " " + convertDictToCmdArgs(cur_param_seq)\ 94 | + "--do_train --do_eval --do_lower_case " + " --output_dir " + finetune_out_dir + " > " + finetune_log_file 95 | 96 | print(finetune_cmd, "\n") 97 | os.system(finetune_cmd) 98 | --------------------------------------------------------------------------------
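For reference, a typical invocation of run_hyperparameter_tuning.py above, using the MRPC grid shipped in mrpc_hyperparameters.json (the output directory name is just an example):

```
python run_hyperparameter_tuning.py \
    --finetune_file run_glue.py \
    --param_file mrpc_hyperparameters.json \
    --root_output_dir mrpc_tuning
```

For readers untangling the second-order logic in helper/meta_loops.py (cv) and run_glue_distillation_meta.py (nlp), here is a minimal, self-contained PyTorch sketch of one MetaDistil-style teacher update on toy linear models. Everything in it (model sizes, the KL-based loss, variable names) is illustrative only; the repository's actual loops differ in losses, batching, and scheduling.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

teacher = nn.Linear(8, 4)  # toy stand-ins for the real networks
student = nn.Linear(8, 4)
t_opt = torch.optim.SGD(teacher.parameters(), lr=0.05)
s_opt = torch.optim.SGD(student.parameters(), lr=0.05)
assume_s_step_size = 0.05  # mirrors --assume_s_step_size
T = 4.0                    # mirrors --kd_T / --temperature

x_train = torch.randn(16, 8)         # training batch
x_quiz = torch.randn(16, 8)          # held-out "quiz" batch
y_quiz = torch.randint(0, 4, (16,))

def kd_loss(s_logits, t_logits):
    # temperature-scaled distillation loss
    return F.kl_div(F.log_softmax(s_logits / T, dim=-1),
                    F.softmax(t_logits / T, dim=-1),
                    reduction='batchmean') * T * T

# (1) pilot update: a differentiable one-step SGD copy of the student;
#     create_graph=True keeps the dependence on the teacher's parameters
loss = kd_loss(student(x_train), teacher(x_train))
grads = torch.autograd.grad(loss, list(student.parameters()), create_graph=True)
pilot_w, pilot_b = (p - assume_s_step_size * g
                    for p, g in zip(student.parameters(), grads))

# (2) quiz the pilot student; the quiz loss reaches the teacher only through
#     the pilot step, which is what makes this a second-order update
quiz_loss = F.cross_entropy(F.linear(x_quiz, pilot_w, pilot_b), y_quiz)
t_opt.zero_grad(); s_opt.zero_grad()
quiz_loss.backward()
t_opt.step()  # the teacher "learns to teach"

# (3) ordinary student update against the (now updated) teacher
s_opt.zero_grad()
kd_loss(student(x_train), teacher(x_train).detach()).backward()
s_opt.step()
```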