├── .gitignore
├── .idea
├── Knowledge-Distillation-PyTorch.iml
├── misc.xml
├── modules.xml
└── vcs.xml
├── LICENSE
├── README.md
└── classification
├── LICENSE
├── README.md
├── TRAINING.md
├── cifar.py
├── imagenet.py
├── images
├── KD_total_loss.png
├── Softmax_output.png
├── act_attention.png
├── at_losses.png
├── at_scheme.png
├── attention.png
├── cross_entropy_loss.png
├── fitnet_scheme.png
├── fitnet_stage1.png
├── fitnet_stage2.png
├── fsp_loss.png
├── fsp_matrix.png
├── fsp_stage1.png
├── fsp_stage2.png
├── mmd_loss.png
├── nst_gram.png
├── nst_kernels.png
├── nst_loss.png
└── nst_total_loss.png
├── loss
├── distillation.py
└── losses.py
├── models
├── __init__.py
├── __pycache__
│ └── __init__.cpython-36.pyc
├── cifar
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── alexnet.cpython-36.pyc
│ │ ├── densenet.cpython-36.pyc
│ │ ├── preresnet.cpython-36.pyc
│ │ ├── resnet.cpython-36.pyc
│ │ ├── resnet_vision.cpython-36.pyc
│ │ ├── resnext.cpython-36.pyc
│ │ ├── vgg.cpython-36.pyc
│ │ └── wrn.cpython-36.pyc
│ ├── alexnet.py
│ ├── densenet.py
│ ├── preresnet.py
│ ├── resnet.py
│ ├── resnet_vision.py
│ ├── resnet_yw.py
│ ├── resnext.py
│ ├── vgg.py
│ └── wrn.py
└── imagenet
│ ├── __init__.py
│ └── resnext.py
├── scripts
├── mimic.sh
├── train.sh
└── train_local.sh
├── source.sh
├── tensorboardX
├── __init__.py
├── __init__.pyc
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── crc32c.cpython-36.pyc
│ ├── embedding.cpython-36.pyc
│ ├── event_file_writer.cpython-36.pyc
│ ├── graph.cpython-36.pyc
│ ├── graph_onnx.cpython-36.pyc
│ ├── record_writer.cpython-36.pyc
│ ├── summary.cpython-36.pyc
│ ├── writer.cpython-36.pyc
│ └── x2num.cpython-36.pyc
├── crc32c.py
├── crc32c.pyc
├── embedding.py
├── embedding.pyc
├── event_file_writer.py
├── event_file_writer.pyc
├── graph.py
├── graph.pyc
├── graph_onnx.py
├── graph_onnx.pyc
├── record_writer.py
├── record_writer.pyc
├── src
│ ├── __init__.py
│ ├── __init__.pyc
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── attr_value_pb2.cpython-36.pyc
│ │ ├── event_pb2.cpython-36.pyc
│ │ ├── graph_pb2.cpython-36.pyc
│ │ ├── node_def_pb2.cpython-36.pyc
│ │ ├── plugin_pr_curve_pb2.cpython-36.pyc
│ │ ├── resource_handle_pb2.cpython-36.pyc
│ │ ├── summary_pb2.cpython-36.pyc
│ │ ├── tensor_pb2.cpython-36.pyc
│ │ ├── tensor_shape_pb2.cpython-36.pyc
│ │ ├── types_pb2.cpython-36.pyc
│ │ └── versions_pb2.cpython-36.pyc
│ ├── attr_value.proto
│ ├── attr_value_pb2.py
│ ├── attr_value_pb2.pyc
│ ├── event.proto
│ ├── event_pb2.py
│ ├── event_pb2.pyc
│ ├── graph.proto
│ ├── graph_pb2.py
│ ├── graph_pb2.pyc
│ ├── node_def.proto
│ ├── node_def_pb2.py
│ ├── node_def_pb2.pyc
│ ├── plugin_pr_curve.proto
│ ├── plugin_pr_curve_pb2.py
│ ├── plugin_pr_curve_pb2.pyc
│ ├── resource_handle.proto
│ ├── resource_handle_pb2.py
│ ├── resource_handle_pb2.pyc
│ ├── summary.proto
│ ├── summary_pb2.py
│ ├── summary_pb2.pyc
│ ├── tensor.proto
│ ├── tensor_pb2.py
│ ├── tensor_pb2.pyc
│ ├── tensor_shape.proto
│ ├── tensor_shape_pb2.py
│ ├── tensor_shape_pb2.pyc
│ ├── types.proto
│ ├── types_pb2.py
│ ├── types_pb2.pyc
│ ├── versions.proto
│ ├── versions_pb2.py
│ └── versions_pb2.pyc
├── summary.py
├── summary.pyc
├── writer.py
├── writer.pyc
├── x2num.py
└── x2num.pyc
├── utility.py
└── utils
├── __init__.py
├── __pycache__
├── __init__.cpython-36.pyc
├── eval.cpython-36.pyc
├── logger.cpython-36.pyc
├── misc.cpython-36.pyc
└── visualize.cpython-36.pyc
├── eval.py
├── images
├── cifar.png
└── imagenet.png
├── logger.py
├── misc.py
├── progress
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── bar.cpython-36.pyc
│ └── helpers.cpython-36.pyc
├── bar.py
├── counter.py
├── helpers.py
└── spinner.py
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | .idea/
4 |
5 | classification/.idea
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # celery beat schedule file
98 | celerybeat-schedule
99 |
100 | # SageMath parsed files
101 | *.sage.py
102 |
103 | # Environments
104 | .env
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
--------------------------------------------------------------------------------
/.idea/Knowledge-Distillation-PyTorch.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 wangjiong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Knowledge-Distillation-PyTorch
2 | Knowledge Distillation Algorithms implemented with PyTorch \
3 | Trying to complete various tasks...
4 | ## Directories
5 | - classification\
6 | Classification on CIFAR-10/100 and ImageNet with PyTorch. \
7 | Based on repository [bearpaw/pytorch-classification](https://github.com/bearpaw/pytorch-classification)
8 |
--------------------------------------------------------------------------------
/classification/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Wei Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/classification/TRAINING.md:
--------------------------------------------------------------------------------
1 |
2 | ## CIFAR-10
3 |
4 | #### AlexNet
5 | ```
6 | python cifar.py -a alexnet --epochs 164 --schedule 81 122 --gamma 0.1 --checkpoint checkpoints/cifar10/alexnet
7 | ```
8 |
9 |
10 | #### VGG19 (BN)
11 | ```
12 | python cifar.py -a vgg19_bn --epochs 164 --schedule 81 122 --gamma 0.1 --checkpoint checkpoints/cifar10/vgg19_bn
13 | ```
14 |
15 | #### ResNet-110
16 | ```
17 | python cifar.py -a resnet --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/resnet-110
18 | ```
19 |
20 | #### ResNet-1202
21 | ```
22 | python cifar.py -a resnet --depth 1202 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/resnet-1202
23 | ```
24 |
25 | #### PreResNet-110
26 | ```
27 | python cifar.py -a preresnet --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/preresnet-110
28 | ```
29 |
30 | #### ResNeXt-29, 8x64d
31 | ```
32 | python cifar.py -a resnext --depth 29 --cardinality 8 --widen-factor 4 --schedule 150 225 --wd 5e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/resnext-8x64d
33 | ```
34 | #### ResNeXt-29, 16x64d
35 | ```
36 | python cifar.py -a resnext --depth 29 --cardinality 16 --widen-factor 4 --schedule 150 225 --wd 5e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/resnext-16x64d
37 | ```
38 |
39 | #### WRN-28-10-drop
40 | ```
41 | python cifar.py -a wrn --depth 28 --depth 28 --widen-factor 10 --drop 0.3 --epochs 200 --schedule 60 120 160 --wd 5e-4 --gamma 0.2 --checkpoint checkpoints/cifar10/WRN-28-10-drop
42 | ```
43 |
44 | #### DenseNet-BC (L=100, k=12)
45 | **Note**:
46 | * DenseNet uses a weight decay of `1e-4`. A larger weight decay (`5e-4`) is harmful to accuracy (95.46 vs. 94.05).
47 | * The official batch size is 64, but there is no big difference between batch sizes 64 and 128 (95.46 vs. 95.11).
48 |
49 | ```
50 | python cifar.py -a densenet --depth 100 --growthRate 12 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/densenet-bc-100-12
51 | ```
52 |
53 | #### DenseNet-BC (L=190, k=40)
54 | ```
55 | python cifar.py -a densenet --depth 190 --growthRate 40 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/densenet-bc-L190-k40
56 | ```
57 |
58 | ## CIFAR-100
59 |
60 | #### AlexNet
61 | ```
62 | python cifar.py -a alexnet --dataset cifar100 --checkpoint checkpoints/cifar100/alexnet --epochs 164 --schedule 81 122 --gamma 0.1
63 | ```
64 |
65 | #### VGG19 (BN)
66 | ```
67 | python cifar.py -a vgg19_bn --dataset cifar100 --checkpoint checkpoints/cifar100/vgg19_bn --epochs 164 --schedule 81 122 --gamma 0.1
68 | ```
69 |
70 | #### ResNet-110
71 | ```
72 | python cifar.py -a resnet --dataset cifar100 --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/resnet-110
73 | ```
74 |
75 | #### ResNet-1202
76 | ```
77 | python cifar.py -a resnet --dataset cifar100 --depth 1202 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/resnet-1202
78 | ```
79 |
80 | #### PreResNet-110
81 | ```
82 | python cifar.py -a preresnet --dataset cifar100 --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/preresnet-110
83 | ```
84 |
85 | #### ResNeXt-29, 8x64d
86 | ```
87 | python cifar.py -a resnext --dataset cifar100 --depth 29 --cardinality 8 --widen-factor 4 --checkpoint checkpoints/cifar100/resnext-8x64d --schedule 150 225 --wd 5e-4 --gamma 0.1
88 | ```
89 | #### ResNeXt-29, 16x64d
90 | ```
91 | python cifar.py -a resnext --dataset cifar100 --depth 29 --cardinality 16 --widen-factor 4 --checkpoint checkpoints/cifar100/resnext-16x64d --schedule 150 225 --wd 5e-4 --gamma 0.1
92 | ```
93 |
94 | #### WRN-28-10-drop
95 | ```
96 | python cifar.py -a wrn --dataset cifar100 --depth 28 --depth 28 --widen-factor 10 --drop 0.3 --epochs 200 --schedule 60 120 160 --wd 5e-4 --gamma 0.2 --checkpoint checkpoints/cifar100/WRN-28-10-drop
97 | ```
98 |
99 | #### DenseNet-BC (L=100, k=12)
100 | ```
101 | python cifar.py -a densenet --dataset cifar100 --depth 100 --growthRate 12 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar100/densenet-bc-100-12
102 | ```
103 |
104 | #### DenseNet-BC (L=190, k=40)
105 | ```
106 | python cifar.py -a densenet --dataset cifar100 --depth 190 --growthRate 40 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar100/densenet-bc-L190-k40
107 | ```
108 |
109 | ## ImageNet
110 | ### ResNet-18
111 | ```
112 | python imagenet.py -a resnet18 --data ~/dataset/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c checkpoints/imagenet/resnet18
113 | ```
114 |
115 | ### ResNeXt-50 (32x4d)
116 | *(Originally trained on 8xGPUs)*
117 | ```
118 | python imagenet.py -a resnext50 --base-width 4 --cardinality 32 --data ~/dataset/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c checkpoints/imagenet/resnext50-32x4d
119 | ```
--------------------------------------------------------------------------------
/classification/images/KD_total_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/KD_total_loss.png
--------------------------------------------------------------------------------
/classification/images/Softmax_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/Softmax_output.png
--------------------------------------------------------------------------------
/classification/images/act_attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/act_attention.png
--------------------------------------------------------------------------------
/classification/images/at_losses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/at_losses.png
--------------------------------------------------------------------------------
/classification/images/at_scheme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/at_scheme.png
--------------------------------------------------------------------------------
/classification/images/attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/attention.png
--------------------------------------------------------------------------------
/classification/images/cross_entropy_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/cross_entropy_loss.png
--------------------------------------------------------------------------------
/classification/images/fitnet_scheme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fitnet_scheme.png
--------------------------------------------------------------------------------
/classification/images/fitnet_stage1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fitnet_stage1.png
--------------------------------------------------------------------------------
/classification/images/fitnet_stage2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fitnet_stage2.png
--------------------------------------------------------------------------------
/classification/images/fsp_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fsp_loss.png
--------------------------------------------------------------------------------
/classification/images/fsp_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fsp_matrix.png
--------------------------------------------------------------------------------
/classification/images/fsp_stage1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fsp_stage1.png
--------------------------------------------------------------------------------
/classification/images/fsp_stage2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fsp_stage2.png
--------------------------------------------------------------------------------
/classification/images/mmd_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/mmd_loss.png
--------------------------------------------------------------------------------
/classification/images/nst_gram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/nst_gram.png
--------------------------------------------------------------------------------
/classification/images/nst_kernels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/nst_kernels.png
--------------------------------------------------------------------------------
/classification/images/nst_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/nst_loss.png
--------------------------------------------------------------------------------
/classification/images/nst_total_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/nst_total_loss.png
--------------------------------------------------------------------------------
/classification/loss/distillation.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | # from utility import vis_feat
4 | # from loss import adversarial
5 | # from loss import discriminator
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | # logger = logging.getLogger('SR')
12 |
13 |
def input2feature(x, normalize=-1, pow=1):
    """Identity supervision: hand the feature tensor back untouched.

    ``normalize`` and ``pow`` are accepted only so this transform shares the
    call signature of the other ``input2*`` transforms; both are ignored.

    :param x: feature tensor
    :return: ``x`` itself (same object, no copy)
    """
    feature = x
    return feature
16 |
17 |
def input2attention(x, normalize=-1, pow=2):
    """Spatial attention map in the style of Attention Transfer.

    ``x`` is raised element-wise to ``pow``, averaged over dim 1 and
    flattened to shape (x.size(0), -1).

    :param x: tensor with at least two dims (dim 1 is averaged away)
    :param normalize: when negative, the flattened map is returned as is;
        any other value -- including ``False``/``True``, which compare as
        0/1 -- applies ``F.normalize`` with its defaults (L2 along dim 1)
    :param pow: order of the element-wise power taken before averaging
    :return: tensor of shape (x.size(0), -1)
    """
    flat_map = x.pow(pow).mean(1).view(x.size(0), -1)
    if normalize < 0:
        return flat_map
    return F.normalize(flat_map)
30 |
31 |
def gram(x, norm=-1):
    """Batched Gram matrix between flattened spatial positions.

    ``x`` is unpacked as (n, c, h, w), so it must be 4-D. It is viewed as
    (n, c, h*w) and the per-sample product ``v^T v`` is returned.

    :param x: 4-D tensor (n, c, h, w); must be compatible with ``Tensor.view``
    :param norm: if > 0, the axis of the (n, c, h*w) view along which
        ``F.normalize`` is applied before the product; otherwise the raw
        values are used
    :return: tensor of shape (n, h*w, h*w)
    """
    batch, channels, _, _ = x.size()
    flat = x.view(batch, channels, -1)
    if norm > 0:
        flat = F.normalize(flat, dim=norm)
    return torch.bmm(flat.transpose(1, 2), flat)
43 |
44 |
def input2gram(x, normalize=2, power=1):
    """Gram-matrix supervision.

    Raises ``x`` element-wise to ``power`` and returns ``gram`` of the
    result, forwarding ``normalize`` as ``gram``'s ``norm`` argument. With
    the default ``normalize=2`` each channel's flattened spatial vector is
    L2-normalised before the product.
    """
    powered = x.pow(power)
    return gram(powered, norm=normalize)
47 |
48 |
def input2similarity(x, normalize=1, power=1):
    """Position-similarity supervision.

    Same computation as ``input2gram`` but with ``normalize=1`` by default,
    which L2-normalises the channel vector at every spatial position, so the
    resulting (n, h*w, h*w) matrix holds cosine similarities between
    positions.
    """
    powered = x.pow(power)
    return gram(powered, norm=normalize)
51 |
52 |
def L1Loss(x, y):
    """Mean absolute error between ``x`` and ``y`` (torch broadcasting applies)."""
    diff = x - y
    return diff.abs().mean()
55 |
56 |
def L2Loss(x, y):
    """Mean squared error between ``x`` and ``y`` (torch broadcasting applies)."""
    return torch.mean((x - y).pow(2))
59 |
60 |
class Distillation(nn.Module):
    """Feature-level distillation loss over dictionaries of intermediate features.

    Each student feature is paired with the teacher feature stored under the
    same key. Both are mapped through a supervision transform and compared
    with an L1 or L2 criterion, and the per-key losses are summed.

    :param supervision: 'feature' (identity), 'attention', 'gram' or
        'similarity' -- selects input2feature / input2attention /
        input2gram / input2similarity
    :param function: 'L1' (``nn.L1Loss()``) or 'L2' (``L2Loss``, mean of
        squared differences)
    :param normalize: forwarded as the ``normalize`` keyword of the
        transform. NOTE(review): the transforms compare it numerically
        (``< 0`` / ``> 0``), so the boolean default ``False`` acts like 0:
        attention maps are still L2-normalised while gram/similarity maps
        are not. Confirm this is intended.
    :raises NotImplementedError: if ``supervision`` is not one of the above
    :raises ValueError: if ``function`` is neither 'L1' nor 'L2'
    """

    def __init__(self, supervision='attention', function='L2', normalize=False):
        super(Distillation, self).__init__()
        if supervision == 'feature':
            self.process = input2feature
        elif supervision == 'attention':
            self.process = input2attention
        elif supervision == 'gram':
            self.process = input2gram
        elif supervision == 'similarity':
            self.process = input2similarity
        else:
            # Fail loudly here: otherwise self.process would be missing and
            # the error would only surface later, inside forward().
            raise NotImplementedError(
                'Supervision type [{}] is not implemented'.format(supervision))

        self.norm = normalize

        if function == 'L1':
            self.function = nn.L1Loss()
        elif function == 'L2':
            self.function = L2Loss
        else:
            raise ValueError('Choose L1, L2 loss, rather than {}'.format(function))

    def forward(self, student, teacher, assistant=None, writer=None, batch=None):
        """
        Sum the distillation losses over all comparable feature positions.

        :param student: dict of student features to distil, keyed by feature
            position; entries that are None are skipped
        :param teacher: dict of teacher features indexed with the same keys;
            if None (or an entry is None) that position is skipped
        :param assistant: optional dict of features added to the student
            feature (as a residual) before comparison
        :param writer: optional tensorboard writer; when given together with
            ``batch``, per-key losses are logged every 10th batch
        :param batch: batch index, used for the logging cadence and as the
            logged step (``batch + 1``)
        :return: scalar tensor; a float zero when no pair could be compared
        """
        losses = []
        # k: feature position; fs: student feature
        for k, fs in student.items():
            if fs is None:
                continue
            ft = teacher[k] if teacher is not None else None
            fa = assistant[k] if assistant is not None else None
            if ft is None:
                continue
            # assistant provides feature residual
            if fa is not None:
                fs = fs + fa
            # map both features to the chosen supervision before comparing
            loss = self.function(self.process(fs, normalize=self.norm),
                                 self.process(ft, normalize=self.norm))
            losses.append(loss)
            # `batch % 10` would raise on None, so logging needs an index.
            if writer and batch is not None and batch % 10 == 0:
                name = 'Mimic Loss feat{}'.format(k)
                writer.add_scalar('Distill_loss_batch/{}'.format(name), loss, batch + 1)
        if losses:
            return torch.sum(torch.stack(losses, dim=0))
        # Keep the zero on the GPU when CUDA is available, but do not crash
        # on CPU-only machines.
        zero = torch.tensor(0).float()
        return zero.cuda() if torch.cuda.is_available() else zero
120 |
--------------------------------------------------------------------------------
/classification/loss/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
7 | # Total loss = kd_loss * alpha + CELoss * (1 - alpha)
def kd_loss(s, t, args=None):
    """
    Knowledge Distillation Loss (soft-target term).

    KL divergence between the temperature-softened teacher and student
    distributions (softmax over dim 1), multiplied by T * T.

    NOTE(review): ``nn.KLDivLoss()`` uses reduction='mean', which divides by
    the total number of elements (batch * classes), not by the batch size
    alone as 'batchmean' would; the loss scale (and any alpha weighting)
    depends on this.

    :param s: student outputs (pre-softmax scores; softmax is over dim 1)
    :param t: teacher outputs with the same shape as ``s``
    :param args: optional namespace; if it is truthy and has ``temperature``,
        that value is used as T, otherwise T = 4
    :return: scalar loss tensor
    """
    T = getattr(args, 'temperature', 4) if args else 4
    student_log_probs = F.log_softmax(s / T, dim=1)
    teacher_probs = F.softmax(t / T, dim=1)
    criterion = nn.KLDivLoss()
    return criterion(student_log_probs, teacher_probs) * (T * T)
19 |
20 |
21 | # attention transfer loss
def attention(x):
    """Attention-Transfer map: channel mean of x**2, flattened and L2-normalised per sample."""
    energy = x.pow(2).mean(1)
    flat = energy.view(x.size(0), -1)
    return F.normalize(flat)
24 |
25 |
def at_loss(student, teacher, args=None):
    """
    Attention Transfer Loss

    For every position ``p`` in ``args.mimic_position`` the attention maps
    (see ``attention``) of ``student[p]`` and ``teacher[p]`` are compared
    with a mean squared error, weighted by the matching
    ``args.mimic_lambda`` entry. Channel counts may differ because
    ``attention`` averages over channels; spatial sizes must match.

    :param student: feature dictionary of 4-D (N, C, H, W) tensors
    :param teacher: feature dictionary with the same keys
    :param args: namespace with equal-length ``mimic_position`` and
        ``mimic_lambda``; required despite the ``None`` default
    :return: list of weighted per-position losses
    """
    losses = []
    assert len(args.mimic_position) == len(args.mimic_lambda)
    for p, weight in zip(args.mimic_position, args.mimic_lambda):
        s = student[p]
        t = teacher[p]
        # Only the spatial dims need to agree; channels are averaged away.
        _, _, h, w = s.shape
        _, _, t_h, t_w = t.shape
        assert h == t_h and w == t_w, \
            'attention maps differ in spatial size at {}: student {} vs teacher {}'.format(
                p, tuple(s.shape), tuple(t.shape))
        loss = (attention(s) - attention(t)).pow(2).mean()
        losses.append(weight * loss)
    return losses
45 |
46 |
def fm_loss(student, teacher, args=None):
    """
    Feature Mimic Loss

    Mean squared error between ``student[p]`` and ``teacher[p]`` for each
    ``p`` in ``args.mimic_position``, each scaled by the matching
    ``args.mimic_lambda`` entry.

    :param student: feature dictionary
    :param teacher: feature dictionary with the same keys (tensors must be
        shape-compatible with the student's)
    :param args: namespace with equal-length ``mimic_position`` and
        ``mimic_lambda``
    :return: list of weighted per-position losses
    """
    assert len(args.mimic_position) == len(args.mimic_lambda)
    return [
        weight * (student[p] - teacher[p]).pow(2).mean()
        for p, weight in zip(args.mimic_position, args.mimic_lambda)
    ]
63 |
64 |
def nst_loss(student, teacher, args=None):
    """
    Neural Selectivity Transfer Loss (paper)

    For each ``p`` in ``args.mimic_position`` the student and teacher maps
    are viewed as (N, C, H*W) and L2-normalised along the channel axis
    (dim=1). They are then turned into (N, H*W, H*W) Gram matrices of
    cosine similarity between spatial positions. The loss is the mean
    squared difference of the two Gram matrices, scaled by the matching
    ``args.mimic_lambda`` entry.

    Student and teacher channel counts may differ; H*W must match.

    :param student: feature dictionary of (N, C, H, W)-viewable tensors
    :param teacher: feature dictionary with the same keys
    :param args: namespace with equal-length ``mimic_position`` and
        ``mimic_lambda``
    :return: list of weighted per-position losses
    :raises AssertionError: if the student and teacher Gram shapes differ
    """
    losses = []
    assert len(args.mimic_position) == len(args.mimic_lambda)
    for p, weight in zip(args.mimic_position, args.mimic_lambda):
        s = student[p]
        t = teacher[p]
        s = F.normalize(s.view(s.shape[0], s.shape[1], -1), dim=1)  # N, C_s, H * W
        gram_s = s.transpose(1, 2).bmm(s)  # N, H * W, H * W
        t = F.normalize(t.view(t.shape[0], t.shape[1], -1), dim=1)  # N, C_t, H * W
        gram_t = t.transpose(1, 2).bmm(t)  # N, H * W, H * W
        # The Grams are square by construction; what must hold is that both
        # sides match, otherwise the subtraction below could silently
        # broadcast (e.g. a 1x1 student map against a larger teacher map).
        assert gram_s.shape == gram_t.shape, \
            'gram shapes differ at {}: student {} vs teacher {}'.format(
                p, tuple(gram_s.shape), tuple(gram_t.shape))
        loss = (gram_s - gram_t).pow(2).mean()
        losses.append(weight * loss)
    return losses
88 |
89 |
def mmd_loss(student, teacher, args=None):
    """
    Maximum Mean Discrepancy Loss (NST Project)

    Squared MMD with the polynomial kernel k(a, b) = (a . b) ** 2. The rows
    of each (N, C, H*W) view are the samples:

        MMD^2 = mean k(s, s') + mean k(t, t') - 2 * mean k(s, t)

    Each term is averaged over the batch and over all row pairs. The
    teacher-teacher term does not depend on the student, but it is required
    for the value to be a true (non-negative) discrepancy that is 0 when
    student and teacher agree.

    NOTE(review): features are L2-normalised along dim=1 (the channel axis,
    i.e. per spatial position) while the kernel is evaluated between channel
    rows; the NST formulation normalises each channel's spatial response.
    Confirm whether dim=2 was intended.

    :param student: feature dictionary
    :param teacher: feature dictionary with the same keys (H*W must match
        the student's; channel counts may differ)
    :param args: namespace with equal-length ``mimic_position`` and
        ``mimic_lambda``
    :return: list of weighted per-position losses
    """
    losses = []
    assert len(args.mimic_position) == len(args.mimic_lambda)
    for p, weight in zip(args.mimic_position, args.mimic_lambda):
        s = student[p]
        t = teacher[p]
        s = F.normalize(s.view(s.shape[0], s.shape[1], -1), dim=1)  # N, C_s, H * W
        t = F.normalize(t.view(t.shape[0], t.shape[1], -1), dim=1)  # N, C_t, H * W
        mmd_ss_mean = s.bmm(s.transpose(2, 1)).pow(2).mean()  # kernel over N, C_s, C_s
        mmd_tt_mean = t.bmm(t.transpose(2, 1)).pow(2).mean()  # kernel over N, C_t, C_t
        mmd_st_mean = s.bmm(t.transpose(2, 1)).pow(2).mean()  # kernel over N, C_s, C_t
        loss = mmd_ss_mean + mmd_tt_mean - 2 * mmd_st_mean
        losses.append(weight * loss)
    return losses
112 |
113 |
def similarity_loss(student, teacher, args=None):
    """
    Similarity Transfer Loss -- not implemented yet.

    A loss function that silently returned ``None`` would only fail later
    (e.g. when summing losses), far from the cause, so this placeholder
    raises immediately.

    :param student: feature dictionary
    :param teacher: feature dictionary
    :param args: unused
    :raises NotImplementedError: always
    """
    raise NotImplementedError('similarity_loss is not implemented yet')
123 |
124 |
def _fsp_matrix(a, b):
    """Return the (N, C_a, C_b) FSP matrix of two maps with equal H*W, averaged over positions."""
    a = a.view(a.shape[0], a.shape[1], -1)  # N, C_a, H * W
    b = b.view(b.shape[0], b.shape[1], -1)  # N, C_b, H * W
    return a.bmm(b.transpose(1, 2)) / a.shape[2]  # N, C_a, C_b


def fsp_loss(student, teacher, args=None):
    """
    Flow of Solution Procedure (FSP) Loss

    ``args.mimic_position`` holds consecutive pairs of feature positions;
    pair ``i`` is (positions[2i], positions[2i+1]) and is weighted by
    ``args.mimic_lambda[i]``. For each pair an FSP matrix (inner products
    between the two maps' channels, averaged over spatial positions) is
    built for student and teacher, and their mean squared difference is
    the loss.

    The two maps of a pair must share H*W, and the student's FSP matrix
    must have the same (C1, C2) shape as the teacher's.

    :param student: feature dictionary
    :param teacher: feature dictionary with the same keys
    :param args: namespace with ``mimic_position`` of length
        ``2 * len(mimic_lambda)``
    :return: list of weighted per-pair losses
    """
    losses = []
    assert len(args.mimic_position) == 2 * len(args.mimic_lambda)
    positions = args.mimic_position
    # Iterate over the weights: they define the number of pairs (see the
    # assert above) and are indexed per pair.
    for i, weight in enumerate(args.mimic_lambda):
        first, second = positions[2 * i], positions[2 * i + 1]
        fsp_s = _fsp_matrix(student[first], student[second])
        fsp_t = _fsp_matrix(teacher[first], teacher[second])
        loss = (fsp_s - fsp_t).pow(2).mean()
        losses.append(weight * loss)
    return losses
150 |
--------------------------------------------------------------------------------
/classification/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/__init__.py
--------------------------------------------------------------------------------
/classification/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/models/cifar/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | """The models subpackage contains definitions for the following models for CIFAR10/CIFAR100
4 | architectures:
5 |
6 | - `AlexNet`_
7 | - `VGG`_
8 | - `ResNet`_
9 | - `SqueezeNet`_
10 | - `DenseNet`_
11 |
12 | You can construct a model with random weights by calling its constructor:
13 |
14 | .. code:: python
15 |
16 | import torchvision.models as models
17 | resnet18 = models.resnet18()
18 | alexnet = models.alexnet()
19 | squeezenet = models.squeezenet1_0()
20 | densenet = models.densenet_161()
21 |
22 | We provide pre-trained models for the ResNet variants and AlexNet, using the
23 | PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing
24 | ``pretrained=True``:
25 |
26 | .. code:: python
27 |
28 | import torchvision.models as models
29 | resnet18 = models.resnet18(pretrained=True)
30 | alexnet = models.alexnet(pretrained=True)
31 |
32 | ImageNet 1-crop error rates (224x224)
33 |
34 | ======================== ============= =============
35 | Network Top-1 error Top-5 error
36 | ======================== ============= =============
37 | ResNet-18 30.24 10.92
38 | ResNet-34 26.70 8.58
39 | ResNet-50 23.85 7.13
40 | ResNet-101 22.63 6.44
41 | ResNet-152 21.69 5.94
42 | Inception v3 22.55 6.44
43 | AlexNet 43.45 20.91
44 | VGG-11 30.98 11.37
45 | VGG-13 30.07 10.75
46 | VGG-16 28.41 9.62
47 | VGG-19 27.62 9.12
48 | SqueezeNet 1.0 41.90 19.58
49 | SqueezeNet 1.1 41.81 19.38
50 | Densenet-121 25.35 7.83
51 | Densenet-169 24.00 7.00
52 | Densenet-201 22.80 6.43
53 | Densenet-161 22.35 6.20
54 | ======================== ============= =============
55 |
56 |
57 | .. _AlexNet: https://arxiv.org/abs/1404.5997
58 | .. _VGG: https://arxiv.org/abs/1409.1556
59 | .. _ResNet: https://arxiv.org/abs/1512.03385
60 | .. _SqueezeNet: https://arxiv.org/abs/1602.07360
61 | .. _DenseNet: https://arxiv.org/abs/1608.06993
62 | """
63 |
64 | from .alexnet import *
65 | from .vgg import *
66 | from .resnet import *
67 | from .resnet_vision import ResNet as resnet_vision
68 | from .resnext import *
69 | from .wrn import *
70 | from .preresnet import *
71 | from .densenet import *
72 |
--------------------------------------------------------------------------------
/classification/models/cifar/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/models/cifar/__pycache__/alexnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/alexnet.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/models/cifar/__pycache__/densenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/densenet.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/models/cifar/__pycache__/preresnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/preresnet.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/models/cifar/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/models/cifar/__pycache__/resnet_vision.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/resnet_vision.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/models/cifar/__pycache__/resnext.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/resnext.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/models/cifar/__pycache__/vgg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/vgg.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/models/cifar/__pycache__/wrn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/wrn.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/models/cifar/alexnet.py:
--------------------------------------------------------------------------------
1 | '''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted.
2 | Without BN, the start learning rate should be 0.01
3 | (c) YANG, Wei
4 | '''
5 | import torch.nn as nn
6 |
7 |
8 | __all__ = ['alexnet']
9 |
10 |
class AlexNet(nn.Module):
    """AlexNet-style convolutional stack followed by a single linear classifier.

    For a 3x32x32 input the convolutional stack ends in a 256x1x1 map, which is
    flattened and passed to ``self.classifier``.
    """

    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()

        def conv_relu(in_ch, out_ch, **conv_kwargs):
            # Convolution immediately followed by an in-place ReLU.
            return [nn.Conv2d(in_ch, out_ch, **conv_kwargs), nn.ReLU(inplace=True)]

        def pool():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        layers = (
            conv_relu(3, 64, kernel_size=11, stride=4, padding=5) + [pool()]
            + conv_relu(64, 192, kernel_size=5, padding=2) + [pool()]
            + conv_relu(192, 384, kernel_size=3, padding=1)
            + conv_relu(384, 256, kernel_size=3, padding=1)
            + conv_relu(256, 256, kernel_size=3, padding=1) + [pool()]
        )
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)
37 |
38 |
def alexnet(**kwargs):
    r"""Build the CIFAR AlexNet variant defined in this module.

    The architecture is adapted from the `"One weird trick..."
    <https://arxiv.org/abs/1404.5997>`_ paper.

    :param kwargs: forwarded unchanged to :class:`AlexNet` (e.g. ``num_classes``)
    :return: a freshly constructed :class:`AlexNet`
    """
    model = AlexNet(**kwargs)
    return model
