├── .gitignore
├── LICENSE
├── README.md
├── cv
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── crd
│   │   ├── __init__.py
│   │   ├── criterion.py
│   │   └── memory.py
│   ├── dataset
│   │   ├── cifar100.py
│   │   ├── cifar_with_held.py
│   │   ├── imagenet.py
│   │   └── meta_cifar100.py
│   ├── distiller_zoo
│   │   ├── AB.py
│   │   ├── AT.py
│   │   ├── CC.py
│   │   ├── FSP.py
│   │   ├── FT.py
│   │   ├── FitNet.py
│   │   ├── KD.py
│   │   ├── KDSVD.py
│   │   ├── MSE.py
│   │   ├── NST.py
│   │   ├── PKT.py
│   │   ├── RKD.py
│   │   ├── SP.py
│   │   ├── VID.py
│   │   └── __init__.py
│   ├── helper
│   │   ├── __init__.py
│   │   ├── loops.py
│   │   ├── meta_loops.py
│   │   ├── pretrain.py
│   │   └── util.py
│   ├── models
│   │   ├── ShuffleNetv1.py
│   │   ├── ShuffleNetv2.py
│   │   ├── __init__.py
│   │   ├── classifier.py
│   │   ├── meta_resnet.py
│   │   ├── meta_resnet_v2.py
│   │   ├── meta_vgg.py
│   │   ├── mobilenetv2.py
│   │   ├── resnet.py
│   │   ├── resnetv2.py
│   │   ├── util.py
│   │   ├── vgg.py
│   │   └── wrn.py
│   ├── scripts
│   │   ├── fetch_pretrained_teachers.sh
│   │   ├── run_cifar_distill.sh
│   │   └── run_cifar_vanilla.sh
│   ├── train_student.py
│   ├── train_student_debug.py
│   ├── train_student_meta.py
│   └── train_teacher.py
└── nlp
    ├── .gitignore
    ├── LICENSE
    ├── README.md
    ├── distillation_meta.py
    ├── download_glue_data.py
    ├── functional_forward_bert.py
    ├── mrpc_hyperparameters.json
    ├── requirements.txt
    ├── run_glue.py
    ├── run_glue_distillation_meta.py
    ├── run_hyperparameter_tuning.py
    └── utils_glue.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Kevin Canwen Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MetaDistil
2 | Code for ACL 2022 paper ["BERT Learns to Teach: Knowledge Distillation with Meta Learning"](https://arxiv.org/abs/2106.04570).
3 |
4 | ## ⚠️ Read before use
5 | Since the release of this paper on arXiv, we have received many requests for the code, so we are releasing it first without cleaning it up. We know implementing a second-order approach is non-trivial and we want to help, but please note that the current code may contain bugs, unused code, incorrect settings, etc. **Please use at your own risk.** We'll verify and clean up the code once we have the chance to do so.
6 |
7 | ## Acknowledgments
8 | The implementation of image classification is based on https://github.com/HobbitLong/RepDistiller
9 |
10 | The implementation of text classification is based on https://github.com/bzantium/pytorch-PKD-for-BERT-compression
11 |
12 | Shout out to the authors of these two repos.
13 |
14 | ## How to use the code
15 | To be added. For now, please see `/nlp/run_glue_distillation_meta.py` and `/cv/train_student_meta.py`.
16 |
--------------------------------------------------------------------------------
/cv/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | data/
3 | output*/
4 | ckpts/
5 | *.pth
6 | *.t7
7 | *.png
8 | *.jpg
9 | tmp*.py
10 |
11 | *.pdf
12 |
13 |
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 | *$py.class
18 |
19 | # C extensions
20 | *.so
21 |
22 | # Distribution / packaging
23 | .Python
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 |
--------------------------------------------------------------------------------
/cv/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2020, Yonglong Tian
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/cv/README.md:
--------------------------------------------------------------------------------
1 | # RepDistiller
2 |
3 | This repo:
4 |
5 | **(1) covers the implementation of the following ICLR 2020 paper:**
6 |
7 | "Contrastive Representation Distillation" (CRD). [Paper](http://arxiv.org/abs/1910.10699), [Project Page](http://hobbitlong.github.io/CRD/).
8 |
9 |
10 |
11 |
12 |
13 | **(2) benchmarks 12 state-of-the-art knowledge distillation methods in PyTorch, including:**
14 |
15 | (KD) - Distilling the Knowledge in a Neural Network
16 | (FitNet) - Fitnets: hints for thin deep nets
17 | (AT) - Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
18 | via Attention Transfer
19 | (SP) - Similarity-Preserving Knowledge Distillation
20 | (CC) - Correlation Congruence for Knowledge Distillation
21 | (VID) - Variational Information Distillation for Knowledge Transfer
22 | (RKD) - Relational Knowledge Distillation
23 | (PKT) - Probabilistic Knowledge Transfer for deep representation learning
24 | (AB) - Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
25 | (FT) - Paraphrasing Complex Network: Network Compression via Factor Transfer
26 | (FSP) - A Gift from Knowledge Distillation:
27 | Fast Optimization, Network Minimization and Transfer Learning
28 | (NST) - Like what you like: knowledge distill via neuron selectivity transfer
29 |
30 | ## Installation
31 |
32 | This repo was tested with Ubuntu 16.04.5 LTS, Python 3.5, PyTorch 0.4.0, and CUDA 9.0, but it should be runnable with any recent PyTorch version (>=0.4.0).
33 |
34 | ## Running
35 |
36 | 1. Fetch the pretrained teacher models by:
37 |
38 | ```
39 | sh scripts/fetch_pretrained_teachers.sh
40 | ```
41 | which will download and save the models to `save/models`
42 |
43 | 2. Run distillation by following the commands in `scripts/run_cifar_distill.sh`. An example of running Geoffrey Hinton's original Knowledge Distillation (KD) is given by:
44 |
45 | ```
46 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill kd --model_s resnet8x4 -r 0.1 -a 0.9 -b 0 --trial 1
47 | ```
48 | where the flags are explained as:
49 | - `--path_t`: specify the path of the teacher model
50 | - `--model_s`: specify the student model, see 'models/\_\_init\_\_.py' to check the available model types.
51 | - `--distill`: specify the distillation method
52 | - `-r`: the weight of the cross-entropy loss between logit and ground truth, default: `1`
53 | - `-a`: the weight of the KD loss, default: `None`
54 | - `-b`: the weight of other distillation losses, default: `None`
55 | - `--trial`: specify the experimental id to differentiate between multiple runs.
56 |
57 | Therefore, the command for running CRD is something like:
58 | ```
59 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill crd --model_s resnet8x4 -a 0 -b 0.8 --trial 1
60 | ```
61 |
62 | 3. A distillation objective can be combined with KD simply by setting `-a` to a non-zero value, as in the following example (combining CRD with KD):
63 | ```
64 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill crd --model_s resnet8x4 -a 1 -b 0.8 --trial 1
65 | ```
66 |
67 | 4. (optional) Train teacher networks from scratch. Example commands are in `scripts/run_cifar_vanilla.sh`
68 |
69 | Note: the default setting is for single-GPU training. If you would like to run this repo with multiple GPUs, you might need to tune the learning rate, which empirically needs to be scaled up linearly with the batch size; see [this paper](https://arxiv.org/abs/1706.02677) and the short sketch after this file.
70 |
71 | ## Benchmark Results on CIFAR-100:
72 |
73 | Performance is measured by classification accuracy (%)
74 |
75 | 1. Teacher and student are of the **same** architectural type.
76 |
77 | | Teacher <br> Student | wrn-40-2 <br> wrn-16-2 | wrn-40-2 <br> wrn-40-1 | resnet56 <br> resnet20 | resnet110 <br> resnet20 | resnet110 <br> resnet32 | resnet32x4 <br> resnet8x4 | vgg13 <br> vgg8 |
78 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:------------------:|:------------------:|:--------------------:|:-----------:|
79 | | Teacher <br> Student | 75.61 <br> 73.26 | 75.61 <br> 71.98 | 72.34 <br> 69.06 | 74.31 <br> 69.06 | 74.31 <br> 71.14 | 79.42 <br> 72.50 | 74.64 <br> 70.36 |
80 | | KD | 74.92 | 73.54 | 70.66 | 70.67 | 73.08 | 73.33 | 72.98 |
81 | | FitNet | 73.58 | 72.24 | 69.21 | 68.99 | 71.06 | 73.50 | 71.02 |
82 | | AT | 74.08 | 72.77 | 70.55 | 70.22 | 72.31 | 73.44 | 71.43 |
83 | | SP | 73.83 | 72.43 | 69.67 | 70.04 | 72.69 | 72.94 | 72.68 |
84 | | CC | 73.56 | 72.21 | 69.63 | 69.48 | 71.48 | 72.97 | 70.71 |
85 | | VID | 74.11 | 73.30 | 70.38 | 70.16 | 72.61 | 73.09 | 71.23 |
86 | | RKD | 73.35 | 72.22 | 69.61 | 69.25 | 71.82 | 71.90 | 71.48 |
87 | | PKT | 74.54 | 73.45 | 70.34 | 70.25 | 72.61 | 73.64 | 72.88 |
88 | | AB | 72.50 | 72.38 | 69.47 | 69.53 | 70.98 | 73.17 | 70.94 |
89 | | FT | 73.25 | 71.59 | 69.84 | 70.22 | 72.37 | 72.86 | 70.58 |
90 | | FSP | 72.91 | 0.00 | 69.95 | 70.11 | 71.89 | 72.62 | 70.23 |
91 | | NST | 73.68 | 72.24 | 69.60 | 69.53 | 71.96 | 73.30 | 71.53 |
92 | | **CRD** | **75.48** | **74.14** | **71.16** | **71.46** | **73.48** | **75.51** | **73.94** |
93 |
94 | 2. Teacher and student are of **different** architectural type.
95 |
96 | | Teacher <br> Student | vgg13 <br> MobileNetV2 | ResNet50 <br> MobileNetV2 | ResNet50 <br> vgg8 | resnet32x4 <br> ShuffleNetV1 | resnet32x4 <br> ShuffleNetV2 | wrn-40-2 <br> ShuffleNetV1 |
97 | |:---------------:|:-----------------:|:--------------------:|:-------------:|:-----------------------:|:-----------------------:|:---------------------:|
98 | | Teacher <br> Student | 74.64 <br> 64.60 | 79.34 <br> 64.60 | 79.34 <br> 70.36 | 79.42 <br> 70.50 | 79.42 <br> 71.82 | 75.61 <br> 70.50 |
99 | | KD | 67.37 | 67.35 | 73.81 | 74.07 | 74.45 | 74.83 |
100 | | FitNet | 64.14 | 63.16 | 70.69 | 73.59 | 73.54 | 73.73 |
101 | | AT | 59.40 | 58.58 | 71.84 | 71.73 | 72.73 | 73.32 |
102 | | SP | 66.30 | 68.08 | 73.34 | 73.48 | 74.56 | 74.52 |
103 | | CC | 64.86 | 65.43 | 70.25 | 71.14 | 71.29 | 71.38 |
104 | | VID | 65.56 | 67.57 | 70.30 | 73.38 | 73.40 | 73.61 |
105 | | RKD | 64.52 | 64.43 | 71.50 | 72.28 | 73.21 | 72.21 |
106 | | PKT | 67.13 | 66.52 | 73.01 | 74.10 | 74.69 | 73.89 |
107 | | AB | 66.06 | 67.20 | 70.65 | 73.55 | 74.31 | 73.34 |
108 | | FT | 61.78 | 60.99 | 70.29 | 71.75 | 72.50 | 72.03 |
109 | | NST | 58.16 | 64.96 | 71.28 | 74.12 | 74.68 | 74.89 |
110 | | **CRD** | **69.73** | **69.11** | **74.30** | **75.11** | **75.65** | **76.05** |
111 |
112 | ## Citation
113 |
114 | If you find this repo useful for your research, please consider citing the paper
115 |
116 | ```
117 | @inproceedings{tian2019crd,
118 | title={Contrastive Representation Distillation},
119 | author={Yonglong Tian and Dilip Krishnan and Phillip Isola},
120 | booktitle={International Conference on Learning Representations},
121 | year={2020}
122 | }
123 | ```
124 | For any questions, please contact Yonglong Tian (yonglong@mit.edu).
125 |
126 | ## Acknowledgement
127 |
128 | Thanks to Baoyun Peng for providing the code of CC and to Frederick Tung for verifying our reimplementation of SP. Thanks also go to authors of other papers who make their code publicly available.
129 |
--------------------------------------------------------------------------------
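The linear-scaling rule mentioned at the end of the README above, as a minimal sketch; the single-GPU reference batch size and learning rate below are assumptions for illustration, not values read from this repo's argument parser:

```python
# Linear LR scaling rule (Goyal et al., 2017): scale the learning rate by the
# same factor as the total batch size when moving to multiple GPUs.
base_lr, base_batch = 0.05, 64            # hypothetical single-GPU reference values
num_gpus = 4
batch_size = base_batch * num_gpus        # 256
lr = base_lr * (batch_size / base_batch)  # 0.05 * 4 = 0.2
print(f"--batch_size {batch_size} --learning_rate {lr}")
```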
/cv/crd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JetRunner/MetaDistil/80e60c11de531b10d1f06ceb2b71c70665bb6aff/cv/crd/__init__.py
--------------------------------------------------------------------------------
/cv/crd/criterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .memory import ContrastMemory
4 |
5 | eps = 1e-7
6 |
7 |
8 | class CRDLoss(nn.Module):
9 | """CRD Loss function
10 | includes two symmetric parts:
11 | (a) using the teacher as anchor, choose positives and negatives from the student side
12 | (b) using the student as anchor, choose positives and negatives from the teacher side
13 |
14 | Args:
15 | opt.s_dim: the dimension of student's feature
16 | opt.t_dim: the dimension of teacher's feature
17 | opt.feat_dim: the dimension of the projection space
18 | opt.nce_k: number of negatives paired with each positive
19 | opt.nce_t: the temperature
20 | opt.nce_m: the momentum for updating the memory buffer
21 | opt.n_data: the number of samples in the training set; therefore the memory buffer has size opt.n_data x opt.feat_dim
22 | """
23 | def __init__(self, opt):
24 | super(CRDLoss, self).__init__()
25 | self.embed_s = Embed(opt.s_dim, opt.feat_dim)
26 | self.embed_t = Embed(opt.t_dim, opt.feat_dim)
27 | self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
28 | self.criterion_t = ContrastLoss(opt.n_data)
29 | self.criterion_s = ContrastLoss(opt.n_data)
30 |
31 | def forward(self, f_s, f_t, idx, contrast_idx=None):
32 | """
33 | Args:
34 | f_s: the feature of student network, size [batch_size, s_dim]
35 | f_t: the feature of teacher network, size [batch_size, t_dim]
36 | idx: the indices of these positive samples in the dataset, size [batch_size]
37 | contrast_idx: the indices of negative samples, size [batch_size, nce_k]
38 |
39 | Returns:
40 | The contrastive loss
41 | """
42 | f_s = self.embed_s(f_s)
43 | f_t = self.embed_t(f_t)
44 | out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
45 | s_loss = self.criterion_s(out_s)
46 | t_loss = self.criterion_t(out_t)
47 | loss = s_loss + t_loss
48 | return loss
49 |
50 |
51 | class ContrastLoss(nn.Module):
52 | """
53 | contrastive loss, corresponding to Eq (18)
54 | """
55 | def __init__(self, n_data):
56 | super(ContrastLoss, self).__init__()
57 | self.n_data = n_data
58 |
59 | def forward(self, x):
60 | bsz = x.shape[0]
61 | m = x.size(1) - 1
62 |
63 | # noise distribution
64 | Pn = 1 / float(self.n_data)
65 |
66 | # loss for positive pair
67 | P_pos = x.select(1, 0)
68 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
69 |
70 | # loss for K negative pair
71 | P_neg = x.narrow(1, 1, m)
72 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
73 |
74 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
75 |
76 | return loss
77 |
78 |
79 | class Embed(nn.Module):
80 | """Embedding module"""
81 | def __init__(self, dim_in=1024, dim_out=128):
82 | super(Embed, self).__init__()
83 | self.linear = nn.Linear(dim_in, dim_out)
84 | self.l2norm = Normalize(2)
85 |
86 | def forward(self, x):
87 | x = x.view(x.shape[0], -1)
88 | x = self.linear(x)
89 | x = self.l2norm(x)
90 | return x
91 |
92 |
93 | class Normalize(nn.Module):
94 | """normalization layer"""
95 | def __init__(self, power=2):
96 | super(Normalize, self).__init__()
97 | self.power = power
98 |
99 | def forward(self, x):
100 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
101 | out = x.div(norm)
102 | return out
103 |
--------------------------------------------------------------------------------
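A hedged usage sketch for `CRDLoss`: the `opt` fields follow the docstring above, with hypothetical dimensions. Note that `ContrastMemory` moves its alias sampler to the GPU in `__init__`, so a CUDA device is required even at construction time:

```python
import torch
from argparse import Namespace
from crd.criterion import CRDLoss

# Hypothetical configuration mirroring the opt fields listed in the docstring.
opt = Namespace(s_dim=64, t_dim=256, feat_dim=128,
                nce_k=16, nce_t=0.07, nce_m=0.5, n_data=1024)

if torch.cuda.is_available():
    criterion = CRDLoss(opt).cuda()
    f_s = torch.randn(8, opt.s_dim, device='cuda')           # student features
    f_t = torch.randn(8, opt.t_dim, device='cuda')           # teacher features
    idx = torch.randint(0, opt.n_data, (8,), device='cuda')  # dataset indices of the batch
    loss = criterion(f_s, f_t, idx)  # negatives drawn from the memory buffer
    print(loss.item())
```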
/cv/crd/memory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 |
5 |
6 | class ContrastMemory(nn.Module):
7 | """
8 | Memory buffer that supplies a large number of negative samples.
9 | """
10 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5):
11 | super(ContrastMemory, self).__init__()
12 | self.nLem = outputSize
13 | self.unigrams = torch.ones(self.nLem)
14 | self.multinomial = AliasMethod(self.unigrams)
15 | self.multinomial.cuda()
16 | self.K = K
17 |
18 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum]))
19 | stdv = 1. / math.sqrt(inputSize / 3)
20 | self.register_buffer('memory_v1', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
21 | self.register_buffer('memory_v2', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
22 |
23 | def forward(self, v1, v2, y, idx=None):
24 | K = int(self.params[0].item())
25 | T = self.params[1].item()
26 | Z_v1 = self.params[2].item()
27 | Z_v2 = self.params[3].item()
28 |
29 | momentum = self.params[4].item()
30 | batchSize = v1.size(0)
31 | outputSize = self.memory_v1.size(0)
32 | inputSize = self.memory_v1.size(1)
33 |
34 | # original score computation
35 | if idx is None:
36 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1)
37 | idx.select(1, 0).copy_(y.data)
38 | # sample
39 | weight_v1 = torch.index_select(self.memory_v1, 0, idx.view(-1)).detach()
40 | weight_v1 = weight_v1.view(batchSize, K + 1, inputSize)
41 | out_v2 = torch.bmm(weight_v1, v2.view(batchSize, inputSize, 1))
42 | out_v2 = torch.exp(torch.div(out_v2, T))
43 | # sample
44 | weight_v2 = torch.index_select(self.memory_v2, 0, idx.view(-1)).detach()
45 | weight_v2 = weight_v2.view(batchSize, K + 1, inputSize)
46 | out_v1 = torch.bmm(weight_v2, v1.view(batchSize, inputSize, 1))
47 | out_v1 = torch.exp(torch.div(out_v1, T))
48 |
49 | # set Z if it hasn't been set yet
50 | if Z_v1 < 0:
51 | self.params[2] = out_v1.mean() * outputSize
52 | Z_v1 = self.params[2].clone().detach().item()
53 | print("normalization constant Z_v1 is set to {:.1f}".format(Z_v1))
54 | if Z_v2 < 0:
55 | self.params[3] = out_v2.mean() * outputSize
56 | Z_v2 = self.params[3].clone().detach().item()
57 | print("normalization constant Z_v2 is set to {:.1f}".format(Z_v2))
58 |
59 | # compute out_v1, out_v2
60 | out_v1 = torch.div(out_v1, Z_v1).contiguous()
61 | out_v2 = torch.div(out_v2, Z_v2).contiguous()
62 |
63 | # update memory
64 | with torch.no_grad():
65 | l_pos = torch.index_select(self.memory_v1, 0, y.view(-1))
66 | l_pos.mul_(momentum)
67 | l_pos.add_(torch.mul(v1, 1 - momentum))
68 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
69 | updated_v1 = l_pos.div(l_norm)
70 | self.memory_v1.index_copy_(0, y, updated_v1)
71 |
72 | ab_pos = torch.index_select(self.memory_v2, 0, y.view(-1))
73 | ab_pos.mul_(momentum)
74 | ab_pos.add_(torch.mul(v2, 1 - momentum))
75 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5)
76 | updated_v2 = ab_pos.div(ab_norm)
77 | self.memory_v2.index_copy_(0, y, updated_v2)
78 |
79 | return out_v1, out_v2
80 |
81 |
82 | class AliasMethod(object):
83 | """
84 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
85 | """
86 | def __init__(self, probs):
87 |
88 | if probs.sum() > 1:
89 | probs.div_(probs.sum())
90 | K = len(probs)
91 | self.prob = torch.zeros(K)
92 | self.alias = torch.LongTensor([0]*K)
93 |
94 | # Sort the data into the outcomes with probabilities
95 | # that are larger and smaller than 1/K.
96 | smaller = []
97 | larger = []
98 | for kk, prob in enumerate(probs):
99 | self.prob[kk] = K*prob
100 | if self.prob[kk] < 1.0:
101 | smaller.append(kk)
102 | else:
103 | larger.append(kk)
104 |
105 | # Loop though and create little binary mixtures that
106 | # appropriately allocate the larger outcomes over the
107 | # overall uniform mixture.
108 | while len(smaller) > 0 and len(larger) > 0:
109 | small = smaller.pop()
110 | large = larger.pop()
111 |
112 | self.alias[small] = large
113 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]
114 |
115 | if self.prob[large] < 1.0:
116 | smaller.append(large)
117 | else:
118 | larger.append(large)
119 |
120 | for last_one in smaller+larger:
121 | self.prob[last_one] = 1
122 |
123 | def cuda(self):
124 | self.prob = self.prob.cuda()
125 | self.alias = self.alias.cuda()
126 |
127 | def draw(self, N):
128 | """ Draw N samples from multinomial """
129 | K = self.alias.size(0)
130 |
131 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
132 | prob = self.prob.index_select(0, kk)
133 | alias = self.alias.index_select(0, kk)
134 | # b is whether a random number is greater than q
135 | b = torch.bernoulli(prob)
136 | oq = kk.mul(b.long())
137 | oj = alias.mul((1-b).long())
138 |
139 | return oq + oj
--------------------------------------------------------------------------------
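A small CPU-runnable demo of `AliasMethod`, which `ContrastMemory` uses to draw negatives in O(1) per sample; the toy distribution is hypothetical:

```python
import torch
from crd.memory import AliasMethod

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])  # toy categorical distribution
sampler = AliasMethod(probs)                # stays on CPU unless .cuda() is called
draws = sampler.draw(100000)
# Empirical frequencies should approach [0.1, 0.2, 0.3, 0.4].
print(torch.bincount(draws, minlength=4).float() / draws.numel())
```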
/cv/dataset/cifar100.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import socket
5 | import numpy as np
6 | from torch.utils.data import DataLoader
7 | from torchvision import datasets, transforms
8 | from PIL import Image
9 |
10 | """
11 | mean = {
12 | 'cifar100': (0.5071, 0.4867, 0.4408),
13 | }
14 |
15 | std = {
16 | 'cifar100': (0.2675, 0.2565, 0.2761),
17 | }
18 | """
19 |
20 |
21 | def get_data_folder():
22 | """
23 | return server-dependent path to store the data
24 | """
25 | hostname = socket.gethostname()
26 | if hostname.startswith('visiongpu'):
27 | data_folder = '/data/vision/phillipi/rep-learn/datasets'
28 | elif hostname.startswith('yonglong-home'):
29 | data_folder = '/home/yonglong/Data/data'
30 | else:
31 | data_folder = './data/'
32 |
33 | if not os.path.isdir(data_folder):
34 | os.makedirs(data_folder)
35 |
36 | return data_folder
37 |
38 |
39 | class CIFAR100Instance(datasets.CIFAR100):
40 | """CIFAR100Instance Dataset.
41 | """
42 | def __getitem__(self, index):
43 | if self.train:
44 | img, target = self.train_data[index], self.train_labels[index]
45 | else:
46 | img, target = self.test_data[index], self.test_labels[index]
47 |
48 | # doing this so that it is consistent with all other datasets
49 | # to return a PIL Image
50 | img = Image.fromarray(img)
51 |
52 | if self.transform is not None:
53 | img = self.transform(img)
54 |
55 | if self.target_transform is not None:
56 | target = self.target_transform(target)
57 |
58 | return img, target, index
59 |
60 |
61 | def get_cifar100_dataloaders(batch_size=128, num_workers=8, is_instance=False):
62 | """
63 | cifar 100
64 | """
65 | data_folder = get_data_folder()
66 |
67 | train_transform = transforms.Compose([
68 | transforms.RandomCrop(32, padding=4),
69 | transforms.RandomHorizontalFlip(),
70 | transforms.ToTensor(),
71 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
72 | ])
73 | test_transform = transforms.Compose([
74 | transforms.ToTensor(),
75 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
76 | ])
77 |
78 | if is_instance:
79 | train_set = CIFAR100Instance(root=data_folder,
80 | download=True,
81 | train=True,
82 | transform=train_transform)
83 | n_data = len(train_set)
84 | else:
85 | train_set = datasets.CIFAR100(root=data_folder,
86 | download=True,
87 | train=True,
88 | transform=train_transform)
89 | train_loader = DataLoader(train_set,
90 | batch_size=batch_size,
91 | shuffle=True,
92 | num_workers=num_workers)
93 |
94 | test_set = datasets.CIFAR100(root=data_folder,
95 | download=True,
96 | train=False,
97 | transform=test_transform)
98 | test_loader = DataLoader(test_set,
99 | batch_size=int(batch_size/2),
100 | shuffle=False,
101 | num_workers=int(num_workers/2))
102 |
103 | if is_instance:
104 | return train_loader, test_loader, n_data
105 | else:
106 | return train_loader, test_loader
107 |
108 |
109 | class CIFAR100InstanceSample(datasets.CIFAR100):
110 | """
111 | CIFAR100Instance+Sample Dataset
112 | """
113 | def __init__(self, root, train=True,
114 | transform=None, target_transform=None,
115 | download=False, k=4096, mode='exact', is_sample=True, percent=1.0):
116 | super().__init__(root=root, train=train, download=download,
117 | transform=transform, target_transform=target_transform)
118 | self.k = k
119 | self.mode = mode
120 | self.is_sample = is_sample
121 |
122 | num_classes = 100
123 | if self.train:
124 | num_samples = len(self.train_data)
125 | label = self.train_labels
126 | else:
127 | num_samples = len(self.test_data)
128 | label = self.test_labels
129 |
130 | self.cls_positive = [[] for i in range(num_classes)]
131 | for i in range(num_samples):
132 | self.cls_positive[label[i]].append(i)
133 |
134 | self.cls_negative = [[] for i in range(num_classes)]
135 | for i in range(num_classes):
136 | for j in range(num_classes):
137 | if j == i:
138 | continue
139 | self.cls_negative[i].extend(self.cls_positive[j])
140 |
141 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
142 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]
143 |
144 | if 0 < percent < 1:
145 | n = int(len(self.cls_negative[0]) * percent)
146 | self.cls_negative = [np.random.permutation(self.cls_negative[i])[0:n]
147 | for i in range(num_classes)]
148 |
149 | self.cls_positive = np.asarray(self.cls_positive)
150 | self.cls_negative = np.asarray(self.cls_negative)
151 |
152 | def __getitem__(self, index):
153 | if self.train:
154 | img, target = self.train_data[index], self.train_labels[index]
155 | else:
156 | img, target = self.test_data[index], self.test_labels[index]
157 |
158 | # doing this so that it is consistent with all other datasets
159 | # to return a PIL Image
160 | img = Image.fromarray(img)
161 |
162 | if self.transform is not None:
163 | img = self.transform(img)
164 |
165 | if self.target_transform is not None:
166 | target = self.target_transform(target)
167 |
168 | if not self.is_sample:
169 | # directly return
170 | return img, target, index
171 | else:
172 | # sample contrastive examples
173 | if self.mode == 'exact':
174 | pos_idx = index
175 | elif self.mode == 'relax':
176 | pos_idx = np.random.choice(self.cls_positive[target], 1)
177 | pos_idx = pos_idx[0]
178 | else:
179 | raise NotImplementedError(self.mode)
180 | replace = True if self.k > len(self.cls_negative[target]) else False
181 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
182 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
183 | return img, target, index, sample_idx
184 |
185 |
186 | def get_cifar100_dataloaders_sample(batch_size=128, num_workers=8, k=4096, mode='exact',
187 | is_sample=True, percent=1.0):
188 | """
189 | cifar 100
190 | """
191 | data_folder = get_data_folder()
192 |
193 | train_transform = transforms.Compose([
194 | transforms.RandomCrop(32, padding=4),
195 | transforms.RandomHorizontalFlip(),
196 | transforms.ToTensor(),
197 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
198 | ])
199 | test_transform = transforms.Compose([
200 | transforms.ToTensor(),
201 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
202 | ])
203 |
204 | train_set = CIFAR100InstanceSample(root=data_folder,
205 | download=True,
206 | train=True,
207 | transform=train_transform,
208 | k=k,
209 | mode=mode,
210 | is_sample=is_sample,
211 | percent=percent)
212 | n_data = len(train_set)
213 | train_loader = DataLoader(train_set,
214 | batch_size=batch_size,
215 | shuffle=True,
216 | num_workers=num_workers)
217 |
218 | test_set = datasets.CIFAR100(root=data_folder,
219 | download=True,
220 | train=False,
221 | transform=test_transform)
222 | test_loader = DataLoader(test_set,
223 | batch_size=int(batch_size/2),
224 | shuffle=False,
225 | num_workers=int(num_workers/2))
226 |
227 | return train_loader, test_loader, n_data
228 |
--------------------------------------------------------------------------------
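Note: `CIFAR100Instance` and `CIFAR100InstanceSample` index the legacy `train_data`/`train_labels` and `test_data`/`test_labels` attributes, which exist in the old torchvision this repo targets but were unified into `data`/`targets` in recent releases. A minimal sketch of the instance dataset for newer torchvision, keeping the same return convention:

```python
from PIL import Image
from torchvision import datasets

class CIFAR100InstanceNew(datasets.CIFAR100):
    """CIFAR100Instance for recent torchvision, where the train/test
    arrays are unified as self.data / self.targets."""
    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)  # return a PIL Image, as the original does
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, index
```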
/cv/dataset/cifar_with_held.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import numpy as np
6 | import sys
7 |
8 | if sys.version_info[0] == 2:
9 | import cPickle as pickle
10 | else:
11 | import pickle
12 |
13 | from torchvision.datasets.vision import VisionDataset
14 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive
15 |
16 |
17 | class CIFAR100WithHeld(VisionDataset):
18 | """`CIFAR10 `_ Dataset.
19 |
20 | Args:
21 | root (string): Root directory of dataset where directory
22 | ``cifar-100-python`` exists or will be saved to if download is set to True.
23 | train (bool, optional): If True, creates dataset from training set, otherwise
24 | creates from test set.
25 | transform (callable, optional): A function/transform that takes in an PIL image
26 | and returns a transformed version. E.g, ``transforms.RandomCrop``
27 | target_transform (callable, optional): A function/transform that takes in the
28 | target and transforms it.
29 | download (bool, optional): If true, downloads the dataset from the internet and
30 | puts it in root directory. If dataset is already downloaded, it is not
31 | downloaded again.
32 |
33 | """
34 | base_folder = 'cifar-100-python'
35 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
36 | filename = "cifar-100-python.tar.gz"
37 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
38 | train_list = [
39 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
40 | ]
41 |
42 | test_list = [
43 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
44 | ]
45 | meta = {
46 | 'filename': 'meta',
47 | 'key': 'fine_label_names',
48 | 'md5': '7973b15100ade9c7d40fb424638fde48',
49 | }
50 |
51 | def __init__(self, root, train=True, transform=None, target_transform=None,
52 | download=False, held=False, held_samples=0):
53 |
54 | super(CIFAR100WithHeld, self).__init__(root, transform=transform,
55 | target_transform=target_transform)
56 |
57 | self.train = train # training set or test set
58 | self.held = held
59 | self.held_samples = held_samples
60 |
61 | if download:
62 | self.download()
63 |
64 | if not self._check_integrity():
65 | raise RuntimeError('Dataset not found or corrupted.' +
66 | ' You can use download=True to download it')
67 |
68 | if self.held:
69 | assert self.train, "Held set is a subset of train set"
70 |
71 | if self.train:
72 | downloaded_list = self.train_list
73 | else:
74 | downloaded_list = self.test_list
75 |
76 | self.data = []
77 | self.targets = []
78 |
79 | # now load the pickled numpy arrays
80 | for file_name, checksum in downloaded_list:
81 | file_path = os.path.join(self.root, self.base_folder, file_name)
82 | with open(file_path, 'rb') as f:
83 | if sys.version_info[0] == 2:
84 | entry = pickle.load(f)
85 | else:
86 | entry = pickle.load(f, encoding='latin1')
87 | self.data.append(entry['data'])
88 | if 'labels' in entry:
89 | self.targets.extend(entry['labels'])
90 | else:
91 | self.targets.extend(entry['fine_labels'])
92 |
93 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
94 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
95 |
96 | if self.held:
97 | self.data = self.data[-self.held_samples:]
98 | self.targets = self.targets[-self.held_samples:]
99 | elif self.train and self.held_samples > 0:  # guard: slicing with [:-0] would drop everything
100 | self.data = self.data[:-self.held_samples]
101 | self.targets = self.targets[:-self.held_samples]
102 |
103 | self._load_meta()
104 |
105 | def _load_meta(self):
106 | path = os.path.join(self.root, self.base_folder, self.meta['filename'])
107 | if not check_integrity(path, self.meta['md5']):
108 | raise RuntimeError('Dataset metadata file not found or corrupted.' +
109 | ' You can use download=True to download it')
110 | with open(path, 'rb') as infile:
111 | if sys.version_info[0] == 2:
112 | data = pickle.load(infile)
113 | else:
114 | data = pickle.load(infile, encoding='latin1')
115 | self.classes = data[self.meta['key']]
116 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
117 |
118 | def __getitem__(self, index):
119 | """
120 | Args:
121 | index (int): Index
122 |
123 | Returns:
124 | tuple: (image, target) where target is index of the target class.
125 | """
126 | img, target = self.data[index], self.targets[index]
127 |
128 | # doing this so that it is consistent with all other datasets
129 | # to return a PIL Image
130 | img = Image.fromarray(img)
131 |
132 | if self.transform is not None:
133 | img = self.transform(img)
134 |
135 | if self.target_transform is not None:
136 | target = self.target_transform(target)
137 |
138 | return img, target
139 |
140 | def __len__(self):
141 | return len(self.data)
142 |
143 | def _check_integrity(self):
144 | root = self.root
145 | for fentry in (self.train_list + self.test_list):
146 | filename, md5 = fentry[0], fentry[1]
147 | fpath = os.path.join(root, self.base_folder, filename)
148 | if not check_integrity(fpath, md5):
149 | return False
150 | return True
151 |
152 | def download(self):
153 | if self._check_integrity():
154 | print('Files already downloaded and verified')
155 | return
156 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
157 |
158 | def extra_repr(self):
159 | return "Split: {}".format("Train" if self.train is True else "Test")
160 |
--------------------------------------------------------------------------------
/cv/dataset/imagenet.py:
--------------------------------------------------------------------------------
1 | """
2 | get data loaders
3 | """
4 | from __future__ import print_function
5 |
6 | import os
7 | import socket
8 | import numpy as np
9 | from torch.utils.data import DataLoader
10 | from torchvision import datasets
11 | from torchvision import transforms
12 |
13 |
14 | def get_data_folder():
15 | """
16 | return server-dependent path to store the data
17 | """
18 | hostname = socket.gethostname()
19 | if hostname.startswith('visiongpu'):
20 | data_folder = '/data/vision/phillipi/rep-learn/datasets/imagenet'
21 | elif hostname.startswith('yonglong-home'):
22 | data_folder = '/home/yonglong/Data/data/imagenet'
23 | else:
24 | data_folder = './data/imagenet'
25 |
26 | if not os.path.isdir(data_folder):
27 | os.makedirs(data_folder)
28 |
29 | return data_folder
30 |
31 |
32 | class ImageFolderInstance(datasets.ImageFolder):
33 | """: Folder datasets which returns the index of the image as well::
34 | """
35 | def __getitem__(self, index):
36 | """
37 | Args:
38 | index (int): Index
39 | Returns:
40 | tuple: (image, target) where target is class_index of the target class.
41 | """
42 | path, target = self.imgs[index]
43 | img = self.loader(path)
44 | if self.transform is not None:
45 | img = self.transform(img)
46 | if self.target_transform is not None:
47 | target = self.target_transform(target)
48 |
49 | return img, target, index
50 |
51 |
52 | class ImageFolderSample(datasets.ImageFolder):
53 | """: Folder datasets which returns (img, label, index, contrast_index):
54 | """
55 | def __init__(self, root, transform=None, target_transform=None,
56 | is_sample=False, k=4096):
57 | super().__init__(root=root, transform=transform, target_transform=target_transform)
58 |
59 | self.k = k
60 | self.is_sample = is_sample
61 |
62 | print('stage1 finished!')
63 |
64 | if self.is_sample:
65 | num_classes = len(self.classes)
66 | num_samples = len(self.samples)
67 | label = np.zeros(num_samples, dtype=np.int32)
68 | for i in range(num_samples):
69 | path, target = self.imgs[i]
70 | label[i] = target
71 |
72 | self.cls_positive = [[] for i in range(num_classes)]
73 | for i in range(num_samples):
74 | self.cls_positive[label[i]].append(i)
75 |
76 | self.cls_negative = [[] for i in range(num_classes)]
77 | for i in range(num_classes):
78 | for j in range(num_classes):
79 | if j == i:
80 | continue
81 | self.cls_negative[i].extend(self.cls_positive[j])
82 |
83 | self.cls_positive = [np.asarray(self.cls_positive[i], dtype=np.int32) for i in range(num_classes)]
84 | self.cls_negative = [np.asarray(self.cls_negative[i], dtype=np.int32) for i in range(num_classes)]
85 |
86 | print('dataset initialized!')
87 |
88 | def __getitem__(self, index):
89 | """
90 | Args:
91 | index (int): Index
92 | Returns:
93 | tuple: (image, target) where target is class_index of the target class.
94 | """
95 | path, target = self.imgs[index]
96 | img = self.loader(path)
97 | if self.transform is not None:
98 | img = self.transform(img)
99 | if self.target_transform is not None:
100 | target = self.target_transform(target)
101 |
102 | if self.is_sample:
103 | # sample contrastive examples
104 | pos_idx = index
105 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True)
106 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
107 | return img, target, index, sample_idx
108 | else:
109 | return img, target, index
110 |
111 |
112 | def get_test_loader(dataset='imagenet', batch_size=128, num_workers=8):
113 | """get the test data loader"""
114 |
115 | if dataset == 'imagenet':
116 | data_folder = get_data_folder()
117 | else:
118 | raise NotImplementedError('dataset not supported: {}'.format(dataset))
119 |
120 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
121 | std=[0.229, 0.224, 0.225])
122 | test_transform = transforms.Compose([
123 | transforms.Resize(256),
124 | transforms.CenterCrop(224),
125 | transforms.ToTensor(),
126 | normalize,
127 | ])
128 |
129 | test_folder = os.path.join(data_folder, 'val')
130 | test_set = datasets.ImageFolder(test_folder, transform=test_transform)
131 | test_loader = DataLoader(test_set,
132 | batch_size=batch_size,
133 | shuffle=False,
134 | num_workers=num_workers,
135 | pin_memory=True)
136 |
137 | return test_loader
138 |
139 |
140 | def get_dataloader_sample(dataset='imagenet', batch_size=128, num_workers=8, is_sample=False, k=4096):
141 | """Data Loader for ImageNet"""
142 |
143 | if dataset == 'imagenet':
144 | data_folder = get_data_folder()
145 | else:
146 | raise NotImplementedError('dataset not supported: {}'.format(dataset))
147 |
148 | # add data transform
149 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
150 | std=[0.229, 0.224, 0.225])
151 | train_transform = transforms.Compose([
152 | transforms.RandomResizedCrop(224),
153 | transforms.RandomHorizontalFlip(),
154 | transforms.ToTensor(),
155 | normalize,
156 | ])
157 | test_transform = transforms.Compose([
158 | transforms.Resize(256),
159 | transforms.CenterCrop(224),
160 | transforms.ToTensor(),
161 | normalize,
162 | ])
163 | train_folder = os.path.join(data_folder, 'train')
164 | test_folder = os.path.join(data_folder, 'val')
165 |
166 | train_set = ImageFolderSample(train_folder, transform=train_transform, is_sample=is_sample, k=k)
167 | test_set = datasets.ImageFolder(test_folder, transform=test_transform)
168 |
169 | train_loader = DataLoader(train_set,
170 | batch_size=batch_size,
171 | shuffle=True,
172 | num_workers=num_workers,
173 | pin_memory=True)
174 | test_loader = DataLoader(test_set,
175 | batch_size=batch_size,
176 | shuffle=False,
177 | num_workers=num_workers,
178 | pin_memory=True)
179 |
180 | print('num_samples', len(train_set.samples))
181 | print('num_class', len(train_set.classes))
182 |
183 | return train_loader, test_loader, len(train_set), len(train_set.classes)
184 |
185 |
186 | def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16, is_instance=False):
187 | """
188 | Data Loader for imagenet
189 | """
190 | if dataset == 'imagenet':
191 | data_folder = get_data_folder()
192 | else:
193 | raise NotImplementedError('dataset not supported: {}'.format(dataset))
194 |
195 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
196 | std=[0.229, 0.224, 0.225])
197 | train_transform = transforms.Compose([
198 | transforms.RandomResizedCrop(224),
199 | transforms.RandomHorizontalFlip(),
200 | transforms.ToTensor(),
201 | normalize,
202 | ])
203 | test_transform = transforms.Compose([
204 | transforms.Resize(256),
205 | transforms.CenterCrop(224),
206 | transforms.ToTensor(),
207 | normalize,
208 | ])
209 |
210 | train_folder = os.path.join(data_folder, 'train')
211 | test_folder = os.path.join(data_folder, 'val')
212 |
213 | if is_instance:
214 | train_set = ImageFolderInstance(train_folder, transform=train_transform)
215 | n_data = len(train_set)
216 | else:
217 | train_set = datasets.ImageFolder(train_folder, transform=train_transform)
218 |
219 | test_set = datasets.ImageFolder(test_folder, transform=test_transform)
220 |
221 | train_loader = DataLoader(train_set,
222 | batch_size=batch_size,
223 | shuffle=True,
224 | num_workers=num_workers,
225 | pin_memory=True)
226 |
227 | test_loader = DataLoader(test_set,
228 | batch_size=batch_size,
229 | shuffle=False,
230 | num_workers=num_workers//2,
231 | pin_memory=True)
232 |
233 | if is_instance:
234 | return train_loader, test_loader, n_data
235 | else:
236 | return train_loader, test_loader
237 |
--------------------------------------------------------------------------------
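A usage sketch for the ImageNet loaders, assuming the dataset is already extracted under `./data/imagenet/{train,val}` (the default fallback of `get_data_folder`):

```python
from dataset.imagenet import get_dataloader_sample

# k negatives per instance, as used for CRD-style contrastive sampling.
train_loader, test_loader, n_data, n_cls = get_dataloader_sample(
    batch_size=64, num_workers=8, is_sample=True, k=4096)
img, target, index, sample_idx = next(iter(train_loader))
print(img.shape, sample_idx.shape)  # [64, 3, 224, 224], [64, 4097] (1 positive + k negatives)
```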
/cv/dataset/meta_cifar100.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import socket
5 | import numpy as np
6 | from torch.utils.data import DataLoader, RandomSampler
7 | from torchvision import transforms
8 | from .cifar_with_held import CIFAR100WithHeld
9 | from PIL import Image
10 |
11 | """
12 | mean = {
13 | 'cifar100': (0.5071, 0.4867, 0.4408),
14 | }
15 |
16 | std = {
17 | 'cifar100': (0.2675, 0.2565, 0.2761),
18 | }
19 | """
20 |
21 |
22 | def get_data_folder():
23 | """
24 | return server-dependent path to store the data
25 | """
26 | hostname = socket.gethostname()
27 | if hostname.startswith('visiongpu'):
28 | data_folder = '/data/vision/phillipi/rep-learn/datasets'
29 | elif hostname.startswith('yonglong-home'):
30 | data_folder = '/home/yonglong/Data/data'
31 | else:
32 | data_folder = './data/'
33 |
34 | if not os.path.isdir(data_folder):
35 | os.makedirs(data_folder)
36 |
37 | return data_folder
38 |
39 |
40 | def get_cifar100_dataloaders(batch_size=128, num_workers=8, held_size=0, num_held_samples=0):
41 | """
42 | cifar 100
43 | """
44 | data_folder = get_data_folder()
45 |
46 | train_transform = transforms.Compose([
47 | transforms.RandomCrop(32, padding=4),
48 | transforms.RandomHorizontalFlip(),
49 | transforms.ToTensor(),
50 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
51 | ])
52 | test_transform = transforms.Compose([
53 | transforms.ToTensor(),
54 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
55 | ])
56 |
57 | train_set = CIFAR100WithHeld(root=data_folder,
58 | download=True,
59 | train=True,
60 | held=False,
61 | held_samples=held_size,
62 | transform=train_transform)
63 | train_loader = DataLoader(train_set,
64 | batch_size=batch_size,
65 | shuffle=True,
66 | num_workers=num_workers)
67 |
68 | held_set = CIFAR100WithHeld(root=data_folder,
69 | download=True,
70 | train=True,
71 | held=True,
72 | held_samples=held_size,
73 | transform=train_transform)
74 |
75 | if num_held_samples == 0:
76 | held_loader = DataLoader(held_set,
77 | batch_size=batch_size,
78 | shuffle=True,
79 | num_workers=num_workers)
80 | else:
81 | held_sampler = RandomSampler(held_set,
82 | replacement=True,
83 | num_samples=num_held_samples)
84 | held_loader = DataLoader(held_set,
85 | sampler=held_sampler,
86 | batch_size=batch_size)
87 |
88 | test_set = CIFAR100WithHeld(root=data_folder,
89 | download=True,
90 | train=False,
91 | transform=test_transform)
92 | test_loader = DataLoader(test_set,
93 | batch_size=int(batch_size/2),
94 | shuffle=False,
95 | num_workers=int(num_workers/2))
96 |
97 | return train_loader, held_loader, test_loader
98 |
--------------------------------------------------------------------------------
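A usage sketch of the meta dataloaders; the held split size below is hypothetical:

```python
from dataset.meta_cifar100 import get_cifar100_dataloaders

# Reserve the last 1,000 training images as the held-out set used for the
# teacher's meta update; the remaining 49,000 stay in the train loader.
train_loader, held_loader, test_loader = get_cifar100_dataloaders(
    batch_size=128, num_workers=4, held_size=1000)
images, labels = next(iter(held_loader))
print(images.shape, labels.shape)  # [128, 3, 32, 32], [128]
```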
/cv/distiller_zoo/AB.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class ABLoss(nn.Module):
8 | """Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
9 | code: https://github.com/bhheo/AB_distillation
10 | """
11 | def __init__(self, feat_num, margin=1.0):
12 | super(ABLoss, self).__init__()
13 | self.w = [2**(i-feat_num+1) for i in range(feat_num)]
14 | self.margin = margin
15 |
16 | def forward(self, g_s, g_t):
17 | bsz = g_s[0].shape[0]
18 | losses = [self.criterion_alternative_l2(s, t) for s, t in zip(g_s, g_t)]
19 | losses = [w * l for w, l in zip(self.w, losses)]
20 | # loss = sum(losses) / bsz
21 | # loss = loss / 1000 * 3
22 | losses = [l / bsz for l in losses]
23 | losses = [l / 1000 * 3 for l in losses]
24 | return losses
25 |
26 | def criterion_alternative_l2(self, source, target):
27 | loss = ((source + self.margin) ** 2 * ((source > -self.margin) & (target <= 0)).float() +
28 | (source - self.margin) ** 2 * ((source <= self.margin) & (target > 0)).float())
29 | return torch.abs(loss).sum()
30 |
--------------------------------------------------------------------------------
/cv/distiller_zoo/AT.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class Attention(nn.Module):
8 | """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
9 | via Attention Transfer
10 | code: https://github.com/szagoruyko/attention-transfer"""
11 | def __init__(self, p=2):
12 | super(Attention, self).__init__()
13 | self.p = p
14 |
15 | def forward(self, g_s, g_t):
16 | return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
17 |
18 | def at_loss(self, f_s, f_t):
19 | s_H, t_H = f_s.shape[2], f_t.shape[2]
20 | if s_H > t_H:
21 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
22 | elif s_H < t_H:
23 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
24 | else:
25 | pass
26 | return (self.at(f_s) - self.at(f_t)).pow(2).mean()
27 |
28 | def at(self, f):
29 | return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))
30 |
--------------------------------------------------------------------------------
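A quick shape sanity check for `Attention` with hypothetical feature shapes; because AT pools attention maps over the channel dimension, student and teacher channel counts may differ, and mismatched spatial sizes are adaptively pooled:

```python
import torch
from distiller_zoo.AT import Attention

at = Attention(p=2)
g_s = [torch.randn(8, 16, 32, 32)]  # one student feature map (hypothetical shape)
g_t = [torch.randn(8, 64, 16, 16)]  # teacher: more channels, smaller spatial size
losses = at(g_s, g_t)               # one scalar loss per feature pair
print(losses[0].item())
```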
/cv/distiller_zoo/CC.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class Correlation(nn.Module):
8 | """Correlation Congruence for Knowledge Distillation, ICCV 2019.
9 | The authors nicely shared the code with me. I restructured their code to be
10 | compatible with my running framework. Credits go to the original author"""
11 | def __init__(self):
12 | super(Correlation, self).__init__()
13 |
14 | def forward(self, f_s, f_t):
15 | delta = torch.abs(f_s - f_t)
16 | loss = torch.mean((delta[:-1] * delta[1:]).sum(1))
17 | return loss
18 |
19 |
20 | # class Correlation(nn.Module):
21 | # """Similarity-preserving loss. My origianl own reimplementation
22 | # based on the paper before emailing the original authors."""
23 | # def __init__(self):
24 | # super(Correlation, self).__init__()
25 | #
26 | # def forward(self, f_s, f_t):
27 | # return self.similarity_loss(f_s, f_t)
28 | # # return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
29 | #
30 | # def similarity_loss(self, f_s, f_t):
31 | # bsz = f_s.shape[0]
32 | # f_s = f_s.view(bsz, -1)
33 | # f_t = f_t.view(bsz, -1)
34 | #
35 | # G_s = torch.mm(f_s, torch.t(f_s))
36 | # G_s = G_s / G_s.norm(2)
37 | # G_t = torch.mm(f_t, torch.t(f_t))
38 | # G_t = G_t / G_t.norm(2)
39 | #
40 | # G_diff = G_t - G_s
41 | # loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
42 | # return loss
43 |
--------------------------------------------------------------------------------
/cv/distiller_zoo/FSP.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class FSP(nn.Module):
9 | """A Gift from Knowledge Distillation:
10 | Fast Optimization, Network Minimization and Transfer Learning"""
11 | def __init__(self, s_shapes, t_shapes):
12 | super(FSP, self).__init__()
13 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
14 | s_c = [s[1] for s in s_shapes]
15 | t_c = [t[1] for t in t_shapes]
16 | if np.any(np.asarray(s_c) != np.asarray(t_c)):
17 | raise ValueError('num of channels not equal (error in FSP)')
18 |
19 | def forward(self, g_s, g_t):
20 | s_fsp = self.compute_fsp(g_s)
21 | t_fsp = self.compute_fsp(g_t)
22 | loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
23 | return loss_group
24 |
25 | @staticmethod
26 | def compute_loss(s, t):
27 | return (s - t).pow(2).mean()
28 |
29 | @staticmethod
30 | def compute_fsp(g):
31 | fsp_list = []
32 | for i in range(len(g) - 1):
33 | bot, top = g[i], g[i + 1]
34 | b_H, t_H = bot.shape[2], top.shape[2]
35 | if b_H > t_H:
36 | bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
37 | elif b_H < t_H:
38 | top = F.adaptive_avg_pool2d(top, (b_H, b_H))
39 | else:
40 | pass
41 | bot = bot.unsqueeze(1)
42 | top = top.unsqueeze(2)
43 | bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)
44 | top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)
45 |
46 | fsp = (bot * top).mean(-1)
47 | fsp_list.append(fsp)
48 | return fsp_list
49 |
--------------------------------------------------------------------------------
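A shape sanity check for `FSP` with hypothetical feature shapes; FSP requires matching channel counts per stage and produces one loss per adjacent pair of feature maps:

```python
import torch
from distiller_zoo.FSP import FSP

s_shapes = [(8, 16, 32, 32), (8, 32, 16, 16)]  # hypothetical student stages
t_shapes = [(8, 16, 32, 32), (8, 32, 16, 16)]  # channels must match per stage
fsp = FSP(s_shapes, t_shapes)
g_s = [torch.randn(*s) for s in s_shapes]
g_t = [torch.randn(*t) for t in t_shapes]
print([l.item() for l in fsp(g_s, g_t)])  # len(shapes) - 1 losses
```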
/cv/distiller_zoo/FT.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class FactorTransfer(nn.Module):
8 | """Paraphrasing Complex Network: Network Compression via Factor Transfer, NeurIPS 2018"""
9 | def __init__(self, p1=2, p2=1):
10 | super(FactorTransfer, self).__init__()
11 | self.p1 = p1
12 | self.p2 = p2
13 |
14 | def forward(self, f_s, f_t):
15 | return self.factor_loss(f_s, f_t)
16 |
17 | def factor_loss(self, f_s, f_t):
18 | s_H, t_H = f_s.shape[2], f_t.shape[2]
19 | if s_H > t_H:
20 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
21 | elif s_H < t_H:
22 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
23 | else:
24 | pass
25 | if self.p2 == 1:
26 | return (self.factor(f_s) - self.factor(f_t)).abs().mean()
27 | else:
28 | return (self.factor(f_s) - self.factor(f_t)).pow(self.p2).mean()
29 |
30 | def factor(self, f):
31 | return F.normalize(f.pow(self.p1).mean(1).view(f.size(0), -1))
32 |
--------------------------------------------------------------------------------
/cv/distiller_zoo/FitNet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class HintLoss(nn.Module):
7 | """Fitnets: hints for thin deep nets, ICLR 2015"""
8 | def __init__(self):
9 | super(HintLoss, self).__init__()
10 | self.crit = nn.MSELoss()
11 |
12 | def forward(self, f_s, f_t):
13 | loss = self.crit(f_s, f_t)
14 | return loss
15 |
--------------------------------------------------------------------------------
/cv/distiller_zoo/KD.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch
6 |
7 |
8 | class DistillKL(nn.Module):
9 | """Distilling the Knowledge in a Neural Network"""
10 | def __init__(self, T):
11 | super(DistillKL, self).__init__()
12 | self.T = T
13 |
14 | def forward(self, y_s, y_t):
15 | p_s = F.log_softmax(y_s/self.T, dim=1)
16 | p_t = F.softmax(y_t/self.T, dim=1)
17 |         loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T**2) / y_s.shape[0]  # T**2 restores gradient scale
18 | return loss
19 |
20 |
21 | class CustomDistillKL(nn.Module):
22 | """Distilling the Knowledge in a Neural Network"""
23 | def __init__(self, T):
24 | super(CustomDistillKL, self).__init__()
25 | self.T = T
26 |
27 | def forward(self, y_s, y_t):
28 | p_s = F.log_softmax(y_s/self.T, dim=1)
29 | p_t = F.softmax(y_t/self.T, dim=1)
30 | loss = (p_t * (p_t.log() - p_s)).sum(dim=1).mean(dim=0) * (self.T ** 2)
31 | return loss
32 |
33 |
34 | if __name__ == '__main__':
35 | kl_1 = DistillKL(3)
36 | kl_2 = CustomDistillKL(3)
37 | student = torch.tensor([[3., 7., 1.], [4., 8., 5.], [0.25, 0.76, 0.2], [0.42, 0.99, 0.8]])
38 | teacher = torch.tensor([[4., 6., 5.], [2., 7., 6.], [0.28, 0.94, 0.1], [0.49, 0.75, 0.0]])
39 | print(kl_1(student, teacher), kl_2(student, teacher))
40 |     assert (kl_1(student, teacher) - kl_2(student, teacher)).abs() < 1e-6
41 |
--------------------------------------------------------------------------------
/cv/distiller_zoo/KDSVD.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class KDSVD(nn.Module):
9 | """
10 | Self-supervised Knowledge Distillation using Singular Value Decomposition
11 | original Tensorflow code: https://github.com/sseung0703/SSKD_SVD
12 | """
13 | def __init__(self, k=1):
14 | super(KDSVD, self).__init__()
15 | self.k = k
16 |
17 | def forward(self, g_s, g_t):
18 | v_sb = None
19 | v_tb = None
20 | losses = []
21 |         for i, (f_s, f_t) in enumerate(zip(g_s, g_t)):
22 |
23 | u_t, s_t, v_t = self.svd(f_t, self.k)
24 | u_s, s_s, v_s = self.svd(f_s, self.k + 3)
25 | v_s, v_t = self.align_rsv(v_s, v_t)
26 | s_t = s_t.unsqueeze(1)
27 | v_t = v_t * s_t
28 | v_s = v_s * s_t
29 |
30 | if i > 0:
31 | s_rbf = torch.exp(-(v_s.unsqueeze(2) - v_sb.unsqueeze(1)).pow(2) / 8)
32 | t_rbf = torch.exp(-(v_t.unsqueeze(2) - v_tb.unsqueeze(1)).pow(2) / 8)
33 |
34 | l2loss = (s_rbf - t_rbf.detach()).pow(2)
35 | l2loss = torch.where(torch.isfinite(l2loss), l2loss, torch.zeros_like(l2loss))
36 | losses.append(l2loss.sum())
37 |
38 | v_tb = v_t
39 | v_sb = v_s
40 |
41 | bsz = g_s[0].shape[0]
42 | losses = [l / bsz for l in losses]
43 | return losses
44 |
45 | def svd(self, feat, n=1):
46 | size = feat.shape
47 | assert len(size) == 4
48 |
49 |         x = feat.view(-1, size[1], size[2] * size[3]).transpose(-2, -1)  # flatten spatial dims (H*W)
50 | u, s, v = torch.svd(x)
51 |
52 | u = self.removenan(u)
53 | s = self.removenan(s)
54 | v = self.removenan(v)
55 |
56 | if n > 0:
57 | u = F.normalize(u[:, :, :n], dim=1)
58 | s = F.normalize(s[:, :n], dim=1)
59 | v = F.normalize(v[:, :, :n], dim=1)
60 |
61 | return u, s, v
62 |
63 | @staticmethod
64 | def removenan(x):
65 | x = torch.where(torch.isfinite(x), x, torch.zeros_like(x))
66 | return x
67 |
68 | @staticmethod
69 | def align_rsv(a, b):
70 | cosine = torch.matmul(a.transpose(-2, -1), b)
71 | max_abs_cosine, _ = torch.max(torch.abs(cosine), 1, keepdim=True)
72 | mask = torch.where(torch.eq(max_abs_cosine, torch.abs(cosine)),
73 | torch.sign(cosine), torch.zeros_like(cosine))
74 | a = torch.matmul(a, mask)
75 | return a, b
76 |
--------------------------------------------------------------------------------
/cv/distiller_zoo/MSE.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch
6 |
7 |
8 | class MSEWithTemperature(nn.Module):
9 |     """MSE between temperature-scaled logits, an L2 variant of the KD objective"""
10 | def __init__(self, T):
11 | super(MSEWithTemperature, self).__init__()
12 | self.T = T
13 |
14 | def forward(self, y_s, y_t):
15 | p_s = y_s / self.T
16 | p_t = y_t / self.T
17 | loss = F.mse_loss(p_s, p_t)
18 | return loss
19 |
--------------------------------------------------------------------------------
/cv/distiller_zoo/NST.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class NSTLoss(nn.Module):
8 |     """Like What You Like: Knowledge Distill via Neuron Selectivity Transfer"""
9 | def __init__(self):
10 | super(NSTLoss, self).__init__()
11 | pass
12 |
13 | def forward(self, g_s, g_t):
14 | return [self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
15 |
16 | def nst_loss(self, f_s, f_t):
17 | s_H, t_H = f_s.shape[2], f_t.shape[2]
18 | if s_H > t_H:
19 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
20 | elif s_H < t_H:
21 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
22 | else:
23 | pass
24 |
25 | f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
26 | f_s = F.normalize(f_s, dim=2)
27 | f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
28 | f_t = F.normalize(f_t, dim=2)
29 |
30 |         # set full_loss to False to skip the detached (constant) teacher-teacher term
31 | full_loss = True
32 | if full_loss:
33 | return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean()
34 | - 2 * self.poly_kernel(f_s, f_t).mean())
35 | else:
36 | return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean()
37 |
38 | def poly_kernel(self, a, b):
39 | a = a.unsqueeze(1)
40 | b = b.unsqueeze(2)
41 | res = (a * b).sum(-1).pow(2)
42 | return res
--------------------------------------------------------------------------------
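NST minimizes a squared MMD between the distributions of normalized per-channel activation maps, with the degree-2 polynomial kernel implemented in poly_kernel. A minimal sketch, again assuming cv/ is on PYTHONPATH:

import torch
from distiller_zoo import NSTLoss  # assumes cv/ is on PYTHONPATH

crit = NSTLoss()
g_s = [torch.randn(2, 16, 8, 8)]
g_t = [torch.randn(2, 32, 8, 8)]  # channel counts may differ; spatial sizes are pooled to match
print(sum(crit(g_s, g_t)))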
/cv/distiller_zoo/PKT.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class PKT(nn.Module):
8 | """Probabilistic Knowledge Transfer for deep representation learning
9 | Code from author: https://github.com/passalis/probabilistic_kt"""
10 | def __init__(self):
11 | super(PKT, self).__init__()
12 |
13 | def forward(self, f_s, f_t):
14 | return self.cosine_similarity_loss(f_s, f_t)
15 |
16 | @staticmethod
17 | def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
18 | # Normalize each vector by its norm
19 | output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
20 | output_net = output_net / (output_net_norm + eps)
21 | output_net[output_net != output_net] = 0
22 |
23 | target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
24 | target_net = target_net / (target_net_norm + eps)
25 | target_net[target_net != target_net] = 0
26 |
27 | # Calculate the cosine similarity
28 | model_similarity = torch.mm(output_net, output_net.transpose(0, 1))
29 | target_similarity = torch.mm(target_net, target_net.transpose(0, 1))
30 |
31 | # Scale cosine similarity to 0..1
32 | model_similarity = (model_similarity + 1.0) / 2.0
33 | target_similarity = (target_similarity + 1.0) / 2.0
34 |
35 | # Transform them into probabilities
36 | model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
37 | target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)
38 |
39 | # Calculate the KL-divergence
40 | loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))
41 |
42 | return loss
43 |
--------------------------------------------------------------------------------
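PKT matches batch-wise similarity distributions, so student and teacher embedding dimensions need not agree; helper/loops.py feeds it the penultimate features (feat[-1]). A minimal sketch under the same PYTHONPATH assumption:

import torch
from distiller_zoo import PKT  # assumes cv/ is on PYTHONPATH

crit = PKT()
f_s = torch.randn(8, 128)  # student embeddings (e.g. feat_s[-1] in helper/loops.py)
f_t = torch.randn(8, 256)  # dimensions may differ; only the batch sizes must match
print(crit(f_s, f_t))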
/cv/distiller_zoo/RKD.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class RKDLoss(nn.Module):
9 |     """Relational Knowledge Distillation, CVPR2019"""
10 | def __init__(self, w_d=25, w_a=50):
11 | super(RKDLoss, self).__init__()
12 | self.w_d = w_d
13 | self.w_a = w_a
14 |
15 | def forward(self, f_s, f_t):
16 | student = f_s.view(f_s.shape[0], -1)
17 | teacher = f_t.view(f_t.shape[0], -1)
18 |
19 | # RKD distance loss
20 | with torch.no_grad():
21 | t_d = self.pdist(teacher, squared=False)
22 | mean_td = t_d[t_d > 0].mean()
23 | t_d = t_d / mean_td
24 |
25 | d = self.pdist(student, squared=False)
26 | mean_d = d[d > 0].mean()
27 | d = d / mean_d
28 |
29 | loss_d = F.smooth_l1_loss(d, t_d)
30 |
31 | # RKD Angle loss
32 | with torch.no_grad():
33 | td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
34 | norm_td = F.normalize(td, p=2, dim=2)
35 | t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
36 |
37 | sd = (student.unsqueeze(0) - student.unsqueeze(1))
38 | norm_sd = F.normalize(sd, p=2, dim=2)
39 | s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
40 |
41 | loss_a = F.smooth_l1_loss(s_angle, t_angle)
42 |
43 | loss = self.w_d * loss_d + self.w_a * loss_a
44 |
45 | return loss
46 |
47 | @staticmethod
48 | def pdist(e, squared=False, eps=1e-12):
49 | e_square = e.pow(2).sum(dim=1)
50 | prod = e @ e.t()
51 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
52 |
53 | if not squared:
54 | res = res.sqrt()
55 |
56 | res = res.clone()
57 | res[range(len(e)), range(len(e))] = 0
58 | return res
59 |
--------------------------------------------------------------------------------
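RKD likewise only compares batch-wise relations (pairwise distances and triplet angles), so the two embedding spaces may have different dimensionality. A minimal sketch:

import torch
from distiller_zoo import RKDLoss  # assumes cv/ is on PYTHONPATH

crit = RKDLoss(w_d=25, w_a=50)  # default distance/angle weights
f_s = torch.randn(8, 64)
f_t = torch.randn(8, 128)  # relations are batch-wise, so embedding dims may differ
print(crit(f_s, f_t))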
/cv/distiller_zoo/SP.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class Similarity(nn.Module):
9 | """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
10 | def __init__(self):
11 | super(Similarity, self).__init__()
12 |
13 | def forward(self, g_s, g_t):
14 | return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
15 |
16 | def similarity_loss(self, f_s, f_t):
17 | bsz = f_s.shape[0]
18 | f_s = f_s.view(bsz, -1)
19 | f_t = f_t.view(bsz, -1)
20 |
21 | G_s = torch.mm(f_s, torch.t(f_s))
22 | # G_s = G_s / G_s.norm(2)
23 | G_s = torch.nn.functional.normalize(G_s)
24 | G_t = torch.mm(f_t, torch.t(f_t))
25 | # G_t = G_t / G_t.norm(2)
26 | G_t = torch.nn.functional.normalize(G_t)
27 |
28 | G_diff = G_t - G_s
29 | loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
30 | return loss
31 |
--------------------------------------------------------------------------------
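Similarity-preserving KD compares bsz-by-bsz Gram matrices, so only the batch sizes of the two feature lists must agree. A minimal sketch:

import torch
from distiller_zoo import Similarity  # assumes cv/ is on PYTHONPATH

crit = Similarity()
g_s = [torch.randn(4, 16, 8, 8)]
g_t = [torch.randn(4, 32, 8, 8)]  # only batch sizes must agree: Gram matrices are bsz x bsz
print(sum(crit(g_s, g_t)))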
/cv/distiller_zoo/VID.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 |
9 | class VIDLoss(nn.Module):
10 | """Variational Information Distillation for Knowledge Transfer (CVPR 2019),
11 | code from author: https://github.com/ssahn0215/variational-information-distillation"""
12 | def __init__(self,
13 | num_input_channels,
14 | num_mid_channel,
15 | num_target_channels,
16 | init_pred_var=5.0,
17 | eps=1e-5):
18 | super(VIDLoss, self).__init__()
19 |
20 | def conv1x1(in_channels, out_channels, stride=1):
21 | return nn.Conv2d(
22 | in_channels, out_channels,
23 | kernel_size=1, padding=0,
24 | bias=False, stride=stride)
25 |
26 | self.regressor = nn.Sequential(
27 | conv1x1(num_input_channels, num_mid_channel),
28 | nn.ReLU(),
29 | conv1x1(num_mid_channel, num_mid_channel),
30 | nn.ReLU(),
31 | conv1x1(num_mid_channel, num_target_channels),
32 | )
33 | self.log_scale = torch.nn.Parameter(
34 | np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
35 | )
36 | self.eps = eps
37 |
38 | def forward(self, input, target):
39 |         # pool for dimension match
40 | s_H, t_H = input.shape[2], target.shape[2]
41 | if s_H > t_H:
42 | input = F.adaptive_avg_pool2d(input, (t_H, t_H))
43 | elif s_H < t_H:
44 | target = F.adaptive_avg_pool2d(target, (s_H, s_H))
45 | else:
46 | pass
47 | pred_mean = self.regressor(input)
48 | pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
49 | pred_var = pred_var.view(1, -1, 1, 1)
50 | neg_log_prob = 0.5*(
51 | (pred_mean-target)**2/pred_var+torch.log(pred_var)
52 | )
53 | loss = torch.mean(neg_log_prob)
54 | return loss
55 |
--------------------------------------------------------------------------------
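Unlike most criteria in this zoo, VIDLoss owns learnable parameters (the regressor and log_scale), so it must be optimized alongside the student; the 'vid' branch in helper/loops.py applies one VIDLoss per intermediate feature pair. A minimal sketch:

import torch
from distiller_zoo import VIDLoss  # assumes cv/ is on PYTHONPATH

crit = VIDLoss(num_input_channels=16, num_mid_channel=32, num_target_channels=32)
f_s = torch.randn(2, 16, 8, 8)   # student feature
f_t = torch.randn(2, 32, 8, 8)   # teacher feature
print(crit(f_s, f_t))
# crit.parameters() (regressor + log_scale) must be included in the trainable optimizer.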
/cv/distiller_zoo/__init__.py:
--------------------------------------------------------------------------------
1 | from .AB import ABLoss
2 | from .AT import Attention
3 | from .CC import Correlation
4 | from .FitNet import HintLoss
5 | from .FSP import FSP
6 | from .FT import FactorTransfer
7 | from .KD import DistillKL
8 | from .KDSVD import KDSVD
9 | from .NST import NSTLoss
10 | from .PKT import PKT
11 | from .RKD import RKDLoss
12 | from .SP import Similarity
13 | from .VID import VIDLoss
14 |
--------------------------------------------------------------------------------
/cv/helper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JetRunner/MetaDistil/80e60c11de531b10d1f06ceb2b71c70665bb6aff/cv/helper/__init__.py
--------------------------------------------------------------------------------
/cv/helper/loops.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import sys
4 | import time
5 | import torch
6 |
7 | from .util import AverageMeter, accuracy
8 |
9 |
10 | def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt):
11 | """vanilla training"""
12 | model.train()
13 |
14 | batch_time = AverageMeter()
15 | data_time = AverageMeter()
16 | losses = AverageMeter()
17 | top1 = AverageMeter()
18 | top5 = AverageMeter()
19 |
20 | end = time.time()
21 | for idx, (input, target) in enumerate(train_loader):
22 | data_time.update(time.time() - end)
23 |
24 | input = input.float()
25 | if torch.cuda.is_available():
26 | input = input.cuda()
27 | target = target.cuda()
28 |
29 | # ===================forward=====================
30 | output = model(input)
31 | loss = criterion(output, target)
32 |
33 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
34 | losses.update(loss.item(), input.size(0))
35 | top1.update(acc1[0], input.size(0))
36 | top5.update(acc5[0], input.size(0))
37 |
38 | # ===================backward=====================
39 | optimizer.zero_grad()
40 | loss.backward()
41 | optimizer.step()
42 |
43 | # ===================meters=====================
44 | batch_time.update(time.time() - end)
45 | end = time.time()
46 |
47 | # tensorboard logger
48 | pass
49 |
50 | # print info
51 | if idx % opt.print_freq == 0:
52 | print('Epoch: [{0}][{1}/{2}]\t'
53 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
54 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
55 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
56 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
57 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
58 | epoch, idx, len(train_loader), batch_time=batch_time,
59 | data_time=data_time, loss=losses, top1=top1, top5=top5))
60 | sys.stdout.flush()
61 |
62 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
63 | .format(top1=top1, top5=top5))
64 |
65 | return top1.avg, losses.avg
66 |
67 |
68 | def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, opt):
69 | """One epoch distillation"""
70 | # set modules as train()
71 | for module in module_list:
72 | module.train()
73 | # set teacher as eval()
74 | module_list[-1].eval()
75 |
76 | if opt.distill == 'abound':
77 | module_list[1].eval()
78 | elif opt.distill == 'factor':
79 | module_list[2].eval()
80 |
81 | criterion_cls = criterion_list[0]
82 | criterion_div = criterion_list[1]
83 | criterion_kd = criterion_list[2]
84 |
85 | model_s = module_list[0]
86 | model_t = module_list[-1]
87 |
88 | batch_time = AverageMeter()
89 | data_time = AverageMeter()
90 | losses = AverageMeter()
91 | top1 = AverageMeter()
92 | top5 = AverageMeter()
93 |
94 | end = time.time()
95 | for idx, data in enumerate(train_loader):
96 | if opt.distill in ['crd']:
97 | input, target, index, contrast_idx = data
98 | else:
99 | input, target, index = data
100 | data_time.update(time.time() - end)
101 |
102 | input = input.float()
103 | if torch.cuda.is_available():
104 | input = input.cuda()
105 | target = target.cuda()
106 | index = index.cuda()
107 | if opt.distill in ['crd']:
108 | contrast_idx = contrast_idx.cuda()
109 |
110 | # ===================forward=====================
111 | preact = False
112 | if opt.distill in ['abound']:
113 | preact = True
114 | feat_s, logit_s = model_s(input, is_feat=True, preact=preact)
115 | with torch.no_grad():
116 | feat_t, logit_t = model_t(input, is_feat=True, preact=preact)
117 | feat_t = [f.detach() for f in feat_t]
118 |
119 | # cls + kl div
120 | loss_cls = criterion_cls(logit_s, target)
121 | loss_div = criterion_div(logit_s, logit_t)
122 |
123 | # other kd beyond KL divergence
124 | if opt.distill == 'kd':
125 | loss_kd = 0
126 | elif opt.distill == 'hint':
127 | f_s = module_list[1](feat_s[opt.hint_layer])
128 | f_t = feat_t[opt.hint_layer]
129 | loss_kd = criterion_kd(f_s, f_t)
130 | elif opt.distill == 'crd':
131 | f_s = feat_s[-1]
132 | f_t = feat_t[-1]
133 | loss_kd = criterion_kd(f_s, f_t, index, contrast_idx)
134 | elif opt.distill == 'attention':
135 | g_s = feat_s[1:-1]
136 | g_t = feat_t[1:-1]
137 | loss_group = criterion_kd(g_s, g_t)
138 | loss_kd = sum(loss_group)
139 | elif opt.distill == 'nst':
140 | g_s = feat_s[1:-1]
141 | g_t = feat_t[1:-1]
142 | loss_group = criterion_kd(g_s, g_t)
143 | loss_kd = sum(loss_group)
144 | elif opt.distill == 'similarity':
145 | g_s = [feat_s[-2]]
146 | g_t = [feat_t[-2]]
147 | loss_group = criterion_kd(g_s, g_t)
148 | loss_kd = sum(loss_group)
149 | elif opt.distill == 'rkd':
150 | f_s = feat_s[-1]
151 | f_t = feat_t[-1]
152 | loss_kd = criterion_kd(f_s, f_t)
153 | elif opt.distill == 'pkt':
154 | f_s = feat_s[-1]
155 | f_t = feat_t[-1]
156 | loss_kd = criterion_kd(f_s, f_t)
157 | elif opt.distill == 'kdsvd':
158 | g_s = feat_s[1:-1]
159 | g_t = feat_t[1:-1]
160 | loss_group = criterion_kd(g_s, g_t)
161 | loss_kd = sum(loss_group)
162 | elif opt.distill == 'correlation':
163 | f_s = module_list[1](feat_s[-1])
164 | f_t = module_list[2](feat_t[-1])
165 | loss_kd = criterion_kd(f_s, f_t)
166 | elif opt.distill == 'vid':
167 | g_s = feat_s[1:-1]
168 | g_t = feat_t[1:-1]
169 | loss_group = [c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, criterion_kd)]
170 | loss_kd = sum(loss_group)
171 | elif opt.distill == 'abound':
172 | # can also add loss to this stage
173 | loss_kd = 0
174 | elif opt.distill == 'fsp':
175 | # can also add loss to this stage
176 | loss_kd = 0
177 | elif opt.distill == 'factor':
178 | factor_s = module_list[1](feat_s[-2])
179 | factor_t = module_list[2](feat_t[-2], is_factor=True)
180 | loss_kd = criterion_kd(factor_s, factor_t)
181 | else:
182 | raise NotImplementedError(opt.distill)
183 |
184 | loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
185 |
186 | acc1, acc5 = accuracy(logit_s, target, topk=(1, 5))
187 | losses.update(loss.item(), input.size(0))
188 | top1.update(acc1[0], input.size(0))
189 | top5.update(acc5[0], input.size(0))
190 |
191 | # ===================backward=====================
192 | optimizer.zero_grad()
193 | loss.backward()
194 | optimizer.step()
195 |
196 | # ===================meters=====================
197 | batch_time.update(time.time() - end)
198 | end = time.time()
199 |
200 | # print info
201 | if idx % opt.print_freq == 0:
202 | print('Epoch: [{0}][{1}/{2}]\t'
203 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
204 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
205 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
206 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
207 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
208 | epoch, idx, len(train_loader), batch_time=batch_time,
209 | data_time=data_time, loss=losses, top1=top1, top5=top5))
210 | sys.stdout.flush()
211 |
212 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
213 | .format(top1=top1, top5=top5))
214 |
215 | return top1.avg, losses.avg
216 |
217 |
218 | def validate(val_loader, model, criterion, opt):
219 | """validation"""
220 | batch_time = AverageMeter()
221 | losses = AverageMeter()
222 | top1 = AverageMeter()
223 | top5 = AverageMeter()
224 |
225 | # switch to evaluate mode
226 | model.eval()
227 |
228 | with torch.no_grad():
229 | end = time.time()
230 | for idx, (input, target) in enumerate(val_loader):
231 |
232 | input = input.float()
233 | if torch.cuda.is_available():
234 | input = input.cuda()
235 | target = target.cuda()
236 |
237 | # compute output
238 | output = model(input)
239 | loss = criterion(output, target)
240 |
241 | # measure accuracy and record loss
242 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
243 | losses.update(loss.item(), input.size(0))
244 | top1.update(acc1[0], input.size(0))
245 | top5.update(acc5[0], input.size(0))
246 |
247 | # measure elapsed time
248 | batch_time.update(time.time() - end)
249 | end = time.time()
250 |
251 | if idx % opt.print_freq == 0:
252 | print('Test: [{0}/{1}]\t'
253 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
254 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
255 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
256 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
257 | idx, len(val_loader), batch_time=batch_time, loss=losses,
258 | top1=top1, top5=top5))
259 |
260 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
261 | .format(top1=top1, top5=top5))
262 |
263 | return top1.avg, top5.avg, losses.avg
264 |
--------------------------------------------------------------------------------
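To see how the pieces of train_distill fit together without the full training scripts, here is a minimal, self-contained sketch for plain KD (opt.distill == 'kd'). TinyNet and the hand-rolled loader are hypothetical stand-ins; the real scripts build models from models.model_dict and datasets from dataset/. Note the loader must yield (input, target, index) triples, and opt needs the gamma/alpha/beta weights combined on internal line 184.

import torch
import torch.nn as nn
from types import SimpleNamespace
from helper.loops import train_distill   # assumes cv/ is on PYTHONPATH
from distiller_zoo import DistillKL

class TinyNet(nn.Module):
    """Hypothetical stand-in exposing the (is_feat, preact) interface the loop expects."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x, is_feat=False, preact=False):
        f = torch.relu(self.conv(x)).mean(dim=(2, 3))  # global average pool
        out = self.fc(f)
        return ([f], out) if is_feat else out

student, teacher = TinyNet(), TinyNet()
module_list = nn.ModuleList([student, teacher])   # teacher last, as the loop assumes
if torch.cuda.is_available():
    module_list = module_list.cuda()
criterion_list = [nn.CrossEntropyLoss(), DistillKL(T=4), DistillKL(T=4)]
optimizer = torch.optim.SGD(student.parameters(), lr=0.05)
opt = SimpleNamespace(distill='kd', gamma=0.1, alpha=0.9, beta=0.0, print_freq=1)

# One toy batch of (input, target, index), as the CIFAR-100 dataset wrappers yield.
loader = [(torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,)), torch.arange(4))]
top1, loss = train_distill(1, loader, module_list, criterion_list, optimizer, opt)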
/cv/helper/meta_loops.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import sys
4 | import time
5 | from collections import OrderedDict
6 |
7 | import torch
8 | from copy import deepcopy as cp
9 |
10 | from .util import AverageMeter, accuracy
11 |
12 |
13 | def train_distill(epoch, train_loader, held_loader, module_list, criterion_list, s_optimizer, t_optimizer, opt):
14 | """One epoch distillation"""
15 |
16 | criterion_cls = criterion_list[0]
17 | criterion_kd = criterion_list[1]
18 |
19 | s_model = module_list[0]
20 | t_model = module_list[-1]
21 |
22 | batch_time = AverageMeter()
23 |
24 | assume_losses = AverageMeter()
25 | assume_top1 = AverageMeter()
26 | assume_top5 = AverageMeter()
27 |
28 | held_losses = AverageMeter()
29 |
30 | real_losses = AverageMeter()
31 | real_top1 = AverageMeter()
32 | real_top5 = AverageMeter()
33 |
34 | end = time.time()
35 |
36 | total_steps_one_epoch = len(train_loader)
37 | batches_buffer = []
38 | round_counter = 0
39 |
40 | for d_idx, d_data in enumerate(train_loader):
41 |
42 | batches_buffer.append((d_idx, d_data))
43 |
44 | if (d_idx + 1) % opt.num_meta_batches != 0 and (d_idx + 1) != total_steps_one_epoch:
45 | continue
46 |
47 | #########################################
48 | # Step 1: Assume S' #
49 | #########################################
50 |
51 | # Time machine!
52 | fast_weights = OrderedDict((name, param) for (name, param) in s_model.named_parameters())
53 | s_model_backup_state_dict, s_optimizer_backup_state_dict = cp(s_model.state_dict()), cp(s_optimizer.state_dict())
54 |
55 | s_model.train()
56 | t_model.eval()
57 |
58 | for idx, data in batches_buffer:
59 |
60 | input, target = data
61 |
62 | input = input.float()
63 | if torch.cuda.is_available():
64 | input = input.cuda()
65 | target = target.cuda()
66 |
67 | logit_s = s_model(input, params=None if idx == 0 else fast_weights)
68 | logit_t = t_model(input)
69 |
70 | assume_loss_cls = criterion_cls(logit_s, target)
71 | assume_loss_kd = criterion_kd(logit_s, logit_t)
72 |
73 | assume_loss = opt.alpha * assume_loss_kd + (1 - opt.alpha) * assume_loss_cls
74 |
75 | acc1, acc5 = accuracy(logit_s, target, topk=(1, 5))
76 | assume_losses.update(assume_loss.item(), input.size(0))
77 | assume_top1.update(acc1[0], input.size(0))
78 | assume_top5.update(acc5[0], input.size(0))
79 |
80 | grads = torch.autograd.grad(assume_loss, s_model.parameters() if idx == 0 else fast_weights.values(),
81 | create_graph=True, retain_graph=True)
82 |
83 | fast_weights = OrderedDict(
84 | (name, param - opt.assume_s_step_size * grad) for ((name, param), grad) in
85 | zip(fast_weights.items(), grads))
86 |
87 | #########################################
88 | # Step 2: Train T with S' on HELD set #
89 | #########################################
90 |
91 | s_prime_loss = None
92 |
93 | t_model.train()
94 | held_batch_num = 0
95 |
96 | for idx, data in enumerate(held_loader):
97 | input, target = data
98 |
99 | input = input.float()
100 | if torch.cuda.is_available():
101 | input = input.cuda()
102 | target = target.cuda()
103 |
104 | logit_s_prime = s_model(input, params=fast_weights)
105 | s_prime_step_loss = criterion_cls(logit_s_prime, target)
106 |
107 | if s_prime_loss is None:
108 | s_prime_loss = s_prime_step_loss
109 | else:
110 | s_prime_loss += s_prime_step_loss
111 |
112 | held_batch_num += 1
113 |
114 | s_prime_loss /= held_batch_num
115 | t_grads = torch.autograd.grad(s_prime_loss, t_model.parameters())
116 |
117 | for p, gr in zip(t_model.parameters(), t_grads):
118 | p.grad = gr
119 |
120 | held_losses.update(s_prime_loss.item(), 1)
121 |
122 | t_optimizer.step()
123 |
124 | # Manual zero_grad
125 | for p in t_model.parameters():
126 | p.grad = None
127 |
128 | for p in s_model.parameters():
129 | p.grad = None
130 |
131 | del t_grads
132 | del grads
133 | del fast_weights
134 |
135 | #########################################
136 | # Step 3: Actually update S #
137 | #########################################
138 |
139 | # We use the Time Machine!
140 | s_model.load_state_dict(s_model_backup_state_dict)
141 | s_optimizer.load_state_dict(s_optimizer_backup_state_dict)
142 |
143 | del s_model_backup_state_dict, s_optimizer_backup_state_dict
144 |
145 | s_model.train()
146 | t_model.eval()
147 |
148 | for idx, data in batches_buffer:
149 |
150 | input, target = data
151 |
152 | input = input.float()
153 | if torch.cuda.is_available():
154 | input = input.cuda()
155 | target = target.cuda()
156 |
157 | logit_s = s_model(input)
158 | with torch.no_grad():
159 | logit_t = t_model(input)
160 |
161 | real_loss_cls = criterion_cls(logit_s, target)
162 | real_loss_kd = criterion_kd(logit_s, logit_t)
163 |
164 | real_loss = opt.alpha * real_loss_kd + (1 - opt.alpha) * real_loss_cls
165 |
166 | acc1, acc5 = accuracy(logit_s, target, topk=(1, 5))
167 | real_losses.update(real_loss.item(), input.size(0))
168 | real_top1.update(acc1[0], input.size(0))
169 | real_top5.update(acc5[0], input.size(0))
170 |
171 | real_loss.backward()
172 |
173 | s_optimizer.step()
174 | s_optimizer.zero_grad()
175 |
176 | round_counter += 1
177 | batches_buffer = []
178 |
179 | # ===================meters=====================
180 | batch_time.update(time.time() - end)
181 | end = time.time()
182 |
183 | # print info
184 | if round_counter % opt.print_freq == 0:
185 | print('Epoch: [{0}][{1}/{2}]\t'
186 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
187 | 'Real_Loss {loss.val:.4f} ({loss.avg:.4f})\t'
188 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
189 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
190 | epoch, idx, len(train_loader), batch_time=batch_time,
191 | loss=real_losses, top1=real_top1, top5=real_top5))
192 | sys.stdout.flush()
193 |
194 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
195 | .format(top1=real_top1, top5=real_top5))
196 |
197 | return real_top1.avg, real_losses.avg
198 |
199 |
200 | def validate(val_loader, model, criterion, opt):
201 | """validation"""
202 | batch_time = AverageMeter()
203 | losses = AverageMeter()
204 | top1 = AverageMeter()
205 | top5 = AverageMeter()
206 |
207 | # switch to evaluate mode
208 | model.eval()
209 |
210 | with torch.no_grad():
211 | end = time.time()
212 | for idx, (input, target) in enumerate(val_loader):
213 |
214 | input = input.float()
215 | if torch.cuda.is_available():
216 | input = input.cuda()
217 | target = target.cuda()
218 |
219 | # compute output
220 | output = model(input)
221 | loss = criterion(output, target)
222 |
223 | # measure accuracy and record loss
224 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
225 | losses.update(loss.item(), input.size(0))
226 | top1.update(acc1[0], input.size(0))
227 | top5.update(acc5[0], input.size(0))
228 |
229 | # measure elapsed time
230 | batch_time.update(time.time() - end)
231 | end = time.time()
232 |
233 | if idx % opt.print_freq == 0:
234 | print('Test: [{0}/{1}]\t'
235 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
236 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
237 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
238 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
239 | idx, len(val_loader), batch_time=batch_time, loss=losses,
240 | top1=top1, top5=top5))
241 |
242 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
243 | .format(top1=top1, top5=top5))
244 |
245 | return top1.avg, top5.avg, losses.avg
246 |
--------------------------------------------------------------------------------
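The "time machine" in Step 1 hinges on one pattern: a functional SGD step whose graph is retained (create_graph=True), so that a later loss computed through the fast weights backpropagates into whatever produced them, here the original parameters, and in MetaDistil also the teacher via the KD term. A minimal sketch with a hand-written functional forward, standing in for the torchmeta params= mechanism the meta models use:

import torch
import torch.nn.functional as F
from collections import OrderedDict

# Plain parameters playing the role of the student's named_parameters().
w = OrderedDict(weight=torch.randn(2, 3, requires_grad=True),
                bias=torch.zeros(2, requires_grad=True))
x, y = torch.randn(4, 3), torch.randint(0, 2, (4,))

def forward(params, x):
    return x @ params['weight'].t() + params['bias']

inner_loss = F.cross_entropy(forward(w, x), y)
grads = torch.autograd.grad(inner_loss, list(w.values()), create_graph=True)
fast = OrderedDict((n, p - 0.1 * g) for (n, p), g in zip(w.items(), grads))

# A held-out loss evaluated through the fast weights still reaches the originals.
held_loss = F.cross_entropy(forward(fast, torch.randn(4, 3)), torch.randint(0, 2, (4,)))
print(torch.autograd.grad(held_loss, list(w.values()))[0].shape)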
/cv/helper/pretrain.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import time
4 | import sys
5 | import torch
6 | import torch.optim as optim
7 | import torch.backends.cudnn as cudnn
8 | from .util import AverageMeter
9 |
10 |
11 | def init(model_s, model_t, init_modules, criterion, train_loader, logger, opt):
12 | model_t.eval()
13 | model_s.eval()
14 | init_modules.train()
15 |
16 | if torch.cuda.is_available():
17 | model_s.cuda()
18 | model_t.cuda()
19 | init_modules.cuda()
20 | cudnn.benchmark = True
21 |
22 | if opt.model_s in ['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
23 | 'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2'] and \
24 | opt.distill == 'factor':
25 | lr = 0.01
26 | else:
27 | lr = opt.learning_rate
28 | optimizer = optim.SGD(init_modules.parameters(),
29 | lr=lr,
30 | momentum=opt.momentum,
31 | weight_decay=opt.weight_decay)
32 |
33 | batch_time = AverageMeter()
34 | data_time = AverageMeter()
35 | losses = AverageMeter()
36 | for epoch in range(1, opt.init_epochs + 1):
37 | batch_time.reset()
38 | data_time.reset()
39 | losses.reset()
40 | end = time.time()
41 | for idx, data in enumerate(train_loader):
42 | if opt.distill in ['crd']:
43 | input, target, index, contrast_idx = data
44 | else:
45 | input, target, index = data
46 | data_time.update(time.time() - end)
47 |
48 | input = input.float()
49 | if torch.cuda.is_available():
50 | input = input.cuda()
51 | target = target.cuda()
52 | index = index.cuda()
53 | if opt.distill in ['crd']:
54 | contrast_idx = contrast_idx.cuda()
55 |
56 | # ============= forward ==============
57 | preact = (opt.distill == 'abound')
58 | feat_s, _ = model_s(input, is_feat=True, preact=preact)
59 | with torch.no_grad():
60 | feat_t, _ = model_t(input, is_feat=True, preact=preact)
61 | feat_t = [f.detach() for f in feat_t]
62 |
63 | if opt.distill == 'abound':
64 | g_s = init_modules[0](feat_s[1:-1])
65 | g_t = feat_t[1:-1]
66 | loss_group = criterion(g_s, g_t)
67 | loss = sum(loss_group)
68 | elif opt.distill == 'factor':
69 | f_t = feat_t[-2]
70 | _, f_t_rec = init_modules[0](f_t)
71 | loss = criterion(f_t_rec, f_t)
72 | elif opt.distill == 'fsp':
73 | loss_group = criterion(feat_s[:-1], feat_t[:-1])
74 | loss = sum(loss_group)
75 | else:
76 |                 raise NotImplementedError('Not supported in init training: {}'.format(opt.distill))
77 |
78 | losses.update(loss.item(), input.size(0))
79 |
80 | # ===================backward=====================
81 | optimizer.zero_grad()
82 | loss.backward()
83 | optimizer.step()
84 |
85 | batch_time.update(time.time() - end)
86 | end = time.time()
87 |
88 | # end of epoch
89 | logger.log_value('init_train_loss', losses.avg, epoch)
90 | print('Epoch: [{0}/{1}]\t'
91 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
92 | 'losses: {losses.val:.3f} ({losses.avg:.3f})'.format(
93 | epoch, opt.init_epochs, batch_time=batch_time, losses=losses))
94 | sys.stdout.flush()
95 |
--------------------------------------------------------------------------------
/cv/helper/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def adjust_learning_rate_new(epoch, optimizer, LUT):
8 | """
9 | new learning rate schedule according to RotNet
10 | """
11 | lr = next((lr for (max_epoch, lr) in LUT if max_epoch > epoch), LUT[-1][1])
12 | for param_group in optimizer.param_groups:
13 | param_group['lr'] = lr
14 |
15 |
16 | def adjust_learning_rate(epoch, opt, optimizer):
17 | """Sets the learning rate to the initial LR decayed by decay rate every steep step"""
18 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
19 | if steps > 0:
20 | new_lr = opt.lr * (opt.lr_decay_rate ** steps)
21 | for param_group in optimizer.param_groups:
22 | param_group['lr'] = new_lr
23 |
24 |
25 | class AverageMeter(object):
26 | """Computes and stores the average and current value"""
27 | def __init__(self):
28 | self.reset()
29 |
30 | def reset(self):
31 | self.val = 0
32 | self.avg = 0
33 | self.sum = 0
34 | self.count = 0
35 |
36 | def update(self, val, n=1):
37 | self.val = val
38 | self.sum += val * n
39 | self.count += n
40 | self.avg = self.sum / self.count
41 |
42 |
43 | def accuracy(output, target, topk=(1,)):
44 | """Computes the accuracy over the k top predictions for the specified values of k"""
45 | with torch.no_grad():
46 | maxk = max(topk)
47 | batch_size = target.size(0)
48 |
49 | _, pred = output.topk(maxk, 1, True, True)
50 | pred = pred.t()
51 | correct = pred.eq(target.view(1, -1).expand_as(pred))
52 |
53 | res = []
54 | for k in topk:
55 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
56 | res.append(correct_k.mul_(100.0 / batch_size))
57 | return res
58 |
59 |
60 | if __name__ == '__main__':
61 |
62 | pass
63 |
--------------------------------------------------------------------------------
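A worked call of accuracy() above, to pin down its semantics: it returns one percentage tensor per k, over the whole batch.

import torch
from helper.util import accuracy  # assumes cv/ is on PYTHONPATH

output = torch.tensor([[0.1, 0.7, 0.2],   # top-1 = class 1 (correct)
                       [0.5, 0.2, 0.3]])  # top-1 = class 0 (wrong), top-2 = {0, 2}
target = torch.tensor([1, 2])
acc1, acc2 = accuracy(output, target, topk=(1, 2))
print(acc1, acc2)  # tensor([50.]), tensor([100.])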
/cv/models/ShuffleNetv1.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNet in PyTorch.
2 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class ShuffleBlock(nn.Module):
10 | def __init__(self, groups):
11 | super(ShuffleBlock, self).__init__()
12 | self.groups = groups
13 |
14 | def forward(self, x):
15 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
16 | N,C,H,W = x.size()
17 | g = self.groups
18 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W)
19 |
20 |
21 | class Bottleneck(nn.Module):
22 | def __init__(self, in_planes, out_planes, stride, groups, is_last=False):
23 | super(Bottleneck, self).__init__()
24 | self.is_last = is_last
25 | self.stride = stride
26 |
27 | mid_planes = int(out_planes/4)
28 | g = 1 if in_planes == 24 else groups
29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
30 | self.bn1 = nn.BatchNorm2d(mid_planes)
31 | self.shuffle1 = ShuffleBlock(groups=g)
32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
33 | self.bn2 = nn.BatchNorm2d(mid_planes)
34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
35 | self.bn3 = nn.BatchNorm2d(out_planes)
36 |
37 | self.shortcut = nn.Sequential()
38 | if stride == 2:
39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))
40 |
41 | def forward(self, x):
42 | out = F.relu(self.bn1(self.conv1(x)))
43 | out = self.shuffle1(out)
44 | out = F.relu(self.bn2(self.conv2(out)))
45 | out = self.bn3(self.conv3(out))
46 | res = self.shortcut(x)
47 | preact = torch.cat([out, res], 1) if self.stride == 2 else out+res
48 | out = F.relu(preact)
49 | # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res)
50 | if self.is_last:
51 | return out, preact
52 | else:
53 | return out
54 |
55 |
56 | class ShuffleNet(nn.Module):
57 | def __init__(self, cfg, num_classes=10):
58 | super(ShuffleNet, self).__init__()
59 | out_planes = cfg['out_planes']
60 | num_blocks = cfg['num_blocks']
61 | groups = cfg['groups']
62 |
63 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
64 | self.bn1 = nn.BatchNorm2d(24)
65 | self.in_planes = 24
66 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
67 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
68 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
69 | self.linear = nn.Linear(out_planes[2], num_classes)
70 |
71 | def _make_layer(self, out_planes, num_blocks, groups):
72 | layers = []
73 | for i in range(num_blocks):
74 | stride = 2 if i == 0 else 1
75 | cat_planes = self.in_planes if i == 0 else 0
76 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes,
77 | stride=stride,
78 | groups=groups,
79 | is_last=(i == num_blocks - 1)))
80 | self.in_planes = out_planes
81 | return nn.Sequential(*layers)
82 |
83 | def get_feat_modules(self):
84 | feat_m = nn.ModuleList([])
85 | feat_m.append(self.conv1)
86 | feat_m.append(self.bn1)
87 | feat_m.append(self.layer1)
88 | feat_m.append(self.layer2)
89 | feat_m.append(self.layer3)
90 | return feat_m
91 |
92 | def get_bn_before_relu(self):
93 | raise NotImplementedError('ShuffleNet currently is not supported for "Overhaul" teacher')
94 |
95 | def forward(self, x, is_feat=False, preact=False):
96 | out = F.relu(self.bn1(self.conv1(x)))
97 | f0 = out
98 | out, f1_pre = self.layer1(out)
99 | f1 = out
100 | out, f2_pre = self.layer2(out)
101 | f2 = out
102 | out, f3_pre = self.layer3(out)
103 | f3 = out
104 | out = F.avg_pool2d(out, 4)
105 | out = out.view(out.size(0), -1)
106 | f4 = out
107 | out = self.linear(out)
108 |
109 | if is_feat:
110 | if preact:
111 | return [f0, f1_pre, f2_pre, f3_pre, f4], out
112 | else:
113 | return [f0, f1, f2, f3, f4], out
114 | else:
115 | return out
116 |
117 |
118 | def ShuffleV1(**kwargs):
119 | cfg = {
120 | 'out_planes': [240, 480, 960],
121 | 'num_blocks': [4, 8, 4],
122 | 'groups': 3
123 | }
124 | return ShuffleNet(cfg, **kwargs)
125 |
126 |
127 | if __name__ == '__main__':
128 |
129 | x = torch.randn(2, 3, 32, 32)
130 | net = ShuffleV1(num_classes=100)
131 | import time
132 | a = time.time()
133 | feats, logit = net(x, is_feat=True, preact=True)
134 | b = time.time()
135 | print(b - a)
136 | for f in feats:
137 | print(f.shape, f.min().item())
138 | print(logit.shape)
139 |
--------------------------------------------------------------------------------
/cv/models/ShuffleNetv2.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNetV2 in PyTorch.
2 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details.
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class ShuffleBlock(nn.Module):
10 | def __init__(self, groups=2):
11 | super(ShuffleBlock, self).__init__()
12 | self.groups = groups
13 |
14 | def forward(self, x):
15 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
16 | N, C, H, W = x.size()
17 | g = self.groups
18 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
19 |
20 |
21 | class SplitBlock(nn.Module):
22 | def __init__(self, ratio):
23 | super(SplitBlock, self).__init__()
24 | self.ratio = ratio
25 |
26 | def forward(self, x):
27 | c = int(x.size(1) * self.ratio)
28 | return x[:, :c, :, :], x[:, c:, :, :]
29 |
30 |
31 | class BasicBlock(nn.Module):
32 | def __init__(self, in_channels, split_ratio=0.5, is_last=False):
33 | super(BasicBlock, self).__init__()
34 | self.is_last = is_last
35 | self.split = SplitBlock(split_ratio)
36 | in_channels = int(in_channels * split_ratio)
37 | self.conv1 = nn.Conv2d(in_channels, in_channels,
38 | kernel_size=1, bias=False)
39 | self.bn1 = nn.BatchNorm2d(in_channels)
40 | self.conv2 = nn.Conv2d(in_channels, in_channels,
41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
42 | self.bn2 = nn.BatchNorm2d(in_channels)
43 | self.conv3 = nn.Conv2d(in_channels, in_channels,
44 | kernel_size=1, bias=False)
45 | self.bn3 = nn.BatchNorm2d(in_channels)
46 | self.shuffle = ShuffleBlock()
47 |
48 | def forward(self, x):
49 | x1, x2 = self.split(x)
50 | out = F.relu(self.bn1(self.conv1(x2)))
51 | out = self.bn2(self.conv2(out))
52 | preact = self.bn3(self.conv3(out))
53 | out = F.relu(preact)
54 | # out = F.relu(self.bn3(self.conv3(out)))
55 | preact = torch.cat([x1, preact], 1)
56 | out = torch.cat([x1, out], 1)
57 | out = self.shuffle(out)
58 | if self.is_last:
59 | return out, preact
60 | else:
61 | return out
62 |
63 |
64 | class DownBlock(nn.Module):
65 | def __init__(self, in_channels, out_channels):
66 | super(DownBlock, self).__init__()
67 | mid_channels = out_channels // 2
68 | # left
69 | self.conv1 = nn.Conv2d(in_channels, in_channels,
70 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
71 | self.bn1 = nn.BatchNorm2d(in_channels)
72 | self.conv2 = nn.Conv2d(in_channels, mid_channels,
73 | kernel_size=1, bias=False)
74 | self.bn2 = nn.BatchNorm2d(mid_channels)
75 | # right
76 | self.conv3 = nn.Conv2d(in_channels, mid_channels,
77 | kernel_size=1, bias=False)
78 | self.bn3 = nn.BatchNorm2d(mid_channels)
79 | self.conv4 = nn.Conv2d(mid_channels, mid_channels,
80 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
81 | self.bn4 = nn.BatchNorm2d(mid_channels)
82 | self.conv5 = nn.Conv2d(mid_channels, mid_channels,
83 | kernel_size=1, bias=False)
84 | self.bn5 = nn.BatchNorm2d(mid_channels)
85 |
86 | self.shuffle = ShuffleBlock()
87 |
88 | def forward(self, x):
89 | # left
90 | out1 = self.bn1(self.conv1(x))
91 | out1 = F.relu(self.bn2(self.conv2(out1)))
92 | # right
93 | out2 = F.relu(self.bn3(self.conv3(x)))
94 | out2 = self.bn4(self.conv4(out2))
95 | out2 = F.relu(self.bn5(self.conv5(out2)))
96 | # concat
97 | out = torch.cat([out1, out2], 1)
98 | out = self.shuffle(out)
99 | return out
100 |
101 |
102 | class ShuffleNetV2(nn.Module):
103 | def __init__(self, net_size, num_classes=10):
104 | super(ShuffleNetV2, self).__init__()
105 | out_channels = configs[net_size]['out_channels']
106 | num_blocks = configs[net_size]['num_blocks']
107 |
108 | # self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
109 | # stride=1, padding=1, bias=False)
110 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
111 | self.bn1 = nn.BatchNorm2d(24)
112 | self.in_channels = 24
113 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
114 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
115 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
116 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
117 | kernel_size=1, stride=1, padding=0, bias=False)
118 | self.bn2 = nn.BatchNorm2d(out_channels[3])
119 | self.linear = nn.Linear(out_channels[3], num_classes)
120 |
121 | def _make_layer(self, out_channels, num_blocks):
122 | layers = [DownBlock(self.in_channels, out_channels)]
123 | for i in range(num_blocks):
124 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1)))
125 | self.in_channels = out_channels
126 | return nn.Sequential(*layers)
127 |
128 | def get_feat_modules(self):
129 | feat_m = nn.ModuleList([])
130 | feat_m.append(self.conv1)
131 | feat_m.append(self.bn1)
132 | feat_m.append(self.layer1)
133 | feat_m.append(self.layer2)
134 | feat_m.append(self.layer3)
135 | return feat_m
136 |
137 | def get_bn_before_relu(self):
138 | raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher')
139 |
140 | def forward(self, x, is_feat=False, preact=False):
141 | out = F.relu(self.bn1(self.conv1(x)))
142 | # out = F.max_pool2d(out, 3, stride=2, padding=1)
143 | f0 = out
144 | out, f1_pre = self.layer1(out)
145 | f1 = out
146 | out, f2_pre = self.layer2(out)
147 | f2 = out
148 | out, f3_pre = self.layer3(out)
149 | f3 = out
150 | out = F.relu(self.bn2(self.conv2(out)))
151 | out = F.avg_pool2d(out, 4)
152 | out = out.view(out.size(0), -1)
153 | f4 = out
154 | out = self.linear(out)
155 | if is_feat:
156 | if preact:
157 | return [f0, f1_pre, f2_pre, f3_pre, f4], out
158 | else:
159 | return [f0, f1, f2, f3, f4], out
160 | else:
161 | return out
162 |
163 |
164 | configs = {
165 | 0.2: {
166 | 'out_channels': (40, 80, 160, 512),
167 | 'num_blocks': (3, 3, 3)
168 | },
169 |
170 | 0.3: {
171 | 'out_channels': (40, 80, 160, 512),
172 | 'num_blocks': (3, 7, 3)
173 | },
174 |
175 | 0.5: {
176 | 'out_channels': (48, 96, 192, 1024),
177 | 'num_blocks': (3, 7, 3)
178 | },
179 |
180 | 1: {
181 | 'out_channels': (116, 232, 464, 1024),
182 | 'num_blocks': (3, 7, 3)
183 | },
184 | 1.5: {
185 | 'out_channels': (176, 352, 704, 1024),
186 | 'num_blocks': (3, 7, 3)
187 | },
188 | 2: {
189 | 'out_channels': (224, 488, 976, 2048),
190 | 'num_blocks': (3, 7, 3)
191 | }
192 | }
193 |
194 |
195 | def ShuffleV2(**kwargs):
196 | model = ShuffleNetV2(net_size=1, **kwargs)
197 | return model
198 |
199 |
200 | if __name__ == '__main__':
201 | net = ShuffleV2(num_classes=100)
202 | x = torch.randn(3, 3, 32, 32)
203 | import time
204 | a = time.time()
205 | feats, logit = net(x, is_feat=True, preact=True)
206 | b = time.time()
207 | print(b - a)
208 | for f in feats:
209 | print(f.shape, f.min().item())
210 | print(logit.shape)
211 |
--------------------------------------------------------------------------------
/cv/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .meta_resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4
2 | from .resnetv2 import ResNet50
3 | from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2
4 | from .meta_vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn
5 | from .mobilenetv2 import mobile_half
6 | from .ShuffleNetv1 import ShuffleV1
7 | from .ShuffleNetv2 import ShuffleV2
8 |
9 | model_dict = {
10 | 'resnet8': resnet8,
11 | 'resnet14': resnet14,
12 | 'resnet20': resnet20,
13 | 'resnet32': resnet32,
14 | 'resnet44': resnet44,
15 | 'resnet56': resnet56,
16 | 'resnet110': resnet110,
17 | 'resnet8x4': resnet8x4,
18 | 'resnet32x4': resnet32x4,
19 | 'ResNet50': ResNet50,
20 | 'wrn_16_1': wrn_16_1,
21 | 'wrn_16_2': wrn_16_2,
22 | 'wrn_40_1': wrn_40_1,
23 | 'wrn_40_2': wrn_40_2,
24 | 'vgg8': vgg8_bn,
25 | 'vgg11': vgg11_bn,
26 | 'vgg13': vgg13_bn,
27 | 'vgg16': vgg16_bn,
28 | 'vgg19': vgg19_bn,
29 | 'MobileNetV2': mobile_half,
30 | 'ShuffleV1': ShuffleV1,
31 | 'ShuffleV2': ShuffleV2,
32 | }
33 |
--------------------------------------------------------------------------------
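model_dict is the single lookup the training scripts use to instantiate architectures by name. A minimal sketch; note the resnet entries resolve to the torchmeta-based meta_resnet variants, so torchmeta must be installed:

import torch
from models import model_dict  # assumes cv/ is on PYTHONPATH and torchmeta is installed

net = model_dict['resnet8'](num_classes=100)
feats, logits = net(torch.randn(2, 3, 32, 32), is_feat=True)
print([f.shape for f in feats], logits.shape)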
/cv/models/classifier.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 |
5 |
6 | #########################################
7 | # ===== Classifiers ===== #
8 | #########################################
9 |
10 | class LinearClassifier(nn.Module):
11 |
12 | def __init__(self, dim_in, n_label=10):
13 | super(LinearClassifier, self).__init__()
14 |
15 | self.net = nn.Linear(dim_in, n_label)
16 |
17 | def forward(self, x):
18 | return self.net(x)
19 |
20 |
21 | class NonLinearClassifier(nn.Module):
22 |
23 | def __init__(self, dim_in, n_label=10, p=0.1):
24 | super(NonLinearClassifier, self).__init__()
25 |
26 | self.net = nn.Sequential(
27 | nn.Linear(dim_in, 200),
28 | nn.Dropout(p=p),
29 | nn.BatchNorm1d(200),
30 | nn.ReLU(inplace=True),
31 | nn.Linear(200, n_label),
32 | )
33 |
34 | def forward(self, x):
35 | return self.net(x)
36 |
--------------------------------------------------------------------------------
/cv/models/meta_resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''Resnet for cifar dataset.
4 | Ported from
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torchmeta.modules import *
13 | import math
14 |
15 |
16 | __all__ = ['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet8x4', 'resnet32x4']
17 |
18 |
19 | def conv3x3(in_planes, out_planes, stride=1):
20 | """3x3 convolution with padding"""
21 | return MetaConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=1, bias=False)
23 |
24 |
25 | class BasicBlock(MetaModule):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
29 | super(BasicBlock, self).__init__()
30 | self.is_last = is_last
31 | self.conv1 = conv3x3(inplanes, planes, stride)
32 | self.bn1 = MetaBatchNorm2d(planes)
33 | self.relu = nn.ReLU(inplace=True)
34 | self.conv2 = conv3x3(planes, planes)
35 | self.bn2 = MetaBatchNorm2d(planes)
36 | self.downsample = downsample
37 | self.stride = stride
38 |
39 | def forward(self, x, params=None):
40 | residual = x
41 |
42 | out = self.conv1(x, params=self.get_subdict(params, 'conv1'))
43 | out = self.bn1(out, params=self.get_subdict(params, 'bn1'))
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out, params=self.get_subdict(params, 'conv2'))
47 | out = self.bn2(out, params=self.get_subdict(params, 'bn2'))
48 |
49 | if self.downsample is not None:
50 | residual = self.downsample(x, params=self.get_subdict(params, 'downsample'))
51 |
52 | out += residual
53 | preact = out
54 | out = F.relu(out)
55 | if self.is_last:
56 | return out, preact
57 | else:
58 | return out
59 |
60 |
61 | class Bottleneck(MetaModule):
62 | expansion = 4
63 |
64 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
65 | super(Bottleneck, self).__init__()
66 | self.is_last = is_last
67 | self.conv1 = MetaConv2d(inplanes, planes, kernel_size=1, bias=False)
68 | self.bn1 = MetaBatchNorm2d(planes)
69 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=stride,
70 | padding=1, bias=False)
71 | self.bn2 = MetaBatchNorm2d(planes)
72 | self.conv3 = MetaConv2d(planes, planes * 4, kernel_size=1, bias=False)
73 | self.bn3 = MetaBatchNorm2d(planes * 4)
74 | self.relu = nn.ReLU(inplace=True)
75 | self.downsample = downsample
76 | self.stride = stride
77 |
78 | def forward(self, x, params=None):
79 | residual = x
80 |
81 | out = self.conv1(x, params=self.get_subdict(params, 'conv1'))
82 | out = self.bn1(out, params=self.get_subdict(params, 'bn1'))
83 | out = self.relu(out)
84 |
85 | out = self.conv2(out, params=self.get_subdict(params, 'conv2'))
86 | out = self.bn2(out, params=self.get_subdict(params, 'bn2'))
87 | out = self.relu(out)
88 |
89 | out = self.conv3(out, params=self.get_subdict(params, 'conv3'))
90 | out = self.bn3(out, params=self.get_subdict(params, 'bn3'))
91 |
92 | if self.downsample is not None:
93 | residual = self.downsample(x, params=self.get_subdict(params, 'downsample'))
94 |
95 | out += residual
96 | preact = out
97 | out = F.relu(out)
98 | if self.is_last:
99 | return out, preact
100 | else:
101 | return out
102 |
103 |
104 | class ResNet(MetaModule):
105 |
106 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10):
107 | super(ResNet, self).__init__()
108 | # Model type specifies number of layers for CIFAR-10 model
109 | if block_name.lower() == 'basicblock':
110 |             assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
111 | n = (depth - 2) // 6
112 | block = BasicBlock
113 | elif block_name.lower() == 'bottleneck':
114 |             assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
115 | n = (depth - 2) // 9
116 | block = Bottleneck
117 | else:
118 |             raise ValueError('block_name should be BasicBlock or Bottleneck')
119 |
120 | self.inplanes = num_filters[0]
121 | self.conv1 = MetaConv2d(3, num_filters[0], kernel_size=3, padding=1,
122 | bias=False)
123 | self.bn1 = MetaBatchNorm2d(num_filters[0])
124 | self.relu = nn.ReLU(inplace=True)
125 | self.layer1 = self._make_layer(block, num_filters[1], n)
126 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2)
127 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2)
128 | self.avgpool = nn.AvgPool2d(8)
129 | self.fc = MetaLinear(num_filters[3] * block.expansion, num_classes)
130 |
131 | for m in self.modules():
132 | if isinstance(m, MetaConv2d):
133 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
134 | elif isinstance(m, (MetaBatchNorm2d, nn.GroupNorm)):
135 | nn.init.constant_(m.weight, 1)
136 | nn.init.constant_(m.bias, 0)
137 |
138 | def _make_layer(self, block, planes, blocks, stride=1):
139 | downsample = None
140 | if stride != 1 or self.inplanes != planes * block.expansion:
141 | downsample = MetaSequential(
142 | MetaConv2d(self.inplanes, planes * block.expansion,
143 | kernel_size=1, stride=stride, bias=False),
144 | MetaBatchNorm2d(planes * block.expansion),
145 | )
146 |
147 | layers = list([])
148 | layers.append(block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1)))
149 | self.inplanes = planes * block.expansion
150 | for i in range(1, blocks):
151 | layers.append(block(self.inplanes, planes, is_last=(i == blocks-1)))
152 |
153 | return MetaSequential(*layers)
154 |
155 | def get_feat_modules(self):
156 | feat_m = nn.ModuleList([])
157 | feat_m.append(self.conv1)
158 | feat_m.append(self.bn1)
159 | feat_m.append(self.relu)
160 | feat_m.append(self.layer1)
161 | feat_m.append(self.layer2)
162 | feat_m.append(self.layer3)
163 | return feat_m
164 |
165 | def get_bn_before_relu(self):
166 | if isinstance(self.layer1[0], Bottleneck):
167 | bn1 = self.layer1[-1].bn3
168 | bn2 = self.layer2[-1].bn3
169 | bn3 = self.layer3[-1].bn3
170 | elif isinstance(self.layer1[0], BasicBlock):
171 | bn1 = self.layer1[-1].bn2
172 | bn2 = self.layer2[-1].bn2
173 | bn3 = self.layer3[-1].bn2
174 | else:
175 | raise NotImplementedError('ResNet unknown block error !!!')
176 |
177 | return [bn1, bn2, bn3]
178 |
179 | def forward(self, x, is_feat=False, preact=False, params=None):
180 | x = self.conv1(x, params=self.get_subdict(params, 'conv1'))
181 | x = self.bn1(x, params=self.get_subdict(params, 'bn1'))
182 | x = self.relu(x) # 32x32
183 | f0 = x
184 |
185 | x, f1_pre = self.layer1(x, params=self.get_subdict(params, 'layer1')) # 32x32
186 | f1 = x
187 | x, f2_pre = self.layer2(x, params=self.get_subdict(params, 'layer2')) # 16x16
188 | f2 = x
189 | x, f3_pre = self.layer3(x, params=self.get_subdict(params, 'layer3')) # 8x8
190 | f3 = x
191 |
192 | x = self.avgpool(x)
193 | x = x.view(x.size(0), -1)
194 | f4 = x
195 | x = self.fc(x, params=self.get_subdict(params, 'fc'))
196 |
197 | if is_feat:
198 | if preact:
199 | return [f0, f1_pre, f2_pre, f3_pre, f4], x
200 | else:
201 | return [f0, f1, f2, f3, f4], x
202 | else:
203 | return x
204 |
205 |
206 | def resnet8(**kwargs):
207 | return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs)
208 |
209 |
210 | def resnet14(**kwargs):
211 | return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs)
212 |
213 |
214 | def resnet20(**kwargs):
215 | return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs)
216 |
217 |
218 | def resnet32(**kwargs):
219 | return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs)
220 |
221 |
222 | def resnet44(**kwargs):
223 | return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs)
224 |
225 |
226 | def resnet56(**kwargs):
227 | return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs)
228 |
229 |
230 | def resnet110(**kwargs):
231 | return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs)
232 |
233 |
234 | def resnet8x4(**kwargs):
235 | return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs)
236 |
237 |
238 | def resnet32x4(**kwargs):
239 | return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs)
240 |
241 |
242 | if __name__ == '__main__':
243 | import torch
244 |
245 | x = torch.randn(2, 3, 32, 32)
246 | net = resnet8x4(num_classes=20)
247 | feats, logit = net(x, is_feat=True, preact=True)
248 |
249 | for f in feats:
250 | print(f.shape, f.min().item())
251 | print(logit.shape)
252 |
253 | for m in net.get_bn_before_relu():
254 | if isinstance(m, MetaBatchNorm2d):
255 | print('pass')
256 | else:
257 | print('warning')
258 |
--------------------------------------------------------------------------------
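A usage note on the meta variant above: torchmeta's MetaModule threads an explicit `params` dict through every layer, so a forward pass can run through hypothetically updated weights without mutating the module. A minimal MAML-style sketch, assuming torchmeta is installed and using the resnet8x4 builder defined above (the 0.05 step size mirrors the --assume_s_step_size default in train_student_meta.py):

    from collections import OrderedDict

    import torch
    import torch.nn.functional as F

    net = resnet8x4(num_classes=100)
    x = torch.randn(4, 3, 32, 32)
    y = torch.randint(0, 100, (4,))

    # one differentiable inner step: gradients flow through the update itself
    params = OrderedDict(net.meta_named_parameters())
    loss = F.cross_entropy(net(x, params=params), y)
    grads = torch.autograd.grad(loss, params.values(), create_graph=True)
    updated = OrderedDict((name, p - 0.05 * g)
                          for (name, p), g in zip(params.items(), grads))

    logits = net(x, params=updated)  # functional forward through updated weights
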
/cv/models/meta_vgg.py:
--------------------------------------------------------------------------------
1 | '''VGG for CIFAR10. FC layers are removed.
2 | (c) YANG, Wei
3 | '''
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import math
7 | from torchmeta.modules import *
8 |
9 |
10 | __all__ = [
11 |     'VGG', 'vgg8', 'vgg8_bn', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
12 |     'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
13 | ]
14 |
15 |
16 | model_urls = {
17 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
18 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
19 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
20 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
21 | }
22 |
23 |
24 | class VGG(MetaModule):
25 |
26 | def __init__(self, cfg, batch_norm=False, num_classes=1000):
27 | super(VGG, self).__init__()
28 | self.block0 = self._make_layers(cfg[0], batch_norm, 3)
29 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1])
30 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1])
31 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1])
32 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1])
33 |
34 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
35 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
36 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
37 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
38 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1))
39 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
40 |
41 | self.classifier = MetaLinear(512, num_classes)
42 | self._initialize_weights()
43 |
44 | def get_feat_modules(self):
45 | feat_m = nn.ModuleList([])
46 | feat_m.append(self.block0)
47 | feat_m.append(self.pool0)
48 | feat_m.append(self.block1)
49 | feat_m.append(self.pool1)
50 | feat_m.append(self.block2)
51 | feat_m.append(self.pool2)
52 | feat_m.append(self.block3)
53 | feat_m.append(self.pool3)
54 | feat_m.append(self.block4)
55 | feat_m.append(self.pool4)
56 | return feat_m
57 |
58 | def get_bn_before_relu(self):
59 | bn1 = self.block1[-1]
60 | bn2 = self.block2[-1]
61 | bn3 = self.block3[-1]
62 | bn4 = self.block4[-1]
63 | return [bn1, bn2, bn3, bn4]
64 |
65 | def forward(self, x, is_feat=False, preact=False, params=None):
66 | h = x.shape[2]
67 | x = F.relu(self.block0(x, params=self.get_subdict(params, 'block0')))
68 | f0 = x
69 | x = self.pool0(x)
70 | x = self.block1(x, params=self.get_subdict(params, 'block1'))
71 | f1_pre = x
72 | x = F.relu(x)
73 | f1 = x
74 | x = self.pool1(x)
75 | x = self.block2(x, params=self.get_subdict(params, 'block2'))
76 | f2_pre = x
77 | x = F.relu(x)
78 | f2 = x
79 | x = self.pool2(x)
80 | x = self.block3(x, params=self.get_subdict(params, 'block3'))
81 | f3_pre = x
82 | x = F.relu(x)
83 | f3 = x
84 | if h == 64:
85 | x = self.pool3(x)
86 |         x = self.block4(x, params=self.get_subdict(params, 'block4'))
87 | f4_pre = x
88 | x = F.relu(x)
89 | f4 = x
90 | x = self.pool4(x)
91 | x = x.view(x.size(0), -1)
92 | f5 = x
93 | x = self.classifier(x, params=self.get_subdict(params, 'classifier'))
94 |
95 | if is_feat:
96 | if preact:
97 | return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], x
98 | else:
99 | return [f0, f1, f2, f3, f4, f5], x
100 | else:
101 | return x
102 |
103 | @staticmethod
104 | def _make_layers(cfg, batch_norm=False, in_channels=3):
105 | layers = []
106 | for v in cfg:
107 | if v == 'M':
108 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
109 | else:
110 | conv2d = MetaConv2d(in_channels, v, kernel_size=3, padding=1)
111 | if batch_norm:
112 | layers += [conv2d, MetaBatchNorm2d(v), nn.ReLU(inplace=True)]
113 | else:
114 | layers += [conv2d, nn.ReLU(inplace=True)]
115 | in_channels = v
116 | layers = layers[:-1]
117 | return MetaSequential(*layers)
118 |
119 | def _initialize_weights(self):
120 | for m in self.modules():
121 | if isinstance(m, MetaConv2d):
122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
123 | m.weight.data.normal_(0, math.sqrt(2. / n))
124 | if m.bias is not None:
125 | m.bias.data.zero_()
126 | elif isinstance(m, MetaBatchNorm2d):
127 | m.weight.data.fill_(1)
128 | m.bias.data.zero_()
129 | elif isinstance(m, MetaLinear):
130 | n = m.weight.size(1)
131 | m.weight.data.normal_(0, 0.01)
132 | m.bias.data.zero_()
133 |
134 |
135 | cfg = {
136 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]],
137 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]],
138 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]],
139 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]],
140 | 'S': [[64], [128], [256], [512], [512]],
141 | }
142 |
143 |
144 | def vgg8(**kwargs):
145 | """VGG 8-layer model (configuration "S")
146 | Args:
147 | pretrained (bool): If True, returns a model pre-trained on ImageNet
148 | """
149 | model = VGG(cfg['S'], **kwargs)
150 | return model
151 |
152 |
153 | def vgg8_bn(**kwargs):
154 | """VGG 8-layer model (configuration "S")
155 | Args:
156 | pretrained (bool): If True, returns a model pre-trained on ImageNet
157 | """
158 | model = VGG(cfg['S'], batch_norm=True, **kwargs)
159 | return model
160 |
161 |
162 | def vgg11(**kwargs):
163 | """VGG 11-layer model (configuration "A")
164 | Args:
165 | pretrained (bool): If True, returns a model pre-trained on ImageNet
166 | """
167 | model = VGG(cfg['A'], **kwargs)
168 | return model
169 |
170 |
171 | def vgg11_bn(**kwargs):
172 | """VGG 11-layer model (configuration "A") with batch normalization"""
173 | model = VGG(cfg['A'], batch_norm=True, **kwargs)
174 | return model
175 |
176 |
177 | def vgg13(**kwargs):
178 | """VGG 13-layer model (configuration "B")
179 | Args:
180 | pretrained (bool): If True, returns a model pre-trained on ImageNet
181 | """
182 | model = VGG(cfg['B'], **kwargs)
183 | return model
184 |
185 |
186 | def vgg13_bn(**kwargs):
187 | """VGG 13-layer model (configuration "B") with batch normalization"""
188 | model = VGG(cfg['B'], batch_norm=True, **kwargs)
189 | return model
190 |
191 |
192 | def vgg16(**kwargs):
193 | """VGG 16-layer model (configuration "D")
194 | Args:
195 | pretrained (bool): If True, returns a model pre-trained on ImageNet
196 | """
197 | model = VGG(cfg['D'], **kwargs)
198 | return model
199 |
200 |
201 | def vgg16_bn(**kwargs):
202 | """VGG 16-layer model (configuration "D") with batch normalization"""
203 | model = VGG(cfg['D'], batch_norm=True, **kwargs)
204 | return model
205 |
206 |
207 | def vgg19(**kwargs):
208 | """VGG 19-layer model (configuration "E")
209 | Args:
210 | pretrained (bool): If True, returns a model pre-trained on ImageNet
211 | """
212 | model = VGG(cfg['E'], **kwargs)
213 | return model
214 |
215 |
216 | def vgg19_bn(**kwargs):
217 | """VGG 19-layer model (configuration 'E') with batch normalization"""
218 | model = VGG(cfg['E'], batch_norm=True, **kwargs)
219 | return model
220 |
221 |
222 | if __name__ == '__main__':
223 | import torch
224 |
225 | x = torch.randn(2, 3, 32, 32)
226 | net = vgg19_bn(num_classes=100)
227 | feats, logit = net(x, is_feat=True, preact=True)
228 |
229 | for f in feats:
230 | print(f.shape, f.min().item())
231 | print(logit.shape)
232 |
233 | for m in net.get_bn_before_relu():
234 | if isinstance(m, MetaBatchNorm2d):
235 | print('pass')
236 | else:
237 | print('warning')
238 |
--------------------------------------------------------------------------------
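A quick sanity check (illustrative, not part of the repo): every Meta* layer falls back to its stored weights when params is None, so passing the module's own parameters explicitly must reproduce the default forward; eval() keeps the batch-norm statistics fixed between the two calls:

    from collections import OrderedDict

    import torch

    net = vgg8_bn(num_classes=100).eval()
    x = torch.randn(2, 3, 32, 32)
    params = OrderedDict(net.meta_named_parameters())
    assert torch.allclose(net(x), net(x, params=params))
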
/cv/models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | """
2 | MobileNetV2 implementation used in
3 |
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import math
9 |
10 | __all__ = ['mobilenetv2_T_w', 'mobile_half']
11 |
12 | BN = None
13 |
14 |
15 | def conv_bn(inp, oup, stride):
16 | return nn.Sequential(
17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
18 | nn.BatchNorm2d(oup),
19 | nn.ReLU(inplace=True)
20 | )
21 |
22 |
23 | def conv_1x1_bn(inp, oup):
24 | return nn.Sequential(
25 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
26 | nn.BatchNorm2d(oup),
27 | nn.ReLU(inplace=True)
28 | )
29 |
30 |
31 | class InvertedResidual(nn.Module):
32 | def __init__(self, inp, oup, stride, expand_ratio):
33 | super(InvertedResidual, self).__init__()
34 | self.blockname = None
35 |
36 | self.stride = stride
37 | assert stride in [1, 2]
38 |
39 | self.use_res_connect = self.stride == 1 and inp == oup
40 |
41 | self.conv = nn.Sequential(
42 | # pw
43 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
44 | nn.BatchNorm2d(inp * expand_ratio),
45 | nn.ReLU(inplace=True),
46 | # dw
47 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False),
48 | nn.BatchNorm2d(inp * expand_ratio),
49 | nn.ReLU(inplace=True),
50 | # pw-linear
51 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
52 | nn.BatchNorm2d(oup),
53 | )
54 | self.names = ['0', '1', '2', '3', '4', '5', '6', '7']
55 |
56 | def forward(self, x):
57 | t = x
58 | if self.use_res_connect:
59 | return t + self.conv(x)
60 | else:
61 | return self.conv(x)
62 |
63 |
64 | class MobileNetV2(nn.Module):
65 |     """MobileNetV2 backbone."""
66 | def __init__(self, T,
67 | feature_dim,
68 | input_size=32,
69 | width_mult=1.,
70 | remove_avg=False):
71 | super(MobileNetV2, self).__init__()
72 | self.remove_avg = remove_avg
73 |
74 | # setting of inverted residual blocks
75 |         self.inverted_residual_setting = [
76 | # t, c, n, s
77 | [1, 16, 1, 1],
78 | [T, 24, 2, 1],
79 | [T, 32, 3, 2],
80 | [T, 64, 4, 2],
81 | [T, 96, 3, 1],
82 | [T, 160, 3, 2],
83 | [T, 320, 1, 1],
84 | ]
85 |
86 | # building first layer
87 | assert input_size % 32 == 0
88 | input_channel = int(32 * width_mult)
89 | self.conv1 = conv_bn(3, input_channel, 2)
90 |
91 | # building inverted residual blocks
92 | self.blocks = nn.ModuleList([])
93 |         for t, c, n, s in self.inverted_residual_setting:
94 | output_channel = int(c * width_mult)
95 | layers = []
96 | strides = [s] + [1] * (n - 1)
97 | for stride in strides:
98 | layers.append(
99 | InvertedResidual(input_channel, output_channel, stride, t)
100 | )
101 | input_channel = output_channel
102 | self.blocks.append(nn.Sequential(*layers))
103 |
104 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280
105 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel)
106 |
107 | # building classifier
108 | self.classifier = nn.Sequential(
109 | # nn.Dropout(0.5),
110 | nn.Linear(self.last_channel, feature_dim),
111 | )
112 |
113 |         H = input_size // (32 // 2)  # total downsampling stride is 16 in this variant
114 | self.avgpool = nn.AvgPool2d(H, ceil_mode=True)
115 |
116 | self._initialize_weights()
117 | print(T, width_mult)
118 |
119 | def get_bn_before_relu(self):
120 | bn1 = self.blocks[1][-1].conv[-1]
121 | bn2 = self.blocks[2][-1].conv[-1]
122 | bn3 = self.blocks[4][-1].conv[-1]
123 | bn4 = self.blocks[6][-1].conv[-1]
124 | return [bn1, bn2, bn3, bn4]
125 |
126 | def get_feat_modules(self):
127 | feat_m = nn.ModuleList([])
128 | feat_m.append(self.conv1)
129 | feat_m.append(self.blocks)
130 | return feat_m
131 |
132 | def forward(self, x, is_feat=False, preact=False):
133 |
134 | out = self.conv1(x)
135 | f0 = out
136 |
137 | out = self.blocks[0](out)
138 | out = self.blocks[1](out)
139 | f1 = out
140 | out = self.blocks[2](out)
141 | f2 = out
142 | out = self.blocks[3](out)
143 | out = self.blocks[4](out)
144 | f3 = out
145 | out = self.blocks[5](out)
146 | out = self.blocks[6](out)
147 | f4 = out
148 |
149 | out = self.conv2(out)
150 |
151 | if not self.remove_avg:
152 | out = self.avgpool(out)
153 | out = out.view(out.size(0), -1)
154 | f5 = out
155 | out = self.classifier(out)
156 |
157 | if is_feat:
158 | return [f0, f1, f2, f3, f4, f5], out
159 | else:
160 | return out
161 |
162 | def _initialize_weights(self):
163 | for m in self.modules():
164 | if isinstance(m, nn.Conv2d):
165 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
166 | m.weight.data.normal_(0, math.sqrt(2. / n))
167 | if m.bias is not None:
168 | m.bias.data.zero_()
169 | elif isinstance(m, nn.BatchNorm2d):
170 | m.weight.data.fill_(1)
171 | m.bias.data.zero_()
172 | elif isinstance(m, nn.Linear):
173 | n = m.weight.size(1)
174 | m.weight.data.normal_(0, 0.01)
175 | m.bias.data.zero_()
176 |
177 |
178 | def mobilenetv2_T_w(T, W, feature_dim=100):
179 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W)
180 | return model
181 |
182 |
183 | def mobile_half(num_classes):
184 | return mobilenetv2_T_w(6, 0.5, num_classes)
185 |
186 |
187 | if __name__ == '__main__':
188 | x = torch.randn(2, 3, 32, 32)
189 |
190 | net = mobile_half(100)
191 |
192 | feats, logit = net(x, is_feat=True, preact=True)
193 | for f in feats:
194 | print(f.shape, f.min().item())
195 | print(logit.shape)
196 |
197 | for m in net.get_bn_before_relu():
198 | if isinstance(m, nn.BatchNorm2d):
199 | print('pass')
200 | else:
201 | print('warning')
202 |
203 |
--------------------------------------------------------------------------------
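Hint-based distillation methods need the student's intermediate feature shapes to size their regressors. One common way to obtain them, sketched here with assumed CIFAR-sized inputs, is a dummy forward pass with is_feat=True:

    import torch

    net = mobile_half(num_classes=100).eval()
    with torch.no_grad():
        feats, _ = net(torch.randn(2, 3, 32, 32), is_feat=True)
    s_shapes = [f.shape for f in feats]  # e.g. to build ConvReg/Connector (models/util.py)
    print(s_shapes)
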
/cv/models/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''ResNet for CIFAR datasets.
4 | Ported from
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import math
13 |
14 |
15 | __all__ = ['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet8x4', 'resnet32x4']
16 |
17 |
18 | def conv3x3(in_planes, out_planes, stride=1):
19 | """3x3 convolution with padding"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21 | padding=1, bias=False)
22 |
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 |
27 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
28 | super(BasicBlock, self).__init__()
29 | self.is_last = is_last
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = nn.BatchNorm2d(planes)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = nn.BatchNorm2d(planes)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | preact = out
53 | out = F.relu(out)
54 | if self.is_last:
55 | return out, preact
56 | else:
57 | return out
58 |
59 |
60 | class Bottleneck(nn.Module):
61 | expansion = 4
62 |
63 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
64 | super(Bottleneck, self).__init__()
65 | self.is_last = is_last
66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
67 | self.bn1 = nn.BatchNorm2d(planes)
68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
69 | padding=1, bias=False)
70 | self.bn2 = nn.BatchNorm2d(planes)
71 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
72 | self.bn3 = nn.BatchNorm2d(planes * 4)
73 | self.relu = nn.ReLU(inplace=True)
74 | self.downsample = downsample
75 | self.stride = stride
76 |
77 | def forward(self, x):
78 | residual = x
79 |
80 | out = self.conv1(x)
81 | out = self.bn1(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv2(out)
85 | out = self.bn2(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 |
91 | if self.downsample is not None:
92 | residual = self.downsample(x)
93 |
94 | out += residual
95 | preact = out
96 | out = F.relu(out)
97 | if self.is_last:
98 | return out, preact
99 | else:
100 | return out
101 |
102 |
103 | class ResNet(nn.Module):
104 |
105 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10):
106 | super(ResNet, self).__init__()
107 | # Model type specifies number of layers for CIFAR-10 model
108 | if block_name.lower() == 'basicblock':
109 |             assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
110 | n = (depth - 2) // 6
111 | block = BasicBlock
112 | elif block_name.lower() == 'bottleneck':
113 |             assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
114 | n = (depth - 2) // 9
115 | block = Bottleneck
116 | else:
117 |             raise ValueError('block_name should be basicblock or bottleneck')
118 |
119 | self.inplanes = num_filters[0]
120 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1,
121 | bias=False)
122 | self.bn1 = nn.BatchNorm2d(num_filters[0])
123 | self.relu = nn.ReLU(inplace=True)
124 | self.layer1 = self._make_layer(block, num_filters[1], n)
125 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2)
126 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2)
127 | self.avgpool = nn.AvgPool2d(8)
128 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes)
129 |
130 | for m in self.modules():
131 | if isinstance(m, nn.Conv2d):
132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
133 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
134 | nn.init.constant_(m.weight, 1)
135 | nn.init.constant_(m.bias, 0)
136 |
137 | def _make_layer(self, block, planes, blocks, stride=1):
138 | downsample = None
139 | if stride != 1 or self.inplanes != planes * block.expansion:
140 | downsample = nn.Sequential(
141 | nn.Conv2d(self.inplanes, planes * block.expansion,
142 | kernel_size=1, stride=stride, bias=False),
143 | nn.BatchNorm2d(planes * block.expansion),
144 | )
145 |
146 |         layers = []
147 | layers.append(block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1)))
148 | self.inplanes = planes * block.expansion
149 | for i in range(1, blocks):
150 | layers.append(block(self.inplanes, planes, is_last=(i == blocks-1)))
151 |
152 | return nn.Sequential(*layers)
153 |
154 | def get_feat_modules(self):
155 | feat_m = nn.ModuleList([])
156 | feat_m.append(self.conv1)
157 | feat_m.append(self.bn1)
158 | feat_m.append(self.relu)
159 | feat_m.append(self.layer1)
160 | feat_m.append(self.layer2)
161 | feat_m.append(self.layer3)
162 | return feat_m
163 |
164 | def get_bn_before_relu(self):
165 | if isinstance(self.layer1[0], Bottleneck):
166 | bn1 = self.layer1[-1].bn3
167 | bn2 = self.layer2[-1].bn3
168 | bn3 = self.layer3[-1].bn3
169 | elif isinstance(self.layer1[0], BasicBlock):
170 | bn1 = self.layer1[-1].bn2
171 | bn2 = self.layer2[-1].bn2
172 | bn3 = self.layer3[-1].bn2
173 | else:
174 |             raise NotImplementedError('unknown ResNet block type')
175 |
176 | return [bn1, bn2, bn3]
177 |
178 | def forward(self, x, is_feat=False, preact=False):
179 | x = self.conv1(x)
180 | x = self.bn1(x)
181 | x = self.relu(x) # 32x32
182 | f0 = x
183 |
184 | x, f1_pre = self.layer1(x) # 32x32
185 | f1 = x
186 | x, f2_pre = self.layer2(x) # 16x16
187 | f2 = x
188 | x, f3_pre = self.layer3(x) # 8x8
189 | f3 = x
190 |
191 | x = self.avgpool(x)
192 | x = x.view(x.size(0), -1)
193 | f4 = x
194 | x = self.fc(x)
195 |
196 | if is_feat:
197 | if preact:
198 | return [f0, f1_pre, f2_pre, f3_pre, f4], x
199 | else:
200 | return [f0, f1, f2, f3, f4], x
201 | else:
202 | return x
203 |
204 |
205 | def resnet8(**kwargs):
206 | return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs)
207 |
208 |
209 | def resnet14(**kwargs):
210 | return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs)
211 |
212 |
213 | def resnet20(**kwargs):
214 | return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs)
215 |
216 |
217 | def resnet32(**kwargs):
218 | return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs)
219 |
220 |
221 | def resnet44(**kwargs):
222 | return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs)
223 |
224 |
225 | def resnet56(**kwargs):
226 | return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs)
227 |
228 |
229 | def resnet110(**kwargs):
230 | return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs)
231 |
232 |
233 | def resnet8x4(**kwargs):
234 | return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs)
235 |
236 |
237 | def resnet32x4(**kwargs):
238 | return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs)
239 |
240 |
241 | if __name__ == '__main__':
242 | import torch
243 |
244 | x = torch.randn(2, 3, 32, 32)
245 | net = resnet8x4(num_classes=20)
246 | feats, logit = net(x, is_feat=True, preact=True)
247 |
248 | for f in feats:
249 | print(f.shape, f.min().item())
250 | print(logit.shape)
251 |
252 | for m in net.get_bn_before_relu():
253 | if isinstance(m, nn.BatchNorm2d):
254 | print('pass')
255 | else:
256 | print('warning')
257 |
--------------------------------------------------------------------------------
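One subtlety worth spelling out: `x, f1_pre = self.layer1(x)` only unpacks because the final block of each stage is constructed with is_last=True, so the nn.Sequential ends by returning a (post-ReLU, pre-ReLU) pair while earlier blocks pass plain tensors along. A tiny demonstration, assuming resnet20's stage-1 input of 16 channels at 32x32:

    import torch

    net = resnet20(num_classes=100)
    out, preact = net.layer1(torch.randn(2, 16, 32, 32))
    assert torch.equal(out, torch.relu(preact))  # out is exactly ReLU(preact)
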
/cv/models/resnetv2.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | For Pre-activation ResNet, see 'preact_resnet.py'.
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class BasicBlock(nn.Module):
13 | expansion = 1
14 |
15 | def __init__(self, in_planes, planes, stride=1, is_last=False):
16 | super(BasicBlock, self).__init__()
17 | self.is_last = is_last
18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.shortcut = nn.Sequential()
24 | if stride != 1 or in_planes != self.expansion * planes:
25 | self.shortcut = nn.Sequential(
26 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
27 | nn.BatchNorm2d(self.expansion * planes)
28 | )
29 |
30 | def forward(self, x):
31 | out = F.relu(self.bn1(self.conv1(x)))
32 | out = self.bn2(self.conv2(out))
33 | out += self.shortcut(x)
34 | preact = out
35 | out = F.relu(out)
36 | if self.is_last:
37 | return out, preact
38 | else:
39 | return out
40 |
41 |
42 | class Bottleneck(nn.Module):
43 | expansion = 4
44 |
45 | def __init__(self, in_planes, planes, stride=1, is_last=False):
46 | super(Bottleneck, self).__init__()
47 | self.is_last = is_last
48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(planes)
50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
51 | self.bn2 = nn.BatchNorm2d(planes)
52 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
53 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
54 |
55 | self.shortcut = nn.Sequential()
56 | if stride != 1 or in_planes != self.expansion * planes:
57 | self.shortcut = nn.Sequential(
58 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
59 | nn.BatchNorm2d(self.expansion * planes)
60 | )
61 |
62 | def forward(self, x):
63 | out = F.relu(self.bn1(self.conv1(x)))
64 | out = F.relu(self.bn2(self.conv2(out)))
65 | out = self.bn3(self.conv3(out))
66 | out += self.shortcut(x)
67 | preact = out
68 | out = F.relu(out)
69 | if self.is_last:
70 | return out, preact
71 | else:
72 | return out
73 |
74 |
75 | class ResNet(nn.Module):
76 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False):
77 | super(ResNet, self).__init__()
78 | self.in_planes = 64
79 |
80 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
81 | self.bn1 = nn.BatchNorm2d(64)
82 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
83 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
84 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
85 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
86 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
87 | self.linear = nn.Linear(512 * block.expansion, num_classes)
88 |
89 | for m in self.modules():
90 | if isinstance(m, nn.Conv2d):
91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
92 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
93 | nn.init.constant_(m.weight, 1)
94 | nn.init.constant_(m.bias, 0)
95 |
96 | # Zero-initialize the last BN in each residual branch,
97 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
98 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
99 | if zero_init_residual:
100 | for m in self.modules():
101 | if isinstance(m, Bottleneck):
102 | nn.init.constant_(m.bn3.weight, 0)
103 | elif isinstance(m, BasicBlock):
104 | nn.init.constant_(m.bn2.weight, 0)
105 |
106 | def get_feat_modules(self):
107 | feat_m = nn.ModuleList([])
108 | feat_m.append(self.conv1)
109 | feat_m.append(self.bn1)
110 | feat_m.append(self.layer1)
111 | feat_m.append(self.layer2)
112 | feat_m.append(self.layer3)
113 | feat_m.append(self.layer4)
114 | return feat_m
115 |
116 | def get_bn_before_relu(self):
117 | if isinstance(self.layer1[0], Bottleneck):
118 | bn1 = self.layer1[-1].bn3
119 | bn2 = self.layer2[-1].bn3
120 | bn3 = self.layer3[-1].bn3
121 | bn4 = self.layer4[-1].bn3
122 | elif isinstance(self.layer1[0], BasicBlock):
123 | bn1 = self.layer1[-1].bn2
124 | bn2 = self.layer2[-1].bn2
125 | bn3 = self.layer3[-1].bn2
126 | bn4 = self.layer4[-1].bn2
127 | else:
128 |             raise NotImplementedError('unknown ResNet block type')
129 |
130 | return [bn1, bn2, bn3, bn4]
131 |
132 | def _make_layer(self, block, planes, num_blocks, stride):
133 | strides = [stride] + [1] * (num_blocks - 1)
134 | layers = []
135 | for i in range(num_blocks):
136 | stride = strides[i]
137 | layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1))
138 | self.in_planes = planes * block.expansion
139 | return nn.Sequential(*layers)
140 |
141 | def forward(self, x, is_feat=False, preact=False):
142 | out = F.relu(self.bn1(self.conv1(x)))
143 | f0 = out
144 | out, f1_pre = self.layer1(out)
145 | f1 = out
146 | out, f2_pre = self.layer2(out)
147 | f2 = out
148 | out, f3_pre = self.layer3(out)
149 | f3 = out
150 | out, f4_pre = self.layer4(out)
151 | f4 = out
152 | out = self.avgpool(out)
153 | out = out.view(out.size(0), -1)
154 | f5 = out
155 | out = self.linear(out)
156 | if is_feat:
157 | if preact:
158 |                 return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], out
159 | else:
160 | return [f0, f1, f2, f3, f4, f5], out
161 | else:
162 | return out
163 |
164 |
165 | def ResNet18(**kwargs):
166 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
167 |
168 |
169 | def ResNet34(**kwargs):
170 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
171 |
172 |
173 | def ResNet50(**kwargs):
174 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
175 |
176 |
177 | def ResNet101(**kwargs):
178 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
179 |
180 |
181 | def ResNet152(**kwargs):
182 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
183 |
184 |
185 | if __name__ == '__main__':
186 | net = ResNet18(num_classes=100)
187 | x = torch.randn(2, 3, 32, 32)
188 | feats, logit = net(x, is_feat=True, preact=True)
189 |
190 | for f in feats:
191 | print(f.shape, f.min().item())
192 | print(logit.shape)
193 |
194 | for m in net.get_bn_before_relu():
195 | if isinstance(m, nn.BatchNorm2d):
196 | print('pass')
197 | else:
198 | print('warning')
199 |
--------------------------------------------------------------------------------
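Unlike torchvision's ImageNet ResNets, this variant keeps a 3x3 stride-1 stem and has no initial max-pool, so 32x32 inputs survive all four stages. A small shape probe (illustrative only):

    import torch

    net = ResNet50(num_classes=100)
    feats, logit = net(torch.randn(2, 3, 32, 32), is_feat=True)
    for f in feats:
        print(tuple(f.shape))  # stage outputs scale as 64/128/256/512 * block.expansion
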
/cv/models/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 | import math
5 |
6 |
7 | class Paraphraser(nn.Module):
8 | """Paraphrasing Complex Network: Network Compression via Factor Transfer"""
9 | def __init__(self, t_shape, k=0.5, use_bn=False):
10 | super(Paraphraser, self).__init__()
11 | in_channel = t_shape[1]
12 | out_channel = int(t_shape[1] * k)
13 | self.encoder = nn.Sequential(
14 | nn.Conv2d(in_channel, in_channel, 3, 1, 1),
15 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(),
16 | nn.LeakyReLU(0.1, inplace=True),
17 | nn.Conv2d(in_channel, out_channel, 3, 1, 1),
18 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(),
19 | nn.LeakyReLU(0.1, inplace=True),
20 | nn.Conv2d(out_channel, out_channel, 3, 1, 1),
21 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(),
22 | nn.LeakyReLU(0.1, inplace=True),
23 | )
24 | self.decoder = nn.Sequential(
25 | nn.ConvTranspose2d(out_channel, out_channel, 3, 1, 1),
26 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(),
27 | nn.LeakyReLU(0.1, inplace=True),
28 | nn.ConvTranspose2d(out_channel, in_channel, 3, 1, 1),
29 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(),
30 | nn.LeakyReLU(0.1, inplace=True),
31 | nn.ConvTranspose2d(in_channel, in_channel, 3, 1, 1),
32 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(),
33 | nn.LeakyReLU(0.1, inplace=True),
34 | )
35 |
36 | def forward(self, f_s, is_factor=False):
37 | factor = self.encoder(f_s)
38 | if is_factor:
39 | return factor
40 | rec = self.decoder(factor)
41 | return factor, rec
42 |
43 |
44 | class Translator(nn.Module):
45 | def __init__(self, s_shape, t_shape, k=0.5, use_bn=True):
46 | super(Translator, self).__init__()
47 | in_channel = s_shape[1]
48 | out_channel = int(t_shape[1] * k)
49 | self.encoder = nn.Sequential(
50 | nn.Conv2d(in_channel, in_channel, 3, 1, 1),
51 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(),
52 | nn.LeakyReLU(0.1, inplace=True),
53 | nn.Conv2d(in_channel, out_channel, 3, 1, 1),
54 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(),
55 | nn.LeakyReLU(0.1, inplace=True),
56 | nn.Conv2d(out_channel, out_channel, 3, 1, 1),
57 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(),
58 | nn.LeakyReLU(0.1, inplace=True),
59 | )
60 |
61 | def forward(self, f_s):
62 | return self.encoder(f_s)
63 |
64 |
65 | class Connector(nn.Module):
66 | """Connect for Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons"""
67 | def __init__(self, s_shapes, t_shapes):
68 | super(Connector, self).__init__()
69 | self.s_shapes = s_shapes
70 | self.t_shapes = t_shapes
71 |
72 |         self.connectors = nn.ModuleList(self._make_connectors(s_shapes, t_shapes))
73 |
74 | @staticmethod
75 |     def _make_connectors(s_shapes, t_shapes):
76 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
77 | connectors = []
78 | for s, t in zip(s_shapes, t_shapes):
79 | if s[1] == t[1] and s[2] == t[2]:
80 | connectors.append(nn.Sequential())
81 | else:
82 | connectors.append(ConvReg(s, t, use_relu=False))
83 | return connectors
84 |
85 | def forward(self, g_s):
86 | out = []
87 | for i in range(len(g_s)):
88 | out.append(self.connectors[i](g_s[i]))
89 |
90 | return out
91 |
92 |
93 | class ConnectorV2(nn.Module):
94 | """A Comprehensive Overhaul of Feature Distillation (ICCV 2019)"""
95 | def __init__(self, s_shapes, t_shapes):
96 | super(ConnectorV2, self).__init__()
97 | self.s_shapes = s_shapes
98 | self.t_shapes = t_shapes
99 |
100 |         self.connectors = nn.ModuleList(self._make_connectors(s_shapes, t_shapes))
101 |
102 |     def _make_connectors(self, s_shapes, t_shapes):
103 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
104 | t_channels = [t[1] for t in t_shapes]
105 | s_channels = [s[1] for s in s_shapes]
106 | connectors = nn.ModuleList([self._build_feature_connector(t, s)
107 | for t, s in zip(t_channels, s_channels)])
108 | return connectors
109 |
110 | @staticmethod
111 | def _build_feature_connector(t_channel, s_channel):
112 | C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False),
113 | nn.BatchNorm2d(t_channel)]
114 | for m in C:
115 | if isinstance(m, nn.Conv2d):
116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
117 | m.weight.data.normal_(0, math.sqrt(2. / n))
118 | elif isinstance(m, nn.BatchNorm2d):
119 | m.weight.data.fill_(1)
120 | m.bias.data.zero_()
121 | return nn.Sequential(*C)
122 |
123 | def forward(self, g_s):
124 | out = []
125 | for i in range(len(g_s)):
126 | out.append(self.connectors[i](g_s[i]))
127 |
128 | return out
129 |
130 |
131 | class ConvReg(nn.Module):
132 | """Convolutional regression for FitNet"""
133 | def __init__(self, s_shape, t_shape, use_relu=True):
134 | super(ConvReg, self).__init__()
135 | self.use_relu = use_relu
136 | s_N, s_C, s_H, s_W = s_shape
137 | t_N, t_C, t_H, t_W = t_shape
138 | if s_H == 2 * t_H:
139 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
140 | elif s_H * 2 == t_H:
141 | self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)
142 | elif s_H >= t_H:
143 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))
144 | else:
145 |             raise NotImplementedError('student size {}, teacher size {}'.format(s_H, t_H))
146 | self.bn = nn.BatchNorm2d(t_C)
147 | self.relu = nn.ReLU(inplace=True)
148 |
149 | def forward(self, x):
150 | x = self.conv(x)
151 | if self.use_relu:
152 | return self.relu(self.bn(x))
153 | else:
154 | return self.bn(x)
155 |
156 |
157 | class Regress(nn.Module):
158 | """Simple Linear Regression for hints"""
159 | def __init__(self, dim_in=1024, dim_out=1024):
160 | super(Regress, self).__init__()
161 | self.linear = nn.Linear(dim_in, dim_out)
162 | self.relu = nn.ReLU(inplace=True)
163 |
164 | def forward(self, x):
165 | x = x.view(x.shape[0], -1)
166 | x = self.linear(x)
167 | x = self.relu(x)
168 | return x
169 |
170 |
171 | class Embed(nn.Module):
172 | """Embedding module"""
173 | def __init__(self, dim_in=1024, dim_out=128):
174 | super(Embed, self).__init__()
175 | self.linear = nn.Linear(dim_in, dim_out)
176 | self.l2norm = Normalize(2)
177 |
178 | def forward(self, x):
179 | x = x.view(x.shape[0], -1)
180 | x = self.linear(x)
181 | x = self.l2norm(x)
182 | return x
183 |
184 |
185 | class LinearEmbed(nn.Module):
186 | """Linear Embedding"""
187 | def __init__(self, dim_in=1024, dim_out=128):
188 | super(LinearEmbed, self).__init__()
189 | self.linear = nn.Linear(dim_in, dim_out)
190 |
191 | def forward(self, x):
192 | x = x.view(x.shape[0], -1)
193 | x = self.linear(x)
194 | return x
195 |
196 |
197 | class MLPEmbed(nn.Module):
198 | """non-linear embed by MLP"""
199 | def __init__(self, dim_in=1024, dim_out=128):
200 | super(MLPEmbed, self).__init__()
201 | self.linear1 = nn.Linear(dim_in, 2 * dim_out)
202 | self.relu = nn.ReLU(inplace=True)
203 | self.linear2 = nn.Linear(2 * dim_out, dim_out)
204 | self.l2norm = Normalize(2)
205 |
206 | def forward(self, x):
207 | x = x.view(x.shape[0], -1)
208 | x = self.relu(self.linear1(x))
209 | x = self.l2norm(self.linear2(x))
210 | return x
211 |
212 |
213 | class Normalize(nn.Module):
214 | """normalization layer"""
215 | def __init__(self, power=2):
216 | super(Normalize, self).__init__()
217 | self.power = power
218 |
219 | def forward(self, x):
220 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
221 | out = x.div(norm)
222 | return out
223 |
224 |
225 | class Flatten(nn.Module):
226 | """flatten module"""
227 | def __init__(self):
228 | super(Flatten, self).__init__()
229 |
230 | def forward(self, feat):
231 | return feat.view(feat.size(0), -1)
232 |
233 |
234 | class PoolEmbed(nn.Module):
235 | """pool and embed"""
236 | def __init__(self, layer=0, dim_out=128, pool_type='avg'):
237 | super().__init__()
238 | if layer == 0:
239 | pool_size = 8
240 | nChannels = 16
241 | elif layer == 1:
242 | pool_size = 8
243 | nChannels = 16
244 | elif layer == 2:
245 | pool_size = 6
246 | nChannels = 32
247 | elif layer == 3:
248 | pool_size = 4
249 | nChannels = 64
250 | elif layer == 4:
251 | pool_size = 1
252 | nChannels = 64
253 | else:
254 | raise NotImplementedError('layer not supported: {}'.format(layer))
255 |
256 | self.embed = nn.Sequential()
257 | if layer <= 3:
258 | if pool_type == 'max':
259 | self.embed.add_module('MaxPool', nn.AdaptiveMaxPool2d((pool_size, pool_size)))
260 | elif pool_type == 'avg':
261 | self.embed.add_module('AvgPool', nn.AdaptiveAvgPool2d((pool_size, pool_size)))
262 |
263 | self.embed.add_module('Flatten', Flatten())
264 | self.embed.add_module('Linear', nn.Linear(nChannels*pool_size*pool_size, dim_out))
265 | self.embed.add_module('Normalize', Normalize(2))
266 |
267 | def forward(self, x):
268 | return self.embed(x)
269 |
270 |
271 | if __name__ == '__main__':
272 | import torch
273 |
274 | g_s = [
275 | torch.randn(2, 16, 16, 16),
276 | torch.randn(2, 32, 8, 8),
277 | torch.randn(2, 64, 4, 4),
278 | ]
279 | g_t = [
280 | torch.randn(2, 32, 16, 16),
281 | torch.randn(2, 64, 8, 8),
282 | torch.randn(2, 128, 4, 4),
283 | ]
284 | s_shapes = [s.shape for s in g_s]
285 | t_shapes = [t.shape for t in g_t]
286 |
287 | net = ConnectorV2(s_shapes, t_shapes)
288 | out = net(g_s)
289 | for f in out:
290 | print(f.shape)
291 |
--------------------------------------------------------------------------------
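A FitNet-style usage sketch (the shapes are hypothetical): ConvReg maps a student feature map onto a teacher hint so an MSE loss can be applied, choosing a stride-2 conv, a transposed conv, or a plain conv based on the relative spatial sizes:

    import torch
    import torch.nn as nn

    f_s = torch.randn(2, 32, 16, 16)  # hypothetical student feature
    f_t = torch.randn(2, 64, 8, 8)    # hypothetical teacher hint
    reg = ConvReg(f_s.shape, f_t.shape)  # 16 == 2 * 8, so the stride-2 conv branch
    loss = nn.MSELoss()(reg(f_s), f_t)
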
/cv/models/vgg.py:
--------------------------------------------------------------------------------
1 | '''VGG for CIFAR10. FC layers are removed.
2 | (c) YANG, Wei
3 | '''
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import math
7 |
8 |
9 | __all__ = [
10 |     'VGG', 'vgg8', 'vgg8_bn', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
11 |     'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
12 | ]
13 |
14 |
15 | model_urls = {
16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
20 | }
21 |
22 |
23 | class VGG(nn.Module):
24 |
25 | def __init__(self, cfg, batch_norm=False, num_classes=1000):
26 | super(VGG, self).__init__()
27 | self.block0 = self._make_layers(cfg[0], batch_norm, 3)
28 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1])
29 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1])
30 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1])
31 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1])
32 |
33 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
34 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
35 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
36 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
37 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1))
38 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
39 |
40 | self.classifier = nn.Linear(512, num_classes)
41 | self._initialize_weights()
42 |
43 | def get_feat_modules(self):
44 | feat_m = nn.ModuleList([])
45 | feat_m.append(self.block0)
46 | feat_m.append(self.pool0)
47 | feat_m.append(self.block1)
48 | feat_m.append(self.pool1)
49 | feat_m.append(self.block2)
50 | feat_m.append(self.pool2)
51 | feat_m.append(self.block3)
52 | feat_m.append(self.pool3)
53 | feat_m.append(self.block4)
54 | feat_m.append(self.pool4)
55 | return feat_m
56 |
57 | def get_bn_before_relu(self):
58 | bn1 = self.block1[-1]
59 | bn2 = self.block2[-1]
60 | bn3 = self.block3[-1]
61 | bn4 = self.block4[-1]
62 | return [bn1, bn2, bn3, bn4]
63 |
64 | def forward(self, x, is_feat=False, preact=False):
65 | h = x.shape[2]
66 | x = F.relu(self.block0(x))
67 | f0 = x
68 | x = self.pool0(x)
69 | x = self.block1(x)
70 | f1_pre = x
71 | x = F.relu(x)
72 | f1 = x
73 | x = self.pool1(x)
74 | x = self.block2(x)
75 | f2_pre = x
76 | x = F.relu(x)
77 | f2 = x
78 | x = self.pool2(x)
79 | x = self.block3(x)
80 | f3_pre = x
81 | x = F.relu(x)
82 | f3 = x
83 | if h == 64:
84 | x = self.pool3(x)
85 | x = self.block4(x)
86 | f4_pre = x
87 | x = F.relu(x)
88 | f4 = x
89 | x = self.pool4(x)
90 | x = x.view(x.size(0), -1)
91 | f5 = x
92 | x = self.classifier(x)
93 |
94 | if is_feat:
95 | if preact:
96 | return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], x
97 | else:
98 | return [f0, f1, f2, f3, f4, f5], x
99 | else:
100 | return x
101 |
102 | @staticmethod
103 | def _make_layers(cfg, batch_norm=False, in_channels=3):
104 | layers = []
105 | for v in cfg:
106 | if v == 'M':
107 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
108 | else:
109 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
110 | if batch_norm:
111 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
112 | else:
113 | layers += [conv2d, nn.ReLU(inplace=True)]
114 | in_channels = v
115 | layers = layers[:-1]
116 | return nn.Sequential(*layers)
117 |
118 | def _initialize_weights(self):
119 | for m in self.modules():
120 | if isinstance(m, nn.Conv2d):
121 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
122 | m.weight.data.normal_(0, math.sqrt(2. / n))
123 | if m.bias is not None:
124 | m.bias.data.zero_()
125 | elif isinstance(m, nn.BatchNorm2d):
126 | m.weight.data.fill_(1)
127 | m.bias.data.zero_()
128 | elif isinstance(m, nn.Linear):
129 | n = m.weight.size(1)
130 | m.weight.data.normal_(0, 0.01)
131 | m.bias.data.zero_()
132 |
133 |
134 | cfg = {
135 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]],
136 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]],
137 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]],
138 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]],
139 | 'S': [[64], [128], [256], [512], [512]],
140 | }
141 |
142 |
143 | def vgg8(**kwargs):
144 | """VGG 8-layer model (configuration "S")
145 | Args:
146 | pretrained (bool): If True, returns a model pre-trained on ImageNet
147 | """
148 | model = VGG(cfg['S'], **kwargs)
149 | return model
150 |
151 |
152 | def vgg8_bn(**kwargs):
153 | """VGG 8-layer model (configuration "S")
154 | Args:
155 | pretrained (bool): If True, returns a model pre-trained on ImageNet
156 | """
157 | model = VGG(cfg['S'], batch_norm=True, **kwargs)
158 | return model
159 |
160 |
161 | def vgg11(**kwargs):
162 | """VGG 11-layer model (configuration "A")
163 | Args:
164 | pretrained (bool): If True, returns a model pre-trained on ImageNet
165 | """
166 | model = VGG(cfg['A'], **kwargs)
167 | return model
168 |
169 |
170 | def vgg11_bn(**kwargs):
171 | """VGG 11-layer model (configuration "A") with batch normalization"""
172 | model = VGG(cfg['A'], batch_norm=True, **kwargs)
173 | return model
174 |
175 |
176 | def vgg13(**kwargs):
177 | """VGG 13-layer model (configuration "B")
178 | Args:
179 | pretrained (bool): If True, returns a model pre-trained on ImageNet
180 | """
181 | model = VGG(cfg['B'], **kwargs)
182 | return model
183 |
184 |
185 | def vgg13_bn(**kwargs):
186 | """VGG 13-layer model (configuration "B") with batch normalization"""
187 | model = VGG(cfg['B'], batch_norm=True, **kwargs)
188 | return model
189 |
190 |
191 | def vgg16(**kwargs):
192 | """VGG 16-layer model (configuration "D")
193 | Args:
194 | pretrained (bool): If True, returns a model pre-trained on ImageNet
195 | """
196 | model = VGG(cfg['D'], **kwargs)
197 | return model
198 |
199 |
200 | def vgg16_bn(**kwargs):
201 | """VGG 16-layer model (configuration "D") with batch normalization"""
202 | model = VGG(cfg['D'], batch_norm=True, **kwargs)
203 | return model
204 |
205 |
206 | def vgg19(**kwargs):
207 | """VGG 19-layer model (configuration "E")
208 | Args:
209 | pretrained (bool): If True, returns a model pre-trained on ImageNet
210 | """
211 | model = VGG(cfg['E'], **kwargs)
212 | return model
213 |
214 |
215 | def vgg19_bn(**kwargs):
216 | """VGG 19-layer model (configuration 'E') with batch normalization"""
217 | model = VGG(cfg['E'], batch_norm=True, **kwargs)
218 | return model
219 |
220 |
221 | if __name__ == '__main__':
222 | import torch
223 |
224 | x = torch.randn(2, 3, 32, 32)
225 | net = vgg19_bn(num_classes=100)
226 | feats, logit = net(x, is_feat=True, preact=True)
227 |
228 | for f in feats:
229 | print(f.shape, f.min().item())
230 | print(logit.shape)
231 |
232 | for m in net.get_bn_before_relu():
233 | if isinstance(m, nn.BatchNorm2d):
234 | print('pass')
235 | else:
236 | print('warning')
237 |
--------------------------------------------------------------------------------
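A caveat: because _make_layers strips the trailing ReLU, get_bn_before_relu returns BatchNorm modules only for the *_bn variants; for the plain models the last module of each block is a Conv2d, so hint methods that expect normalized pre-activations should pair with vgg*_bn. Illustrative check:

    plain = vgg13(num_classes=100)
    bn = vgg13_bn(num_classes=100)
    print([type(m).__name__ for m in plain.get_bn_before_relu()])  # ['Conv2d', ...]
    print([type(m).__name__ for m in bn.get_bn_before_relu()])     # ['BatchNorm2d', ...]
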
/cv/models/wrn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | """
7 | Original Author: Wei Yang
8 | """
9 |
10 | __all__ = ['wrn']
11 |
12 |
13 | class BasicBlock(nn.Module):
14 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
15 | super(BasicBlock, self).__init__()
16 | self.bn1 = nn.BatchNorm2d(in_planes)
17 | self.relu1 = nn.ReLU(inplace=True)
18 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
19 | padding=1, bias=False)
20 | self.bn2 = nn.BatchNorm2d(out_planes)
21 | self.relu2 = nn.ReLU(inplace=True)
22 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
23 | padding=1, bias=False)
24 | self.droprate = dropRate
25 | self.equalInOut = (in_planes == out_planes)
26 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
27 | padding=0, bias=False) or None
28 |
29 | def forward(self, x):
30 | if not self.equalInOut:
31 | x = self.relu1(self.bn1(x))
32 | else:
33 | out = self.relu1(self.bn1(x))
34 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
35 | if self.droprate > 0:
36 | out = F.dropout(out, p=self.droprate, training=self.training)
37 | out = self.conv2(out)
38 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
39 |
40 |
41 | class NetworkBlock(nn.Module):
42 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
43 | super(NetworkBlock, self).__init__()
44 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
45 |
46 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
47 | layers = []
48 | for i in range(nb_layers):
49 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
50 | return nn.Sequential(*layers)
51 |
52 | def forward(self, x):
53 | return self.layer(x)
54 |
55 |
56 | class WideResNet(nn.Module):
57 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
58 | super(WideResNet, self).__init__()
59 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
60 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
61 | n = (depth - 4) // 6
62 | block = BasicBlock
63 | # 1st conv before any network block
64 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
65 | padding=1, bias=False)
66 | # 1st block
67 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
68 | # 2nd block
69 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
70 | # 3rd block
71 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
72 | # global average pooling and classifier
73 | self.bn1 = nn.BatchNorm2d(nChannels[3])
74 | self.relu = nn.ReLU(inplace=True)
75 | self.fc = nn.Linear(nChannels[3], num_classes)
76 | self.nChannels = nChannels[3]
77 |
78 | for m in self.modules():
79 | if isinstance(m, nn.Conv2d):
80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
81 | m.weight.data.normal_(0, math.sqrt(2. / n))
82 | elif isinstance(m, nn.BatchNorm2d):
83 | m.weight.data.fill_(1)
84 | m.bias.data.zero_()
85 | elif isinstance(m, nn.Linear):
86 | m.bias.data.zero_()
87 |
88 | def get_feat_modules(self):
89 | feat_m = nn.ModuleList([])
90 | feat_m.append(self.conv1)
91 | feat_m.append(self.block1)
92 | feat_m.append(self.block2)
93 | feat_m.append(self.block3)
94 | return feat_m
95 |
96 | def get_bn_before_relu(self):
97 | bn1 = self.block2.layer[0].bn1
98 | bn2 = self.block3.layer[0].bn1
99 | bn3 = self.bn1
100 |
101 | return [bn1, bn2, bn3]
102 |
103 | def forward(self, x, is_feat=False, preact=False):
104 | out = self.conv1(x)
105 | f0 = out
106 | out = self.block1(out)
107 | f1 = out
108 | out = self.block2(out)
109 | f2 = out
110 | out = self.block3(out)
111 | f3 = out
112 | out = self.relu(self.bn1(out))
113 | out = F.avg_pool2d(out, 8)
114 | out = out.view(-1, self.nChannels)
115 | f4 = out
116 | out = self.fc(out)
117 | if is_feat:
118 | if preact:
119 | f1 = self.block2.layer[0].bn1(f1)
120 | f2 = self.block3.layer[0].bn1(f2)
121 | f3 = self.bn1(f3)
122 | return [f0, f1, f2, f3, f4], out
123 | else:
124 | return out
125 |
126 |
127 | def wrn(**kwargs):
128 | """
129 |     Constructs a Wide Residual Network.
130 | """
131 | model = WideResNet(**kwargs)
132 | return model
133 |
134 |
135 | def wrn_40_2(**kwargs):
136 | model = WideResNet(depth=40, widen_factor=2, **kwargs)
137 | return model
138 |
139 |
140 | def wrn_40_1(**kwargs):
141 | model = WideResNet(depth=40, widen_factor=1, **kwargs)
142 | return model
143 |
144 |
145 | def wrn_16_2(**kwargs):
146 | model = WideResNet(depth=16, widen_factor=2, **kwargs)
147 | return model
148 |
149 |
150 | def wrn_16_1(**kwargs):
151 | model = WideResNet(depth=16, widen_factor=1, **kwargs)
152 | return model
153 |
154 |
155 | if __name__ == '__main__':
156 | import torch
157 |
158 | x = torch.randn(2, 3, 32, 32)
159 | net = wrn_40_2(num_classes=100)
160 | feats, logit = net(x, is_feat=True, preact=True)
161 |
162 | for f in feats:
163 | print(f.shape, f.min().item())
164 | print(logit.shape)
165 |
166 | for m in net.get_bn_before_relu():
167 | if isinstance(m, nn.BatchNorm2d):
168 | print('pass')
169 | else:
170 | print('warning')
171 |
--------------------------------------------------------------------------------
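Note that WideResNet is a pre-activation architecture: the BN "before ReLU" for blocks 1 and 2 lives at the start of the following block, which is why get_bn_before_relu reaches into block2/block3 and why forward re-normalizes f1..f3 when preact=True. A quick shape probe (illustrative):

    import torch

    net = wrn_16_2(num_classes=100)
    feats, logit = net(torch.randn(2, 3, 32, 32), is_feat=True, preact=True)
    print([tuple(f.shape) for f in feats], tuple(logit.shape))
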
/cv/scripts/fetch_pretrained_teachers.sh:
--------------------------------------------------------------------------------
1 | # fetch pre-trained teacher models
2 |
3 | mkdir -p save/models/
4 |
5 | cd save/models
6 |
7 | mkdir -p wrn_40_2_vanilla
8 | wget http://shape2prog.csail.mit.edu/repo/wrn_40_2_vanilla/ckpt_epoch_240.pth
9 | mv ckpt_epoch_240.pth wrn_40_2_vanilla/
10 |
11 | mkdir -p resnet56_vanilla
12 | wget http://shape2prog.csail.mit.edu/repo/resnet56_vanilla/ckpt_epoch_240.pth
13 | mv ckpt_epoch_240.pth resnet56_vanilla/
14 |
15 | mkdir -p resnet110_vanilla
16 | wget http://shape2prog.csail.mit.edu/repo/resnet110_vanilla/ckpt_epoch_240.pth
17 | mv ckpt_epoch_240.pth resnet110_vanilla/
18 |
19 | mkdir -p resnet32x4_vanilla
20 | wget http://shape2prog.csail.mit.edu/repo/resnet32x4_vanilla/ckpt_epoch_240.pth
21 | mv ckpt_epoch_240.pth resnet32x4_vanilla/
22 |
23 | mkdir -p vgg13_vanilla
24 | wget http://shape2prog.csail.mit.edu/repo/vgg13_vanilla/ckpt_epoch_240.pth
25 | mv ckpt_epoch_240.pth vgg13_vanilla/
26 |
27 | mkdir -p ResNet50_vanilla
28 | wget http://shape2prog.csail.mit.edu/repo/ResNet50_vanilla/ckpt_epoch_240.pth
29 | mv ckpt_epoch_240.pth ResNet50_vanilla/
30 |
31 | cd ../..
--------------------------------------------------------------------------------
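Once fetched, a teacher checkpoint can be restored as below; this mirrors load_teacher() in train_student_meta.py (the 'model' key and the directory layout come from that script and the commands above):

    import torch
    from models import model_dict

    ckpt = torch.load('save/models/resnet32x4_vanilla/ckpt_epoch_240.pth',
                      map_location='cpu')
    net = model_dict['resnet32x4'](num_classes=100)
    net.load_state_dict(ckpt['model'])
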
/cv/scripts/run_cifar_distill.sh:
--------------------------------------------------------------------------------
1 | # sample scripts for running the distillation code
2 | # use resnet32x4 and resnet8x4 as an example
3 |
4 | # kd
5 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill kd --model_s resnet8x4 -r 0.1 -a 0.9 -b 0 --trial 1
6 | # FitNet
7 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill hint --model_s resnet8x4 -a 0 -b 100 --trial 1
8 | # AT
9 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill attention --model_s resnet8x4 -a 0 -b 1000 --trial 1
10 | # SP
11 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill similarity --model_s resnet8x4 -a 0 -b 3000 --trial 1
12 | # CC
13 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill correlation --model_s resnet8x4 -a 0 -b 0.02 --trial 1
14 | # VID
15 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill vid --model_s resnet8x4 -a 0 -b 1 --trial 1
16 | # RKD
17 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill rkd --model_s resnet8x4 -a 0 -b 1 --trial 1
18 | # PKT
19 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill pkt --model_s resnet8x4 -a 0 -b 30000 --trial 1
20 | # AB
21 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill abound --model_s resnet8x4 -a 0 -b 1 --trial 1
22 | # FT
23 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill factor --model_s resnet8x4 -a 0 -b 200 --trial 1
24 | # FSP
25 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill fsp --model_s resnet8x4 -a 0 -b 50 --trial 1
26 | # NST
27 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill nst --model_s resnet8x4 -a 0 -b 50 --trial 1
28 | # CRD
29 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill crd --model_s resnet8x4 -a 0 -b 0.8 --trial 1
30 |
31 | # CRD+KD
32 | python train_student.py --path_t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill crd --model_s resnet8x4 -a 1 -b 0.8 --trial 1
--------------------------------------------------------------------------------
/cv/scripts/run_cifar_vanilla.sh:
--------------------------------------------------------------------------------
1 | # sample scripts for training vanilla teacher models
2 |
3 | python train_teacher.py --model wrn_40_2
4 |
5 | python train_teacher.py --model resnet56
6 |
7 | python train_teacher.py --model resnet110
8 |
9 | python train_teacher.py --model resnet32x4
10 |
11 | python train_teacher.py --model vgg13
12 |
13 | python train_teacher.py --model ResNet50
14 |
--------------------------------------------------------------------------------
/cv/train_student_meta.py:
--------------------------------------------------------------------------------
1 | """
2 | The general training framework for meta distillation.
3 | """
4 |
5 | from __future__ import print_function
6 |
7 | import os
8 | import argparse
9 | import socket
10 | import time
11 |
12 | import tensorboard_logger as tb_logger
13 | import torch
14 | import torch.optim as optim
15 | import torch.nn as nn
16 | import torch.backends.cudnn as cudnn
17 |
18 |
19 | from models import model_dict
20 | from models.util import Embed, ConvReg, LinearEmbed
21 | from models.util import Connector, Translator, Paraphraser
22 |
23 | from dataset.meta_cifar100 import get_cifar100_dataloaders
24 |
25 | from helper.util import adjust_learning_rate
26 |
27 | from distiller_zoo.KD import CustomDistillKL
28 | from distiller_zoo.MSE import MSEWithTemperature
29 | from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss
30 | from distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss
31 | from crd.criterion import CRDLoss
32 |
33 | from helper.meta_loops import train_distill as train, validate
34 | from helper.pretrain import init
35 |
36 |
37 | def parse_option():
38 |
39 | hostname = socket.gethostname()
40 |
41 | parser = argparse.ArgumentParser('argument for training')
42 |
43 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
44 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
45 | parser.add_argument('--save_freq', type=int, default=40, help='save frequency')
46 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
47 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
48 | parser.add_argument('--epochs', type=int, default=240, help='number of training epochs')
49 | parser.add_argument('--init_epochs', type=int, default=30, help='init training for two-stage methods')
50 |
51 | # optimization
52 | parser.add_argument('--lr', type=float, default=0.05, help='learning rate')
53 | parser.add_argument('--teacher_lr', type=float, default=0.05, help='teacher learning rate')
54 | parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='where to decay lr, can be a list')
55 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
56 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
57 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
58 | parser.add_argument('--loss_type', type=str, choices=['mse', 'kl'])
59 |
60 | # held set
61 |     parser.add_argument('--held_size', type=int, help="size of the held-out set")
62 |     parser.add_argument('--num_held_samples', type=int, help="number of held-out samples used for one teacher update")
63 |     parser.add_argument('--num_meta_batches', type=int, default=1, help="number of meta batches used for one teacher update")
64 |     parser.add_argument('--assume_s_step_size', type=float, default=0.05, help="assumed step size for the simulated student gradient update")
65 |
66 | # dataset
67 | parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100'], help='dataset')
68 |
69 | # model
70 | parser.add_argument('--model_s', type=str, default='resnet8',
71 | choices=['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
72 | 'resnet8x4', 'resnet32x4', 'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19'])
73 | # TODO: Add more support
74 | # 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2', 'ResNet50', 'MobileNetV2', 'ShuffleV1', 'ShuffleV2'])
75 | parser.add_argument('--path_t', type=str, default=None, help='teacher model snapshot')
76 |
77 | # distillation
78 | parser.add_argument('--kd_T', type=float, default=4, help='temperature for KD distillation')
79 | parser.add_argument('--trial', type=str, default='1', help='trial id')
80 |
81 | parser.add_argument('-a', '--alpha', type=float, default=None, help='weight balance for KD')
82 |
83 | opt = parser.parse_args()
84 |
85 |     # use a smaller learning rate for MobileNet/ShuffleNet students
86 | if opt.model_s in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']:
87 | opt.lr = 0.01
88 |
89 | # set the path according to the environment
90 | if hostname.startswith('visiongpu'):
91 | opt.model_path = '/path/to/my/student_model'
92 | opt.tb_path = '/path/to/my/student_tensorboards'
93 | else:
94 | opt.model_path = './save/student_model'
95 | opt.tb_path = './save/student_tensorboards'
96 |
97 | iterations = opt.lr_decay_epochs.split(',')
98 | opt.lr_decay_epochs = list([])
99 | for it in iterations:
100 | opt.lr_decay_epochs.append(int(it))
101 |
102 | opt.model_t = get_teacher_name(opt.path_t)
103 |
104 | opt.model_name = 'S:{}_T:{}_{}_{}_a:{}_{}'.format(opt.model_s, opt.model_t, opt.dataset, 'mlkd',
105 | opt.alpha, opt.trial)
106 |
107 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
108 | if not os.path.isdir(opt.tb_folder):
109 | os.makedirs(opt.tb_folder)
110 |
111 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
112 | if not os.path.isdir(opt.save_folder):
113 | os.makedirs(opt.save_folder)
114 |
115 | return opt
116 |
117 |
118 | def get_teacher_name(model_path):
119 |     """parse the teacher architecture from its checkpoint path, e.g. .../resnet32x4_vanilla/ckpt.pth -> resnet32x4"""
120 | segments = model_path.split('/')[-2].split('_')
121 | if segments[0] != 'wrn':
122 | return segments[0]
123 | else:
124 | return segments[0] + '_' + segments[1] + '_' + segments[2]
125 |
126 |
127 | def load_teacher(model_path, n_cls):
128 | print('==> loading teacher model')
129 | model_t = get_teacher_name(model_path)
130 | model = model_dict[model_t](num_classes=n_cls)
131 | if torch.cuda.is_available():
132 | model.load_state_dict(torch.load(model_path)['model'])
133 | else:
134 | model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model'])
135 | print('==> done')
136 | return model
137 |
138 |
139 | def main():
140 | best_acc = 0
141 |
142 | opt = parse_option()
143 |
144 | # tensorboard logger
145 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
146 |
147 | # dataloader
148 | if opt.dataset == 'cifar100':
149 | train_loader, held_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size,
150 | num_workers=opt.num_workers,
151 | held_size=opt.held_size,
152 | num_held_samples=opt.num_held_samples,
153 | )
154 | n_cls = 100
155 | else:
156 | raise NotImplementedError(opt.dataset)
157 |
158 | # model
159 | model_t = load_teacher(opt.path_t, n_cls)
160 | model_s = model_dict[opt.model_s](num_classes=n_cls)
161 |
162 |     data = torch.randn(2, 3, 32, 32)  # dummy CIFAR-sized batch to probe feature shapes
163 | model_t.eval()
164 | model_s.eval()
165 | feat_t, _ = model_t(data, is_feat=True)
166 | feat_s, _ = model_s(data, is_feat=True)
167 |
168 | module_list = nn.ModuleList([])
169 | module_list.append(model_s)
170 | trainable_list = nn.ModuleList([])
171 | trainable_list.append(model_s)
172 |
173 | criterion_cls = nn.CrossEntropyLoss()
174 |
175 | if opt.loss_type == 'mse':
176 | criterion_kd = MSEWithTemperature(T=opt.kd_T)
177 | elif opt.loss_type == 'kl':
178 | criterion_kd = CustomDistillKL(T=opt.kd_T)
179 | else:
180 | raise NotImplementedError()
181 |
182 | criterion_list = nn.ModuleList([])
183 | criterion_list.append(criterion_cls) # classification loss
184 | criterion_list.append(criterion_kd) # other knowledge distillation loss
185 |
186 | # optimizer
187 | s_optimizer = optim.SGD(model_s.parameters(),
188 | lr=opt.lr,
189 | momentum=opt.momentum,
190 | weight_decay=opt.weight_decay)
191 | t_optimizer = optim.SGD(model_t.parameters(),
192 | lr=opt.teacher_lr,
193 | momentum=opt.momentum,
194 | weight_decay=opt.weight_decay)
195 |
196 |     # append the teacher to module_list after the optimizers are built
197 | module_list.append(model_t)
198 |
199 | if torch.cuda.is_available():
200 | module_list.cuda()
201 | criterion_list.cuda()
202 | cudnn.benchmark = True
203 |
204 | # validate teacher accuracy
205 | teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
206 | print('teacher accuracy: ', teacher_acc)
207 |
208 |     # keep the teacher frozen (lr = 0) for the first 150 epochs
209 | for param_group in t_optimizer.param_groups:
210 | param_group['lr'] = 0
211 | # routine
212 | for epoch in range(1, opt.epochs + 1):
213 |
214 | if epoch == 150:
215 | for param_group in t_optimizer.param_groups:
216 | param_group['lr'] = opt.teacher_lr
217 | opt.assume_s_step_size *= opt.lr_decay_rate
218 | if epoch == 180:
219 | for param_group in t_optimizer.param_groups:
220 | param_group['lr'] *= opt.lr_decay_rate
221 | opt.assume_s_step_size *= opt.lr_decay_rate
222 | if epoch == 210:
223 | for param_group in t_optimizer.param_groups:
224 | param_group['lr'] *= opt.lr_decay_rate
225 | opt.assume_s_step_size *= opt.lr_decay_rate
226 | adjust_learning_rate(epoch, opt, s_optimizer)
227 | # adjust_learning_rate(epoch, opt, t_optimizer)
228 |
229 | print("==> training...")
230 |
231 | time1 = time.time()
232 | train_acc, train_loss = train(epoch, train_loader, held_loader, module_list, criterion_list, s_optimizer, t_optimizer, opt)
233 | time2 = time.time()
234 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
235 |
236 | logger.log_value('train_acc', train_acc, epoch)
237 | logger.log_value('train_loss', train_loss, epoch)
238 |
239 | test_acc, test_acc_top5, test_loss = validate(val_loader, model_s, criterion_cls, opt)
240 |
241 | logger.log_value('test_acc', test_acc, epoch)
242 | logger.log_value('test_loss', test_loss, epoch)
243 | logger.log_value('test_acc_top5', test_acc_top5, epoch)
244 |
245 | # save the best model
246 | if test_acc > best_acc:
247 | best_acc = test_acc
248 | state = {
249 | 'epoch': epoch,
250 | 'model': model_s.state_dict(),
251 | 'best_acc': best_acc,
252 | }
253 | save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model_s))
254 | print('saving the best model!')
255 | torch.save(state, save_file)
256 |
257 | # regular saving
258 | if epoch % opt.save_freq == 0:
259 | print('==> Saving...')
260 | state = {
261 | 'epoch': epoch,
262 | 'model': model_s.state_dict(),
263 | 'accuracy': test_acc,
264 | }
265 | save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
266 | torch.save(state, save_file)
267 |
268 |     # This best accuracy is only for printing purposes.
269 |     # The results reported in the paper/README are from the last epoch.
270 | print('best accuracy:', best_acc)
271 |
272 | # save model
273 | state = {
274 | 'opt': opt,
275 | 'model': model_s.state_dict(),
276 | }
277 | save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model_s))
278 | torch.save(state, save_file)
279 |
280 |
281 | if __name__ == '__main__':
282 | main()
283 |
--------------------------------------------------------------------------------
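The meta step this script drives lives in `helper/meta_loops.train_distill`, which is not shown in this section. The mechanism suggested by the options above is a one-step lookahead: the student takes a simulated SGD step of size `--assume_s_step_size` on the distillation loss, the resulting lookahead student is evaluated on the held-out loader, and the teacher is updated through that evaluation. A minimal self-contained sketch of that pattern, assuming toy linear models and standard temperature-scaled KL distillation (all names, shapes, and loss choices below are illustrative, not the repository's actual helpers):

```python
# Sketch of a one-step lookahead teacher update; the real loop is in
# helper/meta_loops.py and may differ in detail.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
teacher, student = nn.Linear(8, 4), nn.Linear(8, 4)
t_opt = torch.optim.SGD(teacher.parameters(), lr=0.05)   # mirrors --teacher_lr
assume_s_step_size = 0.05                                # mirrors --assume_s_step_size

x_train, y_train = torch.randn(16, 8), torch.randint(0, 4, (16,))
x_held, y_held = torch.randn(16, 8), torch.randint(0, 4, (16,))

# 1) Distillation loss of the current student against the teacher.
T = 4.0
kd = F.kl_div(F.log_softmax(student(x_train) / T, dim=-1),
              F.softmax(teacher(x_train) / T, dim=-1),
              reduction='batchmean') * T * T

# 2) Simulate one student SGD step, keeping the graph (create_graph=True)
#    so gradients can later flow back into the teacher.
grads = torch.autograd.grad(kd, list(student.parameters()), create_graph=True)
s_prime = [p - assume_s_step_size * g for p, g in zip(student.parameters(), grads)]

# 3) Evaluate the lookahead student on held-out data with a functional
#    forward pass, then update the teacher to reduce that loss.
held_loss = F.cross_entropy(F.linear(x_held, s_prime[0], s_prime[1]), y_held)
t_opt.zero_grad()
held_loss.backward()   # gradients reach the teacher through the simulated step
t_opt.step()
```

This also explains why `opt.assume_s_step_size` is decayed in lockstep with the learning-rate milestones at epochs 150/180/210 above: the simulated student step should track the real student step size.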
/cv/train_teacher.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import argparse
5 | import socket
6 | import time
7 |
8 | import tensorboard_logger as tb_logger
9 | import torch
10 | import torch.optim as optim
11 | import torch.nn as nn
12 | import torch.backends.cudnn as cudnn
13 |
14 | from models import model_dict
15 |
16 | from dataset.cifar100 import get_cifar100_dataloaders
17 |
18 | from helper.util import adjust_learning_rate, accuracy, AverageMeter
19 | from helper.loops import train_vanilla as train, validate
20 |
21 |
22 | def parse_option():
23 |
24 | hostname = socket.gethostname()
25 |
26 | parser = argparse.ArgumentParser('argument for training')
27 |
28 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
29 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
30 | parser.add_argument('--save_freq', type=int, default=40, help='save frequency')
31 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
32 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
33 | parser.add_argument('--epochs', type=int, default=240, help='number of training epochs')
34 |
35 | # optimization
36 | parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
37 | parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='where to decay lr, can be a list')
38 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
39 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
40 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
41 |
42 |     # model & dataset
43 | parser.add_argument('--model', type=str, default='resnet110',
44 | choices=['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
45 | 'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2',
46 | 'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19',
47 | 'MobileNetV2', 'ShuffleV1', 'ShuffleV2', ])
48 | parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100'], help='dataset')
49 |
50 | parser.add_argument('-t', '--trial', type=int, default=0, help='the experiment id')
51 |
52 | opt = parser.parse_args()
53 |
54 |     # use a smaller learning rate for MobileNet/ShuffleNet models
55 | if opt.model in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']:
56 | opt.learning_rate = 0.01
57 |
58 | # set the path according to the environment
59 | if hostname.startswith('visiongpu'):
60 | opt.model_path = '/path/to/my/model'
61 | opt.tb_path = '/path/to/my/tensorboard'
62 | else:
63 | opt.model_path = './save/models'
64 | opt.tb_path = './save/tensorboard'
65 |
66 | iterations = opt.lr_decay_epochs.split(',')
67 | opt.lr_decay_epochs = list([])
68 | for it in iterations:
69 | opt.lr_decay_epochs.append(int(it))
70 |
71 | opt.model_name = '{}_{}_lr_{}_decay_{}_trial_{}'.format(opt.model, opt.dataset, opt.learning_rate,
72 | opt.weight_decay, opt.trial)
73 |
74 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
75 | if not os.path.isdir(opt.tb_folder):
76 | os.makedirs(opt.tb_folder)
77 |
78 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
79 | if not os.path.isdir(opt.save_folder):
80 | os.makedirs(opt.save_folder)
81 |
82 | return opt
83 |
84 |
85 | def main():
86 | best_acc = 0
87 |
88 | opt = parse_option()
89 |
90 | # dataloader
91 | if opt.dataset == 'cifar100':
92 | train_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size, num_workers=opt.num_workers)
93 | n_cls = 100
94 | else:
95 | raise NotImplementedError(opt.dataset)
96 |
97 | # model
98 | model = model_dict[opt.model](num_classes=n_cls)
99 |
100 | # optimizer
101 | optimizer = optim.SGD(model.parameters(),
102 | lr=opt.learning_rate,
103 | momentum=opt.momentum,
104 | weight_decay=opt.weight_decay)
105 |
106 | criterion = nn.CrossEntropyLoss()
107 |
108 | if torch.cuda.is_available():
109 | model = model.cuda()
110 | criterion = criterion.cuda()
111 | cudnn.benchmark = True
112 |
113 | # tensorboard
114 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
115 |
116 | # routine
117 | for epoch in range(1, opt.epochs + 1):
118 |
119 | adjust_learning_rate(epoch, opt, optimizer)
120 | print("==> training...")
121 |
122 | time1 = time.time()
123 | train_acc, train_loss = train(epoch, train_loader, model, criterion, optimizer, opt)
124 | time2 = time.time()
125 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
126 |
127 | logger.log_value('train_acc', train_acc, epoch)
128 | logger.log_value('train_loss', train_loss, epoch)
129 |
130 | test_acc, test_acc_top5, test_loss = validate(val_loader, model, criterion, opt)
131 |
132 | logger.log_value('test_acc', test_acc, epoch)
133 | logger.log_value('test_acc_top5', test_acc_top5, epoch)
134 | logger.log_value('test_loss', test_loss, epoch)
135 |
136 | # save the best model
137 | if test_acc > best_acc:
138 | best_acc = test_acc
139 | state = {
140 | 'epoch': epoch,
141 | 'model': model.state_dict(),
142 | 'best_acc': best_acc,
143 | 'optimizer': optimizer.state_dict(),
144 | }
145 | save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model))
146 | print('saving the best model!')
147 | torch.save(state, save_file)
148 |
149 | # regular saving
150 | if epoch % opt.save_freq == 0:
151 | print('==> Saving...')
152 | state = {
153 | 'epoch': epoch,
154 | 'model': model.state_dict(),
155 | 'accuracy': test_acc,
156 | 'optimizer': optimizer.state_dict(),
157 | }
158 | save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
159 | torch.save(state, save_file)
160 |
161 |     # This best accuracy is only for printing purposes.
162 |     # The results reported in the paper/README are from the last epoch.
163 | print('best accuracy:', best_acc)
164 |
165 | # save model
166 | state = {
167 | 'opt': opt,
168 | 'model': model.state_dict(),
169 | 'optimizer': optimizer.state_dict(),
170 | }
171 | save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
172 | torch.save(state, save_file)
173 |
174 |
175 | if __name__ == '__main__':
176 | main()
177 |
--------------------------------------------------------------------------------
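`adjust_learning_rate` is imported from `helper/util.py`, which is not included in this section. Under the defaults above (`--lr_decay_epochs 150,180,210`, `--lr_decay_rate 0.1`) it implements a step decay; a likely equivalent sketch, using this file's `opt` attribute names:

```python
# Step-decay schedule matching the flags above; the real implementation
# lives in helper/util.py and may differ in detail.
import numpy as np

def adjust_learning_rate(epoch, opt, optimizer):
    # number of decay milestones already passed at this epoch
    steps = int(np.sum(epoch > np.asarray(opt.lr_decay_epochs)))
    if steps > 0:
        new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
```

With the defaults, the rate stays at 0.05 through epoch 150 and then drops to 0.005, 0.0005, and 0.00005 after epochs 150, 180, and 210 respectively.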
/nlp/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | .idea
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/nlp/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Canwen Xu and Wangchunshu Zhou
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/nlp/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch-PKD-for-BERT-Compression
2 |
3 | PyTorch implementation of the distillation method described in the paper [**Patient Knowledge Distillation for BERT Model Compression**](https://arxiv.org/abs/1908.09355). This repository borrows heavily from [Pytorch-Transformers](https://github.com/huggingface/pytorch-transformers) by Hugging Face.
4 |
5 | ## Steps to run the code
6 | ### 1. Download the GLUE data
7 | ```
8 | $ python download_glue_data.py
9 | ```
10 |
11 | ### 2. Fine-tune the teacher BERT model
12 | Run the following command to fine-tune the teacher model and save it.
13 | ```
14 | python run_glue.py \
15 | --model_type bert \
16 | --model_name_or_path bert-base-uncased \
17 | --task_name $TASK_NAME \
18 | --do_train \
19 | --do_eval \
20 | --do_lower_case \
21 | --data_dir $GLUE_DIR/$TASK_NAME \
22 | --max_seq_length 128 \
23 | --per_gpu_eval_batch_size=8 \
24 | --per_gpu_train_batch_size=8 \
25 | --learning_rate 2e-5 \
26 | --num_train_epochs 3.0 \
27 | --output_dir /tmp/$TASK_NAME/
28 | ```
29 |
30 | ### 3. Distill the student model with the teacher BERT
31 | `$TEACHER_MODEL` is the folder containing your fine-tuned teacher from step 2.
32 | ```
33 | python run_glue_distillation.py \
34 | --model_type bert \
35 | --teacher_model $TEACHER_MODEL \
36 | --student_model bert-base-uncased \
37 | --task_name $TASK_NAME \
38 | --num_hidden_layers 6 \
39 | --alpha 0.5 \
40 | --beta 100.0 \
41 | --do_train \
42 | --do_eval \
43 | --do_lower_case \
44 | --data_dir $GLUE_DIR/$TASK_NAME \
45 | --max_seq_length 128 \
46 | --per_gpu_eval_batch_size=8 \
47 | --per_gpu_train_batch_size=8 \
48 | --learning_rate 2e-5 \
49 | --num_train_epochs 4.0 \
50 | --output_dir /tmp/$TASK_NAME/
51 | ```
52 |
53 | ## Experimental results on the dev set
54 | model | num_layers | SST-2 | MRPC-f1/acc | QQP-f1/acc | MNLI-m/mm | QNLI | RTE
55 | -- | -- | -- | -- | -- | -- | -- | --
56 | base | 12 | 0.9232 | 0.89/0.8358 | 0.8818/0.9121 | 0.8432/0.8479 | 0.916 | 0.6751
57 | finetuned | 6 | 0.9002 | 0.8741/0.8186 | 0.8672/0.901 | 0.8051/0.8033 | 0.8662 | 0.6101
58 | distill | 6 | 0.9071 | 0.8885/0.8382 | 0.8704/0.9016 | 0.8153/0.821 | 0.8642 | 0.6318
59 |
--------------------------------------------------------------------------------
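The `--alpha` and `--beta` flags in the README above weight the three loss terms that `distillation_meta.py` (next file) returns: the hard-label training loss, the soft distillation loss, and the patient feature loss. A hedged sketch of the combination following the PKD paper's convention; the exact weighting used by `run_glue_distillation_meta.py` is not shown in this section:

```python
# PKD-style combination of the three losses returned by MetaPatientDistillation;
# the training script's actual formula may differ.
def pkd_objective(train_loss, soft_loss, pkd_loss, alpha=0.5, beta=100.0):
    # (1 - alpha) * hard-label loss + alpha * soft KD loss + beta * patient loss
    return (1.0 - alpha) * train_loss + alpha * soft_loss + beta * pkd_loss
```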
/nlp/distillation_meta.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from operator import itemgetter
5 | from functional_forward_bert import functional_bert_for_classification
6 | from transformers import PreTrainedModel
7 |
8 |
9 | class MetaPatientDistillation(nn.Module):
10 | def __init__(self, t_config, s_config):
11 | super(MetaPatientDistillation, self).__init__()
12 | self.t_config = t_config
13 | self.s_config = s_config
14 |
15 | def forward(self, t_model, s_model, order, input_ids, token_type_ids, attention_mask, labels, args, teacher_grad):
16 | if teacher_grad:
17 | t_outputs = t_model(input_ids=input_ids,
18 | token_type_ids=token_type_ids,
19 | attention_mask=attention_mask)
20 | else:
21 | with torch.no_grad():
22 | t_outputs = t_model(input_ids=input_ids,
23 | token_type_ids=token_type_ids,
24 | attention_mask=attention_mask)
25 |
26 | if isinstance(s_model, PreTrainedModel):
27 | s_outputs = s_model(input_ids=input_ids,
28 | token_type_ids=token_type_ids,
29 | attention_mask=attention_mask,
30 | labels=labels)
31 | else:
32 | s_outputs = functional_bert_for_classification(
33 | s_model,
34 | self.s_config,
35 | input_ids=input_ids,
36 | token_type_ids=token_type_ids,
37 | attention_mask=attention_mask,
38 | labels=labels
39 | )
40 |
41 | t_logits, t_features = t_outputs[0], t_outputs[-1]
42 | train_loss, s_logits, s_features = s_outputs[0], s_outputs[1], s_outputs[-1]
43 |
44 | if args.logits_mse:
45 | soft_loss = F.mse_loss(t_logits, s_logits)
46 | else:
47 | T = args.temperature
48 | soft_targets = F.softmax(t_logits / T, dim=-1)
49 |
50 | probs = F.softmax(s_logits / T, dim=-1)
51 |             soft_loss = F.mse_loss(soft_targets, probs) * T * T  # MSE between temperature-softened distributions, scaled by T^2 as in standard KD
52 |
53 |         if args.beta == 0:  # skip computing pkd_loss entirely when beta is zero to save time
54 | pkd_loss = torch.zeros_like(soft_loss)
55 | else:
56 | t_features = torch.cat(t_features[1:-1], dim=0).view(self.t_config.num_hidden_layers - 1,
57 | -1,
58 | args.max_seq_length,
59 | self.t_config.hidden_size)[:, :, 0]
60 |
61 | s_features = torch.cat(s_features[1:-1], dim=0).view(self.s_config.num_hidden_layers - 1,
62 | -1,
63 | args.max_seq_length,
64 | self.s_config.hidden_size)[:, :, 0]
65 |
66 | t_features = itemgetter(order)(t_features)
67 | t_features = t_features / t_features.norm(dim=-1).unsqueeze(-1)
68 | s_features = s_features / s_features.norm(dim=-1).unsqueeze(-1)
69 | pkd_loss = F.mse_loss(s_features, t_features, reduction="mean")
70 |
71 | return train_loss, soft_loss, pkd_loss
72 |
73 | def s_prime_forward(self, s_prime, input_ids, token_type_ids, attention_mask, labels, args):
74 |
75 | s_outputs = functional_bert_for_classification(
76 | s_prime,
77 | self.s_config,
78 | input_ids=input_ids,
79 | token_type_ids=token_type_ids,
80 | attention_mask=attention_mask,
81 | labels=labels,
82 | is_train=False
83 | )
84 |
85 |         train_loss = s_outputs[0]  # only the loss is needed here; logits and features go unused
86 |
87 | return train_loss
88 |
--------------------------------------------------------------------------------
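The reshaping in the `pkd_loss` branch above is the core of patient distillation: stack the per-layer hidden states, keep only the `[CLS]` position, select the teacher layers named by `order`, L2-normalize, and take an MSE. A toy replay of those tensor operations; the shapes and the `order` value are illustrative, not taken from the training script:

```python
import torch
import torch.nn.functional as F
from operator import itemgetter

batch, seq_len, hidden = 4, 128, 768
num_t_layers, num_s_layers = 12, 6

# Stand-ins for stacked intermediate hidden states; keep only the [CLS]
# position (index 0 on the sequence axis), as the forward() above does.
t_features = torch.randn(num_t_layers - 1, batch, seq_len, hidden)[:, :, 0]
s_features = torch.randn(num_s_layers - 1, batch, seq_len, hidden)[:, :, 0]

order = [1, 3, 5, 7, 9]                 # teacher layers the 5 student layers imitate
t_sel = itemgetter(order)(t_features)   # == t_features[order], shape (5, B, H)

t_sel = t_sel / t_sel.norm(dim=-1).unsqueeze(-1)
s_norm = s_features / s_features.norm(dim=-1).unsqueeze(-1)
pkd_loss = F.mse_loss(s_norm, t_sel, reduction="mean")
print(float(pkd_loss))
```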
/nlp/download_glue_data.py:
--------------------------------------------------------------------------------
1 | ''' Script for downloading all GLUE data.
2 | Note: for legal reasons, we are unable to host MRPC.
3 | You can either use the version hosted by the SentEval team, which is already tokenized,
4 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
5 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
6 | You should then rename and place specific files in a folder (see below for an example).
7 | mkdir MRPC
8 | cabextract MSRParaphraseCorpus.msi -d MRPC
9 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
10 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
11 | rm MRPC/_*
12 | rm MSRParaphraseCorpus.msi
13 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
14 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
15 | '''
16 |
17 | import os
18 | import sys
19 | import shutil
20 | import argparse
21 | import tempfile
22 | import urllib.request
23 | import zipfile
24 |
25 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
26 | TASK2PATH = {
27 | "CoLA": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
28 | "SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
29 | "MRPC": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
30 | "QQP": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
31 | "STS": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
32 | "MNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
33 | "SNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
34 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
35 | "RTE": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
36 | "WNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
37 | "diagnostic": 'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
38 |
39 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
40 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
41 |
42 |
43 | def download_and_extract(task, data_dir):
44 | print("Downloading and extracting %s..." % task)
45 | data_file = "%s.zip" % task
46 | urllib.request.urlretrieve(TASK2PATH[task], data_file)
47 | with zipfile.ZipFile(data_file) as zip_ref:
48 | zip_ref.extractall(data_dir)
49 | os.remove(data_file)
50 | print("\tCompleted!")
51 |
52 |
53 | def format_mrpc(data_dir, path_to_data):
54 | print("Processing MRPC...")
55 | mrpc_dir = os.path.join(data_dir, "MRPC")
56 | if not os.path.isdir(mrpc_dir):
57 | os.mkdir(mrpc_dir)
58 | if path_to_data:
59 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
60 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
61 | else:
62 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
63 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
64 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
65 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
66 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
67 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
68 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
69 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
70 |
71 | dev_ids = []
72 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
73 | for row in ids_fh:
74 | dev_ids.append(row.strip().split('\t'))
75 |
76 | with open(mrpc_train_file, encoding="utf8") as data_fh, \
77 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
78 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
79 | header = data_fh.readline()
80 | train_fh.write(header)
81 | dev_fh.write(header)
82 | for row in data_fh:
83 | label, id1, id2, s1, s2 = row.strip().split('\t')
84 | if [id1, id2] in dev_ids:
85 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
86 | else:
87 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
88 |
89 | with open(mrpc_test_file, encoding="utf8") as data_fh, \
90 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
91 | header = data_fh.readline()
92 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
93 | for idx, row in enumerate(data_fh):
94 | label, id1, id2, s1, s2 = row.strip().split('\t')
95 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
96 | print("\tCompleted!")
97 |
98 |
99 | def download_diagnostic(data_dir):
100 | print("Downloading and extracting diagnostic...")
101 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
102 | os.mkdir(os.path.join(data_dir, "diagnostic"))
103 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
104 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
105 | print("\tCompleted!")
106 | return
107 |
108 |
109 | def get_tasks(task_names):
110 | task_names = task_names.split(',')
111 | if "all" in task_names:
112 | tasks = TASKS
113 | else:
114 | tasks = []
115 | for task_name in task_names:
116 | assert task_name in TASKS, "Task %s not found!" % task_name
117 | tasks.append(task_name)
118 | return tasks
119 |
120 |
121 | def main(arguments):
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
124 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
125 | type=str, default='all')
126 | parser.add_argument('--path_to_mrpc',
127 |                         help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
128 | type=str, default='')
129 | args = parser.parse_args(arguments)
130 |
131 | if not os.path.isdir(args.data_dir):
132 | os.mkdir(args.data_dir)
133 | tasks = get_tasks(args.tasks)
134 |
135 | for task in tasks:
136 | if task == 'MRPC':
137 | format_mrpc(args.data_dir, args.path_to_mrpc)
138 | elif task == 'diagnostic':
139 | download_diagnostic(args.data_dir)
140 | else:
141 | download_and_extract(task, args.data_dir)
142 |
143 |
144 | if __name__ == '__main__':
145 | sys.exit(main(sys.argv[1:]))
146 |
--------------------------------------------------------------------------------
/nlp/mrpc_hyperparameters.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": [42, 2022, 12],
3 | "max_seq_length": [128],
4 | "per_gpu_train_batch_size": [4,8],
5 | "per_gpu_eval_batch_size": [8],
6 | "gradient_accumulation_steps": [1],
7 | "num_train_epochs": [4,6,8],
8 | "learning_rate": [2e-5, 1e-5],
9 | "evaluate_during_training": ["True"],
10 | "logging_steps": [50],
11 | "save_steps": [50],
12 | "model_type": ["bert"],
13 | "model_name_or_path": ["bert-base-uncased"],
14 | "task_name": ["MRPC"],
15 | "data_dir": ["glue_data/MRPC"],
16 | "warmup_steps": [100,150]
17 | }
--------------------------------------------------------------------------------
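Each key in this file maps to a list of candidate values, and `run_hyperparameter_tuning.py` (below) crosses every list with every other, so this grid expands to 3 × 2 × 3 × 2 × 2 = 72 configurations. A quick way to check the count for any grid file:

```python
# Count the runs a JSON grid expands to (stdlib only; math.prod needs Python 3.8+).
import json, math

with open("mrpc_hyperparameters.json") as f:
    grid = json.load(f)
print(math.prod(len(v) for v in grid.values()))  # -> 72 for the grid above
```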
/nlp/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 |
--------------------------------------------------------------------------------
/nlp/run_hyperparameter_tuning.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import argparse
4 | # import sys
5 | # sys.path.append("/".join(os.getcwd().split("/")[:-1]) + "/")
6 |
7 | def convertTextToParams(input_file):
8 | param_dict_seq = []
9 | with open(input_file, "r") as f:
10 | header = None
11 | for cur_line in f:
12 | if header is None:
13 | header = list(cur_line.split())
14 | else:
15 | cur_row = list(cur_line.split())
16 | cur_dict = {}
17 | for i in range(len(header)):
18 | cur_dict[header[i]] = str(cur_row[i])
19 | param_dict_seq.append(cur_dict)
20 |
21 | return param_dict_seq
22 |
23 |
24 | param_dict_seq_global = []
25 | def computeAllParamCombinations(key_list, cur_key_idx, param_dict, cur_param_seq_dict):
26 | global param_dict_seq_global
27 | if cur_key_idx >= len(key_list):
28 | param_dict_seq_global.append(cur_param_seq_dict)
29 | return
30 |
31 | cur_key = key_list[cur_key_idx]
32 | for cur_val in param_dict[cur_key]:
33 | cur_param_seq_dict[cur_key] = str(cur_val)
34 | computeAllParamCombinations(key_list, cur_key_idx + 1, param_dict, cur_param_seq_dict.copy())
35 |
36 | def convertJsonToParams(input_file):
37 | global param_dict_seq_global
38 | with open(input_file, "r") as f:
39 | param_dict = json.load(f)
40 |
41 | param_keys = list(param_dict.keys())
42 | computeAllParamCombinations(param_keys, 0, param_dict, {})
43 | return param_dict_seq_global
44 |
45 | def convertDictToCmdArgs(input_dict):
46 | out_string = ""
47 | for key, val in input_dict.items():
48 | out_string += "--" + str(key) + " " + str(val) + " "
49 | return out_string
50 |
51 | def createFolderNameFromParamDict(input_dict):
52 | out_string = ""
53 | for key, val in input_dict.items():
54 | key_split = "".join([x[0] for x in key.lower().split("_")])
55 | val = val.replace("/", "_")
56 | out_string += key_split + "_" + str(val).lower() + "_"
57 |
58 | return out_string[:-1]
59 |
60 | if __name__ == "__main__":
61 | parser = argparse.ArgumentParser()
62 | parser.add_argument('--finetune_file', type=str, required=True,
63 | help='Finetuning file')
64 | parser.add_argument('--param_file', type=str, required=True,
65 | help='Hyperparameter file. Can be in ".json" or ".tsv" format')
66 | parser.add_argument('--root_output_dir', type=str, default="temp",
67 | help='Root output directory')
68 | args = parser.parse_known_args()[0]
69 |
70 | if args.param_file.endswith(".json"):
71 | param_dict_seq = convertJsonToParams(args.param_file)
72 | else:
73 | param_dict_seq = convertTextToParams(args.param_file)
74 |
75 | total_config_counts = len(param_dict_seq)
76 | cur_config_count = 0
77 | for cur_param_seq in param_dict_seq:
78 | cur_config_count += 1
79 | print("Running configuration {} out of {}:".format(cur_config_count, total_config_counts))
80 | print(cur_param_seq, "\n")
81 |
82 | cur_output_folder = os.path.join(args.root_output_dir, createFolderNameFromParamDict(cur_param_seq))
83 |
84 | # Finetuning:
85 | finetune_out_dir = os.path.join(cur_output_folder, 'finetuning')
86 |
87 |         # Create the output directory if it doesn't exist (needed for storing the log file)
88 | if not os.path.exists(finetune_out_dir):
89 | os.makedirs(finetune_out_dir)
90 |
91 | finetune_log_file = os.path.join(finetune_out_dir, 'logs.txt')
92 |
93 | finetune_cmd = "CUDA_VISIBLE_DEVICES=0 python3 " + args.finetune_file + " " + convertDictToCmdArgs(cur_param_seq)\
94 | + "--do_train --do_eval --do_lower_case " + " --output_dir " + finetune_out_dir + " > " + finetune_log_file
95 |
96 | print(finetune_cmd, "\n")
97 | os.system(finetune_cmd)
98 |
--------------------------------------------------------------------------------
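The recursive `computeAllParamCombinations` above is a hand-rolled Cartesian product that relies on a module-level accumulator; `itertools.product` expresses the same expansion without global state. A behavior-equivalent sketch (the example grid is illustrative):

```python
# Same expansion as computeAllParamCombinations, via itertools.product.
import itertools

def all_param_combinations(param_dict):
    keys = list(param_dict)
    return [dict(zip(keys, map(str, values)))
            for values in itertools.product(*param_dict.values())]

combos = all_param_combinations({"seed": [42, 2022], "learning_rate": [2e-5, 1e-5]})
print(len(combos), combos[0])  # 4 {'seed': '42', 'learning_rate': '2e-05'}
```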