45 |
--------------------------------------------------------------------------------
/classification/models/cifar/densenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 | __all__ = ['densenet']
8 |
9 |
10 | from torch.autograd import Variable
11 |
12 |
class Bottleneck(nn.Module):
    """Dense layer with a 1x1 bottleneck: BN-ReLU-Conv1x1, BN-ReLU-Conv3x3.

    The ``growthRate`` new channels are concatenated after the input, so the
    output has ``inplanes + growthRate`` channels.
    """

    def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0):
        super(Bottleneck, self).__init__()
        width = expansion * growthRate  # channels of the intermediate 1x1 conv
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, growthRate, kernel_size=3,
                               padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate

    def forward(self, x):
        new_features = self.conv1(self.relu(self.bn1(x)))
        new_features = self.conv2(self.relu(self.bn2(new_features)))
        if self.dropRate > 0:
            new_features = F.dropout(new_features, p=self.dropRate,
                                     training=self.training)
        return torch.cat((x, new_features), 1)
38 |
39 |
class BasicBlock(nn.Module):
    """Dense layer without bottleneck: BN-ReLU-Conv3x3 producing ``growthRate`` channels.

    The new channels are concatenated after the input, so the output has
    ``inplanes + growthRate`` channels.

    :param inplanes: number of input channels
    :param expansion: accepted for signature parity with ``Bottleneck``; unused here
    :param growthRate: number of channels added by this layer
    :param dropRate: dropout probability applied to the new channels (0 disables)
    """

    def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3,
                               padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate

    def forward(self, x):
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        if self.dropRate > 0:
            out = F.dropout(out, p=self.dropRate, training=self.training)

        # Dense connectivity: keep the input and append the new features.
        out = torch.cat((x, out), 1)

        return out
60 |
61 |
class Transition(nn.Module):
    """Transition between dense blocks: BN-ReLU-Conv1x1 to ``outplanes``, then 2x2 average pooling."""

    def __init__(self, inplanes, outplanes):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        squeezed = self.conv1(self.relu(self.bn1(x)))
        return F.avg_pool2d(squeezed, 2)
76 |
77 |
class DenseNet(nn.Module):
    """DenseNet with three dense blocks joined by two transitions.

    The final ``AvgPool2d(8)`` expects an 8x8 map after the two 2x
    down-samplings, i.e. 32x32 inputs.

    :param depth: total depth; must satisfy ``(depth - 4) % 3 == 0``
    :param block: dense layer type, ``Bottleneck`` (default) or ``BasicBlock``
    :param dropRate: dropout probability used inside each dense layer
    :param num_classes: size of the classifier output
    :param growthRate: channels added by every dense layer
    :param compressionRate: channel reduction factor of each transition
    """

    def __init__(self, depth=22, block=Bottleneck, dropRate=0,
                 num_classes=10, growthRate=12, compressionRate=2):
        super(DenseNet, self).__init__()

        assert (depth - 4) % 3 == 0, 'depth should be 3n+4'
        # Layers per dense block. Integer division is required: ``n`` feeds
        # ``range`` in ``_make_denseblock`` and ``/`` yields a float on
        # Python 3. A Bottleneck layer holds two convs, hence the divisor 6.
        n = (depth - 4) // 3 if block == BasicBlock else (depth - 4) // 6

        self.growthRate = growthRate
        self.dropRate = dropRate

        # self.inplanes is an instance attribute that the _make_* helpers
        # update to track the running channel count.
        self.inplanes = growthRate * 2
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1,
                               bias=False)
        self.dense1 = self._make_denseblock(block, n)
        self.trans1 = self._make_transition(compressionRate)
        self.dense2 = self._make_denseblock(block, n)
        self.trans2 = self._make_transition(compressionRate)
        self.dense3 = self._make_denseblock(block, n)
        self.bn = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(self.inplanes, num_classes)

        # Weight initialization: normal(0, sqrt(2 / fan_out)) for convs,
        # unit scale / zero shift for batch norm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_denseblock(self, block, blocks):
        """Stack ``blocks`` dense layers; each one widens the map by growthRate."""
        layers = []
        for i in range(blocks):
            # Currently we fix the expansion ratio as the default value
            layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate))
            self.inplanes += self.growthRate

        return nn.Sequential(*layers)

    def _make_transition(self, compressionRate):
        """Create a transition that divides the current channel count by compressionRate."""
        inplanes = self.inplanes
        # ``//`` already floors; int() normalises a float compressionRate.
        outplanes = int(self.inplanes // compressionRate)
        self.inplanes = outplanes
        return Transition(inplanes, outplanes)

    def forward(self, x):
        x = self.conv1(x)

        x = self.trans1(self.dense1(x))
        x = self.trans2(self.dense2(x))
        x = self.dense3(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
143 |
144 |
def densenet(**kwargs):
    """
    Constructs a DenseNet model.

    :param kwargs: forwarded unchanged to :class:`DenseNet` (e.g. ``depth``,
        ``block``, ``growthRate``, ``num_classes``)
    :return: a freshly constructed :class:`DenseNet`
    """
    return DenseNet(**kwargs)
--------------------------------------------------------------------------------
/classification/models/cifar/preresnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''Resnet for cifar dataset.
4 | Ported form
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import math
12 |
13 |
14 | __all__ = ['preresnet']
15 |
16 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param stride: convolution stride (default 1)
    :return: the configured ``nn.Conv2d``
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
27 |
28 |
class BasicBlock(nn.Module):
    """Pre-activation basic block: two BN-ReLU-Conv3x3 units plus a shortcut.

    The shortcut is the raw input, or ``downsample(x)`` when a projection is given.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        out += x if self.downsample is None else self.downsample(x)
        return out
59 |
60 |
class Bottleneck(nn.Module):
    """Pre-activation bottleneck block.

    Three BN-ReLU-Conv units (1x1, 3x3 with ``stride``, 1x1 widening to
    ``planes * 4``) followed by the addition of the input or ``downsample(x)``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        units = ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3))
        for norm, conv in units:
            out = conv(self.relu(norm(out)))
        out += x if self.downsample is None else self.downsample(x)
        return out
98 |
99 |
class PreResNet(nn.Module):
    """Pre-activation ResNet with three stages of 16/32/64 base channels.

    ``depth`` selects the number of blocks per stage: 6n+2 for BasicBlock,
    9n+2 for Bottleneck. The final ``AvgPool2d(8)`` expects an 8x8 map.
    """

    def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
        super(PreResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        kind = block_name.lower()
        if kind == 'basicblock':
            assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            blocks_per_stage, block = (depth - 2) // 6, BasicBlock
        elif kind == 'bottleneck':
            assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            blocks_per_stage, block = (depth - 2) // 9, Bottleneck
        else:
            raise ValueError('block_name shoule be Basicblock or Bottleneck')

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                               bias=False)
        self.layer1 = self._make_layer(block, 16, blocks_per_stage)
        self.layer2 = self._make_layer(block, 32, blocks_per_stage, stride=2)
        self.layer3 = self._make_layer(block, 64, blocks_per_stage, stride=2)
        final_width = 64 * block.expansion
        self.bn = nn.BatchNorm2d(final_width)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(final_width, num_classes)

        # Conv weights ~ normal(0, sqrt(2 / fan_out)); BN starts as identity.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` blocks; only the first may change stride or width."""
        out_planes = planes * block.expansion
        needs_projection = stride != 1 or self.inplanes != out_planes
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, out_planes,
                      kernel_size=1, stride=stride, bias=False),
        ) if needs_projection else None

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3):  # 32x32 -> 16x16 -> 8x8
            out = stage(out)
        out = self.relu(self.bn(out))
        out = self.avgpool(out)
        return self.fc(out.view(out.size(0), -1))
165 |
166 |
def preresnet(**kwargs):
    """
    Constructs a pre-activation ResNet (:class:`PreResNet`) model.

    :param kwargs: forwarded unchanged to :class:`PreResNet`
        (``depth``, ``num_classes``, ``block_name``)
    :return: a freshly constructed :class:`PreResNet`
    """
    return PreResNet(**kwargs)
172 |
--------------------------------------------------------------------------------
/classification/models/cifar/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''Resnet for cifar dataset.
4 | Ported form
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import math
12 |
13 | __all__ = ['resnet']
14 |
15 |
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    :param in_planes: input channel count
    :param out_planes: output channel count
    :param stride: convolution stride; 1 keeps the spatial size
    :return: the convolution module
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
26 |
27 |
class BasicBlock(nn.Module):
    """Post-activation basic block: Conv-BN-ReLU, Conv-BN, add shortcut, ReLU."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        identity = x if self.downsample is None else self.downsample(x)
        out += identity
        return self.relu(out)
58 |
59 |
class Bottleneck(nn.Module):
    """Post-activation bottleneck: 1x1 reduce, 3x3 (``stride``), 1x1 expand to ``planes * 4``.

    Each conv is followed by BN; ReLU follows the first two and the residual sum.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += x if self.downsample is None else self.downsample(x)
        return self.relu(out)
97 |
98 |
class ResNet(nn.Module):
    """ResNet with three stages that also returns its stage outputs.

    ``forward`` returns ``(logits, features)`` where ``features`` maps the
    stage index (1, 2, 3) to the output of ``layer1`` / ``layer2`` /
    ``layer3``. Presumably these are consumed by the feature-based
    distillation losses; confirm against the training script.
    """

    def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
        super(ResNet, self).__init__()
        # The depth fixes how many residual blocks each of the three stages gets.
        name = block_name.lower()
        if name == 'basicblock':
            assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            block, per_stage = BasicBlock, (depth - 2) // 6
        elif name == 'bottleneck':
            assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            block, per_stage = Bottleneck, (depth - 2) // 9
        else:
            raise ValueError('block_name shoule be either Basicblock or Bottleneck')

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, per_stage)
        self.layer2 = self._make_layer(block, 32, per_stage, stride=2)
        self.layer3 = self._make_layer(block, 64, per_stage, stride=2)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # Conv weights ~ normal(0, sqrt(2 / fan_out)); BN starts as identity.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` blocks; a 1x1 conv + BN projection fixes the first shortcut if needed."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage += [block(self.inplanes, planes) for _ in range(1, blocks)]
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))  # 32x32
        features = {}
        # Stage outputs: 32x32, 16x16, 8x8.
        for idx, stage in enumerate((self.layer1, self.layer2, self.layer3), start=1):
            x = stage(x)
            features[idx] = x
        pooled = self.avgpool(x)
        logits = self.fc(pooled.view(pooled.size(0), -1))
        return logits, features
164 |
165 |
def resnet(**kwargs):
    """Construct a :class:`ResNet`, passing every keyword argument through.

    :return: ``ResNet(**kwargs)``
    """
    model = ResNet(**kwargs)
    return model
171 |
--------------------------------------------------------------------------------
/classification/models/cifar/resnet_yw.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''Resnet for cifar dataset.
4 | Ported form
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import math
12 |
13 |
14 | __all__ = ['resnet']
15 |
16 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a 3x3 convolution (padding 1, no bias) from in_planes to out_planes.

    :param in_planes: input channel count
    :param out_planes: output channel count
    :param stride: convolution stride
    :return: ``nn.Conv2d`` instance
    """
    layer = nn.Conv2d(in_planes, out_planes, 3, stride=stride, padding=1, bias=False)
    return layer
27 |
28 |
class BasicBlock(nn.Module):
    """Post-activation basic block (two 3x3 convs with BN) and an identity/projection shortcut."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.relu(self.bn1(branch))
        branch = self.bn2(self.conv2(branch))
        if self.downsample is None:
            branch += x
        else:
            branch += self.downsample(x)
        return self.relu(branch)
59 |
60 |
class Bottleneck(nn.Module):
    """Post-activation bottleneck block producing ``planes * 4`` channels."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = x
        # Reduce and spatial convs are followed by BN + ReLU ...
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            branch = self.relu(norm(conv(branch)))
        # ... the expanding conv only by BN; ReLU comes after the residual sum.
        branch = self.bn3(self.conv3(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
98 |
99 |
class ResNet(nn.Module):
    """ResNet with three stages (16/32/64 base channels); ``forward`` returns logits only.

    ``depth`` must be 6n+2 for BasicBlock or 9n+2 for Bottleneck. The final
    ``AvgPool2d(8)`` expects an 8x8 map.
    """

    def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
        super(ResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        choice = block_name.lower()
        if choice == 'basicblock':
            assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            num_blocks = (depth - 2) // 6
            block = BasicBlock
        elif choice == 'bottleneck':
            assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            num_blocks = (depth - 2) // 9
            block = Bottleneck
        else:
            raise ValueError('block_name shoule be either Basicblock or Bottleneck')

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, num_blocks)
        self.layer2 = self._make_layer(block, 32, num_blocks, stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks, stride=2)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # Conv weights ~ normal(0, sqrt(2 / fan_out)); BN starts as identity.
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                fan = mod.kernel_size[0] * mod.kernel_size[1] * mod.out_channels
                mod.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(mod, nn.BatchNorm2d):
                mod.weight.data.fill_(1)
                mod.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks; the first handles stride/width changes."""
        width = planes * block.expansion
        projection = None
        if stride != 1 or self.inplanes != width:
            projection = nn.Sequential(
                nn.Conv2d(self.inplanes, width,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )

        modules = [block(self.inplanes, planes, stride, projection)]
        self.inplanes = width
        for _ in range(1, blocks):
            modules.append(block(self.inplanes, planes))
        return nn.Sequential(*modules)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))  # 32x32
        out = self.layer3(self.layer2(self.layer1(out)))  # 32x32 -> 16x16 -> 8x8
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)
165 |
166 |
def resnet(**kwargs):
    """Return a new :class:`ResNet` built from the given keyword arguments."""
    return ResNet(**kwargs)
172 |
--------------------------------------------------------------------------------
/classification/models/cifar/resnext.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | """
3 | Creates a ResNeXt Model as defined in:
4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
5 | Aggregated residual transformations for deep neural networks.
6 | arXiv preprint arXiv:1611.05431.
7 | import from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py
8 | """
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.nn import init
12 |
13 | __all__ = ['resnext']
14 |
15 |
class ResNeXtBottleneck(nn.Module):
    """
    ResNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
    """
    def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor):
        """ Constructor
        Args:
            in_channels: input channel dimensionality
            out_channels: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            cardinality: num of convolution groups.
            widen_factor: factor to reduce the input dimensionality before convolution.
        """
        super(ResNeXtBottleneck, self).__init__()
        # Width of the grouped 3x3 conv; must be divisible by ``cardinality``.
        D = cardinality * out_channels // widen_factor
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_reduce = nn.BatchNorm2d(D)
        self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn = nn.BatchNorm2d(D)
        self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_expand = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        # A projection is needed whenever the residual branch changes the channel
        # count OR the spatial size; otherwise ``residual + bottleneck`` fails.
        if stride != 1 or in_channels != out_channels:
            self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False))
            self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels))

    def forward(self, x):
        # Call submodules (not ``.forward``) so registered forward hooks fire,
        # e.g. hooks used to capture intermediate features for distillation.
        bottleneck = self.conv_reduce(x)
        bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
        bottleneck = self.conv_conv(bottleneck)
        bottleneck = F.relu(self.bn(bottleneck), inplace=True)
        bottleneck = self.conv_expand(bottleneck)
        bottleneck = self.bn_expand(bottleneck)
        residual = self.shortcut(x)
        return F.relu(residual + bottleneck, inplace=True)
52 |
53 |
class CifarResNeXt(nn.Module):
    """
    ResNext optimized for the Cifar dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """
    def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0):
        """ Constructor
        Args:
            cardinality: number of convolution groups.
            depth: number of layers.
            num_classes: number of classes
            widen_factor: factor to adjust the channel dimensionality
            dropRate: accepted for interface compatibility; not used.
        """
        super(CifarResNeXt, self).__init__()
        self.cardinality = cardinality
        self.depth = depth
        self.block_depth = (self.depth - 2) // 9
        self.widen_factor = widen_factor
        self.num_classes = num_classes
        self.output_size = 64
        self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor]

        self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn_1 = nn.BatchNorm2d(64)
        self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1)
        self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2)
        self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2)
        # The classifier input must match the last stage's channel count
        # (this equals 1024 only for the default widen_factor=4).
        self.classifier = nn.Linear(self.stages[3], num_classes)
        init.kaiming_normal_(self.classifier.weight)

        # state_dict() tensors share storage with the parameters/buffers, so the
        # in-place writes below initialise the model itself. Build it only once.
        state = self.state_dict()
        for key, tensor in state.items():
            suffix = key.split('.')[-1]
            if suffix == 'weight':
                if 'conv' in key:
                    init.kaiming_normal_(tensor, mode='fan_out')
                if 'bn' in key:
                    tensor[...] = 1
            elif suffix == 'bias':
                tensor[...] = 0

    def block(self, name, in_channels, out_channels, pool_stride=2):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            name: string name of the current block.
            in_channels: number of input channels
            out_channels: number of output channels
            pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        block = nn.Sequential()
        for bottleneck in range(self.block_depth):
            name_ = '%s_bottleneck_%d' % (name, bottleneck)
            if bottleneck == 0:
                block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality,
                                                          self.widen_factor))
            else:
                block.add_module(name_,
                                 ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.widen_factor))
        return block

    def forward(self, x):
        # Call submodules directly (not ``.forward``) so forward hooks run.
        x = self.conv_1_3x3(x)
        x = F.relu(self.bn_1(x), inplace=True)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = F.avg_pool2d(x, 8, 1)
        # Keep the batch dimension explicit, as the other CIFAR models do.
        x = x.view(x.size(0), -1)
        return self.classifier(x)
122 |
def resnext(**kwargs):
    """Build a CifarResNeXt, forwarding all keyword arguments to its constructor."""
    return CifarResNeXt(**kwargs)
--------------------------------------------------------------------------------
/classification/models/cifar/vgg.py:
--------------------------------------------------------------------------------
1 | '''VGG for CIFAR10. FC layers are removed.
2 | (c) YANG, Wei
3 | '''
4 | import torch.nn as nn
5 | import torch.utils.model_zoo as model_zoo
6 | import math
7 |
8 |
# Public API of this module (star-import surface).
__all__ = [
    'VGG',
    'vgg11',
    'vgg11_bn',
    'vgg13',
    'vgg13_bn',
    'vgg16',
    'vgg16_bn',
    'vgg19_bn',
    'vgg19',
]
13 |
14 |
_MODEL_ZOO_BASE = 'https://download.pytorch.org/models/'

# ImageNet VGG checkpoint locations. None of the factories in this module
# load them; they are kept for reference.
model_urls = {
    'vgg11': _MODEL_ZOO_BASE + 'vgg11-bbd30ac9.pth',
    'vgg13': _MODEL_ZOO_BASE + 'vgg13-c768596a.pth',
    'vgg16': _MODEL_ZOO_BASE + 'vgg16-397923af.pth',
    'vgg19': _MODEL_ZOO_BASE + 'vgg19-dcbb9e9d.pth',
}
21 |
22 |
class VGG(nn.Module):
    """VGG classifier for CIFAR: a convolutional feature extractor followed by a
    single linear layer (the fully connected head of ImageNet VGG is removed).

    Args:
        features: module whose output, flattened per sample, has 512 values
            (e.g. ``make_layers(cfg[...])`` applied to 32x32 inputs).
        num_classes: number of output classes.
    """

    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Linear(512, num_classes)
        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """He-normal (fan_out) conv weights, identity BatchNorm, N(0, 0.01) linear
        weights, and zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
51 |
52 |
def make_layers(cfg, batch_norm=False):
    """Build the VGG convolutional trunk described by ``cfg``.

    Each integer entry adds a 3x3, padding-1 convolution with that many output
    channels, followed by BatchNorm2d (when ``batch_norm``) and an in-place
    ReLU. Each ``'M'`` entry adds a 2x2, stride-2 max pool. The first
    convolution expects 3 input channels.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)
67 |
68 |
def _vgg_cfg(stage_depths):
    """Expand per-stage convolution counts into a make_layers() spec.

    VGG has five stages of widths 64, 128, 256, 512, 512. Each stage is
    ``depth`` convolutions of that width followed by a max pool ('M').
    """
    spec = []
    for width, depth in zip((64, 128, 256, 512, 512), stage_depths):
        spec.extend([width] * depth)
        spec.append('M')
    return spec


# Layer layouts for make_layers(): integers are conv output channels, 'M' is a
# 2x2 max pool. Keys are the VGG configuration letters.
cfg = {
    'A': _vgg_cfg((1, 1, 2, 2, 2)),  # used by vgg11
    'B': _vgg_cfg((2, 2, 2, 2, 2)),  # used by vgg13
    'D': _vgg_cfg((2, 2, 3, 3, 3)),  # used by vgg16
    'E': _vgg_cfg((2, 2, 4, 4, 4)),  # used by vgg19
}
75 |
76 |
def vgg11(**kwargs):
    """VGG 11-layer model (configuration "A") without batch normalization.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``VGG``.
    No pretrained weights are loaded, and there is no ``pretrained`` option:
    passing one raises a TypeError from ``VGG.__init__``.
    """
    model = VGG(make_layers(cfg['A']), **kwargs)
    return model
85 |
86 |
def vgg11_bn(**kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization.

    Keyword arguments are forwarded to ``VGG``.
    """
    return VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
91 |
92 |
def vgg13(**kwargs):
    """VGG 13-layer model (configuration "B") without batch normalization.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``VGG``.
    No pretrained weights are loaded, and there is no ``pretrained`` option:
    passing one raises a TypeError from ``VGG.__init__``.
    """
    model = VGG(make_layers(cfg['B']), **kwargs)
    return model
101 |
102 |
def vgg13_bn(**kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization.

    Keyword arguments are forwarded to ``VGG``.
    """
    return VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
107 |
108 |
def vgg16(**kwargs):
    """VGG 16-layer model (configuration "D") without batch normalization.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``VGG``.
    No pretrained weights are loaded, and there is no ``pretrained`` option:
    passing one raises a TypeError from ``VGG.__init__``.
    """
    model = VGG(make_layers(cfg['D']), **kwargs)
    return model
117 |
118 |
def vgg16_bn(**kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization.

    Keyword arguments are forwarded to ``VGG``.
    """
    return VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
123 |
124 |
def vgg19(**kwargs):
    """VGG 19-layer model (configuration "E") without batch normalization.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``VGG``.
    No pretrained weights are loaded, and there is no ``pretrained`` option:
    passing one raises a TypeError from ``VGG.__init__``.
    """
    model = VGG(make_layers(cfg['E']), **kwargs)
    return model
133 |
134 |
def vgg19_bn(**kwargs):
    """VGG 19-layer model (configuration "E") with batch normalization.

    Keyword arguments are forwarded to ``VGG``.
    """
    return VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
139 |
--------------------------------------------------------------------------------
/classification/models/cifar/wrn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
# Only the factory function is exported on star-import.
__all__ = [
    'wrn',
]
7 |
8 |
class BasicBlock(nn.Module):
    """Pre-activation residual block: BN-ReLU-conv3x3, BN-ReLU-(dropout)-conv3x3.

    When the block changes shape (different channel count or stride != 1), the
    shortcut is a strided 1x1 convolution applied to the pre-activated input.
    Otherwise the shortcut is the identity on the raw input.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        # An identity shortcut is only valid when neither the channel count nor
        # the spatial size changes. Checking channels alone made a stride-2
        # block with equal channels crash on the residual add.
        self.equalInOut = (in_planes == out_planes) and stride == 1
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        preact = self.relu1(self.bn1(x))
        if not self.equalInOut:
            # Shape-changing blocks feed the pre-activated input to both paths.
            x = preact
        out = self.relu2(self.bn2(self.conv1(preact)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        shortcut = x if self.equalInOut else self.convShortcut(x)
        return torch.add(shortcut, out)
35 |
36 |
class NetworkBlock(nn.Module):
    """A stage of ``nb_layers`` blocks built by ``block(in, out, stride, dropRate)``.

    Only the first block maps ``in_planes`` to ``out_planes`` and applies
    ``stride``; the rest keep ``out_planes`` channels with stride 1.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(nb_layers):
            first = i == 0
            # Conditional expressions instead of `cond and a or b`, which picks
            # the wrong operand whenever the "true" value is falsy.
            layers.append(block(in_planes if first else out_planes, out_planes,
                                stride if first else 1, dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
50 |
51 |
class WideResNet(nn.Module):
    """Wide Residual Network of depth 6n+4 with three stages of BasicBlocks.

    ``forward`` returns ``(logits, features)``, where ``features`` maps stage
    index 1, 2, 3 to that stage's output tensor so callers can use the
    intermediate activations.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out1 = self.block1(out)
        out2 = self.block2(out1)
        out3 = self.block3(out2)
        out = self.relu(self.bn1(out3))
        # Global average pooling. avg_pool2d(out, 8) only matched the 8x8 map
        # produced by 32x32 inputs; with larger inputs, view(-1, C) silently
        # folded the leftover spatial positions into the batch dimension.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        return self.fc(out), {1: out1, 2: out2, 3: out3}
93 |
94 |
def wrn(**kwargs):
    """Build a WideResNet; every keyword argument goes to its constructor."""
    return WideResNet(**kwargs)
101 |
--------------------------------------------------------------------------------
/classification/models/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .resnext import *
4 |
--------------------------------------------------------------------------------
/classification/models/imagenet/resnext.py:
--------------------------------------------------------------------------------
"""
Creates a ResNeXt Model as defined in:
Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
Aggregated residual transformations for deep neural networks.
arXiv preprint arXiv:1611.05431.
Ported from https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua
"""
# The docstring must precede every statement (including __future__ imports)
# to become the module's __doc__; __future__ imports may follow it.
from __future__ import division
9 | import math
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 | import torch
14 |
# Public factory functions exported on star-import.
__all__ = [
    'resnext50',
    'resnext101',
    'resnext152',
]
16 |
class Bottleneck(nn.Module):
    """
    ResNeXt bottleneck type C: 1x1 reduce -> 3x3 grouped conv -> 1x1 expand,
    each followed by batch norm, plus a residual connection.
    """
    expansion = 4

    def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: base channel count; the block outputs planes * expansion channels
            baseWidth: base width.
            cardinality: num of convolution groups.
            stride: conv stride. Replaces pooling layer.
            downsample: optional module applied to the input to form the residual
                when the shape changes (None means identity).
        """
        super(Bottleneck, self).__init__()

        # Per-group width: planes scaled by baseWidth / 64, floored. True division
        # relies on the module's `from __future__ import division` under Python 2.
        D = int(math.floor(planes * (baseWidth / 64)))
        C = cardinality
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(D * C)
        self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
        self.bn2 = nn.BatchNorm2d(D * C)
        # Use the class-level expansion instead of a duplicated literal 4.
        self.conv3 = nn.Conv2d(D * C, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
68 |
69 |
class ResNeXt(nn.Module):
    """
    ResNeXt optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """
    def __init__(self, baseWidth, cardinality, layers, num_classes):
        """ Constructor
        Args:
            baseWidth: baseWidth for ResNeXt.
            cardinality: number of convolution groups.
            layers: config of layers, e.g., [3, 4, 6, 3]
            num_classes: number of classes
        """
        super(ResNeXt, self).__init__()
        block = Bottleneck

        self.cardinality = cardinality
        self.baseWidth = baseWidth
        self.num_classes = num_classes
        self.inplanes = 64
        self.output_size = 64

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], 2)
        self.layer3 = self._make_layer(block, 256, layers[2], 2)
        self.layer4 = self._make_layer(block, 512, layers[3], 2)
        # Global average pooling. AvgPool2d(7) only reduced to 1x1 for 224x224
        # inputs (total stride 32 -> 7x7); the adaptive pool gives the same
        # result there and also handles other input resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            block: block type used to construct ResNext
            planes: number of output channels (need to multiply by block.expansion)
            blocks: number of blocks to be built
            stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut so the residual matches the block output shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
150 |
151 |
def resnext50(baseWidth, cardinality, num_classes=1000):
    """
    Construct ResNeXt-50 (stage depths [3, 4, 6, 3]).
    num_classes defaults to 1000 (ImageNet).
    """
    model = ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], num_classes)
    return model
158 |
159 |
def resnext101(baseWidth, cardinality, num_classes=1000):
    """
    Construct ResNeXt-101 (stage depths [3, 4, 23, 3]).
    num_classes defaults to 1000 (ImageNet).
    """
    model = ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], num_classes)
    return model
166 |
167 |
def resnext152(baseWidth, cardinality, num_classes=1000):
    """
    Construct ResNeXt-152 (stage depths [3, 8, 36, 3]).
    num_classes defaults to 1000 (ImageNet).
    """
    model = ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], num_classes)
    return model
174 |
--------------------------------------------------------------------------------
/classification/scripts/mimic.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet
3 | depth=20
4 | dataset=cifar100
5 | epochs=200
6 | lr=1e-1
7 | supervision=feature
8 |
9 | exp=${model}${depth}_${dataset}_${epochs}e_${lr}_resnet50_6908_l21_ce1e-1
10 | path=${dataset}/${model}/${supervision}
11 | EXP_DIR=/mnt/lustre21/wangjiong/classification_school/playground/exp
12 | mkdir -p ${EXP_DIR}/${path}/${exp}/model
13 | now=$(date +"%Y%m%d_%H%M%S")
14 |
15 | part=Pixel2
16 | numGPU=1
17 | nodeGPU=1
18 | pg-run -rn ${path}/${exp} -c "
19 | srun -p ${part} --job-name=${exp} --gres=gpu:${nodeGPU} -n ${numGPU} --ntasks-per-node=${nodeGPU} \
20 | python -u cifar.py \
21 | --dataset ${dataset} \
22 | --data_dir /mnt/lustre21/wangjiong/Data_t1/datasets/cifar/${dataset} \
23 | --workers 4 \
24 | --reset \
25 | --epochs ${epochs} \
26 | --start-epoch 0 \
27 | --train-batch 256 \
28 | --test-batch 200 \
29 | \
30 | --lr ${lr} \
31 | --drop 0 \
32 | --schedule 100 150 \
33 | --gamma 0.1 \
34 | --weight-decay 1e-4 \
35 | --checkpoint ${EXP_DIR}/${path}/${exp}/model \
36 | \
37 | --arch ${model} \
38 | --depth ${depth} \
39 | \
40 | --teacher \
41 | --teacher_arch resnet \
42 | --teacher_depth 50 \
43 | --teacher_block_name BasicBlock \
44 | --teacher_checkpoint ${EXP_DIR}/../../model_zoo/${dataset}/resnet/resnet50_69_08.pth.tar \
45 | \
46 | --mimic_mean ${supervision} \
47 | --mimic_function L2 \
48 | --normalize 1 \
49 | \
50 | --manualSeed 345 \
51 | --gpu-id 0 \
52 | 2>&1 | tee -a ${EXP_DIR}/${path}/${exp}/train-${now}.log \
53 | & \
54 | "
55 |
--------------------------------------------------------------------------------
/classification/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | model=resnet
3 | depth=110
4 | dataset=cifar10
5 | epochs=220
6 | lr=1e-1
7 |
8 | exp=${model}${depth}_${dataset}_${epochs}e_${lr}
9 | path=${dataset}/${model}
10 | EXP_DIR=/mnt/lustre21/wangjiong/classification_school/playground/exp
11 | mkdir -p ${EXP_DIR}/${path}/${exp}/model
12 | now=$(date +"%Y%m%d_%H%M%S")
13 |
14 | part=Pixel
15 | numGPU=1
16 | nodeGPU=1
17 | pg-run -rn ${path}/${exp} -c "
18 | srun -p ${part} --job-name=${exp} --gres=gpu:${nodeGPU} -n ${numGPU} --ntasks-per-node=${nodeGPU} \
19 | python -u cifar.py \
20 | --dataset ${dataset} \
21 | --data_dir /mnt/lustre21/wangjiong/Data_t1/datasets/cifar/${dataset} \
22 | --workers 4 \
23 | --reset \
24 | --epochs ${epochs} \
25 | --start-epoch 0 \
26 | --train-batch 256 \
27 | --test-batch 200 \
28 | \
29 | --lr ${lr} \
30 | --drop 0 \
31 | --schedule 100 150 \
32 | --gamma 0.1 \
33 | --weight-decay 1e-4 \
34 | --checkpoint ${EXP_DIR}/${path}/${exp}/model \
35 | \
36 | --arch ${model} \
37 | --depth ${depth} \
38 | \
39 | --gpu-id 0 \
40 | 2>&1 | tee -a ${EXP_DIR}/${path}/${exp}/train-${now}.log \
41 | & \
42 | "
43 |
--------------------------------------------------------------------------------
/classification/scripts/train_local.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Run cifar.py locally (no cluster): data from ./data, checkpoints and a tee'd
# log under ./model. The python command runs in the background ('&').
# The pg-run/srun wrapper lines are commented out. The final '& \' continues
# onto the '# "' line, which turns the leftover closing quote into a comment.
# NOTE(review): --epochs is hard-coded to 5 although epochs=50 is set above, and
# --schedule 10 20 therefore never triggers within 5 epochs. This looks like a
# deliberate smoke-test setting; confirm before relying on it.
# NOTE(review): exp, path, EXP_DIR, part, numGPU and nodeGPU are only used by
# the commented-out cluster lines. --cardinality / --widen-factor presumably
# only matter for resnext/wrn archs in cifar.py; verify.
model=resnet_vision
depth=18
dataset=cifar10
epochs=50
lr=1e-1

exp=${model}${depth}_${dataset}_${epochs}e_${lr}
path=${dataset}/${model}
EXP_DIR=/mnt/lustre21/wangjiong/classification_school/playground/exp
# mkdir -p ${EXP_DIR}/${path}/${exp}/model
mkdir -p ./model
now=$(date +"%Y%m%d_%H%M%S")

part=Pixel
numGPU=1
nodeGPU=1
# pg-run -rn ${exp} -c "
# srun -p ${part} --job-name=${exp} --gres=gpu:${nodeGPU} -n ${numGPU} --ntasks-per-node=${nodeGPU} \
python -u cifar.py \
--dataset ${dataset} \
--data_dir ./data \
--workers 4 \
--epochs 5 \
--start-epoch 0 \
--train-batch 256 \
--test-batch 200 \
\
--lr ${lr} \
--drop 0 \
--schedule 10 20 \
--weight-decay 1e-4 \
--checkpoint ./model \
\
--arch ${model} \
--depth ${depth} \
--cardinality 32 \
--widen-factor 4 \
\
--manualSeed 345 \
--gpu-id 0 \
2>&1 | tee -a ./model/train-${now}.log \
& \
# "
45 |
--------------------------------------------------------------------------------
/classification/source.sh:
--------------------------------------------------------------------------------
# Cluster environment setup. Source this file (`source source.sh`) rather than
# executing it, so that the exports and the conda activation apply to the
# current shell.
# Prepend the Anaconda, CUDA 9.0, NCCL and Intel libraries to the loader path.
export LD_LIBRARY_PATH=/mnt/lustre/wangjiong/Data_t1/anaconda3/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/mnt/lustre/share/cuda-9.0/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/mnt/lustre/share/nccl_2.1.15-1+cuda9.0_x86_64/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/mnt/lustre/share/intel64/lib/:$LD_LIBRARY_PATH
# Put the CUDA tools and the Anaconda binaries (for `activate`) on PATH.
export PATH=/mnt/lustre/share/cuda-9.0/bin:$PATH
export PATH=/mnt/lustre/wangjiong/Data_t1/anaconda3/bin:$PATH
# export PATH=/usr/bin:$PATH
# Shared PyTorch 0.4.0 environment; alternatives are kept commented out.
source activate /mnt/lustre/share/fundamental-support/envs/pytorch-0.4.0
# source activate pytorch0.4.0-py3
# source activate pytorch
11 |
--------------------------------------------------------------------------------
/classification/tensorboardX/__init__.py:
--------------------------------------------------------------------------------
1 | """A module for visualization with tensorboard
2 | """
3 |
4 | from .writer import FileWriter, SummaryWriter
5 | from .record_writer import RecordWriter
6 |
--------------------------------------------------------------------------------
/classification/tensorboardX/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__init__.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/__pycache__/crc32c.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/crc32c.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/__pycache__/embedding.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/embedding.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/__pycache__/event_file_writer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/event_file_writer.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/__pycache__/graph.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/graph.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/__pycache__/graph_onnx.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/graph_onnx.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/__pycache__/record_writer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/record_writer.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/__pycache__/summary.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/summary.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/__pycache__/writer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/writer.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/__pycache__/x2num.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/x2num.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/crc32c.py:
--------------------------------------------------------------------------------
1 | import array
2 |
3 |
def _make_crc_table(reversed_poly):
    """Return the 256-entry byte lookup table for a reflected (LSB-first) CRC-32."""
    table = []
    for index in range(256):
        value = index
        for _ in range(8):
            if value & 1:
                value = (value >> 1) ^ reversed_poly
            else:
                value >>= 1
        table.append(value)
    return tuple(table)


# Byte-wise lookup table for CRC-32C (Castagnoli), reflected polynomial
# 0x82F63B78. E.g. CRC_TABLE[1] == 0xF26B8303 and CRC_TABLE[128] is the
# polynomial itself.
CRC_TABLE = _make_crc_table(0x82F63B78)


# Starting checksum value for a new computation.
CRC_INIT = 0

# Keeps arithmetic within 32 bits (Python integers are unbounded).
_MASK = 0xFFFFFFFF
75 |
76 |
def crc_update(crc, data):
    """Fold ``data`` into a running CRC-32C value and return the new value.

    ``data`` may be anything ``array.array("B", ...)`` accepts. On Python 3
    that means bytes, bytearray or an iterable of ints in 0..255; text must be
    encoded first, because a ``str`` raises TypeError. An ``array.array``
    whose items are one byte wide is used as-is, without copying.

    Args:
        crc: 32-bit checksum to update.
        data: the bytes to add.

    Returns:
        The updated 32-bit CRC-32C value.
    """
    if type(data) is array.array and data.itemsize == 1:
        octets = data
    else:
        octets = array.array("B", data)

    # Reflected table-driven update; the register is kept inverted while the
    # bytes are processed and inverted back at the end.
    register = crc ^ _MASK
    for octet in octets:
        register = (CRC_TABLE[(register ^ octet) & 0xff] ^ (register >> 8)) & _MASK
    return register ^ _MASK
98 |
99 |
def crc_finalize(crc):
    """Finalize a CRC-32C checksum; call this as the last step of a computation.

    Args:
        crc: checksum value as an integer.

    Returns:
        ``crc`` reduced to its low 32 bits.
    """
    finalized = crc & _MASK
    return finalized
112 |
113 |
def crc32c(data):
    """Return the CRC-32C (Castagnoli) checksum of ``data``.

    Args:
        data: bytes-like data accepted by ``crc_update``.

    Returns:
        The 32-bit checksum as an integer.
    """
    running = crc_update(CRC_INIT, data)
    return crc_finalize(running)
124 |
--------------------------------------------------------------------------------
/classification/tensorboardX/crc32c.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/crc32c.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/embedding.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
def make_tsv(metadata, save_path):
    """Write ``<save_path>/metadata.tsv`` with one ``str()``-converted entry per line.

    All entries are converted before the file is opened; an existing file is
    overwritten.
    """
    lines = [str(item) + '\n' for item in metadata]
    with open(os.path.join(save_path, 'metadata.tsv'), 'w') as tsv:
        tsv.writelines(lines)
9 |
10 |
# https://github.com/tensorflow/tensorboard/issues/44 image label will be squared
def make_sprite(label_img, save_path):
    """Save ``label_img`` as a square sprite sheet at ``<save_path>/sprite.png``.

    The grid has ``nrow = ceil(sqrt(N))`` columns, and the batch is padded with
    blank images up to ``nrow ** 2`` so that the grid is square.

    Args:
        label_img: batch of images whose ``size(0)`` is the image count. The
            layout is presumably N x C x H x W; TODO confirm against callers.
        save_path: directory the sprite is written to.
    """
    import math
    import torch
    import torchvision
    from .x2num import makenp
    # this ensures the sprite image has correct dimension as described in
    # https://www.tensorflow.org/get_started/embedding_viz
    nrow = int(math.ceil((label_img.size(0)) ** 0.5))

    label_img = torch.from_numpy(makenp(label_img))  # for other framework
    # Pad with blank images of the same dtype so #images equals nrow*nrow.
    # Previously the padding was torch.randn(...) * 255: non-deterministic
    # speckle in the saved sprite, in randn's default dtype, which need not
    # match label_img's dtype for torch.cat.
    pad_shape = (nrow ** 2 - label_img.size(0),) + tuple(label_img.size()[1:])
    label_img = torch.cat((label_img, label_img.new_zeros(pad_shape)), 0)

    torchvision.utils.save_image(label_img, os.path.join(save_path, 'sprite.png'), nrow=nrow, padding=0)
26 |
27 |
def append_pbtxt(metadata, label_img, save_path, global_step, tag):
    """Append one ``embeddings { ... }`` entry to ``<save_path>/projector_config.pbtxt``.

    Paths written into the entry are relative to ``save_path`` and point into
    the per-step subdirectory named after ``global_step``. The entry text is
    assembled completely before the file is opened. A failure while reading
    ``label_img`` therefore never leaves a half-written entry behind.

    Args:
        metadata: if not None, a ``metadata_path`` line is written.
        label_img: if not None, a ``sprite`` section is written. Its sizes
            along dims 3 and 2 are written as the two ``single_image_dim``
            values (presumably width then height of N x C x H x W images;
            TODO confirm).
        save_path: directory holding ``projector_config.pbtxt``.
        global_step: name of the step subdirectory. Converted with ``str()``,
            so integer steps are accepted as well as strings.
        tag: embedding name used in ``tensor_name``.
    """
    step = str(global_step)
    lines = [
        'embeddings {\n',
        'tensor_name: "{}:{}"\n'.format(tag, step),
        'tensor_path: "{}"\n'.format(os.path.join(step, 'tensors.tsv')),
    ]
    if metadata is not None:
        lines.append('metadata_path: "{}"\n'.format(os.path.join(step, 'metadata.tsv')))
    if label_img is not None:
        lines.extend([
            'sprite {\n',
            'image_path: "{}"\n'.format(os.path.join(step, 'sprite.png')),
            'single_image_dim: {}\n'.format(label_img.size(3)),
            'single_image_dim: {}\n'.format(label_img.size(2)),
            '}\n',
        ])
    lines.append('}\n')
    with open(os.path.join(save_path, 'projector_config.pbtxt'), 'a') as f:
        f.writelines(lines)
43 |
44 |
def make_mat(matlist, save_path):
    """Write ``matlist`` to ``<save_path>/tensors.tsv``, one row per line.

    Every element of a row is converted with ``str()``, and the elements are
    separated by tabs.

    Args:
      matlist: iterable of rows, each an iterable of values.
      save_path: existing directory that receives ``tensors.tsv``.
    """
    with open(os.path.join(save_path, 'tensors.tsv'), 'w') as tsv_file:
        for row in matlist:
            tsv_file.write('\t'.join(map(str, row)) + '\n')
50 |
--------------------------------------------------------------------------------
/classification/tensorboardX/embedding.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/embedding.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/event_file_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Writes events to disk in a logdir."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import logging
22 | import os.path
23 | import socket
24 | import threading
25 | import time
26 |
27 | import six
28 |
29 | from .src import event_pb2
30 | from .record_writer import RecordWriter
31 |
32 |
def directory_check(path):
    '''Ensure the log directory ``path`` exists, creating parents as needed.

    Creation is attempted directly instead of "check, then create". Two
    writers (threads or processes) creating the same directory at the same
    time therefore cannot crash with "File exists".

    Raises:
      OSError: if the directory cannot be created and does not already exist
        as a directory (e.g. permission denied, or ``path`` is a regular file).
    '''
    import errno
    try:
        os.makedirs(path)
    except OSError as err:
        # Only "already exists" is fine, and only when it really is a directory.
        if err.errno != errno.EEXIST or not os.path.isdir(path):
            raise
38 |
39 |
class EventsWriter(object):
    '''Writes `Event` protocol buffers to an event file.'''

    def __init__(self, file_prefix):
        '''Create a new event file and write its initial event.

        Event files have a name of the form
        '/some/file/path/events.out.tfevents.[timestamp].[hostname]'
        where ``file_prefix`` supplies '/some/file/path/events' and the
        timestamp is the first 10 characters of ``str(time.time())``.

        Args:
          file_prefix: path prefix of the event file.
        '''
        self._file_prefix = file_prefix + ".out.tfevents." + str(time.time())[:10] + "." + socket.gethostname()

        # Deliberately no logging.basicConfig(filename=...) here. That would
        # attach a text handler for the *root* logger to this binary file, so
        # any WARNING+ record logged anywhere in the process would be appended
        # into the record stream and corrupt it. It would also silently take
        # over the application's logging configuration.

        self._num_outstanding_events = 0

        # Opens the event file for writing (an existing file is truncated).
        self._py_recordio_writer = RecordWriter(self._file_prefix)

        # Every event file starts with an event that only carries a wall time.
        self._event = event_pb2.Event()
        self._event.wall_time = time.time()
        self.write_event(self._event)

    def write_event(self, event):
        '''Serialize ``event`` and append it to the file.

        Raises:
          TypeError: if ``event`` is not an ``event_pb2.Event``.
        '''
        if not isinstance(event, event_pb2.Event):
            raise TypeError("Expected an event_pb2.Event proto, "
                            " but got %s" % type(event))
        return self._write_serialized_event(event.SerializeToString())

    def _write_serialized_event(self, event_str):
        '''Append one already-serialized event and count it as outstanding.'''
        self._num_outstanding_events += 1
        self._py_recordio_writer.write(event_str)

    def flush(self):
        '''Reset the outstanding-event counter and return True.

        RecordWriter.write already flushes the file after every record, so
        nothing is buffered at this level.
        '''
        self._num_outstanding_events = 0
        return True

    def close(self):
        '''Call self.flush() and return its result.

        NOTE(review): the underlying file is intentionally left open.
        EventFileWriter.reopen() resumes writing through this same writer.
        '''
        return_value = self.flush()
        return return_value
86 |
87 |
class EventFileWriter(object):
    """Write `Event` protocol buffers to an event file from a background thread.

    Construction creates ``logdir`` if needed and a new event file inside it.
    ``add_event`` only enqueues events. A daemon ``_EventLoggerThread``
    dequeues them and writes them in the tfrecord format, which is similar to
    RecordIO.
    """

    def __init__(self, logdir, max_queue=10, flush_secs=120):
        """Create the directory, the event file and the writer thread.

        Args:
          logdir: A string. Directory where the event file will be written.
          max_queue: Integer. Capacity of the queue of pending events; when it
            is full, ``add_event`` blocks until the writer thread catches up.
          flush_secs: Number. How often, in seconds, the writer thread flushes
            the event writer.
        """
        directory_check(logdir)
        pending = six.moves.queue.Queue(max_queue)
        writer = EventsWriter(os.path.join(logdir, "events"))
        self._logdir = logdir
        self._event_queue = pending
        self._ev_writer = writer
        self._closed = False
        self._worker = _EventLoggerThread(pending, writer, flush_secs)
        self._worker.start()

    def get_logdir(self):
        """Return the directory that holds the event file."""
        return self._logdir

    def reopen(self):
        """Accept events again after ``close()``.

        Events keep going to the same event file and writer thread that were
        created in ``__init__``; no new file is opened. Does nothing if the
        writer is not closed.
        """
        if self._closed:
            self._closed = False

    def add_event(self, event):
        """Enqueue ``event`` (an ``Event`` proto); events are dropped while closed."""
        if self._closed:
            return
        self._event_queue.put(event)

    def flush(self):
        """Wait until every queued event has been handled, then flush the writer."""
        self._event_queue.join()
        self._ev_writer.flush()

    def close(self):
        """Flush pending events, close the event writer and stop accepting events."""
        self.flush()
        self._ev_writer.close()
        self._closed = True
162 |
163 |
164 | class _EventLoggerThread(threading.Thread):
165 | """Thread that logs events."""
166 |
167 | def __init__(self, queue, ev_writer, flush_secs):
168 | """Creates an _EventLoggerThread.
169 | Args:
170 | queue: A Queue from which to dequeue events.
171 | ev_writer: An event writer. Used to log brain events for
172 | the visualizer.
173 | flush_secs: How often, in seconds, to flush the
174 | pending file to disk.
175 | """
176 | threading.Thread.__init__(self)
177 | self.daemon = True
178 | self._queue = queue
179 | self._ev_writer = ev_writer
180 | self._flush_secs = flush_secs
181 | # The first event will be flushed immediately.
182 | self._next_event_flush_time = 0
183 |
184 | def run(self):
185 | while True:
186 | event = self._queue.get()
187 | try:
188 | self._ev_writer.write_event(event)
189 | # Flush the event writer every so often.
190 | now = time.time()
191 | if now > self._next_event_flush_time:
192 | self._ev_writer.flush()
193 | # Do it again in two minutes.
194 | self._next_event_flush_time = now + self._flush_secs
195 | finally:
196 | self._queue.task_done()
197 |
--------------------------------------------------------------------------------
/classification/tensorboardX/event_file_writer.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/event_file_writer.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/graph.py:
--------------------------------------------------------------------------------
1 | from .src.graph_pb2 import GraphDef
2 | from .src.node_def_pb2 import NodeDef
3 | from .src.versions_pb2 import VersionDef
4 | from .src.attr_value_pb2 import AttrValue
5 | from .src.tensor_shape_pb2 import TensorShapeProto
6 |
7 |
def replace(name, scope):
    """Return ``name`` qualified by its scope, i.e. ``'<scope[name]>/<name>'``.

    Raises KeyError if ``name`` has no entry in ``scope``.
    """
    prefix = scope[name]
    return prefix + '/' + name
10 |
11 |
def parse(graph):
    """Flatten a traced PyTorch JIT graph into a list of node dicts.

    The first pass over ``graph.nodes()`` builds ``scope``, which maps value
    unique names to scope names. Every input except the first takes the scope
    of the node that consumes it, and each node's first output takes the
    node's scope. The value named '0' is then assigned to scope 'input'.

    The second pass emits one dict per node, with node and input names
    qualified by ``replace``. Finally every graph input is emitted as a
    'Parameter' node; inputs with no scope yet are placed in scope 'unused'.

    Returns:
        list of dicts with keys 'name', 'op', 'inputs' and 'attr'.

    Raises:
        AssertionError: if a node has an empty scope name.
    """
    scope = {}
    for node in graph.nodes():
        input_names = [value.uniqueName() for value in node.inputs()]
        for consumed in input_names[1:]:
            scope[consumed] = node.scopeName()
        out_name = next(node.outputs()).uniqueName()
        assert node.scopeName() != '', '{} has empty scope name'.format(node)
        scope[out_name] = node.scopeName()
    scope['0'] = 'input'

    flat_nodes = []
    for node in graph.nodes():
        attr_map = {key: node[key] for key in node.attributeNames()}
        # singlequote will be escaped by tensorboard, so blank it out
        attr_text = str(attr_map).replace("'", ' ')
        qualified_inputs = [replace(value.uniqueName(), scope) for value in node.inputs()]
        out_name = next(node.outputs()).uniqueName()
        flat_nodes.append({'name': replace(out_name, scope), 'op': node.kind(),
                           'inputs': qualified_inputs, 'attr': attr_text})

    for value in graph.inputs():
        unique_name = value.uniqueName()
        scope.setdefault(unique_name, 'unused')
        flat_nodes.append({'name': replace(unique_name, scope), 'op': 'Parameter',
                           'inputs': [], 'attr': str(value.type())})

    return flat_nodes
39 |
40 |
def graph(model, args, verbose=False):
    """Trace ``model`` on ``args`` and convert the trace to a TensorBoard ``GraphDef``.

    The model is traced with training mode off, the trace is optimized, and
    every node returned by ``parse`` becomes a ``NodeDef`` whose attributes
    are stored as one UTF-8 string under the key 'lanpa'.
    NOTE(review): relies on legacy/private torch APIs (``set_training``,
    ``jit.trace`` returning a tuple, ``_optimize_trace``); verify against the
    installed PyTorch version.

    Args:
        model: module to trace.
        args: example inputs passed to ``torch.jit.trace``.
        verbose: if True, print the traced graph.
    """
    import torch
    with torch.onnx.set_training(model, False):
        trace, _ = torch.jit.trace(model, args)
    torch.onnx._optimize_trace(trace, False)
    torch_graph = trace.graph()
    if verbose:
        print(torch_graph)
    node_defs = [
        NodeDef(name=entry['name'], op=entry['op'], input=entry['inputs'],
                attr={'lanpa': AttrValue(s=entry['attr'].encode(encoding='utf_8'))})
        for entry in parse(torch_graph)
    ]
    return GraphDef(node=node_defs, versions=VersionDef(producer=22))
56 |
--------------------------------------------------------------------------------
/classification/tensorboardX/graph.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/graph.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/graph_onnx.py:
--------------------------------------------------------------------------------
1 | from .src.graph_pb2 import GraphDef
2 | from .src.node_def_pb2 import NodeDef
3 | from .src.versions_pb2 import VersionDef
4 | from .src.attr_value_pb2 import AttrValue
5 | from .src.tensor_shape_pb2 import TensorShapeProto
6 | # from .src.onnx_pb2 import ModelProto
7 |
8 |
def gg(fname):
    """Load the ONNX model in ``fname`` and convert its graph to a ``GraphDef``.

    Graph inputs and outputs become 'Variable' nodes with 'dtype' and 'shape'
    attributes. Every ONNX op becomes a node named after its first output,
    with all of its attributes flattened into one 'parameters' string. Names
    are then prefixed with their op type ('<op>_<name>'), and matches found by
    smartGrouping are moved under 'FC<k>/' or 'Conv<k>/'.
    """
    import itertools
    import onnx  # 0.2.1
    onnx_graph = onnx.load(fname).graph

    nodes = []
    for value_info in itertools.chain(onnx_graph.input, onnx_graph.output):
        dims = value_info.type.tensor_type.shape.dim
        shape = TensorShapeProto(dim=[TensorShapeProto.Dim(size=d.dim_value) for d in dims])
        nodes.append(NodeDef(
            name=value_info.name,
            op='Variable',
            input=[],
            attr={
                'dtype': AttrValue(type=value_info.type.tensor_type.elem_type),
                'shape': AttrValue(shape=shape),
            },
        ))

    for op_node in onnx_graph.node:
        params = ', '.join(
            ' = '.join(str(field[1]) for field in attribute.ListFields())
            for attribute in op_node.attribute
        ).encode(encoding='utf_8')
        nodes.append(NodeDef(
            name=op_node.output[0],
            op=op_node.op_type,
            input=op_node.input,
            attr={'parameters': AttrValue(s=params)},
        ))

    # two pass token replacement, appends opname to object id
    mapping = {node.name: node.op + '_' + node.name for node in nodes}
    nodes, mapping = updatenodes(nodes, mapping)
    mapping = smartGrouping(nodes, mapping)
    nodes, mapping = updatenodes(nodes, mapping)

    return GraphDef(node=nodes, versions=VersionDef(producer=22))
54 |
55 |
def updatenodes(nodes, mapping):
    """Rename every node, and every input reference, through ``mapping``.

    Args:
        nodes: NodeDef-like objects with a ``name`` and a mutable ``input``
            sequence (a protobuf repeated field or a list).
        mapping: dict from old name to new name. It must contain every node
            name and every input name.

    Returns:
        ``(nodes, newmap)``: the same list, mutated in place, and an identity
        mapping over the new names, ready for another renaming pass.

    Raises:
        KeyError: if a name is missing from ``mapping``. The offending node is
            left untouched, though earlier nodes have already been renamed.
    """
    for node in nodes:
        # Resolve all names first, so a KeyError cannot leave a node with only
        # part of its inputs rewritten.
        newname = mapping[node.name]
        newinput = [mapping[inputnode] for inputnode in node.input]
        node.name = newname
        # Clear in place (valid for protobuf repeated fields and lists) instead
        # of calling remove() once per element, which is quadratic.
        del node.input[:]
        node.input.extend(newinput)
    return nodes, {v: v for v in mapping.values()}
69 |
70 |
def findnode(nodes, name):
    """Return the first node in ``nodes`` whose ``name`` equals ``name``.

    Returns None when no node matches.
    """
    return next((candidate for candidate in nodes if candidate.name == name), None)
78 |
79 |
def parser(s, nodes, node):
    """Debug helper: recursively match a nested pattern tuple against a node.

    NOTE(review): nothing visible calls this. It looks like scaffolding for
    the "recursive parse" TODO that precedes smartGrouping. It prints its
    progress and returns None when the pattern is empty or the op matches,
    and False on an op mismatch. ``s`` is presumably ``(op_name,
    sub_pattern)``, as in the ``gemm`` pattern comment of smartGrouping.
    TODO confirm.
    """
    print(s)
    if len(s) == 0:
        return
    if len(s) > 0:
        if s[0] == node.op:
            print(s[0], node.name, s[1], node.input)
            # Recurse into every input of the matched node with the sub-pattern.
            for n in node.input:
                print(n, s[1])
                # NOTE(review): findnode returns None for an unknown name, which
                # makes the recursive call fail on ``node.op``.
                parser(s[1], nodes, findnode(nodes, n))
        else:
            return False
92 |
93 |
94 | # TODO: use recursive parse
95 |
def smartGrouping(nodes, mapping):
    """Put recognised fully-connected and conv-with-bias subgraphs into groups.

    This looks for the two hand-coded patterns drawn below. For each match,
    the ``mapping`` entries of the matched nodes are rewritten to
    ``'FC<k>/<name>'`` or ``'Conv<k>/<name>'``, where ``k`` counts matches of
    that pattern from 1. Once the mapping is applied (as ``gg`` does with
    ``updatenodes``), each match shows up as one group.

    Args:
        nodes: NodeDef-like objects with ``name``, ``op`` and ``input``.
        mapping: dict from node name to display name; updated in place.

    Returns:
        The same ``mapping`` dict.
    """
    # a fully-connected (Gemm) layer is: (TODO: check var1.size(0)==var2.size(0))
    # GEMM <-- Variable (c1)
    #      ^-- Transpose (c2) <-- Variable (c3)

    # a Conv with bias is: (TODO: check var1.size(0)==var2.size(0))
    # Add <-- Conv (c2) <-- Variable (c3)
    #     ^-- Variable (c1)
    #
    # gemm = ('Gemm', ('Variable', ('Transpose', ('Variable'))))

    FCcounter = 1
    Convcounter = 1
    for node in nodes:
        if node.op == 'Gemm':
            c1 = c2 = c3 = False
            for name_in in node.input:
                # NOTE(review): findnode returns None when ``name_in`` is not the
                # name of any node, and ``n.op`` would then raise AttributeError.
                n = findnode(nodes, name_in)
                if n.op == 'Variable':
                    # If several inputs are Variables, the last one wins.
                    c1 = True
                    c1name = n.name
                if n.op == 'Transpose':
                    c2 = True
                    c2name = n.name
                    # c3: the Transpose must read exactly one input, a Variable.
                    if len(n.input) == 1:
                        nn = findnode(nodes, n.input[0])
                        if nn.op == 'Variable':
                            c3 = True
                            c3name = nn.name
                # print(n.op, n.name, c1, c2, c3)
            # The c*name variables are read only when all three flags were set
            # for this node. Flags are reset per node and always set together
            # with their names, so stale names from earlier nodes are never used.
            if c1 and c2 and c3:
                # print(c1name, c2name, c3name)
                mapping[c1name] = 'FC{}/{}'.format(FCcounter, c1name)
                mapping[c2name] = 'FC{}/{}'.format(FCcounter, c2name)
                mapping[c3name] = 'FC{}/{}'.format(FCcounter, c3name)
                mapping[node.name] = 'FC{}/{}'.format(FCcounter, node.name)
                FCcounter += 1
                continue
        if node.op == 'Add':
            c1 = c2 = c3 = False
            for name_in in node.input:
                n = findnode(nodes, name_in)
                if n.op == 'Variable':
                    c1 = True
                    c1name = n.name
                if n.op == 'Conv':
                    c2 = True
                    c2name = n.name
                    # c3: any Variable among the Conv's inputs (presumably its weight).
                    if len(n.input) >= 1:
                        for nn_name in n.input:
                            nn = findnode(nodes, nn_name)
                            if nn.op == 'Variable':
                                c3 = True
                                c3name = nn.name

            if c1 and c2 and c3:
                # print(c1name, c2name, c3name)
                mapping[c1name] = 'Conv{}/{}'.format(Convcounter, c1name)
                mapping[c2name] = 'Conv{}/{}'.format(Convcounter, c2name)
                mapping[c3name] = 'Conv{}/{}'.format(Convcounter, c3name)
                mapping[node.name] = 'Conv{}/{}'.format(Convcounter, node.name)
                Convcounter += 1
    return mapping
159 |
--------------------------------------------------------------------------------
/classification/tensorboardX/graph_onnx.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/graph_onnx.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/record_writer.py:
--------------------------------------------------------------------------------
1 | """
Write TFRecord-framed records to a file; used here for TensorBoard event files.
The code was borrowed from https://github.com/TeamHG-Memex/tensorboard_logger
4 | """
5 |
6 | import re
7 | import struct
8 |
9 | from .crc32c import crc32c
10 |
# Characters allowed as the first character of a TF op name.
_VALID_OP_NAME_START = re.compile(r'^[A-Za-z0-9.]')
# Runs of characters allowed in the rest of a TF op name.
_VALID_OP_NAME_PART = re.compile(r'[A-Za-z0-9_.\-/]+')
13 |
14 |
class RecordWriter(object):
    """Append length-prefixed, CRC-protected records (TFRecord framing) to a file.

    Each record is written as::

        uint64 length | uint32 masked_crc32c(length) | data | uint32 masked_crc32c(data)

    with every integer little-endian.
    """

    def __init__(self, path, flush_secs=2):
        """Open ``path`` for binary writing, truncating any existing file.

        Args:
          path: file to create.
          flush_secs: currently unused; every ``write`` flushes immediately.
        """
        self._name_to_tf_name = {}
        self._tf_names = set()
        self.path = path
        self.flush_secs = flush_secs  # TODO. flush every flush_secs, not every time.
        self._writer = open(path, 'wb')

    def write(self, event_str):
        """Frame ``event_str`` (bytes) as one record, write it and flush."""
        w = self._writer.write
        # TFRecord integers are little-endian. Native-order 'Q'/'I' would
        # produce unreadable files on big-endian hosts.
        header = struct.pack('<Q', len(event_str))
        w(header)
        w(struct.pack('<I', masked_crc32c(header)))
        w(event_str)
        w(struct.pack('<I', masked_crc32c(event_str)))
        self._writer.flush()

    def flush(self):
        """Flush buffered bytes of the underlying file."""
        self._writer.flush()

    def close(self):
        """Close the underlying file; later writes raise ValueError."""
        self._writer.close()
32 |
33 |
def masked_crc32c(data):
    """Return the masked CRC-32C of ``data`` used in the record framing.

    The raw 32-bit checksum is rotated right by 15 bits, then 0xa282ead8 is
    added modulo 2**32.
    """
    crc = u32(crc32c(data))
    rotated = u32(crc << 17) | (crc >> 15)
    return u32(rotated + 0xa282ead8)
37 |
38 |
def u32(x):
    """Return the low 32 bits of integer ``x`` as a non-negative int."""
    low_mask = (1 << 32) - 1
    return x & low_mask
41 |
42 |
def make_valid_tf_name(name):
    """Coerce ``name`` into a valid TensorFlow op name.

    A '.' is prepended when the first character is not allowed, rather than
    dropping characters. The runs of allowed characters are then joined with
    '_'. Each interior run of disallowed characters therefore becomes a single
    '_', and trailing ones are dropped.
    """
    if _VALID_OP_NAME_START.match(name) is None:
        # Must make it valid somehow, but don't want to remove stuff
        name = '.' + name
    allowed_runs = _VALID_OP_NAME_PART.findall(name)
    return '_'.join(allowed_runs)
48 |
--------------------------------------------------------------------------------
/classification/tensorboardX/record_writer.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/record_writer.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__init__.py
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__init__.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/attr_value_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/attr_value_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/event_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/event_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/graph_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/graph_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/node_def_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/node_def_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/plugin_pr_curve_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/plugin_pr_curve_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/resource_handle_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/resource_handle_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/summary_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/summary_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/tensor_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/tensor_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/tensor_shape_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/tensor_shape_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/types_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/types_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/__pycache__/versions_pb2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/versions_pb2.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/attr_value.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorboard;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "AttrValueProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 |
9 | import "tensorboardX/src/tensor.proto";
10 | import "tensorboardX/src/tensor_shape.proto";
11 | import "tensorboardX/src/types.proto";
12 |
// Protocol buffer representing the value for an attr used to configure an Op.
// Comment indicates the corresponding attr type. Only the field matching the
// attr type may be filled.
message AttrValue {
  // LINT.IfChange
  // Homogeneous list values; each field holds one "list(...)" attr type.
  message ListValue {
    repeated bytes s = 2;                        // "list(string)"
    repeated int64 i = 3 [packed = true];        // "list(int)"
    repeated float f = 4 [packed = true];        // "list(float)"
    repeated bool b = 5 [packed = true];         // "list(bool)"
    repeated DataType type = 6 [packed = true];  // "list(type)"
    repeated TensorShapeProto shape = 7;         // "list(shape)"
    repeated TensorProto tensor = 8;             // "list(tensor)"
    repeated NameAttrList func = 9;              // "list(attr)"
  }
  // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)

  // Exactly one of the following is set. NOTE(review): field numbers
  // presumably mirror TensorFlow's attr_value.proto (see java_package above);
  // keep them unchanged so TensorBoard can decode the serialized graphs.
  oneof value {
    bytes s = 2;                 // "string"
    int64 i = 3;                 // "int"
    float f = 4;                 // "float"
    bool b = 5;                  // "bool"
    DataType type = 6;           // "type"
    TensorShapeProto shape = 7;  // "shape"
    TensorProto tensor = 8;      // "tensor"
    ListValue list = 1;          // any "list(...)"

    // "func" represents a function. func.name is a function's name or
    // a primitive op's name. func.attr.first is the name of an attr
    // defined for that function. func.attr.second is the value for
    // that attr in the instantiation.
    NameAttrList func = 10;

    // This is a placeholder only used in nodes defined inside a
    // function.  It indicates the attr value will be supplied when
    // the function is instantiated.  For example, let us suppose a
    // node "N" in function "FN". "N" has an attr "A" with value
    // placeholder = "foo". When FN is instantiated with attr "foo"
    // set to "bar", the instantiated node N's attr A will have been
    // given the value "bar".
    string placeholder = 9;
  }
}
56 |
// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NameAttrList {
  // Name of the function or primitive op.
  string name = 1;
  // Attr name -> attr value. ("map" requires its key/value type parameters;
  // a bare "map attr = 2;" is not valid proto3.)
  map<string, AttrValue> attr = 2;
}
63 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/attr_value_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/attr_value_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/event.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorboard;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "EventProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.util";
8 |
9 | import "tensorboardX/src/summary.proto";
10 |
// Protocol buffer representing an event that happened during
// the execution of a Brain model.
message Event {
  // Timestamp of the event.
  double wall_time = 1;

  // Global step of the event.
  int64 step = 2;

  // At most one of the following payloads is set per event.
  oneof what {
    // An event file was started, with the specified version.
    // This is used to identify the contents of the record IO files
    // easily. Current version is "brain.Event:2". All versions
    // start with "brain.Event:".
    string file_version = 3;
    // An encoded version of a GraphDef.
    bytes graph_def = 4;
    // A summary was generated.
    Summary summary = 5;
    // The user output a log message. Not all messages are logged, only ones
    // generated via the Python tensorboard_logging module.
    LogMessage log_message = 6;
    // The state of the session which can be used for restarting after crashes.
    SessionLog session_log = 7;
    // The metadata returned by running a session.run() call.
    TaggedRunMetadata tagged_run_metadata = 8;
    // An encoded version of a MetaGraphDef.
    bytes meta_graph_def = 9;
  }
}
41 |
// Protocol buffer used for logging messages to the events file.
message LogMessage {
  // Severity of the message; the numeric values coincide with Python's
  // logging levels (DEBUG=10 ... 50).
  enum Level {
    UNKNOWN = 0;
    DEBUG = 10;
    INFO = 20;
    WARN = 30;
    ERROR = 40;
    FATAL = 50;
  }
  Level level = 1;
  // Text of the message.
  string message = 2;
}
55 |
// Protocol buffer used for logging session state.
message SessionLog {
  // Lifecycle state being recorded.
  enum SessionStatus {
    STATUS_UNSPECIFIED = 0;
    START = 1;
    STOP = 2;
    CHECKPOINT = 3;
  }

  SessionStatus status = 1;
  // This checkpoint_path contains both the path and filename.
  string checkpoint_path = 2;
  // Free-form message accompanying the status.
  string msg = 3;
}
70 |
// For logging the metadata output for a single session.run() call.
message TaggedRunMetadata {
  // Tag name associated with this metadata.
  string tag = 1;
  // Byte-encoded version of the `RunMetadata` proto in order to allow lazy
  // deserialization.
  bytes run_metadata = 2;
}
79 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/event_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/event_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/graph.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorboard;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "GraphProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 |
9 | import "tensorboardX/src/node_def.proto";
10 | //import "tensorflow/core/framework/function.proto";
11 | import "tensorboardX/src/versions.proto";
12 |
// Represents the graph of operations
message GraphDef {
  // The operations (nodes) of the graph.
  repeated NodeDef node = 1;

  // Compatibility versions of the graph.  See core/public/version.h for version
  // history.  The GraphDef version is distinct from the TensorFlow version, and
  // each release of TensorFlow will support a range of GraphDef versions.
  VersionDef versions = 4;

  // Deprecated single version field; use versions above instead.  Since all
  // GraphDef changes before "versions" was introduced were forward
  // compatible, this field is entirely ignored.
  int32 version = 3 [deprecated = true];

  // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
  //
  // "library" provides user-defined functions.
  //
  // Naming:
  //   * library.function.name are in a flat namespace.
  //     NOTE: We may need to change it to be hierarchical to support
  //     different orgs. E.g.,
  //     { "/google/nn", { ... }},
  //     { "/google/vision", { ... }}
  //     { "/org_foo/module_bar", { ... }}
  //     map<string, FunctionDefLib> named_lib;
  //   * If node[i].op is the name of one function in "library",
  //     node[i] is deemed as a function call. Otherwise, node[i].op
  //     must be a primitive operation supported by the runtime.
  //
  //
  // Function call semantics:
  //
  //   * The callee may start execution as soon as some of its inputs
  //     are ready. The caller may want to use Tuple() mechanism to
  //     ensure all inputs are ready in the same time.
  //
  //   * The consumer of return values may start executing as soon as
  //     the return values the consumer depends on are ready.  The
  //     consumer may want to use Tuple() mechanism to ensure the
  //     consumer does not start until all return values of the callee
  //     function are ready.
  //
  // NOTE: field 2 ("library") is not defined in this copy, because the
  // function.proto import is commented out at the top of this file.
  //FunctionDefLibrary library = 2;
};
57 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/graph_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: tensorboardX/src/graph.proto
3 | # Python bindings for tensorboard.GraphDef (repeated NodeDef `node`, VersionDef `versions`, deprecated int32 `version`).
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))  # Py2: identity; Py3: str -> latin-1 bytes for the descriptor literals below
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 | # NOTE(review): old-style protoc output (Descriptors constructed directly); newer protobuf runtimes may reject this at import time — regenerate with a matching protoc rather than editing by hand.
13 | _sym_db = _symbol_database.Default()
14 |
15 | # GraphDef references NodeDef and VersionDef, so their modules are imported (and registered) first.
16 | from tensorboardX.src import node_def_pb2 as tensorboardX_dot_src_dot_node__def__pb2
17 | from tensorboardX.src import versions_pb2 as tensorboardX_dot_src_dot_versions__pb2
18 |
19 | # File descriptor; serialized_pb is the compiled form of graph.proto and must stay in sync with it.
20 | DESCRIPTOR = _descriptor.FileDescriptor(
21 | name='tensorboardX/src/graph.proto',
22 | package='tensorboard',
23 | syntax='proto3',
24 | serialized_pb=_b('\n\x1ctensorboardX/src/graph.proto\x12\x0btensorboard\x1a\x1ftensorboardX/src/node_def.proto\x1a\x1ftensorboardX/src/versions.proto\"n\n\x08GraphDef\x12\"\n\x04node\x18\x01 \x03(\x0b\x32\x14.tensorboard.NodeDef\x12)\n\x08versions\x18\x04 \x01(\x0b\x32\x17.tensorboard.VersionDef\x12\x13\n\x07version\x18\x03 \x01(\x05\x42\x02\x18\x01\x42,\n\x18org.tensorflow.frameworkB\x0bGraphProtosP\x01\xf8\x01\x01\x62\x06proto3')
25 | ,
26 | dependencies=[tensorboardX_dot_src_dot_node__def__pb2.DESCRIPTOR,tensorboardX_dot_src_dot_versions__pb2.DESCRIPTOR,])
27 |
28 | # Message descriptor for GraphDef. Message-typed fields start with message_type=None
29 | # and are resolved after construction (see the fields_by_name assignments below).
30 |
31 | _GRAPHDEF = _descriptor.Descriptor(
32 | name='GraphDef',
33 | full_name='tensorboard.GraphDef',
34 | filename=None,
35 | file=DESCRIPTOR,
36 | containing_type=None,
37 | fields=[
38 | _descriptor.FieldDescriptor(
39 | name='node', full_name='tensorboard.GraphDef.node', index=0,
40 | number=1, type=11, cpp_type=10, label=3,
41 | has_default_value=False, default_value=[],
42 | message_type=None, enum_type=None, containing_type=None,
43 | is_extension=False, extension_scope=None,
44 | options=None),
45 | _descriptor.FieldDescriptor(
46 | name='versions', full_name='tensorboard.GraphDef.versions', index=1,
47 | number=4, type=11, cpp_type=10, label=1,
48 | has_default_value=False, default_value=None,
49 | message_type=None, enum_type=None, containing_type=None,
50 | is_extension=False, extension_scope=None,
51 | options=None),
52 | _descriptor.FieldDescriptor(
53 | name='version', full_name='tensorboard.GraphDef.version', index=2,
54 | number=3, type=5, cpp_type=1, label=1,
55 | has_default_value=False, default_value=0,
56 | message_type=None, enum_type=None, containing_type=None,
57 | is_extension=False, extension_scope=None,
58 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001'))),  # '\030\001' encodes [deprecated = true]
59 | ],
60 | extensions=[
61 | ],
62 | nested_types=[],
63 | enum_types=[
64 | ],
65 | options=None,
66 | is_extendable=False,
67 | syntax='proto3',
68 | extension_ranges=[],
69 | oneofs=[
70 | ],
71 | serialized_start=111,
72 | serialized_end=221,
73 | )
74 | # Resolve message-typed fields to descriptors from the dependency modules, then register the file.
75 | _GRAPHDEF.fields_by_name['node'].message_type = tensorboardX_dot_src_dot_node__def__pb2._NODEDEF
76 | _GRAPHDEF.fields_by_name['versions'].message_type = tensorboardX_dot_src_dot_versions__pb2._VERSIONDEF
77 | DESCRIPTOR.message_types_by_name['GraphDef'] = _GRAPHDEF
78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
79 | # Build the concrete GraphDef message class from its descriptor and register it.
80 | GraphDef = _reflection.GeneratedProtocolMessageType('GraphDef', (_message.Message,), dict(
81 | DESCRIPTOR = _GRAPHDEF,
82 | __module__ = 'tensorboardX.src.graph_pb2'
83 | # @@protoc_insertion_point(class_scope:tensorboard.GraphDef)
84 | ))
85 | _sym_db.RegisterMessage(GraphDef)
86 |
87 | # Attach options after class creation: file options (java_package etc.) and the deprecated flag on `version`.
88 | DESCRIPTOR.has_options = True
89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\013GraphProtosP\001\370\001\001'))
90 | _GRAPHDEF.fields_by_name['version'].has_options = True
91 | _GRAPHDEF.fields_by_name['version']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001'))
92 | # @@protoc_insertion_point(module_scope)
93 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/graph_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/graph_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/node_def.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorboard;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "NodeProto";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 |
9 | import "tensorboardX/src/attr_value.proto";
10 | // A single node (operation) in a graph.
11 | message NodeDef {
12 | // The name given to this operator. Used for naming inputs,
13 | // logging, visualization, etc. Unique within a single GraphDef.
14 | // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
15 | string name = 1;
16 |
17 | // The operation name. There may be custom parameters in attrs.
18 | // Op names starting with an underscore are reserved for internal use.
19 | string op = 2;
20 |
21 | // Each input is "node:src_output" with "node" being a string name and
22 | // "src_output" indicating which output tensor to use from "node". If
23 | // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
24 | // may optionally be followed by control inputs that have the format
25 | // "^node".
26 | repeated string input = 3;
27 |
28 | // A (possibly partial) specification for the device on which this
29 | // node should be placed.
30 | // The expected syntax for this string is as follows:
31 | //
32 | // DEVICE_SPEC ::= PARTIAL_SPEC
33 | //
34 | // PARTIAL_SPEC ::= ("/" CONSTRAINT) *
35 | // CONSTRAINT ::= ("job:" JOB_NAME)
36 | // | ("replica:" [1-9][0-9]*)
37 | // | ("task:" [1-9][0-9]*)
38 | // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") )
39 | // NOTE(review): the example below uses "replica:0", which [1-9][0-9]* does not match; the digit patterns are presumably meant to admit 0 — confirm.
40 | // Valid values for this string include:
41 | // * "/job:worker/replica:0/task:1/gpu:3" (full specification)
42 | // * "/job:worker/gpu:3" (partial specification)
43 | // * "" (no specification)
44 | //
45 | // If the constraints do not resolve to a single device (or if this
46 | // field is empty or not present), the runtime will attempt to
47 | // choose a device automatically.
48 | string device = 4;
49 |
50 | // Operation-specific graph-construction-time configuration.
51 | // Note that this should include all attrs defined in the
52 | // corresponding OpDef, including those with a value matching
53 | // the default -- this allows the default to change and makes
54 | // NodeDefs easier to interpret on their own. However, if
55 | // an attr with a default is not specified in this list, the
56 | // default will be used.
57 | // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
58 | // one of the names from the corresponding OpDef's attr field).
59 | // The values must have a type matching the corresponding OpDef
60 | // attr's type field.
61 | // TODO(josh11b): Add some examples here showing best practices.
62 | map<string, AttrValue> attr = 5;  // key: attr name; value: AttrValue from attr_value.proto
63 | };
64 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/node_def_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: tensorboardX/src/node_def.proto
3 | # Python bindings for tensorboard.NodeDef and its synthetic map-entry type NodeDef.AttrEntry.
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))  # Py2: identity; Py3: str -> latin-1 bytes for the descriptor literals below
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 | # NOTE(review): old-style protoc output (Descriptors constructed directly); newer protobuf runtimes may reject this at import time — regenerate with a matching protoc rather than editing by hand.
13 | _sym_db = _symbol_database.Default()
14 |
15 | # attr values are AttrValue messages, so attr_value_pb2 is imported (and registered) first.
16 | from tensorboardX.src import attr_value_pb2 as tensorboardX_dot_src_dot_attr__value__pb2
17 |
18 | # File descriptor; serialized_pb is the compiled form of node_def.proto and must stay in sync with it.
19 | DESCRIPTOR = _descriptor.FileDescriptor(
20 | name='tensorboardX/src/node_def.proto',
21 | package='tensorboard',
22 | syntax='proto3',
23 | serialized_pb=_b('\n\x1ftensorboardX/src/node_def.proto\x12\x0btensorboard\x1a!tensorboardX/src/attr_value.proto\"\xb5\x01\n\x07NodeDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02op\x18\x02 \x01(\t\x12\r\n\x05input\x18\x03 \x03(\t\x12\x0e\n\x06\x64\x65vice\x18\x04 \x01(\t\x12,\n\x04\x61ttr\x18\x05 \x03(\x0b\x32\x1e.tensorboard.NodeDef.AttrEntry\x1a\x43\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.tensorboard.AttrValue:\x02\x38\x01\x42*\n\x18org.tensorflow.frameworkB\tNodeProtoP\x01\xf8\x01\x01\x62\x06proto3')
24 | ,
25 | dependencies=[tensorboardX_dot_src_dot_attr__value__pb2.DESCRIPTOR,])
26 |
27 |
28 | # Synthetic entry message behind `map<string, AttrValue> attr`: string `key` (tag 1)
29 | # and AttrValue `value` (tag 2). Its options bytes '8\001' set map_entry = true.
30 | _NODEDEF_ATTRENTRY = _descriptor.Descriptor(
31 | name='AttrEntry',
32 | full_name='tensorboard.NodeDef.AttrEntry',
33 | filename=None,
34 | file=DESCRIPTOR,
35 | containing_type=None,
36 | fields=[
37 | _descriptor.FieldDescriptor(
38 | name='key', full_name='tensorboard.NodeDef.AttrEntry.key', index=0,
39 | number=1, type=9, cpp_type=9, label=1,
40 | has_default_value=False, default_value=_b("").decode('utf-8'),
41 | message_type=None, enum_type=None, containing_type=None,
42 | is_extension=False, extension_scope=None,
43 | options=None),
44 | _descriptor.FieldDescriptor(
45 | name='value', full_name='tensorboard.NodeDef.AttrEntry.value', index=1,
46 | number=2, type=11, cpp_type=10, label=1,
47 | has_default_value=False, default_value=None,
48 | message_type=None, enum_type=None, containing_type=None,
49 | is_extension=False, extension_scope=None,
50 | options=None),
51 | ],
52 | extensions=[
53 | ],
54 | nested_types=[],
55 | enum_types=[
56 | ],
57 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')),  # map_entry = true
58 | is_extendable=False,
59 | syntax='proto3',
60 | extension_ranges=[],
61 | oneofs=[
62 | ],
63 | serialized_start=198,
64 | serialized_end=265,
65 | )
66 | # Descriptor for NodeDef; `attr` is a repeated (label=3) AttrEntry field, the wire form of a proto3 map.
67 | _NODEDEF = _descriptor.Descriptor(
68 | name='NodeDef',
69 | full_name='tensorboard.NodeDef',
70 | filename=None,
71 | file=DESCRIPTOR,
72 | containing_type=None,
73 | fields=[
74 | _descriptor.FieldDescriptor(
75 | name='name', full_name='tensorboard.NodeDef.name', index=0,
76 | number=1, type=9, cpp_type=9, label=1,
77 | has_default_value=False, default_value=_b("").decode('utf-8'),
78 | message_type=None, enum_type=None, containing_type=None,
79 | is_extension=False, extension_scope=None,
80 | options=None),
81 | _descriptor.FieldDescriptor(
82 | name='op', full_name='tensorboard.NodeDef.op', index=1,
83 | number=2, type=9, cpp_type=9, label=1,
84 | has_default_value=False, default_value=_b("").decode('utf-8'),
85 | message_type=None, enum_type=None, containing_type=None,
86 | is_extension=False, extension_scope=None,
87 | options=None),
88 | _descriptor.FieldDescriptor(
89 | name='input', full_name='tensorboard.NodeDef.input', index=2,
90 | number=3, type=9, cpp_type=9, label=3,
91 | has_default_value=False, default_value=[],
92 | message_type=None, enum_type=None, containing_type=None,
93 | is_extension=False, extension_scope=None,
94 | options=None),
95 | _descriptor.FieldDescriptor(
96 | name='device', full_name='tensorboard.NodeDef.device', index=3,
97 | number=4, type=9, cpp_type=9, label=1,
98 | has_default_value=False, default_value=_b("").decode('utf-8'),
99 | message_type=None, enum_type=None, containing_type=None,
100 | is_extension=False, extension_scope=None,
101 | options=None),
102 | _descriptor.FieldDescriptor(
103 | name='attr', full_name='tensorboard.NodeDef.attr', index=4,
104 | number=5, type=11, cpp_type=10, label=3,
105 | has_default_value=False, default_value=[],
106 | message_type=None, enum_type=None, containing_type=None,
107 | is_extension=False, extension_scope=None,
108 | options=None),
109 | ],
110 | extensions=[
111 | ],
112 | nested_types=[_NODEDEF_ATTRENTRY, ],
113 | enum_types=[
114 | ],
115 | options=None,
116 | is_extendable=False,
117 | syntax='proto3',
118 | extension_ranges=[],
119 | oneofs=[
120 | ],
121 | serialized_start=84,
122 | serialized_end=265,
123 | )
124 | # Resolve message-typed fields and nesting now that all descriptors exist, then register the file.
125 | _NODEDEF_ATTRENTRY.fields_by_name['value'].message_type = tensorboardX_dot_src_dot_attr__value__pb2._ATTRVALUE
126 | _NODEDEF_ATTRENTRY.containing_type = _NODEDEF
127 | _NODEDEF.fields_by_name['attr'].message_type = _NODEDEF_ATTRENTRY
128 | DESCRIPTOR.message_types_by_name['NodeDef'] = _NODEDEF
129 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
130 | # Build the concrete classes; AttrEntry is created inside NodeDef's namespace as NodeDef.AttrEntry.
131 | NodeDef = _reflection.GeneratedProtocolMessageType('NodeDef', (_message.Message,), dict(
132 |
133 | AttrEntry = _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), dict(
134 | DESCRIPTOR = _NODEDEF_ATTRENTRY,
135 | __module__ = 'tensorboardX.src.node_def_pb2'
136 | # @@protoc_insertion_point(class_scope:tensorboard.NodeDef.AttrEntry)
137 | ))
138 | ,
139 | DESCRIPTOR = _NODEDEF,
140 | __module__ = 'tensorboardX.src.node_def_pb2'
141 | # @@protoc_insertion_point(class_scope:tensorboard.NodeDef)
142 | ))
143 | _sym_db.RegisterMessage(NodeDef)
144 | _sym_db.RegisterMessage(NodeDef.AttrEntry)
145 |
146 | # Attach options after class creation: file options (java_package etc.) and map_entry on AttrEntry.
147 | DESCRIPTOR.has_options = True
148 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\tNodeProtoP\001\370\001\001'))
149 | _NODEDEF_ATTRENTRY.has_options = True
150 | _NODEDEF_ATTRENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001'))
151 | # @@protoc_insertion_point(module_scope)
152 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/node_def_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/node_def_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/plugin_pr_curve.proto:
--------------------------------------------------------------------------------
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | syntax = "proto3";
17 |
18 | package tensorboard;
19 | // Plugin-specific data for TensorBoard's PR-curve (precision-recall) plugin.
20 | message PrCurvePluginData {
21 | // Version `0` is the only supported version.
22 | int32 version = 1;
23 | // Number of thresholds the curve is computed at (inferred from the field name — confirm against the summary writer).
24 | uint32 num_thresholds = 2;
25 | }
26 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/plugin_pr_curve_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: tensorboardX/src/plugin_pr_curve.proto
3 | # Python bindings for tensorboard.PrCurvePluginData (int32 `version`, uint32 `num_thresholds`).
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))  # Py2: identity; Py3: str -> latin-1 bytes for the descriptor literals below
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 | # NOTE(review): old-style protoc output (Descriptors constructed directly); newer protobuf runtimes may reject this at import time — regenerate with a matching protoc rather than editing by hand.
13 | _sym_db = _symbol_database.Default()
14 |
15 | # This .proto imports no other files, so there are no dependency modules to load.
16 |
17 | # File descriptor; serialized_pb is the compiled form of plugin_pr_curve.proto and must stay in sync with it.
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='tensorboardX/src/plugin_pr_curve.proto',
20 | package='tensorboard',
21 | syntax='proto3',
22 | serialized_pb=_b('\n&tensorboardX/src/plugin_pr_curve.proto\x12\x0btensorboard\"<\n\x11PrCurvePluginData\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x16\n\x0enum_thresholds\x18\x02 \x01(\rb\x06proto3')
23 | )
24 |
25 |
26 | # Message descriptor for PrCurvePluginData: `version` (int32, tag 1) and
27 | # `num_thresholds` (uint32, tag 2).
28 | _PRCURVEPLUGINDATA = _descriptor.Descriptor(
29 | name='PrCurvePluginData',
30 | full_name='tensorboard.PrCurvePluginData',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='version', full_name='tensorboard.PrCurvePluginData.version', index=0,
37 | number=1, type=5, cpp_type=1, label=1,
38 | has_default_value=False, default_value=0,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None),
42 | _descriptor.FieldDescriptor(
43 | name='num_thresholds', full_name='tensorboard.PrCurvePluginData.num_thresholds', index=1,
44 | number=2, type=13, cpp_type=3, label=1,
45 | has_default_value=False, default_value=0,
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None),
49 | ],
50 | extensions=[
51 | ],
52 | nested_types=[],
53 | enum_types=[
54 | ],
55 | options=None,
56 | is_extendable=False,
57 | syntax='proto3',
58 | extension_ranges=[],
59 | oneofs=[
60 | ],
61 | serialized_start=55,
62 | serialized_end=115,
63 | )
64 | # Register the message with the file descriptor, then the file with the default symbol database.
65 | DESCRIPTOR.message_types_by_name['PrCurvePluginData'] = _PRCURVEPLUGINDATA
66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
67 | # Build the concrete PrCurvePluginData message class from its descriptor and register it.
68 | PrCurvePluginData = _reflection.GeneratedProtocolMessageType('PrCurvePluginData', (_message.Message,), dict(
69 | DESCRIPTOR = _PRCURVEPLUGINDATA,
70 | __module__ = 'tensorboardX.src.plugin_pr_curve_pb2'
71 | # @@protoc_insertion_point(class_scope:tensorboard.PrCurvePluginData)
72 | ))
73 | _sym_db.RegisterMessage(PrCurvePluginData)
74 |
75 |
76 | # @@protoc_insertion_point(module_scope)
77 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/plugin_pr_curve_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/plugin_pr_curve_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/resource_handle.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorboard;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "ResourceHandle";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 |
9 | // Protocol buffer representing a handle to a TensorFlow resource. Handles are
10 | // not valid across executions, but can be serialized back and forth from within
11 | // a single run.
12 | message ResourceHandleProto {
13 | // Unique name for the device containing the resource.
14 | string device = 1;
15 |
16 | // Container in which this resource is placed.
17 | string container = 2;
18 |
19 | // Unique name of this resource.
20 | string name = 3;
21 |
22 | // Hash code for the type of the resource. Is only valid in the same device
23 | // and in the same execution.
24 | uint64 hash_code = 4;
25 |
26 | // For debugging only: the name of the type pointed to by this handle, if
27 | // available.
28 | string maybe_type_name = 5;
29 | };
30 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/resource_handle_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: tensorboardX/src/resource_handle.proto
3 | # Python bindings for tensorboard.ResourceHandleProto (device, container, name, hash_code, maybe_type_name).
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))  # Py2: identity; Py3: str -> latin-1 bytes for the descriptor literals below
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 | # NOTE(review): old-style protoc output (Descriptors constructed directly); newer protobuf runtimes may reject this at import time — regenerate with a matching protoc rather than editing by hand.
13 | _sym_db = _symbol_database.Default()
14 |
15 | # This .proto imports no other files, so there are no dependency modules to load.
16 |
17 | # File descriptor; serialized_pb is the compiled form of resource_handle.proto and must stay in sync with it.
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='tensorboardX/src/resource_handle.proto',
20 | package='tensorboard',
21 | syntax='proto3',
22 | serialized_pb=_b('\n&tensorboardX/src/resource_handle.proto\x12\x0btensorboard\"r\n\x13ResourceHandleProto\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\t\x12\x11\n\tcontainer\x18\x02 \x01(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x11\n\thash_code\x18\x04 \x01(\x04\x12\x17\n\x0fmaybe_type_name\x18\x05 \x01(\tB/\n\x18org.tensorflow.frameworkB\x0eResourceHandleP\x01\xf8\x01\x01\x62\x06proto3')
23 | )
24 |
25 |
26 | # Message descriptor for ResourceHandleProto: four string fields (tags 1, 2, 3, 5)
27 | # and `hash_code` as uint64 (type=4, tag 4).
28 | _RESOURCEHANDLEPROTO = _descriptor.Descriptor(
29 | name='ResourceHandleProto',
30 | full_name='tensorboard.ResourceHandleProto',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='device', full_name='tensorboard.ResourceHandleProto.device', index=0,
37 | number=1, type=9, cpp_type=9, label=1,
38 | has_default_value=False, default_value=_b("").decode('utf-8'),
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None),
42 | _descriptor.FieldDescriptor(
43 | name='container', full_name='tensorboard.ResourceHandleProto.container', index=1,
44 | number=2, type=9, cpp_type=9, label=1,
45 | has_default_value=False, default_value=_b("").decode('utf-8'),
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None),
49 | _descriptor.FieldDescriptor(
50 | name='name', full_name='tensorboard.ResourceHandleProto.name', index=2,
51 | number=3, type=9, cpp_type=9, label=1,
52 | has_default_value=False, default_value=_b("").decode('utf-8'),
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None),
56 | _descriptor.FieldDescriptor(
57 | name='hash_code', full_name='tensorboard.ResourceHandleProto.hash_code', index=3,
58 | number=4, type=4, cpp_type=4, label=1,
59 | has_default_value=False, default_value=0,
60 | message_type=None, enum_type=None, containing_type=None,
61 | is_extension=False, extension_scope=None,
62 | options=None),
63 | _descriptor.FieldDescriptor(
64 | name='maybe_type_name', full_name='tensorboard.ResourceHandleProto.maybe_type_name', index=4,
65 | number=5, type=9, cpp_type=9, label=1,
66 | has_default_value=False, default_value=_b("").decode('utf-8'),
67 | message_type=None, enum_type=None, containing_type=None,
68 | is_extension=False, extension_scope=None,
69 | options=None),
70 | ],
71 | extensions=[
72 | ],
73 | nested_types=[],
74 | enum_types=[
75 | ],
76 | options=None,
77 | is_extendable=False,
78 | syntax='proto3',
79 | extension_ranges=[],
80 | oneofs=[
81 | ],
82 | serialized_start=55,
83 | serialized_end=169,
84 | )
85 | # Register the message with the file descriptor, then the file with the default symbol database.
86 | DESCRIPTOR.message_types_by_name['ResourceHandleProto'] = _RESOURCEHANDLEPROTO
87 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
88 | # Build the concrete ResourceHandleProto message class from its descriptor and register it.
89 | ResourceHandleProto = _reflection.GeneratedProtocolMessageType('ResourceHandleProto', (_message.Message,), dict(
90 | DESCRIPTOR = _RESOURCEHANDLEPROTO,
91 | __module__ = 'tensorboardX.src.resource_handle_pb2'
92 | # @@protoc_insertion_point(class_scope:tensorboard.ResourceHandleProto)
93 | ))
94 | _sym_db.RegisterMessage(ResourceHandleProto)
95 |
96 | # Attach file options (java_package, java_outer_classname, etc.) after class creation.
97 | DESCRIPTOR.has_options = True
98 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\016ResourceHandleP\001\370\001\001'))
99 | # @@protoc_insertion_point(module_scope)
100 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/resource_handle_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/resource_handle_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/summary.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorboard;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "SummaryProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 |
9 | import "tensorboardX/src/tensor.proto";
10 |
11 | // Metadata associated with a series of Summary data.
12 | message SummaryDescription {
13 | // Hint on how plugins should process the data in this series.
14 | // Supported values include "scalar", "histogram", "image", and "audio".
15 | string type_hint = 1;
16 | }
17 |
18 | // Serialization format for the histogram module in TensorFlow's
19 | // core/lib/histogram/histogram.h.
20 | message HistogramProto {
21 | double min = 1;
22 | double max = 2;
23 | double num = 3;
24 | double sum = 4;
25 | double sum_squares = 5;
26 | // bucket_limit and bucket below are parallel: both must have the same length.
27 | // Parallel arrays encoding the bucket boundaries and the bucket values.
28 | // bucket(i) is the count for bucket i. The range for
29 | // a bucket is:
30 | // i == 0: -DBL_MAX .. bucket_limit(0)
31 | // i != 0: bucket_limit(i-1) .. bucket_limit(i)
32 | repeated double bucket_limit = 6 [packed = true];
33 | repeated double bucket = 7 [packed = true];
34 | };
35 |
36 | // A SummaryMetadata encapsulates information on which plugins are able to make
37 | // use of a certain summary value.
38 | message SummaryMetadata {
39 | message PluginData {
40 | // The name of the plugin this data pertains to.
41 | string plugin_name = 1;
42 |
43 | // The content to store for the plugin. The best practice is for this JSON
44 | // string to be the canonical JSON serialization of a protocol buffer
45 | // defined by the plugin. Converting that protobuf to and from JSON is the
46 | // responsibility of the plugin code, and is not enforced by
47 | // TensorFlow/TensorBoard.
48 | string content = 2;
49 | }
50 |
51 | // A list of plugin data. A single summary value instance may be used by more
52 | // than one plugin.
53 | repeated PluginData plugin_data = 1;
54 | };
55 |
56 | // A Summary is a set of named values to be displayed by the
57 | // visualizer.
58 | //
59 | // Summaries are produced regularly during training, as controlled by
60 | // the "summary_interval_secs" attribute of the training operation.
61 | // Summaries are also produced at the end of an evaluation.
62 | message Summary {
63 | message Image {
64 | // Dimensions of the image.
65 | int32 height = 1;
66 | int32 width = 2;
67 | // Valid colorspace values are
68 | // 1 - grayscale
69 | // 2 - grayscale + alpha
70 | // 3 - RGB
71 | // 4 - RGBA
72 | // 5 - DIGITAL_YUV
73 | // 6 - BGRA
74 | int32 colorspace = 3;
75 | // Image data in encoded format. All image formats supported by
76 | // image_codec::CoderUtil can be stored here.
77 | bytes encoded_image_string = 4;
78 | }
79 |
80 | message Audio {
81 | // Sample rate of the audio in Hz.
82 | float sample_rate = 1;
83 | // Number of channels of audio.
84 | int64 num_channels = 2;
85 | // Length of the audio in frames (samples per channel).
86 | int64 length_frames = 3;
87 | // Encoded audio data and its associated RFC 2045 content type (e.g.
88 | // "audio/wav").
89 | bytes encoded_audio_string = 4;
90 | string content_type = 5;
91 | }
92 |
93 | message Value {
94 | // Name of the node that output this summary; in general, the name of a
95 | // TensorSummary node. If the node in question has multiple outputs, then
96 | // a ":\d+" suffix will be appended, like "some_op:13".
97 | // Might not be set for legacy summaries (i.e. those not using the tensor
98 | // value field).
99 | string node_name = 7;
100 |
101 | // Tag name for the data. Will only be used by legacy summaries
102 | // (i.e. those not using the tensor value field).
103 | // For legacy summaries, will be used as the title of the graph
104 | // in the visualizer.
105 | //
106 | // Tag is usually "op_name:value_name", where "op_name" itself can have
107 | // structure to indicate grouping.
108 | string tag = 1;
109 | SummaryMetadata metadata = 9;
110 | // Value associated with the tag; at most one of these is set (oneof).
111 | oneof value {
112 | float simple_value = 2;
113 | bytes obsolete_old_style_histogram = 3;
114 | Image image = 4;
115 | HistogramProto histo = 5;
116 | Audio audio = 6;
117 | TensorProto tensor = 8;
118 | }
119 | }
120 |
121 | // Set of values for the summary.
122 | repeated Value value = 1;
123 | }
124 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/summary_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/summary_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/tensor.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorboard;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "TensorProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 |
9 | import "tensorboardX/src/resource_handle.proto";
10 | import "tensorboardX/src/tensor_shape.proto";
11 | import "tensorboardX/src/types.proto";
12 |
13 | // Protocol buffer representing a tensor.
14 | message TensorProto {
15 | DataType dtype = 1;
16 |
17 | // Shape of the tensor. TODO(touts): sort out the 0-rank issues.
18 | TensorShapeProto tensor_shape = 2;
19 |
20 | // Only one of the representations below is set: either "tensor_content" or
21 | // one of the "xxx_val" attributes. We are not using oneof because oneofs cannot
22 | // contain repeated fields, so it would require another extra set of messages.
23 |
24 | // Version number.
25 | //
26 | // In version 0, if the "repeated xxx" representations contain only one
27 | // element, that element is repeated to fill the shape. This makes it easy
28 | // to represent a constant Tensor with a single value.
29 | int32 version_number = 3;
30 |
31 | // Serialized raw tensor content from either Tensor::AsProtoTensorContent or
32 | // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
33 | // can be used for all tensor types. The purpose of this representation is to
34 | // reduce serialization overhead during RPC call by avoiding serialization of
35 | // many repeated small items.
36 | bytes tensor_content = 4;
37 |
38 | // Type specific representations that make it easy to create tensor protos in
39 | // all languages. Only the representation corresponding to "dtype" can
40 | // be set. The values hold the flattened representation of the tensor in
41 | // row major order.
42 |
43 | // DT_HALF. Note that since protobuf has no int16 type, we'll have some
44 | // pointless zero padding for each value here.
45 | repeated int32 half_val = 13 [packed = true];
46 |
47 | // DT_FLOAT.
48 | repeated float float_val = 5 [packed = true];
49 |
50 | // DT_DOUBLE.
51 | repeated double double_val = 6 [packed = true];
52 |
53 | // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
54 | repeated int32 int_val = 7 [packed = true];
55 |
56 | // DT_STRING
57 | repeated bytes string_val = 8;
58 |
59 | // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
60 | // and imaginary parts of i-th single precision complex.
61 | repeated float scomplex_val = 9 [packed = true];
62 |
63 | // DT_INT64
64 | repeated int64 int64_val = 10 [packed = true];
65 |
66 | // DT_BOOL
67 | repeated bool bool_val = 11 [packed = true];
68 |
69 | // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
70 | // and imaginary parts of i-th double precision complex.
71 | repeated double dcomplex_val = 12 [packed = true];
72 |
73 | // DT_RESOURCE
74 | repeated ResourceHandleProto resource_handle_val = 14;
75 | };
76 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/tensor_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/tensor_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/tensor_shape.proto:
--------------------------------------------------------------------------------
1 | // Protocol buffer representing the shape of tensors.
2 |
3 | syntax = "proto3";
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "TensorShapeProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 |
9 | package tensorboard;
10 |
11 | // Dimensions of a tensor.
12 | message TensorShapeProto {
13 | // One dimension of the tensor.
14 | message Dim {
15 | // Size of the tensor in that dimension.
16 | // This value must be >= -1, but values of -1 are reserved for "unknown"
17 | // shapes (values of -1 mean "unknown" dimension). Certain wrappers
18 | // that work with TensorShapeProto may fail at runtime when deserializing
19 | // a TensorShapeProto containing a dim value of -1.
20 | int64 size = 1;
21 |
22 | // Optional name of the tensor dimension.
23 | string name = 2;
24 | };
25 |
26 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40}
27 | // for a 30 x 40 2D tensor. If an entry has size -1, this
28 | // corresponds to a dimension of unknown size. The names are
29 | // optional.
30 | //
31 | // The order of entries in "dim" matters: It indicates the layout of the
32 | // values in the tensor in-memory representation.
33 | //
34 | // The first entry in "dim" is the outermost dimension used to layout the
35 | // values, the last entry is the innermost dimension. This matches the
36 | // in-memory layout of RowMajor Eigen tensors.
37 | //
38 | // If "dim.size()" > 0, "unknown_rank" must be false.
39 | repeated Dim dim = 2;
40 |
41 | // If true, the number of dimensions in the shape is unknown.
42 | //
43 | // If true, "dim.size()" must be 0.
44 | bool unknown_rank = 3;
45 | };
46 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/tensor_shape_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: tensorboardX/src/tensor_shape.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='tensorboardX/src/tensor_shape.proto',
20 | package='tensorboard',
21 | syntax='proto3',
22 | serialized_pb=_b('\n#tensorboardX/src/tensor_shape.proto\x12\x0btensorboard\"{\n\x10TensorShapeProto\x12.\n\x03\x64im\x18\x02 \x03(\x0b\x32!.tensorboard.TensorShapeProto.Dim\x12\x14\n\x0cunknown_rank\x18\x03 \x01(\x08\x1a!\n\x03\x44im\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\tB2\n\x18org.tensorflow.frameworkB\x11TensorShapeProtosP\x01\xf8\x01\x01\x62\x06proto3')
23 | )
24 |
25 |
26 |
27 |
28 | _TENSORSHAPEPROTO_DIM = _descriptor.Descriptor(
29 | name='Dim',
30 | full_name='tensorboard.TensorShapeProto.Dim',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='size', full_name='tensorboard.TensorShapeProto.Dim.size', index=0,
37 | number=1, type=3, cpp_type=2, label=1,
38 | has_default_value=False, default_value=0,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None),
42 | _descriptor.FieldDescriptor(
43 | name='name', full_name='tensorboard.TensorShapeProto.Dim.name', index=1,
44 | number=2, type=9, cpp_type=9, label=1,
45 | has_default_value=False, default_value=_b("").decode('utf-8'),
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None),
49 | ],
50 | extensions=[
51 | ],
52 | nested_types=[],
53 | enum_types=[
54 | ],
55 | options=None,
56 | is_extendable=False,
57 | syntax='proto3',
58 | extension_ranges=[],
59 | oneofs=[
60 | ],
61 | serialized_start=142,
62 | serialized_end=175,
63 | )
64 |
65 | _TENSORSHAPEPROTO = _descriptor.Descriptor(
66 | name='TensorShapeProto',
67 | full_name='tensorboard.TensorShapeProto',
68 | filename=None,
69 | file=DESCRIPTOR,
70 | containing_type=None,
71 | fields=[
72 | _descriptor.FieldDescriptor(
73 | name='dim', full_name='tensorboard.TensorShapeProto.dim', index=0,
74 | number=2, type=11, cpp_type=10, label=3,
75 | has_default_value=False, default_value=[],
76 | message_type=None, enum_type=None, containing_type=None,
77 | is_extension=False, extension_scope=None,
78 | options=None),
79 | _descriptor.FieldDescriptor(
80 | name='unknown_rank', full_name='tensorboard.TensorShapeProto.unknown_rank', index=1,
81 | number=3, type=8, cpp_type=7, label=1,
82 | has_default_value=False, default_value=False,
83 | message_type=None, enum_type=None, containing_type=None,
84 | is_extension=False, extension_scope=None,
85 | options=None),
86 | ],
87 | extensions=[
88 | ],
89 | nested_types=[_TENSORSHAPEPROTO_DIM, ],
90 | enum_types=[
91 | ],
92 | options=None,
93 | is_extendable=False,
94 | syntax='proto3',
95 | extension_ranges=[],
96 | oneofs=[
97 | ],
98 | serialized_start=52,
99 | serialized_end=175,
100 | )
101 |
102 | _TENSORSHAPEPROTO_DIM.containing_type = _TENSORSHAPEPROTO
103 | _TENSORSHAPEPROTO.fields_by_name['dim'].message_type = _TENSORSHAPEPROTO_DIM
104 | DESCRIPTOR.message_types_by_name['TensorShapeProto'] = _TENSORSHAPEPROTO
105 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
106 |
107 | TensorShapeProto = _reflection.GeneratedProtocolMessageType('TensorShapeProto', (_message.Message,), dict(
108 |
109 | Dim = _reflection.GeneratedProtocolMessageType('Dim', (_message.Message,), dict(
110 | DESCRIPTOR = _TENSORSHAPEPROTO_DIM,
111 | __module__ = 'tensorboardX.src.tensor_shape_pb2'
112 | # @@protoc_insertion_point(class_scope:tensorboard.TensorShapeProto.Dim)
113 | ))
114 | ,
115 | DESCRIPTOR = _TENSORSHAPEPROTO,
116 | __module__ = 'tensorboardX.src.tensor_shape_pb2'
117 | # @@protoc_insertion_point(class_scope:tensorboard.TensorShapeProto)
118 | ))
119 | _sym_db.RegisterMessage(TensorShapeProto)
120 | _sym_db.RegisterMessage(TensorShapeProto.Dim)
121 |
122 |
123 | DESCRIPTOR.has_options = True
124 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\021TensorShapeProtosP\001\370\001\001'))
125 | # @@protoc_insertion_point(module_scope)
126 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/tensor_shape_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/tensor_shape_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/types.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorboard;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "TypesProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 |
9 | // LINT.IfChange
10 | enum DataType {
11 | // Not a legal value for DataType. Used to indicate a DataType field
12 | // has not been set.
13 | DT_INVALID = 0;
14 |
15 | // Data types that all computation devices are expected to be
16 | // capable to support.
17 | DT_FLOAT = 1;
18 | DT_DOUBLE = 2;
19 | DT_INT32 = 3;
20 | DT_UINT8 = 4;
21 | DT_INT16 = 5;
22 | DT_INT8 = 6;
23 | DT_STRING = 7;
24 | DT_COMPLEX64 = 8; // Single-precision complex
25 | DT_INT64 = 9;
26 | DT_BOOL = 10;
27 | DT_QINT8 = 11; // Quantized int8
28 | DT_QUINT8 = 12; // Quantized uint8
29 | DT_QINT32 = 13; // Quantized int32
30 | DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
31 | DT_QINT16 = 15; // Quantized int16
32 | DT_QUINT16 = 16; // Quantized uint16
33 | DT_UINT16 = 17;
34 | DT_COMPLEX128 = 18; // Double-precision complex
35 | DT_HALF = 19;
36 | DT_RESOURCE = 20;
37 |
38 | // TODO(josh11b): DT_GENERIC_PROTO = ??;
39 | // TODO(jeff,josh11b): DT_UINT64? DT_UINT32?
40 |
41 | // Do not use! These are only for parameters. Every enum above
42 | // should have a corresponding value below (verified by types_test).
43 | DT_FLOAT_REF = 101;
44 | DT_DOUBLE_REF = 102;
45 | DT_INT32_REF = 103;
46 | DT_UINT8_REF = 104;
47 | DT_INT16_REF = 105;
48 | DT_INT8_REF = 106;
49 | DT_STRING_REF = 107;
50 | DT_COMPLEX64_REF = 108;
51 | DT_INT64_REF = 109;
52 | DT_BOOL_REF = 110;
53 | DT_QINT8_REF = 111;
54 | DT_QUINT8_REF = 112;
55 | DT_QINT32_REF = 113;
56 | DT_BFLOAT16_REF = 114;
57 | DT_QINT16_REF = 115;
58 | DT_QUINT16_REF = 116;
59 | DT_UINT16_REF = 117;
60 | DT_COMPLEX128_REF = 118;
61 | DT_HALF_REF = 119;
62 | DT_RESOURCE_REF = 120;
63 | }
64 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go)
65 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/types_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: tensorboardX/src/types.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf.internal import enum_type_wrapper
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | from google.protobuf import descriptor_pb2
12 | # @@protoc_insertion_point(imports)
13 |
14 | _sym_db = _symbol_database.Default()
15 |
16 |
17 |
18 |
19 | DESCRIPTOR = _descriptor.FileDescriptor(
20 | name='tensorboardX/src/types.proto',
21 | package='tensorboard',
22 | syntax='proto3',
23 | serialized_pb=_b('\n\x1ctensorboardX/src/types.proto\x12\x0btensorboard*\xc2\x05\n\x08\x44\x61taType\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0c\n\x08\x44T_FLOAT\x10\x01\x12\r\n\tDT_DOUBLE\x10\x02\x12\x0c\n\x08\x44T_INT32\x10\x03\x12\x0c\n\x08\x44T_UINT8\x10\x04\x12\x0c\n\x08\x44T_INT16\x10\x05\x12\x0b\n\x07\x44T_INT8\x10\x06\x12\r\n\tDT_STRING\x10\x07\x12\x10\n\x0c\x44T_COMPLEX64\x10\x08\x12\x0c\n\x08\x44T_INT64\x10\t\x12\x0b\n\x07\x44T_BOOL\x10\n\x12\x0c\n\x08\x44T_QINT8\x10\x0b\x12\r\n\tDT_QUINT8\x10\x0c\x12\r\n\tDT_QINT32\x10\r\x12\x0f\n\x0b\x44T_BFLOAT16\x10\x0e\x12\r\n\tDT_QINT16\x10\x0f\x12\x0e\n\nDT_QUINT16\x10\x10\x12\r\n\tDT_UINT16\x10\x11\x12\x11\n\rDT_COMPLEX128\x10\x12\x12\x0b\n\x07\x44T_HALF\x10\x13\x12\x0f\n\x0b\x44T_RESOURCE\x10\x14\x12\x10\n\x0c\x44T_FLOAT_REF\x10\x65\x12\x11\n\rDT_DOUBLE_REF\x10\x66\x12\x10\n\x0c\x44T_INT32_REF\x10g\x12\x10\n\x0c\x44T_UINT8_REF\x10h\x12\x10\n\x0c\x44T_INT16_REF\x10i\x12\x0f\n\x0b\x44T_INT8_REF\x10j\x12\x11\n\rDT_STRING_REF\x10k\x12\x14\n\x10\x44T_COMPLEX64_REF\x10l\x12\x10\n\x0c\x44T_INT64_REF\x10m\x12\x0f\n\x0b\x44T_BOOL_REF\x10n\x12\x10\n\x0c\x44T_QINT8_REF\x10o\x12\x11\n\rDT_QUINT8_REF\x10p\x12\x11\n\rDT_QINT32_REF\x10q\x12\x13\n\x0f\x44T_BFLOAT16_REF\x10r\x12\x11\n\rDT_QINT16_REF\x10s\x12\x12\n\x0e\x44T_QUINT16_REF\x10t\x12\x11\n\rDT_UINT16_REF\x10u\x12\x15\n\x11\x44T_COMPLEX128_REF\x10v\x12\x0f\n\x0b\x44T_HALF_REF\x10w\x12\x13\n\x0f\x44T_RESOURCE_REF\x10xB,\n\x18org.tensorflow.frameworkB\x0bTypesProtosP\x01\xf8\x01\x01\x62\x06proto3')
24 | )
25 |
26 | _DATATYPE = _descriptor.EnumDescriptor(
27 | name='DataType',
28 | full_name='tensorboard.DataType',
29 | filename=None,
30 | file=DESCRIPTOR,
31 | values=[
32 | _descriptor.EnumValueDescriptor(
33 | name='DT_INVALID', index=0, number=0,
34 | options=None,
35 | type=None),
36 | _descriptor.EnumValueDescriptor(
37 | name='DT_FLOAT', index=1, number=1,
38 | options=None,
39 | type=None),
40 | _descriptor.EnumValueDescriptor(
41 | name='DT_DOUBLE', index=2, number=2,
42 | options=None,
43 | type=None),
44 | _descriptor.EnumValueDescriptor(
45 | name='DT_INT32', index=3, number=3,
46 | options=None,
47 | type=None),
48 | _descriptor.EnumValueDescriptor(
49 | name='DT_UINT8', index=4, number=4,
50 | options=None,
51 | type=None),
52 | _descriptor.EnumValueDescriptor(
53 | name='DT_INT16', index=5, number=5,
54 | options=None,
55 | type=None),
56 | _descriptor.EnumValueDescriptor(
57 | name='DT_INT8', index=6, number=6,
58 | options=None,
59 | type=None),
60 | _descriptor.EnumValueDescriptor(
61 | name='DT_STRING', index=7, number=7,
62 | options=None,
63 | type=None),
64 | _descriptor.EnumValueDescriptor(
65 | name='DT_COMPLEX64', index=8, number=8,
66 | options=None,
67 | type=None),
68 | _descriptor.EnumValueDescriptor(
69 | name='DT_INT64', index=9, number=9,
70 | options=None,
71 | type=None),
72 | _descriptor.EnumValueDescriptor(
73 | name='DT_BOOL', index=10, number=10,
74 | options=None,
75 | type=None),
76 | _descriptor.EnumValueDescriptor(
77 | name='DT_QINT8', index=11, number=11,
78 | options=None,
79 | type=None),
80 | _descriptor.EnumValueDescriptor(
81 | name='DT_QUINT8', index=12, number=12,
82 | options=None,
83 | type=None),
84 | _descriptor.EnumValueDescriptor(
85 | name='DT_QINT32', index=13, number=13,
86 | options=None,
87 | type=None),
88 | _descriptor.EnumValueDescriptor(
89 | name='DT_BFLOAT16', index=14, number=14,
90 | options=None,
91 | type=None),
92 | _descriptor.EnumValueDescriptor(
93 | name='DT_QINT16', index=15, number=15,
94 | options=None,
95 | type=None),
96 | _descriptor.EnumValueDescriptor(
97 | name='DT_QUINT16', index=16, number=16,
98 | options=None,
99 | type=None),
100 | _descriptor.EnumValueDescriptor(
101 | name='DT_UINT16', index=17, number=17,
102 | options=None,
103 | type=None),
104 | _descriptor.EnumValueDescriptor(
105 | name='DT_COMPLEX128', index=18, number=18,
106 | options=None,
107 | type=None),
108 | _descriptor.EnumValueDescriptor(
109 | name='DT_HALF', index=19, number=19,
110 | options=None,
111 | type=None),
112 | _descriptor.EnumValueDescriptor(
113 | name='DT_RESOURCE', index=20, number=20,
114 | options=None,
115 | type=None),
116 | _descriptor.EnumValueDescriptor(
117 | name='DT_FLOAT_REF', index=21, number=101,
118 | options=None,
119 | type=None),
120 | _descriptor.EnumValueDescriptor(
121 | name='DT_DOUBLE_REF', index=22, number=102,
122 | options=None,
123 | type=None),
124 | _descriptor.EnumValueDescriptor(
125 | name='DT_INT32_REF', index=23, number=103,
126 | options=None,
127 | type=None),
128 | _descriptor.EnumValueDescriptor(
129 | name='DT_UINT8_REF', index=24, number=104,
130 | options=None,
131 | type=None),
132 | _descriptor.EnumValueDescriptor(
133 | name='DT_INT16_REF', index=25, number=105,
134 | options=None,
135 | type=None),
136 | _descriptor.EnumValueDescriptor(
137 | name='DT_INT8_REF', index=26, number=106,
138 | options=None,
139 | type=None),
140 | _descriptor.EnumValueDescriptor(
141 | name='DT_STRING_REF', index=27, number=107,
142 | options=None,
143 | type=None),
144 | _descriptor.EnumValueDescriptor(
145 | name='DT_COMPLEX64_REF', index=28, number=108,
146 | options=None,
147 | type=None),
148 | _descriptor.EnumValueDescriptor(
149 | name='DT_INT64_REF', index=29, number=109,
150 | options=None,
151 | type=None),
152 | _descriptor.EnumValueDescriptor(
153 | name='DT_BOOL_REF', index=30, number=110,
154 | options=None,
155 | type=None),
156 | _descriptor.EnumValueDescriptor(
157 | name='DT_QINT8_REF', index=31, number=111,
158 | options=None,
159 | type=None),
160 | _descriptor.EnumValueDescriptor(
161 | name='DT_QUINT8_REF', index=32, number=112,
162 | options=None,
163 | type=None),
164 | _descriptor.EnumValueDescriptor(
165 | name='DT_QINT32_REF', index=33, number=113,
166 | options=None,
167 | type=None),
168 | _descriptor.EnumValueDescriptor(
169 | name='DT_BFLOAT16_REF', index=34, number=114,
170 | options=None,
171 | type=None),
172 | _descriptor.EnumValueDescriptor(
173 | name='DT_QINT16_REF', index=35, number=115,
174 | options=None,
175 | type=None),
176 | _descriptor.EnumValueDescriptor(
177 | name='DT_QUINT16_REF', index=36, number=116,
178 | options=None,
179 | type=None),
180 | _descriptor.EnumValueDescriptor(
181 | name='DT_UINT16_REF', index=37, number=117,
182 | options=None,
183 | type=None),
184 | _descriptor.EnumValueDescriptor(
185 | name='DT_COMPLEX128_REF', index=38, number=118,
186 | options=None,
187 | type=None),
188 | _descriptor.EnumValueDescriptor(
189 | name='DT_HALF_REF', index=39, number=119,
190 | options=None,
191 | type=None),
192 | _descriptor.EnumValueDescriptor(
193 | name='DT_RESOURCE_REF', index=40, number=120,
194 | options=None,
195 | type=None),
196 | ],
197 | containing_type=None,
198 | options=None,
199 | serialized_start=46,
200 | serialized_end=752,
201 | )
202 | _sym_db.RegisterEnumDescriptor(_DATATYPE)
203 |
204 | DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE)
205 | DT_INVALID = 0
206 | DT_FLOAT = 1
207 | DT_DOUBLE = 2
208 | DT_INT32 = 3
209 | DT_UINT8 = 4
210 | DT_INT16 = 5
211 | DT_INT8 = 6
212 | DT_STRING = 7
213 | DT_COMPLEX64 = 8
214 | DT_INT64 = 9
215 | DT_BOOL = 10
216 | DT_QINT8 = 11
217 | DT_QUINT8 = 12
218 | DT_QINT32 = 13
219 | DT_BFLOAT16 = 14
220 | DT_QINT16 = 15
221 | DT_QUINT16 = 16
222 | DT_UINT16 = 17
223 | DT_COMPLEX128 = 18
224 | DT_HALF = 19
225 | DT_RESOURCE = 20
226 | DT_FLOAT_REF = 101
227 | DT_DOUBLE_REF = 102
228 | DT_INT32_REF = 103
229 | DT_UINT8_REF = 104
230 | DT_INT16_REF = 105
231 | DT_INT8_REF = 106
232 | DT_STRING_REF = 107
233 | DT_COMPLEX64_REF = 108
234 | DT_INT64_REF = 109
235 | DT_BOOL_REF = 110
236 | DT_QINT8_REF = 111
237 | DT_QUINT8_REF = 112
238 | DT_QINT32_REF = 113
239 | DT_BFLOAT16_REF = 114
240 | DT_QINT16_REF = 115
241 | DT_QUINT16_REF = 116
242 | DT_UINT16_REF = 117
243 | DT_COMPLEX128_REF = 118
244 | DT_HALF_REF = 119
245 | DT_RESOURCE_REF = 120
246 |
247 |
248 | DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE
249 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
250 |
251 |
252 | DESCRIPTOR.has_options = True
253 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\013TypesProtosP\001\370\001\001'))
254 | # @@protoc_insertion_point(module_scope)
255 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/types_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/types_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/src/versions.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package tensorboard;
4 | option cc_enable_arenas = true;
5 | option java_outer_classname = "VersionsProtos";
6 | option java_multiple_files = true;
7 | option java_package = "org.tensorflow.framework";
8 |
9 | // Version information for a piece of serialized data
10 | //
11 | // There are different types of versions for each type of data
12 | // (GraphDef, etc.), but they all have the same common shape
13 | // described here.
14 | //
15 | // Each consumer has "consumer" and "min_producer" versions (specified
16 | // elsewhere). A consumer is allowed to consume this data if
17 | //
18 | // producer >= min_producer
19 | // consumer >= min_consumer
20 | // consumer not in bad_consumers
21 | //
22 | message VersionDef {
23 | // The version of the code that produced this data.
24 | int32 producer = 1;
25 |
26 | // Any consumer below this version is not allowed to consume this data.
27 | int32 min_consumer = 2;
28 |
29 | // Specific consumer versions which are disallowed (e.g. due to bugs).
30 | repeated int32 bad_consumers = 3;
31 | };
32 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/versions_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: tensorboardX/src/versions.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='tensorboardX/src/versions.proto',
20 | package='tensorboard',
21 | syntax='proto3',
22 | serialized_pb=_b('\n\x1ftensorboardX/src/versions.proto\x12\x0btensorboard\"K\n\nVersionDef\x12\x10\n\x08producer\x18\x01 \x01(\x05\x12\x14\n\x0cmin_consumer\x18\x02 \x01(\x05\x12\x15\n\rbad_consumers\x18\x03 \x03(\x05\x42/\n\x18org.tensorflow.frameworkB\x0eVersionsProtosP\x01\xf8\x01\x01\x62\x06proto3')
23 | )
24 |
25 |
26 |
27 |
28 | _VERSIONDEF = _descriptor.Descriptor(
29 | name='VersionDef',
30 | full_name='tensorboard.VersionDef',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='producer', full_name='tensorboard.VersionDef.producer', index=0,
37 | number=1, type=5, cpp_type=1, label=1,
38 | has_default_value=False, default_value=0,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None),
42 | _descriptor.FieldDescriptor(
43 | name='min_consumer', full_name='tensorboard.VersionDef.min_consumer', index=1,
44 | number=2, type=5, cpp_type=1, label=1,
45 | has_default_value=False, default_value=0,
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None),
49 | _descriptor.FieldDescriptor(
50 | name='bad_consumers', full_name='tensorboard.VersionDef.bad_consumers', index=2,
51 | number=3, type=5, cpp_type=1, label=3,
52 | has_default_value=False, default_value=[],
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None),
56 | ],
57 | extensions=[
58 | ],
59 | nested_types=[],
60 | enum_types=[
61 | ],
62 | options=None,
63 | is_extendable=False,
64 | syntax='proto3',
65 | extension_ranges=[],
66 | oneofs=[
67 | ],
68 | serialized_start=48,
69 | serialized_end=123,
70 | )
71 |
72 | DESCRIPTOR.message_types_by_name['VersionDef'] = _VERSIONDEF
73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
74 |
75 | VersionDef = _reflection.GeneratedProtocolMessageType('VersionDef', (_message.Message,), dict(
76 | DESCRIPTOR = _VERSIONDEF,
77 | __module__ = 'tensorboardX.src.versions_pb2'
78 | # @@protoc_insertion_point(class_scope:tensorboard.VersionDef)
79 | ))
80 | _sym_db.RegisterMessage(VersionDef)
81 |
82 |
83 | DESCRIPTOR.has_options = True
84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\016VersionsProtosP\001\370\001\001'))
85 | # @@protoc_insertion_point(module_scope)
86 |
--------------------------------------------------------------------------------
/classification/tensorboardX/src/versions_pb2.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/versions_pb2.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/summary.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/summary.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/writer.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/writer.pyc
--------------------------------------------------------------------------------
/classification/tensorboardX/x2num.py:
--------------------------------------------------------------------------------
# DO NOT alter/destroy/free the input object!
2 |
3 | import numpy as np
4 |
5 |
def makenp(x, modality=None):
    """Convert ``x`` to a NumPy array without mutating the caller's object.

    Args:
        x: a NumPy array, a Python/NumPy scalar, a list/tuple of values, or a
            PyTorch, Chainer or MXNet object. Frameworks are detected by
            looking for their name in ``str(type(x))``, so none of them has to
            be importable for the other paths to work.
        modality: ``'IMG'`` to request image handling, otherwise ``None``.

    Returns:
        numpy.ndarray. Scalars are wrapped into a 1-element array. A NumPy
        ``uint8`` array with ``modality='IMG'`` is rescaled to ``float32`` by
        dividing by 255. Framework objects are delegated to the matching
        ``*_np`` converter.

    Raises:
        TypeError: if ``x`` is of an unsupported type (instead of silently
            returning ``None``).
    """
    if isinstance(x, np.ndarray):
        if modality == 'IMG' and x.dtype == np.uint8:
            return x.astype(np.float32) / 255.0
        return x
    if np.isscalar(x):
        return np.array([x])
    if isinstance(x, (list, tuple)):
        # Plain Python sequences: convert once, then reuse the ndarray path.
        return makenp(np.asarray(x), modality)
    type_name = str(type(x))
    if 'torch' in type_name:
        return pytorch_np(x, modality)
    if 'chainer' in type_name:
        return chainer_np(x, modality)
    if 'mxnet' in type_name:
        return mxnet_np(x, modality)
    raise TypeError('makenp: unsupported input type {}'.format(type(x)))
19 | return mxnet_np(x, modality)
20 |
21 |
def pytorch_np(x, modality):
    """Convert a PyTorch tensor or ``Variable`` to a NumPy array.

    Args:
        x: a ``torch`` tensor or ``torch.autograd.Variable``.
        modality: ``'IMG'`` to post-process the array with ``_prepare_image``.

    Returns:
        numpy.ndarray produced by ``.cpu().numpy()``; ``x`` itself is not
        modified.
    """
    import torch
    # Resolve the Variable class defensively: fall back to the submodule path
    # when the top-level alias is missing. Only a missing attribute is
    # tolerated here; any other error propagates instead of being swallowed.
    try:
        variable_cls = torch.autograd.Variable
    except AttributeError:
        variable_cls = torch.autograd.variable.Variable
    if isinstance(x, variable_cls):
        # Unwrap to the underlying tensor before converting.
        x = x.data
    x = x.cpu().numpy()
    if modality == 'IMG':
        x = _prepare_image(x)
    return x
34 |
35 |
def theano_np(x, modality=None):
    """Placeholder for Theano conversion; not implemented.

    ``makenp`` never dispatches here. The function fails loudly rather than
    importing ``theano`` and returning ``None``, which would look like a
    successful (empty) conversion to a caller.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('theano_np: Theano inputs are not supported yet')
39 |
40 |
def caffe2_np(x, modality=None):
    """Placeholder for Caffe2 conversion; not implemented.

    ``makenp`` never dispatches here. The function raises instead of silently
    returning ``None``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('caffe2_np: Caffe2 inputs are not supported yet')
43 |
44 |
def mxnet_np(x, modality):
    """Convert ``x`` via its ``asnumpy()`` method (presumably an MXNet NDArray).

    When ``modality == 'IMG'`` the resulting array is additionally passed
    through ``_prepare_image``.
    """
    array = x.asnumpy()
    return _prepare_image(array) if modality == 'IMG' else array
50 |
51 |
def chainer_np(x, modality):
    """Copy ``x.data`` to host memory with ``chainer.cuda.to_cpu``.

    When ``modality == 'IMG'`` the host array is additionally passed through
    ``_prepare_image``.
    """
    import chainer
    host_array = chainer.cuda.to_cpu(x.data)
    if modality != 'IMG':
        return host_array
    return _prepare_image(host_array)
58 |
59 |
def make_grid(I, ncols=8):
    """Tile a batch of 3-channel images into one CHW canvas, filled row by row.

    Args:
        I: numpy array of shape ``(N, 3, H, W)``.
        ncols: maximum number of images per row; the column count actually
            used is ``min(N, ncols)``.

    Returns:
        float64 numpy array of shape ``(3, H * nrows, W * ncols)`` with
        ``nrows = ceil(N / ncols)``. Cells past the last image stay zero.

    Raises:
        ValueError: if the batch is empty or ``ncols < 1`` (these previously
            surfaced as an opaque ZeroDivisionError or a negative-dimension
            error).
    """
    assert isinstance(I, np.ndarray), 'plugin error, should pass numpy array here'
    assert I.ndim == 4 and I.shape[1] == 3
    nimg, _, H, W = I.shape
    if nimg == 0:
        raise ValueError('make_grid: cannot build a grid from an empty batch')
    if ncols < 1:
        raise ValueError('make_grid: ncols must be >= 1, got {}'.format(ncols))
    ncols = min(nimg, ncols)
    nrows = -(-nimg // ncols)  # integer ceil(nimg / ncols)
    canvas = np.zeros((3, H * nrows, W * ncols))
    for i in range(nimg):
        # Image i goes to row i // ncols, column i % ncols.
        row, col = divmod(i, ncols)
        canvas[:, row * H:(row + 1) * H, col * W:(col + 1) * W] = I[i]
    return canvas
77 |
78 |
79 | def _prepare_image(I):
80 | assert isinstance(I, np.ndarray), 'plugin error, should pass numpy array here'
81 | assert I.ndim == 2 or I.ndim == 3 or I.ndim == 4
82 | if I.ndim == 4: # NCHW
83 | if I.shape[1] == 1: # N1HW
84 | I = np.concatenate((I, I, I), 1) # N3HW
85 | assert I.shape[1] == 3
86 | I = make_grid(I) # 3xHxW
87 | if I.ndim == 3 and I.shape[0] == 1: # 1xHxW
88 | I = np.concatenate((I, I, I), 0) # 3xHxW
89 | if I.ndim == 2: # HxW
90 | I = np.expand_dims(I, 0) # 1xHxW
91 | I = np.concatenate((I, I, I), 0) # 3xHxW
92 | I = I.transpose(1, 2, 0)
93 |
94 | return I
95 |
--------------------------------------------------------------------------------
/classification/tensorboardX/x2num.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/x2num.pyc
--------------------------------------------------------------------------------
/classification/utility.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
def attention(x, p=2):
    """Attention map used by the attention-transfer loss.

    Raises ``|x|`` to the power ``p``, averages over dimension 1, flattens
    every remaining dimension per sample, and L2-normalises each row with
    ``F.normalize``.

    Args:
        x: tensor whose dim 0 is the batch and dim 1 is averaged out
            (presumably an ``(N, C, H, W)`` feature map — TODO confirm for all
            callers).
        p: exponent applied to the absolute activations. The default ``2``
            reproduces the original ``x.pow(2)`` result exactly, because
            ``|x| ** 2 == x ** 2``.

    Returns:
        tensor of shape ``(N, -1)`` whose rows have unit L2 norm (up to
        ``F.normalize``'s epsilon for all-zero rows).
    """
    # abs() makes odd or fractional exponents well-defined; a no-op for p == 2.
    return F.normalize(x.abs().pow(p).mean(1).view(x.shape[0], -1))
8 |
9 |
def at_loss(x, y):
    """Attention-transfer loss: mean squared difference of two attention maps.

    Both inputs are reduced with ``attention`` first. The resulting maps are
    subtracted elementwise, so they must have compatible (broadcastable)
    shapes.
    """
    delta = attention(x) - attention(y)
    return torch.mean(delta ** 2)
12 |
13 |
--------------------------------------------------------------------------------
/classification/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from utils.misc import *
4 | from utils.logger import *
5 | from utils.visualize import *
6 | from utils.eval import *
7 |
8 | # progress bar
9 | import os
10 | import sys
11 | from utils.progress.bar import Bar as Bar
12 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
13 |
--------------------------------------------------------------------------------
/classification/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/utils/__pycache__/eval.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/__pycache__/eval.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/utils/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/utils/__pycache__/misc.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/__pycache__/misc.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/utils/__pycache__/visualize.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/__pycache__/visualize.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | __all__ = ['accuracy']
4 |
5 |
def accuracy(output, target, topk=(1,)):
    """
    Computes the precision@k for the specified values of k.

    :param output: class scores, presumably shaped (batch, num_classes)
    :param target: ground-truth class indices, presumably shaped (batch,)
    :param topk: tuple of k values to evaluate
    :return: list with one 0-d tensor per k: the percentage (0-100) of samples
             whose target is among the k highest-scoring classes
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk largest scores per row (sorted), transposed to (maxk, batch)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # pred.t() is non-contiguous, so `correct` can be too; `.view(-1)` then
        # raises for k > 1 on recent PyTorch, while `.reshape(-1)` copies if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
26 |
--------------------------------------------------------------------------------
/classification/utils/images/cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/images/cifar.png
--------------------------------------------------------------------------------
/classification/utils/images/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/images/imagenet.png
--------------------------------------------------------------------------------
/classification/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 | import math
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import matplotlib.pyplot as plt
8 | import os
9 | import sys
10 | import numpy as np
11 |
12 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
13 |
14 |
def savefig(fname, dpi=None):
    """
    Save the current matplotlib figure to ``fname``.
    :param fname: output file name
    :param dpi: output resolution; 150 when not given
    :return: None
    """
    if dpi is None:
        dpi = 150
    plt.savefig(fname, dpi=dpi)
23 |
24 |
def plot_overlap(logger, names=None):
    """
    Plot columns of ``logger`` on the current axes and return their legend labels.

    :param logger: a ``Logger`` whose ``numbers`` dict holds each column's history
    :param names: columns to plot; defaults to all of ``logger.names``
    :return: one ``'<title>(<name>)'`` legend label per plotted column
    """
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        # A resumed Logger may hold its values as text; convert explicitly so
        # they are plotted as numbers rather than as categorical string labels.
        values = np.asarray(numbers[name], dtype=float)
        plt.plot(np.arange(len(values)), values)
    return [logger.title + '(' + name + ')' for name in names]
32 |
33 |
class Logger(object):
    """
    Save training process to log file with simple plot function.

    The file is tab-separated: one header row of column names, then one row of
    ``{:.6f}``-formatted numbers per call to :meth:`append`.
    """
    def __init__(self, fpath, title=None, resume=False):
        """
        :param fpath: path of the log file; ``None`` means no file is opened
        :param title: prefix for plot legend entries
        :param resume: if True, load the existing log at ``fpath`` into
                       ``names``/``numbers`` and append new rows to it;
                       otherwise the file is truncated
        """
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                with open(fpath, 'r') as f:
                    # header row; rstrip drops the trailing tab written by set_names
                    self.names = f.readline().rstrip().split('\t')
                    self.numbers = {name: [] for name in self.names}
                    for row in f:
                        row = row.rstrip()
                        if not row:
                            continue
                        for name, value in zip(self.names, row.split('\t')):
                            # parse back to float so resumed history has the same
                            # type as values added later through append()
                            self.numbers[name].append(float(value))
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        """
        Write the header row and start an empty history for each name.

        NOTE(review): when resuming this still appends a second header row to
        the existing file; presumably callers skip it on resume - confirm.
        """
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        """Write one row (each value formatted ``{:.6f}``) and record it in ``numbers``."""
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for name, num in zip(self.names, numbers):
            self.file.write("{0:.6f}".format(num))
            self.file.write('\t')
            self.numbers[name].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        """Plot the selected columns (default: all) against their row index."""
        names = self.names if names is None else names
        for name in names:
            values = np.asarray(self.numbers[name], dtype=float)
            plt.plot(np.arange(len(values)), values)
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying log file, if one was opened."""
        if self.file is not None:
            self.file.close()
95 |
96 |
class LoggerMonitor(object):
    """
    Load and visualize multiple logs.
    """
    def __init__(self, paths):
        """
        :param paths: dict mapping a legend title to the path of a log file
                      previously written by ``Logger``
        """
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # Resuming re-opens the log in append mode, but the monitor only
            # reads the loaded numbers; release the handle instead of leaking it.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        """Overlay the ``names`` columns of every log in the left half of a new figure."""
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        # legend placed outside the axes, to the right of the plot
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
119 |
120 |
if __name__ == '__main__':
    # Example: a single Logger (kept for reference, not executed)
    #   logger = Logger('test.txt')
    #   logger.set_names(['Train loss', 'Valid loss', 'Test loss'])
    #   length = 100
    #   t = np.arange(length)
    #   train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    #   valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    #   test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    #   for i in range(0, length):
    #       logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    #   logger.plot()

    # Example: overlay the validation accuracy of several runs
    paths = {
        'resadvnet20': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }
    LoggerMonitor(paths).plot(names=['Valid Acc.'])
    savefig('test.pdf', dpi=300)
148 |
--------------------------------------------------------------------------------
/classification/utils/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Some helper functions for PyTorch, including:
3 | - get_mean_and_std: calculate the mean and std value of dataset.
4 | - init_params: net parameter initialization.
5 | - mkdir_p / AverageMeter: directory creation and running-average tracking.
6 | """
7 |
8 | import errno
9 | import os
10 | import sys
11 | import time
12 | import math
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.init as init
17 | from torch.autograd import Variable
18 |
19 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
20 |
21 |
def get_mean_and_std(dataset, num_workers=2):
    """
    Compute the mean and std value of dataset.

    Both statistics are computed per image and per channel, then averaged over
    the images, so ``std`` is the mean per-image standard deviation rather than
    the standard deviation over all pixels of the dataset.
    :param dataset: map-style dataset yielding ``(image, target)`` pairs, images
                    presumably shaped (C, H, W)
    :param num_workers: DataLoader worker processes (default 2, as before)
    :return: ``(mean, std)`` 1-D tensors with one entry per channel
    """
    # Order does not affect the statistics; a fixed order keeps the
    # floating-point sums reproducible between runs.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    mean = std = None
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        # (1, C, H, W) -> (C, H*W): one row of pixels per channel
        pixels = inputs.transpose(0, 1).reshape(inputs.size(1), -1)
        if mean is None:
            mean = torch.zeros(inputs.size(1))
            std = torch.zeros(inputs.size(1))
        mean += pixels.mean(1)
        std += pixels.std(1)
    if mean is None:
        # empty dataset: keep the previous 3-channel (NaN after division) result
        mean, std = torch.zeros(3), torch.zeros(3)
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
40 |
41 |
def init_params(net):
    """
    Initialize layer parameters in place.

    - Conv2d: Kaiming-normal weights (``fan_out`` mode), zero bias
    - BatchNorm2d: weight 1, bias 0
    - Linear: normal weights with std 1e-3, zero bias
    :param net: an ``nn.Module``; every module yielded by ``net.modules()`` is visited
    :return: None
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` raised "Boolean value of Tensor with more than one
            # value is ambiguous" for any real bias; test for presence instead.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has no weight/bias tensors
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
60 |
61 |
def mkdir_p(path):
    """
    Create ``path`` and any missing parent directories, like ``mkdir -p``.

    An already existing directory is not an error.
    :param path: directory to create
    :return: None
    :raises OSError: if ``path`` cannot be created or exists as a non-directory
    """
    # exist_ok=True replaces the manual errno.EEXIST check. It also tolerates an
    # existing directory when the OS reports a different error first (e.g.
    # EACCES/EROFS), which the EEXIST-only check re-raised.
    os.makedirs(path, exist_ok=True)
75 |
76 |
class AverageMeter(object):
    """
    Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record ``val`` as the latest value, weighted as ``n`` observations.
        :param val: latest value (e.g. a per-batch mean)
        :param n: number of observations ``val`` stands for
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # n=0 on an empty meter used to raise ZeroDivisionError
        self.avg = self.sum / self.count if self.count else 0
96 |
--------------------------------------------------------------------------------
/classification/utils/progress/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2012 Giorgos Verigakis
2 | #
3 | # Permission to use, copy, modify, and distribute this software for any
4 | # purpose with or without fee is hereby granted, provided that the above
5 | # copyright notice and this permission notice appear in all copies.
6 | #
7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 |
15 | from __future__ import division
16 |
17 | from collections import deque
18 | from datetime import timedelta
19 | from math import ceil
20 | from sys import stderr
21 | from time import time
22 |
23 |
24 | __version__ = '1.3'
25 |
26 |
class Infinite(object):
    """
    Base progress indicator with no known end point.

    Tracks the number of steps taken (``index``), the start time and a simple
    moving average (``avg``) of seconds per step over the last ``sma_window``
    recorded samples. Subclasses render output by overriding ``update``,
    ``start`` and ``finish``.
    """
    file = stderr
    sma_window = 10  # Simple Moving Average window

    def __init__(self, *args, **kwargs):
        self.index = 0
        self.start_ts = time()
        self.avg = 0
        self._ts = self.start_ts
        self._xput = deque(maxlen=self.sma_window)
        # every keyword argument (message, max, suffix, ...) overrides the
        # attribute of the same name on this instance
        for key, val in kwargs.items():
            setattr(self, key, val)

    def __getitem__(self, key):
        """Expose public attributes to %-formatting, e.g. ``'%(index)d' % self``."""
        if key.startswith('_'):
            return None
        return getattr(self, key, None)

    @property
    def elapsed(self):
        """Whole seconds since construction."""
        return int(time() - self.start_ts)

    @property
    def elapsed_td(self):
        """``elapsed`` as a ``timedelta``."""
        return timedelta(seconds=self.elapsed)

    def update_avg(self, n, dt):
        """Record that ``n`` steps took ``dt`` seconds and refresh ``avg``."""
        # Only recompute after recording a sample: with n <= 0 on a fresh
        # instance the window is empty and the average would divide by zero.
        if n > 0:
            self._xput.append(dt / n)
            self.avg = sum(self._xput) / len(self._xput)

    def update(self):
        pass

    def start(self):
        pass

    def finish(self):
        pass

    def next(self, n=1):
        """Advance by ``n`` steps, update the timing average and redraw."""
        now = time()
        dt = now - self._ts
        self.update_avg(n, dt)
        self._ts = now
        self.index = self.index + n
        self.update()

    def iter(self, it):
        """Yield the items of ``it``, advancing one step per item; finish() always runs."""
        try:
            for x in it:
                yield x
                self.next()
        finally:
            self.finish()
82 |
83 |
class Progress(Infinite):
    """Progress indicator with a known end point ``max`` (default 100)."""
    def __init__(self, *args, **kwargs):
        super(Progress, self).__init__(*args, **kwargs)
        self.max = kwargs.get('max', 100)

    @property
    def eta(self):
        """Estimated remaining seconds: average seconds per step times steps left."""
        return int(ceil(self.avg * self.remaining))

    @property
    def eta_td(self):
        """``eta`` as a ``timedelta``."""
        return timedelta(seconds=self.eta)

    @property
    def percent(self):
        """``progress`` scaled to 0-100."""
        return self.progress * 100

    @property
    def progress(self):
        """Completed fraction, capped at 1; 0 when ``max`` is 0."""
        # max == 0 (e.g. after iter() over an empty sequence) used to raise
        # ZeroDivisionError whenever the bar was drawn
        if self.max == 0:
            return 0
        return min(1, self.index / self.max)

    @property
    def remaining(self):
        """Steps left until ``max`` (never negative)."""
        return max(self.max - self.index, 0)

    def start(self):
        self.update()

    def goto(self, index):
        """Jump to absolute step ``index``."""
        incr = index - self.index
        self.next(incr)

    def iter(self, it):
        """Like Infinite.iter, but takes ``max`` from ``len(it)`` when available."""
        try:
            self.max = len(it)
        except TypeError:
            pass

        try:
            for x in it:
                yield x
                self.next()
        finally:
            self.finish()
128 |
--------------------------------------------------------------------------------
/classification/utils/progress/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/progress/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/utils/progress/__pycache__/bar.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/progress/__pycache__/bar.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/utils/progress/__pycache__/helpers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/progress/__pycache__/helpers.cpython-36.pyc
--------------------------------------------------------------------------------
/classification/utils/progress/bar.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2012 Giorgos Verigakis
4 | #
5 | # Permission to use, copy, modify, and distribute this software for any
6 | # purpose with or without fee is hereby granted, provided that the above
7 | # copyright notice and this permission notice appear in all copies.
8 | #
9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 |
17 | from __future__ import unicode_literals
18 | from utils.progress import Progress
19 | from utils.progress.helpers import WritelnMixin
20 |
21 |
class Bar(WritelnMixin, Progress):
    """Single-line progress bar: ``<message> |####    | <index>/<max>``."""
    width = 32
    message = ''
    suffix = '%(index)d/%(max)d'
    bar_prefix = ' |'
    bar_suffix = '| '
    empty_fill = ' '
    fill = '#'
    hide_cursor = True

    def update(self):
        """Redraw the bar so its filled part reflects ``progress``."""
        n_filled = int(self.width * self.progress)
        n_empty = self.width - n_filled
        self.writeln(''.join([
            self.message % self,
            self.bar_prefix,
            self.fill * n_filled,
            self.empty_fill * n_empty,
            self.bar_suffix,
            self.suffix % self,
        ]))
43 |
44 |
class ChargingBar(Bar):
    """Bar drawn with solid blocks over dots, with a percentage suffix."""
    fill = '█'
    empty_fill = '∙'
    bar_prefix = ' '
    bar_suffix = ' '
    suffix = '%(percent)d%%'
51 |
52 |
class FillingSquaresBar(ChargingBar):
    """ChargingBar variant that fills hollow squares."""
    fill = '▣'
    empty_fill = '▢'
56 |
57 |
class FillingCirclesBar(ChargingBar):
    """ChargingBar variant that fills hollow circles."""
    fill = '◉'
    empty_fill = '◯'
61 |
62 |
class IncrementalBar(Bar):
    """Bar that draws the boundary cell with a partial glyph for finer resolution."""
    phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')

    def update(self):
        """Redraw the bar; the fractional part of the fill selects a phase glyph."""
        exact_fill = self.width * self.progress
        full_cells = int(exact_fill)  # cells drawn with the last (full) phase
        # glyph for the partially filled cell; index 0 means no partial cell
        phase = int((exact_fill - full_cells) * len(self.phases))
        partial = self.phases[phase] if phase > 0 else ''
        blank_cells = max(0, (self.width - full_cells) - len(partial))

        self.writeln(''.join([
            self.message % self,
            self.bar_prefix,
            self.phases[-1] * full_cells,
            partial,
            self.empty_fill * blank_cells,
            self.bar_suffix,
            self.suffix % self,
        ]))
81 |
82 |
class PixelBar(IncrementalBar):
    """IncrementalBar drawn with Braille-dot phases."""
    phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿')
85 |
86 |
class ShadyBar(IncrementalBar):
    """IncrementalBar drawn with shade-block phases."""
    phases = (' ', '░', '▒', '▓', '█')
89 |
--------------------------------------------------------------------------------
/classification/utils/progress/counter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2012 Giorgos Verigakis
4 | #
5 | # Permission to use, copy, modify, and distribute this software for any
6 | # purpose with or without fee is hereby granted, provided that the above
7 | # copyright notice and this permission notice appear in all copies.
8 | #
9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 |
17 | from __future__ import unicode_literals
18 | from utils.progress import Infinite, Progress
19 | from utils.progress.helpers import WriteMixin
20 |
21 |
class Counter(WriteMixin, Infinite):
    """Displays the running step count, redrawn in place."""
    hide_cursor = True
    message = ''

    def update(self):
        """Show the current ``index``."""
        text = str(self.index)
        self.write(text)
28 |
29 |
class Countdown(WriteMixin, Progress):
    """Displays how many steps remain until ``max``, redrawn in place."""
    hide_cursor = True

    def update(self):
        """Show the current ``remaining`` count."""
        text = str(self.remaining)
        self.write(text)
35 |
36 |
class Stack(WriteMixin, Progress):
    """Single-character indicator that fills up as ``progress`` grows."""
    phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█')
    hide_cursor = True

    def update(self):
        """Show the phase glyph matching the current progress."""
        level = int(self.progress * len(self.phases))
        self.write(self.phases[min(len(self.phases) - 1, level)])
45 |
46 |
class Pie(Stack):
    """Stack indicator drawn as a filling circle."""
    phases = ('○', '◔', '◑', '◕', '●')
49 |
--------------------------------------------------------------------------------
/classification/utils/progress/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2012 Giorgos Verigakis
2 | #
3 | # Permission to use, copy, modify, and distribute this software for any
4 | # purpose with or without fee is hereby granted, provided that the above
5 | # copyright notice and this permission notice appear in all copies.
6 | #
7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 |
15 | from __future__ import print_function
16 |
17 | from signal import signal, SIGINT
18 | from sys import exit
19 |
20 |
21 | HIDE_CURSOR = '\x1b[?25l'
22 | SHOW_CURSOR = '\x1b[?25h'
23 |
24 |
class WriteMixin(object):
    """Mixin that redraws output in place by backspacing over the previous text (TTY only)."""
    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super(WriteMixin, self).__init__(**kwargs)
        self._width = 0  # widest text written so far; that many chars get overwritten
        if message:
            self.message = message

        if not self.file.isatty():
            return
        if self.hide_cursor:
            print(HIDE_CURSOR, end='', file=self.file)
        print(self.message, end='', file=self.file)
        self.file.flush()

    def write(self, s):
        """Replace the previously written text with ``s``."""
        if not self.file.isatty():
            return
        backspaces = '\b' * self._width
        padded = s.ljust(self._width)
        print(backspaces + padded, end='', file=self.file)
        self._width = max(self._width, len(s))
        self.file.flush()

    def finish(self):
        """Restore the cursor if it was hidden."""
        if self.file.isatty() and self.hide_cursor:
            print(SHOW_CURSOR, end='', file=self.file)
51 |
52 |
class WritelnMixin(object):
    """Mixin that redraws a whole line in place with carriage return + erase-line (TTY only)."""
    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super(WritelnMixin, self).__init__(**kwargs)
        if message:
            self.message = message

        if self.file.isatty() and self.hide_cursor:
            print(HIDE_CURSOR, end='', file=self.file)

    def clearln(self):
        """Move to column 0 and erase the current line."""
        if self.file.isatty():
            print('\r\x1b[K', end='', file=self.file)

    def writeln(self, line):
        """Replace the current terminal line with ``line``."""
        if not self.file.isatty():
            return
        self.clearln()
        print(line, end='', file=self.file)
        self.file.flush()

    def finish(self):
        """Terminate the line and restore the cursor if it was hidden."""
        if not self.file.isatty():
            return
        print(file=self.file)
        if self.hide_cursor:
            print(SHOW_CURSOR, end='', file=self.file)
79 |
80 |
class SigIntMixin(object):
    """
    Registers a signal handler that calls finish on SIGINT
    """

    def __init__(self, *args, **kwargs):
        super(SigIntMixin, self).__init__(*args, **kwargs)
        handler = self._sigint_handler
        signal(SIGINT, handler)

    def _sigint_handler(self, signum, frame):
        """On Ctrl-C: let the indicator clean up via finish(), then exit with status 0."""
        self.finish()
        exit(0)
93 |
--------------------------------------------------------------------------------
/classification/utils/progress/spinner.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2012 Giorgos Verigakis
4 | #
5 | # Permission to use, copy, modify, and distribute this software for any
6 | # purpose with or without fee is hereby granted, provided that the above
7 | # copyright notice and this permission notice appear in all copies.
8 | #
9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 |
17 | from __future__ import unicode_literals
18 | from utils.progress import Infinite
19 | from utils.progress.helpers import WriteMixin
20 |
21 |
class Spinner(WriteMixin, Infinite):
    """Indicator that cycles through ``phases``, one glyph per step."""
    phases = ('-', '\\', '|', '/')
    hide_cursor = True
    message = ''

    def update(self):
        """Show the glyph for the current step, wrapping around ``phases``."""
        self.write(self.phases[self.index % len(self.phases)])
30 |
31 |
class PieSpinner(Spinner):
    """Spinner drawn as a rotating quarter-filled circle."""
    phases = ['◷', '◶', '◵', '◴']
34 |
35 |
class MoonSpinner(Spinner):
    """Spinner drawn as a rotating half-filled circle."""
    phases = ['◑', '◒', '◐', '◓']
38 |
39 |
class LineSpinner(Spinner):
    """Spinner drawn as a horizontal line bouncing up and down."""
    phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻']
42 |
43 |
class PixelSpinner(Spinner):
    """Spinner drawn with rotating Braille-dot glyphs."""
    phases = ['⣾', '⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽']
46 |
--------------------------------------------------------------------------------
/classification/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 | import torchvision
5 |
6 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
7 |
8 |
9 | # functions to show an image
def make_image(img, mean=(0, 0, 0), std=(1, 1, 1)):
    """
    Convert a normalized 3-channel image tensor into an (H, W, C) numpy array.

    :param img: image tensor laid out (C, H, W) with at least 3 channels; it is
                not modified
    :param mean: per-channel mean to add back
    :param std: per-channel std to multiply back
    :return: numpy array with the channel axis moved last, ready for ``plt.imshow``
    """
    # Work on a private CPU copy: the caller's tensor is no longer de-normalized
    # in place, and CUDA / autograd tensors can be converted to numpy.
    img = img.detach().cpu().clone()
    for i in range(0, 3):
        img[i] = img[i] * std[i] + mean[i]  # de-normalize
    return np.transpose(img.numpy(), (1, 2, 0))
15 |
16 |
def gauss(x, a, b, c):
    """Evaluate the Gaussian ``a * exp(-(x - b)^2 / (2 * c^2))`` element-wise on tensor ``x``."""
    exponent = -(x - b) ** 2 / (2 * c * c)
    return a * torch.exp(exponent)
19 |
20 |
def colorize(x):
    """
    Converts a one-channel gray-scale image to a color heatmap image.

    Each RGB channel is a Gaussian bump (see ``gauss``) of the gray value, and
    the result is clamped to at most 1.
    :param x: gray-scale tensor shaped (H, W), (1, H, W) or (N, 1, H, W)
    :return: float tensor shaped (3, H, W) for 2-/3-D input, (N, 3, H, W) for 4-D input
    :raises ValueError: for input of any other rank
    """
    if x.dim() == 2:
        # torch.unsqueeze has no `out=` argument; the former call raised TypeError
        x = x.unsqueeze(0)
    if x.dim() == 3:
        gray = x[0]  # (H, W)
        cl = torch.zeros([3, x.size(1), x.size(2)])
        channels = cl
    elif x.dim() == 4:
        # take the single channel explicitly: (N, 1, H, W) cannot be assigned
        # into an (N, H, W) slice when N > 1
        gray = x[:, 0]  # (N, H, W)
        cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
        channels = cl.transpose(0, 1)  # view of cl with the RGB axis first
    else:
        raise ValueError('colorize expects a 2-, 3- or 4-D tensor, got %d-D' % x.dim())
    channels[0] = gauss(gray, .5, .6, .2) + gauss(gray, 1, .8, .3)
    channels[1] = gauss(gray, 1, .5, .3)
    channels[2] = gauss(gray, 1, .2, .3)
    # the red channel sums two bumps and can exceed 1; clamp for every rank
    return cl.clamp_(max=1)
41 |
42 |
def show_batch(images, mean=(2, 2, 2), std=(0.5, 0.5, 0.5)):
    """Tile ``images`` into a grid, de-normalize it with ``mean``/``std`` and display it."""
    grid = torchvision.utils.make_grid(images)
    picture = make_image(grid, mean, std)
    plt.imshow(picture)
    plt.show()
47 |
48 |
def show_mask_single(images, mask, mean=(2, 2, 2), std=(0.5, 0.5, 0.5)):
    """
    Plot a batch of images above the same images blended with a mask.

    :param images: normalized image batch, presumably (N, 3, H, W) - TODO confirm
    :param mask: single-channel mask batch, presumably (N, 1, h, w); it is
                 resized to the image resolution before blending
    :param mean: per-channel mean used to de-normalize the images
    :param std: per-channel std used to de-normalize the images
    """
    # de-normalized copy of the images, used as the background for the mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * std[i] + mean[i]

    plt.subplot(2, 1, 1)
    plt.imshow(make_image(torchvision.utils.make_grid(images), mean, std))
    plt.axis('off')

    # The former code called an undefined `upsampling` helper (NameError).
    # Resize with nearest-neighbour interpolation (the nn.Upsample default)
    # straight to the image size so expand_as below always lines up.
    mask = torch.nn.functional.interpolate(
        mask, size=tuple(im_data.shape[-2:]), mode='nearest')
    blended = 0.3 * im_data + 0.7 * mask.expand_as(im_data)
    plt.subplot(2, 1, 2)
    plt.imshow(make_image(torchvision.utils.make_grid(blended)))
    plt.axis('off')
77 |
78 |
def show_mask(images, masklist, mean=(2, 2, 2), std=(0.5, 0.5, 0.5)):
    """
    Plot a batch of images followed by one image/mask blend row per mask.

    :param images: normalized image batch, presumably (N, 3, H, W) - TODO confirm
    :param masklist: sequence of single-channel mask tensors, presumably
                     (N, 1, h, w); each is moved to the CPU and resized to the
                     image resolution before blending
    :param mean: per-channel mean used to de-normalize the images
    :param std: per-channel std used to de-normalize the images
    """
    rows = 1 + len(masklist)

    # de-normalized copy of the images, used as the background for every mask
    im_data = images.clone()
    for c in range(0, 3):
        im_data[:, c, :, :] = im_data[:, c, :, :] * std[c] + mean[c]

    plt.subplot(rows, 1, 1)
    plt.imshow(make_image(torchvision.utils.make_grid(images), mean, std))
    plt.axis('off')

    for i, raw_mask in enumerate(masklist):
        mask = raw_mask.data.cpu()
        # The former code called an undefined `upsampling` helper (NameError).
        # Resize with nearest-neighbour interpolation (the nn.Upsample default)
        # straight to the image size so expand_as below always lines up.
        mask = torch.nn.functional.interpolate(
            mask, size=tuple(im_data.shape[-2:]), mode='nearest')
        blended = 0.3 * im_data + 0.7 * mask.expand_as(im_data)
        plt.subplot(rows, 1, i + 2)
        plt.imshow(make_image(torchvision.utils.make_grid(blended)))
        plt.axis('off')
109 |
110 | # x = torch.zeros(1, 3, 3)
111 | # out = colorize(x)
112 | # out_im = make_image(out)
113 | # plt.imshow(out_im)
114 | # plt.show()
115 |
--------------------------------------------------------------------------------