├── .gitignore ├── .idea ├── Knowledge-Distillation-PyTorch.iml ├── misc.xml ├── modules.xml └── vcs.xml ├── LICENSE ├── README.md └── classification ├── LICENSE ├── README.md ├── TRAINING.md ├── cifar.py ├── imagenet.py ├── images ├── KD_total_loss.png ├── Softmax_output.png ├── act_attention.png ├── at_losses.png ├── at_scheme.png ├── attention.png ├── cross_entropy_loss.png ├── fitnet_scheme.png ├── fitnet_stage1.png ├── fitnet_stage2.png ├── fsp_loss.png ├── fsp_matrix.png ├── fsp_stage1.png ├── fsp_stage2.png ├── mmd_loss.png ├── nst_gram.png ├── nst_kernels.png ├── nst_loss.png └── nst_total_loss.png ├── loss ├── distillation.py └── losses.py ├── models ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc ├── cifar │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── alexnet.cpython-36.pyc │ │ ├── densenet.cpython-36.pyc │ │ ├── preresnet.cpython-36.pyc │ │ ├── resnet.cpython-36.pyc │ │ ├── resnet_vision.cpython-36.pyc │ │ ├── resnext.cpython-36.pyc │ │ ├── vgg.cpython-36.pyc │ │ └── wrn.cpython-36.pyc │ ├── alexnet.py │ ├── densenet.py │ ├── preresnet.py │ ├── resnet.py │ ├── resnet_vision.py │ ├── resnet_yw.py │ ├── resnext.py │ ├── vgg.py │ └── wrn.py └── imagenet │ ├── __init__.py │ └── resnext.py ├── scripts ├── mimic.sh ├── train.sh └── train_local.sh ├── source.sh ├── tensorboardX ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── crc32c.cpython-36.pyc │ ├── embedding.cpython-36.pyc │ ├── event_file_writer.cpython-36.pyc │ ├── graph.cpython-36.pyc │ ├── graph_onnx.cpython-36.pyc │ ├── record_writer.cpython-36.pyc │ ├── summary.cpython-36.pyc │ ├── writer.cpython-36.pyc │ └── x2num.cpython-36.pyc ├── crc32c.py ├── crc32c.pyc ├── embedding.py ├── embedding.pyc ├── event_file_writer.py ├── event_file_writer.pyc ├── graph.py ├── graph.pyc ├── graph_onnx.py ├── graph_onnx.pyc ├── record_writer.py ├── record_writer.pyc ├── src │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ ├── 
__init__.cpython-36.pyc │ │ ├── attr_value_pb2.cpython-36.pyc │ │ ├── event_pb2.cpython-36.pyc │ │ ├── graph_pb2.cpython-36.pyc │ │ ├── node_def_pb2.cpython-36.pyc │ │ ├── plugin_pr_curve_pb2.cpython-36.pyc │ │ ├── resource_handle_pb2.cpython-36.pyc │ │ ├── summary_pb2.cpython-36.pyc │ │ ├── tensor_pb2.cpython-36.pyc │ │ ├── tensor_shape_pb2.cpython-36.pyc │ │ ├── types_pb2.cpython-36.pyc │ │ └── versions_pb2.cpython-36.pyc │ ├── attr_value.proto │ ├── attr_value_pb2.py │ ├── attr_value_pb2.pyc │ ├── event.proto │ ├── event_pb2.py │ ├── event_pb2.pyc │ ├── graph.proto │ ├── graph_pb2.py │ ├── graph_pb2.pyc │ ├── node_def.proto │ ├── node_def_pb2.py │ ├── node_def_pb2.pyc │ ├── plugin_pr_curve.proto │ ├── plugin_pr_curve_pb2.py │ ├── plugin_pr_curve_pb2.pyc │ ├── resource_handle.proto │ ├── resource_handle_pb2.py │ ├── resource_handle_pb2.pyc │ ├── summary.proto │ ├── summary_pb2.py │ ├── summary_pb2.pyc │ ├── tensor.proto │ ├── tensor_pb2.py │ ├── tensor_pb2.pyc │ ├── tensor_shape.proto │ ├── tensor_shape_pb2.py │ ├── tensor_shape_pb2.pyc │ ├── types.proto │ ├── types_pb2.py │ ├── types_pb2.pyc │ ├── versions.proto │ ├── versions_pb2.py │ └── versions_pb2.pyc ├── summary.py ├── summary.pyc ├── writer.py ├── writer.pyc ├── x2num.py └── x2num.pyc ├── utility.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── eval.cpython-36.pyc ├── logger.cpython-36.pyc ├── misc.cpython-36.pyc └── visualize.cpython-36.pyc ├── eval.py ├── images ├── cifar.png └── imagenet.png ├── logger.py ├── misc.py ├── progress ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── bar.cpython-36.pyc │ └── helpers.cpython-36.pyc ├── bar.py ├── counter.py ├── helpers.py └── spinner.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | .idea/ 4 | 5 | classification/.idea 6 | # Byte-compiled / optimized / DLL files 7 | 
__pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | -------------------------------------------------------------------------------- /.idea/Knowledge-Distillation-PyTorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 10 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 wangjiong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and 
to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Knowledge-Distillation-PyTorch 2 | Knowledge Distillation Algorithms implemented with PyTorch \ 3 | Trying to complete various tasks... 4 | ## Directories 5 | - classification\ 6 | Classification on CIFAR-10/100 and ImageNet with PyTorch. 
\ 7 | Based on repository [bearpaw/pytorch-classification](https://github.com/bearpaw/pytorch-classification) 8 | -------------------------------------------------------------------------------- /classification/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Wei Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /classification/TRAINING.md: -------------------------------------------------------------------------------- 1 | 2 | ## CIFAR-10 3 | 4 | #### AlexNet 5 | ``` 6 | python cifar.py -a alexnet --epochs 164 --schedule 81 122 --gamma 0.1 --checkpoint checkpoints/cifar10/alexnet 7 | ``` 8 | 9 | 10 | #### VGG19 (BN) 11 | ``` 12 | python cifar.py -a vgg19_bn --epochs 164 --schedule 81 122 --gamma 0.1 --checkpoint checkpoints/cifar10/vgg19_bn 13 | ``` 14 | 15 | #### ResNet-110 16 | ``` 17 | python cifar.py -a resnet --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/resnet-110 18 | ``` 19 | 20 | #### ResNet-1202 21 | ``` 22 | python cifar.py -a resnet --depth 1202 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/resnet-1202 23 | ``` 24 | 25 | #### PreResNet-110 26 | ``` 27 | python cifar.py -a preresnet --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/preresnet-110 28 | ``` 29 | 30 | #### ResNeXt-29, 8x64d 31 | ``` 32 | python cifar.py -a resnext --depth 29 --cardinality 8 --widen-factor 4 --schedule 150 225 --wd 5e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/resnext-8x64d 33 | ``` 34 | #### ResNeXt-29, 16x64d 35 | ``` 36 | python cifar.py -a resnext --depth 29 --cardinality 16 --widen-factor 4 --schedule 150 225 --wd 5e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/resnext-16x64d 37 | ``` 38 | 39 | #### WRN-28-10-drop 40 | ``` 41 | python cifar.py -a wrn --depth 28 --depth 28 --widen-factor 10 --drop 0.3 --epochs 200 --schedule 60 120 160 --wd 5e-4 --gamma 0.2 --checkpoint checkpoints/cifar10/WRN-28-10-drop 42 | ``` 43 | 44 | #### DenseNet-BC (L=100, k=12) 45 | **Note**: 46 | * DenseNet uses weight decay value `1e-4`. Larger weight decay (`5e-4`) is harmful for the accuracy (95.46 vs. 94.05) 47 | * Official batch size is 64. 
But there is no big difference using batchsize 64 or 128 (95.46 vs 95.11). 48 | 49 | ``` 50 | python cifar.py -a densenet --depth 100 --growthRate 12 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/densenet-bc-100-12 51 | ``` 52 | 53 | #### DenseNet-BC (L=190, k=40) 54 | ``` 55 | python cifar.py -a densenet --depth 190 --growthRate 40 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/densenet-bc-L190-k40 56 | ``` 57 | 58 | ## CIFAR-100 59 | 60 | #### AlexNet 61 | ``` 62 | python cifar.py -a alexnet --dataset cifar100 --checkpoint checkpoints/cifar100/alexnet --epochs 164 --schedule 81 122 --gamma 0.1 63 | ``` 64 | 65 | #### VGG19 (BN) 66 | ``` 67 | python cifar.py -a vgg19_bn --dataset cifar100 --checkpoint checkpoints/cifar100/vgg19_bn --epochs 164 --schedule 81 122 --gamma 0.1 68 | ``` 69 | 70 | #### ResNet-110 71 | ``` 72 | python cifar.py -a resnet --dataset cifar100 --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/resnet-110 73 | ``` 74 | 75 | #### ResNet-1202 76 | ``` 77 | python cifar.py -a resnet --dataset cifar100 --depth 1202 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/resnet-1202 78 | ``` 79 | 80 | #### PreResNet-110 81 | ``` 82 | python cifar.py -a preresnet --dataset cifar100 --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/preresnet-110 83 | ``` 84 | 85 | #### ResNeXt-29, 8x64d 86 | ``` 87 | python cifar.py -a resnext --dataset cifar100 --depth 29 --cardinality 8 --widen-factor 4 --checkpoint checkpoints/cifar100/resnext-8x64d --schedule 150 225 --wd 5e-4 --gamma 0.1 88 | ``` 89 | #### ResNeXt-29, 16x64d 90 | ``` 91 | python cifar.py -a resnext --dataset cifar100 --depth 29 --cardinality 16 --widen-factor 4 --checkpoint checkpoints/cifar100/resnext-16x64d --schedule 150 225 --wd 5e-4 --gamma 0.1 92 | 
``` 93 | 94 | #### WRN-28-10-drop 95 | ``` 96 | python cifar.py -a wrn --dataset cifar100 --depth 28 --depth 28 --widen-factor 10 --drop 0.3 --epochs 200 --schedule 60 120 160 --wd 5e-4 --gamma 0.2 --checkpoint checkpoints/cifar100/WRN-28-10-drop 97 | ``` 98 | 99 | #### DenseNet-BC (L=100, k=12) 100 | ``` 101 | python cifar.py -a densenet --dataset cifar100 --depth 100 --growthRate 12 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar100/densenet-bc-100-12 102 | ``` 103 | 104 | #### DenseNet-BC (L=190, k=40) 105 | ``` 106 | python cifar.py -a densenet --dataset cifar100 --depth 190 --growthRate 40 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar100/densenet-bc-L190-k40 107 | ``` 108 | 109 | ## ImageNet 110 | ### ResNet-18 111 | ``` 112 | python imagenet.py -a resnet18 --data ~/dataset/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c checkpoints/imagenet/resnet18 113 | ``` 114 | 115 | ### ResNeXt-50 (32x4d) 116 | *(Originally trained on 8xGPUs)* 117 | ``` 118 | python imagenet.py -a resnext50 --base-width 4 --cardinality 32 --data ~/dataset/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c checkpoints/imagenet/resnext50-32x4d 119 | ``` -------------------------------------------------------------------------------- /classification/images/KD_total_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/KD_total_loss.png -------------------------------------------------------------------------------- /classification/images/Softmax_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/Softmax_output.png 
-------------------------------------------------------------------------------- /classification/images/act_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/act_attention.png -------------------------------------------------------------------------------- /classification/images/at_losses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/at_losses.png -------------------------------------------------------------------------------- /classification/images/at_scheme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/at_scheme.png -------------------------------------------------------------------------------- /classification/images/attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/attention.png -------------------------------------------------------------------------------- /classification/images/cross_entropy_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/cross_entropy_loss.png -------------------------------------------------------------------------------- /classification/images/fitnet_scheme.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fitnet_scheme.png -------------------------------------------------------------------------------- /classification/images/fitnet_stage1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fitnet_stage1.png -------------------------------------------------------------------------------- /classification/images/fitnet_stage2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fitnet_stage2.png -------------------------------------------------------------------------------- /classification/images/fsp_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fsp_loss.png -------------------------------------------------------------------------------- /classification/images/fsp_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fsp_matrix.png -------------------------------------------------------------------------------- /classification/images/fsp_stage1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fsp_stage1.png 
-------------------------------------------------------------------------------- /classification/images/fsp_stage2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/fsp_stage2.png -------------------------------------------------------------------------------- /classification/images/mmd_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/mmd_loss.png -------------------------------------------------------------------------------- /classification/images/nst_gram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/nst_gram.png -------------------------------------------------------------------------------- /classification/images/nst_kernels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/nst_kernels.png -------------------------------------------------------------------------------- /classification/images/nst_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/images/nst_loss.png -------------------------------------------------------------------------------- /classification/images/nst_total_loss.png: -------------------------------------------------------------------------------- 
def input2feature(x, normalize=-1, pow=1):
    """Return the feature tensor unchanged (identity supervision).

    :param x: feature tensor
    :param normalize: unused; accepted so every ``input2*`` mapper can be
        called the same way
    :param pow: unused; accepted for the same reason
    :return: ``x`` itself
    """
    return x


def input2attention(x, normalize=-1, pow=2):
    """Turn a feature tensor into a flattened attention map.

    Follows Attention Transfer: raise to ``pow`` element-wise, average over
    dim 1, then flatten everything after dim 0.

    :param x: feature tensor; dim 1 is averaged away
    :param normalize: a negative value returns the raw map; any other value
        (including ``False``/``0``) applies ``F.normalize`` with its default
        arguments
    :param pow: order of the element-wise power
    :return: tensor of shape ``(x.size(0), -1)``
    """
    attention = x.pow(pow).mean(1).view(x.size(0), -1)
    if normalize < 0:
        return attention
    return F.normalize(attention)


def gram(x, norm=-1):
    """Compute the batched product ``x^T x`` over flattened positions.

    :param x: 4-D tensor unpacked as ``(n, c, h, w)``
    :param norm: axis of the ``(n, c, h*w)`` view to L2-normalize along;
        values ``<= 0`` (including ``False``) skip normalization
    :return: tensor of shape ``(n, h*w, h*w)``
    """
    n, c, h, w = x.size()
    flat = x.view(n, c, -1)
    if norm > 0:
        flat = F.normalize(flat, dim=norm)
    return flat.transpose(1, 2).bmm(flat)


def input2gram(x, normalize=2, power=1):
    """Map features to ``gram(x ** power)``, normalizing along axis 2 by default."""
    return gram(x.pow(power), norm=normalize)


def input2similarity(x, normalize=1, power=1):
    """Map features to ``gram(x ** power)``, normalizing along axis 1 by default."""
    return gram(x.pow(power), norm=normalize)


def L1Loss(x, y):
    """Return the mean absolute difference between ``x`` and ``y``."""
    return torch.abs(x - y).mean()


def L2Loss(x, y):
    """Return the mean squared difference between ``x`` and ``y``."""
    return (x - y).pow(2).mean()


class Distillation(nn.Module):
    """Feature-mimicking loss between student and teacher feature dicts.

    Each feature pair is first mapped by a supervision function (identity,
    attention map, or gram/similarity matrix) and then compared with an L1
    or L2 distance; the per-position losses are summed.
    """

    def __init__(self, supervision='attention', function='L2', normalize=False):
        """
        :param supervision: one of 'feature', 'attention', 'gram', 'similarity'
        :param function: 'L1' or 'L2'
        :param normalize: forwarded as ``normalize=`` to the supervision
            mapper. NOTE: the mappers compare it numerically, so the default
            ``False`` (== 0) normalizes attention maps (``0 < 0`` is false)
            but does not normalize gram/similarity inputs (``0 > 0`` is false).
        :raises ValueError: on an unknown ``supervision`` or ``function``
        """
        super(Distillation, self).__init__()
        processors = {
            'feature': input2feature,
            'attention': input2attention,
            'gram': input2gram,
            'similarity': input2similarity,
        }
        # The original error branches used undefined `logger`/`args` names and
        # died with NameError; fail with an explicit message instead.
        if supervision not in processors:
            raise ValueError(
                'Supervision type [{}] is not implemented'.format(supervision))
        self.process = processors[supervision]

        self.norm = normalize

        if function == 'L1':
            self.function = nn.L1Loss()
        elif function == 'L2':
            self.function = L2Loss
        else:
            raise ValueError(
                'Choose L1, L2 loss, rather than {}'.format(function))

    def forward(self, student, teacher, assistant=None, writer=None, batch=None):
        """Sum the mimic losses over all feature positions present in both nets.

        :param student: dict of student features to distill, keyed by position
        :param teacher: dict of teacher features with the same keys, or None
        :param assistant: optional dict of assistant features that are added
            to the student features as a residual
        :param writer: optional tensorboard-style writer exposing ``add_scalar``
        :param batch: batch index; a scalar is logged every 10th batch when
            both ``writer`` and ``batch`` are given
        :return: 0-dim tensor holding the summed loss (zero if nothing matched)
        """
        losses = []
        reference = None  # any student tensor, used to pick the fallback device
        # k: feature position; fs: student feature
        for k, fs in student.items():
            if fs is None:
                continue
            reference = fs
            ft = teacher[k] if teacher is not None else None
            fa = assistant[k] if assistant is not None else None
            if ft is None:
                continue
            # assistant provides feature residual
            if fa is not None:
                fs = fs + fa
            # map input features to supervision
            fs = self.process(fs, normalize=self.norm)
            ft = self.process(ft, normalize=self.norm)
            loss = self.function(fs, ft)
            losses.append(loss)
            # `batch` defaults to None; guard it so passing only a writer
            # does not raise TypeError on `None % 10`.
            if writer and batch is not None and batch % 10 == 0:
                name = 'Mimic Loss feat{}'.format(k)
                writer.add_scalar('Distill_loss_batch/{}'.format(name), loss, batch + 1)

        if losses:
            return torch.sum(torch.stack(losses, dim=0))

        # Nothing to compare: return a float zero on the features' device
        # rather than unconditionally calling .cuda(), which fails on CPU hosts.
        if reference is not None:
            device = reference.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.zeros((), device=device)
def nst_loss(student, teacher, args=None):
    """
    Neural Selectivity Transfer Loss (paper)

    For every mimic position, both feature maps are flattened to
    (N, C, H * W), L2-normalized along dim 1, turned into
    (N, H * W, H * W) gram matrices and compared by mean squared error.

    :param student: feature dictionary
    :param teacher: feature dictionary
    :param args: namespace providing ``mimic_position`` (keys into both
        dictionaries) and ``mimic_lambda`` (one weight per position)
    :return: list with one weighted loss tensor per mimic position
    """
    losses = []
    assert len(args.mimic_position) == len(args.mimic_lambda)
    for i, p in enumerate(args.mimic_position):
        s = student[p]
        t = teacher[p]
        s = F.normalize(s.view(s.shape[0], s.shape[1], -1), dim=1)  # N, C, H * W
        gram_s = s.transpose(1, 2).bmm(s)  # N, H * W, H * W
        # The original passed print(...) as the assert message; print returns
        # None, so the AssertionError was empty. Pass the text itself.
        assert gram_s.shape[1] == gram_s.shape[2], \
            "gram_student's shape: {}".format(gram_s.shape)

        t = F.normalize(t.view(t.shape[0], t.shape[1], -1), dim=1)
        gram_t = t.transpose(1, 2).bmm(t)
        assert gram_t.shape[1] == gram_t.shape[2], \
            "gram_teacher's shape: {}".format(gram_t.shape)
        loss = (gram_s - gram_t).pow(2).mean()
        losses.append(args.mimic_lambda[i] * loss)
    return losses


def mmd_loss(student, teacher, args=None):
    """
    Maximum Mean Discrepancy Loss (NST Project)

    For every mimic position, both feature maps are flattened to
    (N, C, H * W) and L2-normalized along dim 1. The loss is
    ``mean((s s^T)^2) - 2 * mean((s t^T)^2)``.
    NOTE(review): the teacher-teacher term of a full MMD is not included,
    presumably because it does not depend on the student; confirm.

    :param student: feature dictionary
    :param teacher: feature dictionary
    :param args: namespace providing ``mimic_position`` and ``mimic_lambda``
    :return: list with one weighted loss tensor per mimic position
    """
    losses = []
    assert len(args.mimic_position) == len(args.mimic_lambda)
    for i, p in enumerate(args.mimic_position):
        s = student[p]
        t = teacher[p]
        s = F.normalize(s.view(s.shape[0], s.shape[1], -1), dim=1)  # N, C, H * W
        t = F.normalize(t.view(t.shape[0], t.shape[1], -1), dim=1)
        student_term = s.bmm(s.transpose(2, 1)).pow(2).mean()  # from N, C, C
        cross_term = s.bmm(t.transpose(2, 1)).pow(2).mean()
        loss = student_term - 2 * cross_term
        losses.append(args.mimic_lambda[i] * loss)
    return losses
def fsp_loss(student, teacher, args=None):
    """
    Flow of Solving Problem Loss

    ``args.mimic_position`` lists feature keys in consecutive pairs; each
    pair yields one FSP matrix per network (the product of the two flattened
    feature maps divided by the number of spatial positions), and the two
    matrices are compared by mean squared error.

    :param student: feature dictionary
    :param teacher: feature dictionary
    :param args: namespace providing ``mimic_position`` (two keys per pair)
        and ``mimic_lambda`` (one weight per pair)
    :return: list with one weighted loss tensor per position pair
    """

    def _fsp_matrix(first, second):
        # Flatten both maps to (N, C, H * W) and average the channel-wise
        # products over the spatial positions -> (N, C1, C2).
        first = first.view(first.shape[0], first.shape[1], -1)
        second = second.view(second.shape[0], second.shape[1], -1)
        return first.bmm(second.transpose(1, 2)) / first.shape[2]

    losses = []
    assert len(args.mimic_position) == 2 * len(args.mimic_lambda)
    # Iterate over mimic_lambda: it is what the assertion above ties to the
    # number of position pairs and what weights each loss. The original
    # looped over an unrelated `args.mimic_theta` attribute.
    for i, weight in enumerate(args.mimic_lambda):
        key_a = args.mimic_position[2 * i]
        key_b = args.mimic_position[2 * i + 1]
        fsp_s = _fsp_matrix(student[key_a], student[key_b])
        fsp_t = _fsp_matrix(teacher[key_a], teacher[key_b])
        loss = (fsp_s - fsp_t).pow(2).mean()
        losses.append(weight * loss)
    return losses
/classification/models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | """The models subpackage contains definitions for the following model for CIFAR10/CIFAR100 4 | architectures: 5 | 6 | - `AlexNet`_ 7 | - `VGG`_ 8 | - `ResNet`_ 9 | - `SqueezeNet`_ 10 | - `DenseNet`_ 11 | 12 | You can construct a model with random weights by calling its constructor: 13 | 14 | .. code:: python 15 | 16 | import torchvision.models as models 17 | resnet18 = models.resnet18() 18 | alexnet = models.alexnet() 19 | squeezenet = models.squeezenet1_0() 20 | densenet = models.densenet_161() 21 | 22 | We provide pre-trained models for the ResNet variants and AlexNet, using the 23 | PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing 24 | ``pretrained=True``: 25 | 26 | .. code:: python 27 | 28 | import torchvision.models as models 29 | resnet18 = models.resnet18(pretrained=True) 30 | alexnet = models.alexnet(pretrained=True) 31 | 32 | ImageNet 1-crop error rates (224x224) 33 | 34 | ======================== ============= ============= 35 | Network Top-1 error Top-5 error 36 | ======================== ============= ============= 37 | ResNet-18 30.24 10.92 38 | ResNet-34 26.70 8.58 39 | ResNet-50 23.85 7.13 40 | ResNet-101 22.63 6.44 41 | ResNet-152 21.69 5.94 42 | Inception v3 22.55 6.44 43 | AlexNet 43.45 20.91 44 | VGG-11 30.98 11.37 45 | VGG-13 30.07 10.75 46 | VGG-16 28.41 9.62 47 | VGG-19 27.62 9.12 48 | SqueezeNet 1.0 41.90 19.58 49 | SqueezeNet 1.1 41.81 19.38 50 | Densenet-121 25.35 7.83 51 | Densenet-169 24.00 7.00 52 | Densenet-201 22.80 6.43 53 | Densenet-161 22.35 6.20 54 | ======================== ============= ============= 55 | 56 | 57 | .. _AlexNet: https://arxiv.org/abs/1404.5997 58 | .. _VGG: https://arxiv.org/abs/1409.1556 59 | .. _ResNet: https://arxiv.org/abs/1512.03385 60 | .. _SqueezeNet: https://arxiv.org/abs/1602.07360 61 | .. 
_DenseNet: https://arxiv.org/abs/1608.06993 62 | """ 63 | 64 | from .alexnet import * 65 | from .vgg import * 66 | from .resnet import * 67 | from .resnet_vision import ResNet as resnet_vision 68 | from .resnext import * 69 | from .wrn import * 70 | from .preresnet import * 71 | from .densenet import * 72 | -------------------------------------------------------------------------------- /classification/models/cifar/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /classification/models/cifar/__pycache__/alexnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/alexnet.cpython-36.pyc -------------------------------------------------------------------------------- /classification/models/cifar/__pycache__/densenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/densenet.cpython-36.pyc -------------------------------------------------------------------------------- /classification/models/cifar/__pycache__/preresnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/preresnet.cpython-36.pyc 
-------------------------------------------------------------------------------- /classification/models/cifar/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /classification/models/cifar/__pycache__/resnet_vision.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/resnet_vision.cpython-36.pyc -------------------------------------------------------------------------------- /classification/models/cifar/__pycache__/resnext.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/resnext.cpython-36.pyc -------------------------------------------------------------------------------- /classification/models/cifar/__pycache__/vgg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/models/cifar/__pycache__/vgg.cpython-36.pyc -------------------------------------------------------------------------------- /classification/models/cifar/__pycache__/wrn.cpython-36.pyc: -------------------------------------------------------------------------------- 
class AlexNet(nn.Module):
    """
    AlexNet-style network: five convolutions followed by one linear classifier.

    ``features`` is a sequential stack of convolution / ReLU / max-pool
    layers ending in 256 channels; ``classifier`` is a single
    ``nn.Linear(256, num_classes)`` applied to the flattened feature map.
    """

    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        # (in_channels, out_channels, kernel, stride, padding, pool_after)
        conv_plan = (
            (3, 64, 11, 4, 5, True),
            (64, 192, 5, 1, 2, True),
            (192, 384, 3, 1, 1, False),
            (384, 256, 3, 1, 1, False),
            (256, 256, 3, 1, 1, True),
        )
        stack = []
        for c_in, c_out, kernel, stride, pad, pool_after in conv_plan:
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=kernel,
                                   stride=stride, padding=pad))
            stack.append(nn.ReLU(inplace=True))
            if pool_after:
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.features = nn.Sequential(*stack)
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        """Run the convolutional stack, flatten per sample, and classify."""
        feature_map = self.features(x)
        flat = feature_map.view(feature_map.size(0), -1)
        return self.classifier(flat)
class Bottleneck(nn.Module):
    """
    DenseNet bottleneck layer: BN-ReLU-Conv(1x1) followed by BN-ReLU-Conv(3x3).

    The 1x1 convolution produces ``expansion * growthRate`` channels and the
    3x3 convolution produces ``growthRate`` channels, which are concatenated
    to the layer input along the channel axis.
    """

    def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0):
        super(Bottleneck, self).__init__()
        planes = expansion * growthRate
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3,
                               padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate

    def forward(self, x):
        """Return ``x`` concatenated with the ``growthRate`` new channels."""
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        if self.dropRate > 0:
            out = F.dropout(out, p=self.dropRate, training=self.training)

        out = torch.cat((x, out), 1)

        return out


class BasicBlock(nn.Module):
    """
    DenseNet basic layer: BN-ReLU-Conv(3x3) producing ``growthRate`` channels.

    ``expansion`` is accepted so both layer types share one constructor
    signature; this layer does not use it.
    """

    def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3,
                               padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.dropRate = dropRate

    def forward(self, x):
        """Return ``x`` concatenated with the ``growthRate`` new channels."""
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        if self.dropRate > 0:
            out = F.dropout(out, p=self.dropRate, training=self.training)

        out = torch.cat((x, out), 1)

        return out


class Transition(nn.Module):
    """Transition between dense blocks: BN-ReLU-Conv(1x1) then 2x2 average pooling."""

    def __init__(self, inplanes, outplanes):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1,
                               bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Change the channel count to ``outplanes`` and halve the spatial size."""
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    """
    DenseNet with three dense blocks separated by two transitions.

    :param depth: network depth; must satisfy ``(depth - 4) % 3 == 0``
    :param block: layer class used inside the dense blocks
        (``Bottleneck`` or ``BasicBlock``)
    :param dropRate: dropout probability applied inside every layer (0 disables it)
    :param num_classes: size of the final linear layer's output
    :param growthRate: channels added by every layer
    :param compressionRate: divisor applied to the channel count at each transition
    """

    def __init__(self, depth=22, block=Bottleneck, dropRate=0,
                 num_classes=10, growthRate=12, compressionRate=2):
        super(DenseNet, self).__init__()

        assert (depth - 4) % 3 == 0, 'depth should be 3n+4'
        # Floor division in both branches: the layer count is passed to
        # range(), which rejects the float produced by true division.
        n = (depth - 4) // 3 if block == BasicBlock else (depth - 4) // 6

        self.growthRate = growthRate
        self.dropRate = dropRate

        # self.inplanes tracks the running channel count; the _make_* helpers
        # read and update it while the network is assembled.
        self.inplanes = growthRate * 2
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1,
                               bias=False)
        self.dense1 = self._make_denseblock(block, n)
        self.trans1 = self._make_transition(compressionRate)
        self.dense2 = self._make_denseblock(block, n)
        self.trans2 = self._make_transition(compressionRate)
        self.dense3 = self._make_denseblock(block, n)
        self.bn = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(self.inplanes, num_classes)

        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_denseblock(self, block, blocks):
        """Stack ``blocks`` layers of type ``block``, growing ``self.inplanes`` per layer."""
        layers = []
        for _ in range(blocks):
            # Currently we fix the expansion ratio as the default value
            layers.append(block(self.inplanes, growthRate=self.growthRate,
                                dropRate=self.dropRate))
            self.inplanes += self.growthRate

        return nn.Sequential(*layers)

    def _make_transition(self, compressionRate):
        """Build a transition that divides the channel count by ``compressionRate``."""
        inplanes = self.inplanes
        outplanes = int(math.floor(self.inplanes // compressionRate))
        self.inplanes = outplanes
        return Transition(inplanes, outplanes)

    def forward(self, x):
        """Return the class scores for a batch of images."""
        x = self.conv1(x)

        x = self.trans1(self.dense1(x))
        x = self.trans2(self.dense2(x))
        x = self.dense3(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def densenet(**kwargs):
    """
    Constructs a DenseNet model; keyword arguments are forwarded to ``DenseNet``.
    """
    return DenseNet(**kwargs)
def conv3x3(in_planes, out_planes, stride=1):
    """
    3x3 convolution with padding 1 and no bias.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param stride: convolution stride
    :return: the configured ``nn.Conv2d``
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Pre-activation residual block: two BN-ReLU-Conv(3x3) stages plus a shortcut."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return the block output added to the (optionally projected) input."""
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        # The shortcut is projected only when a downsample module was given.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class Bottleneck(nn.Module):
    """Pre-activation bottleneck block: 1x1, 3x3 and 1x1 convolutions plus a shortcut."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        # Output width follows the class-level expansion that _make_layer
        # also uses when sizing the shortcut.
        self.conv3 = nn.Conv2d(planes, planes * self.expansion,
                               kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return the block output added to the (optionally projected) input."""
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class PreResNet(nn.Module):
    """
    Pre-activation ResNet with three stages of 16, 32 and 64 base channels.

    :param depth: network depth; ``6n+2`` for ``BasicBlock``, ``9n+2`` for ``Bottleneck``
    :param num_classes: size of the final linear layer's output
    :param block_name: ``'BasicBlock'`` or ``'Bottleneck'`` (case-insensitive)
    :raises ValueError: if ``block_name`` is neither of the two block types
    """

    def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
        super(PreResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        if block_name.lower() == 'basicblock':
            assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            n = (depth - 2) // 6
            block = BasicBlock
        elif block_name.lower() == 'bottleneck':
            assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            n = (depth - 2) // 9
            block = Bottleneck
        else:
            raise ValueError('block_name should be Basicblock or Bottleneck')

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                               bias=False)
        self.layer1 = self._make_layer(block, 16, n)
        self.layer2 = self._make_layer(block, 32, n, stride=2)
        self.layer3 = self._make_layer(block, 64, n, stride=2)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        downsample = None
        # A 1x1 projection is needed whenever the shortcut changes resolution
        # or channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class scores; size comments assume a 32x32 input."""
        x = self.conv1(x)

        x = self.layer1(x)  # 32x32
        x = self.layer2(x)  # 16x16
        x = self.layer3(x)  # 8x8
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def preresnet(**kwargs):
    """
    Constructs a pre-activation ResNet; keyword arguments are forwarded to ``PreResNet``.
    """
    return PreResNet(**kwargs)
def conv3x3(in_planes, out_planes, stride=1):
    """
    3x3 convolution with padding 1 and no bias.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param stride: convolution stride
    :return: the configured ``nn.Conv2d``
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Residual block of two Conv(3x3)-BN stages with a ReLU after the addition."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ReLU(block(x) + shortcut(x))."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # The shortcut is projected only when a downsample module was given.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1, 3x3 and 1x1 convolutions, each followed by BN."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Output width follows the class-level expansion that _make_layer
        # also uses when sizing the shortcut.
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ReLU(block(x) + shortcut(x))."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """
    ResNet with three stages of 16, 32 and 64 base channels that also returns
    the output of every stage.

    :param depth: network depth; ``6n+2`` for ``BasicBlock``, ``9n+2`` for ``Bottleneck``
    :param num_classes: size of the final linear layer's output
    :param block_name: ``'BasicBlock'`` or ``'Bottleneck'`` (case-insensitive)
    :raises ValueError: if ``block_name`` is neither of the two block types
    """

    def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
        super(ResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        if block_name.lower() == 'basicblock':
            assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            n = (depth - 2) // 6
            block = BasicBlock
        elif block_name.lower() == 'bottleneck':
            assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            n = (depth - 2) // 9
            block = Bottleneck
        else:
            raise ValueError('block_name should be either Basicblock or Bottleneck')

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, n)
        self.layer2 = self._make_layer(block, 32, n, stride=2)
        self.layer3 = self._make_layer(block, 64, n, stride=2)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        downsample = None
        # A 1x1 projection (with BN) is needed whenever the shortcut changes
        # resolution or channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Return ``(logits, features)``.

        ``features`` maps 1, 2 and 3 to the outputs of ``layer1``, ``layer2``
        and ``layer3``.  Size comments assume a 32x32 input.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)  # 32x32

        f1 = self.layer1(x)  # 32x32
        f2 = self.layer2(f1)  # 16x16
        f3 = self.layer3(f2)  # 8x8

        x = self.avgpool(f3)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x, {1: f1, 2: f2, 3: f3}


def resnet(**kwargs):
    """
    Constructs a ResNet model; keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(**kwargs)
def conv3x3(in_planes, out_planes, stride=1):
    """
    3x3 convolution with padding 1 and no bias.

    :param in_planes: number of input channels
    :param out_planes: number of output channels
    :param stride: convolution stride
    :return: the configured ``nn.Conv2d``
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Residual block of two Conv(3x3)-BN stages with a ReLU after the addition."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ReLU(block(x) + shortcut(x))."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # The shortcut is projected only when a downsample module was given.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1, 3x3 and 1x1 convolutions, each followed by BN."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Output width follows the class-level expansion that _make_layer
        # also uses when sizing the shortcut.
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ReLU(block(x) + shortcut(x))."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """
    ResNet with three stages of 16, 32 and 64 base channels.

    :param depth: network depth; ``6n+2`` for ``BasicBlock``, ``9n+2`` for ``Bottleneck``
    :param num_classes: size of the final linear layer's output
    :param block_name: ``'BasicBlock'`` or ``'Bottleneck'`` (case-insensitive)
    :raises ValueError: if ``block_name`` is neither of the two block types
    """

    def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
        super(ResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        if block_name.lower() == 'basicblock':
            assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            n = (depth - 2) // 6
            block = BasicBlock
        elif block_name.lower() == 'bottleneck':
            assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            n = (depth - 2) // 9
            block = Bottleneck
        else:
            raise ValueError('block_name should be either Basicblock or Bottleneck')

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, n)
        self.layer2 = self._make_layer(block, 32, n, stride=2)
        self.layer3 = self._make_layer(block, 64, n, stride=2)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        downsample = None
        # A 1x1 projection (with BN) is needed whenever the shortcut changes
        # resolution or channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class scores; size comments assume a 32x32 input."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)  # 32x32

        x = self.layer1(x)  # 32x32
        x = self.layer2(x)  # 16x16
        x = self.layer3(x)  # 8x8

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet(**kwargs):
    """
    Constructs a ResNet model; keyword arguments are forwarded to ``ResNet``.
    """
    return ResNet(**kwargs)
class ResNeXtBottleneck(nn.Module):
    """
    RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
    """

    def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor):
        """ Constructor
        Args:
            in_channels: input channel dimensionality
            out_channels: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            cardinality: num of convolution groups.
            widen_factor: factor to reduce the input dimensionality before convolution.
        """
        super(ResNeXtBottleneck, self).__init__()
        # Width of the grouped 3x3 convolution.
        D = cardinality * out_channels // widen_factor
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_reduce = nn.BatchNorm2d(D)
        self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn = nn.BatchNorm2d(D)
        self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_expand = nn.BatchNorm2d(out_channels)

        # Identity shortcut unless the channel count changes, in which case a
        # strided 1x1 projection with BN is used.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut.add_module('shortcut_conv',
                                     nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
                                               padding=0, bias=False))
            self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """Return ReLU(shortcut(x) + bottleneck(x))."""
        # Modules are called directly (not via .forward) so registered hooks run.
        bottleneck = self.conv_reduce(x)
        bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
        bottleneck = self.conv_conv(bottleneck)
        bottleneck = F.relu(self.bn(bottleneck), inplace=True)
        bottleneck = self.conv_expand(bottleneck)
        bottleneck = self.bn_expand(bottleneck)
        residual = self.shortcut(x)
        return F.relu(residual + bottleneck, inplace=True)


class CifarResNeXt(nn.Module):
    """
    ResNext optimized for the Cifar dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """

    def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0):
        """ Constructor
        Args:
            cardinality: number of convolution groups.
            depth: number of layers.
            num_classes: number of classes
            widen_factor: factor to adjust the channel dimensionality
            dropRate: accepted for interface compatibility; not used by this model.
        """
        super(CifarResNeXt, self).__init__()
        self.cardinality = cardinality
        self.depth = depth
        self.block_depth = (self.depth - 2) // 9
        self.widen_factor = widen_factor
        self.num_classes = num_classes
        self.output_size = 64
        self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor]

        self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn_1 = nn.BatchNorm2d(64)
        self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1)
        self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2)
        self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2)
        # The classifier width follows the last stage so every widen_factor
        # works; it equals 1024 for the default widen_factor of 4.
        self.classifier = nn.Linear(self.stages[3], num_classes)
        init.kaiming_normal_(self.classifier.weight)

        for name, param in self.named_parameters():
            leaf = name.split('.')[-1]
            if leaf == 'weight':
                if 'conv' in name:
                    init.kaiming_normal_(param, mode='fan_out')
                if 'bn' in name:
                    init.constant_(param, 1)
            elif leaf == 'bias':
                init.constant_(param, 0)

    def block(self, name, in_channels, out_channels, pool_stride=2):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            name: string name of the current block.
            in_channels: number of input channels
            out_channels: number of output channels
            pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        block = nn.Sequential()
        for bottleneck in range(self.block_depth):
            name_ = '%s_bottleneck_%d' % (name, bottleneck)
            if bottleneck == 0:
                # Only the first bottleneck changes channels and resolution.
                block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride,
                                                          self.cardinality, self.widen_factor))
            else:
                block.add_module(name_, ResNeXtBottleneck(out_channels, out_channels, 1,
                                                          self.cardinality, self.widen_factor))
        return block

    def forward(self, x):
        """Return the class scores for a batch of images."""
        x = self.conv_1_3x3(x)
        x = F.relu(self.bn_1(x), inplace=True)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = F.avg_pool2d(x, 8, 1)
        # Flatten to the last stage's channel count, matching the classifier.
        x = x.view(-1, self.stages[3])
        return self.classifier(x)


def resnext(**kwargs):
    """Constructs a ResNeXt; keyword arguments are forwarded to ``CifarResNeXt``.
    """
    model = CifarResNeXt(**kwargs)
    return model
class VGG(nn.Module):
    """Convolutional feature stack followed by a single linear classifier.

    The classifier takes 512 flattened features, so `features` must reduce
    its input to 512 values per sample.
    """

    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Linear(512, num_classes)
        self._initialize_weights()

    def forward(self, x):
        """Run the feature stack, flatten per sample, and classify."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear layers in place."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kernel area * out_channels)).
                kernel_h, kernel_w = module.kernel_size[0], module.kernel_size[1]
                fan_out = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()


def make_layers(cfg, batch_norm=False):
    """Build the feature extractor described by `cfg`.

    Args:
        cfg: sequence whose items are either 'M' (2x2 max-pool, stride 2) or
            an int giving the output channels of a 3x3 conv (padding 1).
        batch_norm: when True, insert BatchNorm2d between each conv and its ReLU.
    Returns: an nn.Sequential of the layers; the first conv expects 3 input channels.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)
class BasicBlock(nn.Module):
    """Pre-activation residual block: BN-ReLU-conv3x3, BN-ReLU-(dropout)-conv3x3."""

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        """
        Args:
            in_planes: number of input channels.
            out_planes: number of output channels.
            stride: stride of the first 3x3 conv (and of the shortcut projection).
            dropRate: dropout probability applied between the two convs; 0 disables it.
        """
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # A 1x1 projection is needed only when the channel count changes.
        # Written as an explicit branch instead of `cond and a or b`.
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                          padding=0, bias=False)

    def forward(self, x):
        """Return shortcut + residual branch.

        With equal in/out channels the shortcut is the raw input; otherwise it
        is the 1x1 projection of the BN-ReLU activated input.
        """
        activated = self.relu1(self.bn1(x))
        shortcut = x if self.equalInOut else self.convShortcut(activated)
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(shortcut, out)


class NetworkBlock(nn.Module):
    """A sequence of `nb_layers` blocks; only the first changes channels/stride."""

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Build the nn.Sequential of blocks for this stage."""
        layers = []
        for i in range(nb_layers):
            # Conditional expressions replace `i == 0 and x or y`, which
            # silently picked the fallback whenever x was falsy.
            first = (i == 0)
            layers.append(block(in_planes if first else out_planes,
                                out_planes,
                                stride if first else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    """Wide residual network with three stages, returning logits and stage features."""

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        """
        Args:
            depth: total depth; must be of the form 6n+4.
            num_classes: number of output classes.
            widen_factor: multiplier on the per-stage channel widths (16, 32, 64).
            dropRate: dropout probability forwarded to every block.
        """
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Named `fan_out` so it no longer shadows the block count `n`.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # Only the bias is reset; the weight keeps nn.Linear's default init.
                m.bias.data.zero_()

    def forward(self, x):
        """Return (logits, {1: block1 output, 2: block2 output, 3: block3 output})."""
        out = self.conv1(x)
        out1 = self.block1(out)
        out2 = self.block2(out1)
        out3 = self.block3(out2)
        out = self.relu(self.bn1(out3))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out), {1: out1, 2: out2, 3: out3}


def wrn(**kwargs):
    """
    Constructs a Wide Residual Network.
    """
    model = WideResNet(**kwargs)
    return model
class Bottleneck(nn.Module):
    """
    ResNeXt bottleneck type C
    """
    expansion = 4

    def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality (before the x4 expansion)
            baseWidth: base width.
            cardinality: num of convolution groups.
            stride: conv stride. Replaces pooling layer.
            downsample: optional module applied to the input to form the shortcut.
        """
        super(Bottleneck, self).__init__()

        # Per-group width scaled by baseWidth/64, times the number of groups.
        group_width = int(math.floor(planes * (baseWidth / 64.0)))
        inner = group_width * cardinality
        outer = planes * 4

        self.conv1 = nn.Conv2d(inplanes, inner, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(inner)
        self.conv2 = nn.Conv2d(inner, inner, kernel_size=3, stride=stride, padding=1,
                               groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(inner)
        self.conv3 = nn.Conv2d(inner, outer, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(outer)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)


class ResNeXt(nn.Module):
    """
    ResNext optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """
    def __init__(self, baseWidth, cardinality, layers, num_classes):
        """ Constructor
        Args:
            baseWidth: baseWidth for ResNeXt.
            cardinality: number of convolution groups.
            layers: config of layers, e.g., [3, 4, 6, 3]
            num_classes: number of classes
        """
        super(ResNeXt, self).__init__()
        block = Bottleneck

        self.cardinality = cardinality
        self.baseWidth = baseWidth
        self.num_classes = num_classes
        self.inplanes = 64
        self.output_size = 64

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # layer1..layer4: (planes, stride) per stage, built in order because
        # each stage reads and updates self.inplanes.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, layers[index], stride)
            setattr(self, 'layer%d' % (index + 1), stage)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))

    def _make_layer(self, block, planes, blocks, stride=1):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            block: block type used to construct ResNext
            planes: number of output channels (need to multiply by block.expansion)
            blocks: number of blocks to be built
            stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        expanded = planes * block.expansion

        # The first block needs a projection shortcut when it changes the
        # resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != expanded:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, expanded,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )

        stack = [block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample)]
        self.inplanes = expanded
        stack.extend(block(self.inplanes, planes, self.baseWidth, self.cardinality)
                     for _ in range(1, blocks))

        return nn.Sequential(*stack)

    def forward(self, x):
        """Return class logits for a batch of images."""
        x = self.maxpool1(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


def resnext50(baseWidth, cardinality):
    """
    Construct ResNeXt-50.
    """
    return ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], 1000)


def resnext101(baseWidth, cardinality):
    """
    Construct ResNeXt-101.
    """
    return ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], 1000)
171 | """ 172 | model = ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], 1000) 173 | return model 174 | -------------------------------------------------------------------------------- /classification/scripts/mimic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet 3 | depth=20 4 | dataset=cifar100 5 | epochs=200 6 | lr=1e-1 7 | supervision=feature 8 | 9 | exp=${model}${depth}_${dataset}_${epochs}e_${lr}_resnet50_6908_l21_ce1e-1 10 | path=${dataset}/${model}/${supervision} 11 | EXP_DIR=/mnt/lustre21/wangjiong/classification_school/playground/exp 12 | mkdir -p ${EXP_DIR}/${path}/${exp}/model 13 | now=$(date +"%Y%m%d_%H%M%S") 14 | 15 | part=Pixel2 16 | numGPU=1 17 | nodeGPU=1 18 | pg-run -rn ${path}/${exp} -c " 19 | srun -p ${part} --job-name=${exp} --gres=gpu:${nodeGPU} -n ${numGPU} --ntasks-per-node=${nodeGPU} \ 20 | python -u cifar.py \ 21 | --dataset ${dataset} \ 22 | --data_dir /mnt/lustre21/wangjiong/Data_t1/datasets/cifar/${dataset} \ 23 | --workers 4 \ 24 | --reset \ 25 | --epochs ${epochs} \ 26 | --start-epoch 0 \ 27 | --train-batch 256 \ 28 | --test-batch 200 \ 29 | \ 30 | --lr ${lr} \ 31 | --drop 0 \ 32 | --schedule 100 150 \ 33 | --gamma 0.1 \ 34 | --weight-decay 1e-4 \ 35 | --checkpoint ${EXP_DIR}/${path}/${exp}/model \ 36 | \ 37 | --arch ${model} \ 38 | --depth ${depth} \ 39 | \ 40 | --teacher \ 41 | --teacher_arch resnet \ 42 | --teacher_depth 50 \ 43 | --teacher_block_name BasicBlock \ 44 | --teacher_checkpoint ${EXP_DIR}/../../model_zoo/${dataset}/resnet/resnet50_69_08.pth.tar \ 45 | \ 46 | --mimic_mean ${supervision} \ 47 | --mimic_function L2 \ 48 | --normalize 1 \ 49 | \ 50 | --manualSeed 345 \ 51 | --gpu-id 0 \ 52 | 2>&1 | tee -a ${EXP_DIR}/${path}/${exp}/train-${now}.log \ 53 | & \ 54 | " 55 | -------------------------------------------------------------------------------- /classification/scripts/train.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet 3 | depth=110 4 | dataset=cifar10 5 | epochs=220 6 | lr=1e-1 7 | 8 | exp=${model}${depth}_${dataset}_${epochs}e_${lr} 9 | path=${dataset}/${model} 10 | EXP_DIR=/mnt/lustre21/wangjiong/classification_school/playground/exp 11 | mkdir -p ${EXP_DIR}/${path}/${exp}/model 12 | now=$(date +"%Y%m%d_%H%M%S") 13 | 14 | part=Pixel 15 | numGPU=1 16 | nodeGPU=1 17 | pg-run -rn ${path}/${exp} -c " 18 | srun -p ${part} --job-name=${exp} --gres=gpu:${nodeGPU} -n ${numGPU} --ntasks-per-node=${nodeGPU} \ 19 | python -u cifar.py \ 20 | --dataset ${dataset} \ 21 | --data_dir /mnt/lustre21/wangjiong/Data_t1/datasets/cifar/${dataset} \ 22 | --workers 4 \ 23 | --reset \ 24 | --epochs ${epochs} \ 25 | --start-epoch 0 \ 26 | --train-batch 256 \ 27 | --test-batch 200 \ 28 | \ 29 | --lr ${lr} \ 30 | --drop 0 \ 31 | --schedule 100 150 \ 32 | --gamma 0.1 \ 33 | --weight-decay 1e-4 \ 34 | --checkpoint ${EXP_DIR}/${path}/${exp}/model \ 35 | \ 36 | --arch ${model} \ 37 | --depth ${depth} \ 38 | \ 39 | --gpu-id 0 \ 40 | 2>&1 | tee -a ${EXP_DIR}/${path}/${exp}/train-${now}.log \ 41 | & \ 42 | " 43 | -------------------------------------------------------------------------------- /classification/scripts/train_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model=resnet_vision 3 | depth=18 4 | dataset=cifar10 5 | epochs=50 6 | lr=1e-1 7 | 8 | exp=${model}${depth}_${dataset}_${epochs}e_${lr} 9 | path=${dataset}/${model} 10 | EXP_DIR=/mnt/lustre21/wangjiong/classification_school/playground/exp 11 | # mkdir -p ${EXP_DIR}/${path}/${exp}/model 12 | mkdir -p ./model 13 | now=$(date +"%Y%m%d_%H%M%S") 14 | 15 | part=Pixel 16 | numGPU=1 17 | nodeGPU=1 18 | # pg-run -rn ${exp} -c " 19 | # srun -p ${part} --job-name=${exp} --gres=gpu:${nodeGPU} -n ${numGPU} --ntasks-per-node=${nodeGPU} \ 20 | python -u cifar.py \ 21 | --dataset 
${dataset} \ 22 | --data_dir ./data \ 23 | --workers 4 \ 24 | --epochs 5 \ 25 | --start-epoch 0 \ 26 | --train-batch 256 \ 27 | --test-batch 200 \ 28 | \ 29 | --lr ${lr} \ 30 | --drop 0 \ 31 | --schedule 10 20 \ 32 | --weight-decay 1e-4 \ 33 | --checkpoint ./model \ 34 | \ 35 | --arch ${model} \ 36 | --depth ${depth} \ 37 | --cardinality 32 \ 38 | --widen-factor 4 \ 39 | \ 40 | --manualSeed 345 \ 41 | --gpu-id 0 \ 42 | 2>&1 | tee -a ./model/train-${now}.log \ 43 | & \ 44 | # " 45 | -------------------------------------------------------------------------------- /classification/source.sh: -------------------------------------------------------------------------------- 1 | export LD_LIBRARY_PATH=/mnt/lustre/wangjiong/Data_t1/anaconda3/lib:$LD_LIBRARY_PATH 2 | export LD_LIBRARY_PATH=/mnt/lustre/share/cuda-9.0/lib64:$LD_LIBRARY_PATH 3 | export LD_LIBRARY_PATH=/mnt/lustre/share/nccl_2.1.15-1+cuda9.0_x86_64/lib:$LD_LIBRARY_PATH 4 | export LD_LIBRARY_PATH=/mnt/lustre/share/intel64/lib/:$LD_LIBRARY_PATH 5 | export PATH=/mnt/lustre/share/cuda-9.0/bin:$PATH 6 | export PATH=/mnt/lustre/wangjiong/Data_t1/anaconda3/bin:$PATH 7 | # export PATH=/usr/bin:$PATH 8 | source activate /mnt/lustre/share/fundamental-support/envs/pytorch-0.4.0 9 | # source activate pytorch0.4.0-py3 10 | # source activate pytorch 11 | -------------------------------------------------------------------------------- /classification/tensorboardX/__init__.py: -------------------------------------------------------------------------------- 1 | """A module for visualization with tensorboard 2 | """ 3 | 4 | from .writer import FileWriter, SummaryWriter 5 | from .record_writer import RecordWriter 6 | -------------------------------------------------------------------------------- /classification/tensorboardX/__init__.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__init__.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/__pycache__/crc32c.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/crc32c.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/__pycache__/embedding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/embedding.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/__pycache__/event_file_writer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/event_file_writer.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/__pycache__/graph.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/graph.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/__pycache__/graph_onnx.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/graph_onnx.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/__pycache__/record_writer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/record_writer.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/__pycache__/summary.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/summary.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/__pycache__/writer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/writer.cpython-36.pyc -------------------------------------------------------------------------------- 
/classification/tensorboardX/__pycache__/x2num.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/__pycache__/x2num.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/crc32c.py: -------------------------------------------------------------------------------- 1 | import array 2 | 3 | 4 | CRC_TABLE = ( 5 | 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 6 | 0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb, 7 | 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, 8 | 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 9 | 0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b, 10 | 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, 11 | 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 12 | 0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b, 13 | 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, 14 | 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 15 | 0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5, 16 | 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, 17 | 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 18 | 0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a, 19 | 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, 20 | 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 21 | 0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48, 22 | 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, 23 | 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 24 | 0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198, 25 | 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, 26 | 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 27 | 0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8, 28 | 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, 29 | 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 30 | 0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789, 31 | 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, 32 | 
0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 33 | 0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9, 34 | 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, 35 | 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 36 | 0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829, 37 | 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, 38 | 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 39 | 0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043, 40 | 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, 41 | 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 42 | 0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc, 43 | 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, 44 | 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 45 | 0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652, 46 | 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, 47 | 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 48 | 0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982, 49 | 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, 50 | 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 51 | 0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2, 52 | 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, 53 | 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 54 | 0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f, 55 | 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, 56 | 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 57 | 0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f, 58 | 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, 59 | 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 60 | 0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f, 61 | 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, 62 | 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 63 | 0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321, 64 | 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, 65 | 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 66 | 0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e, 67 | 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, 68 | 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351, 69 | ) 70 | 71 | 72 | CRC_INIT = 0 73 | 74 
_MASK = 0xFFFFFFFF


def crc_update(crc, data):
    """Update CRC-32C checksum with data.

    Args:
      crc: 32-bit checksum to update as long.
      data: byte array, string or iterable over bytes. Text strings are
        encoded as UTF-8 before being checksummed.

    Returns:
      32-bit updated CRC-32C as long.
    """

    if isinstance(data, array.array) and data.itemsize == 1:
        # Already a one-byte-per-item array: iterate it directly, no copy.
        buf = data
    else:
        # array("B", ...) rejects text strings on Python 3, although the
        # docstring promises string support. Encode anything text-like first;
        # bytes/bytearray (and Python 2 `str`, which is bytes) pass through.
        if not isinstance(data, (bytes, bytearray)) and hasattr(data, 'encode'):
            data = data.encode('utf-8')
        buf = array.array("B", data)

    crc ^= _MASK
    for b in buf:
        # Table-driven update, one byte at a time.
        table_index = (crc ^ b) & 0xff
        crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & _MASK
    return crc ^ _MASK


def crc_finalize(crc):
    """Finalize CRC-32C checksum.

    This function should be called as last step of crc calculation.

    Args:
      crc: 32-bit checksum as long.

    Returns:
      finalized 32-bit checksum as long
    """
    return crc & _MASK
def make_tsv(metadata, save_path):
    """Write one metadata entry per line to `<save_path>/metadata.tsv`.

    Args:
        metadata: iterable of labels; each item is converted with str().
        save_path: existing directory that receives the file.
    """
    metadata = [str(x) for x in metadata]
    with open(os.path.join(save_path, 'metadata.tsv'), 'w') as f:
        for x in metadata:
            f.write(x + '\n')


# https://github.com/tensorflow/tensorboard/issues/44 image label will be squared
def make_sprite(label_img, save_path):
    """Save the label images as one square sprite at `<save_path>/sprite.png`.

    Args:
        label_img: batch of images; indexed here as size(0) = number of images.
        save_path: existing directory that receives the file.
    """
    import math
    import torch
    import torchvision
    from .x2num import makenp
    # this ensures the sprite image has correct dimension as described in
    # https://www.tensorflow.org/get_started/embedding_viz
    nrow = int(math.ceil((label_img.size(0)) ** 0.5))

    label_img = torch.from_numpy(makenp(label_img))  # for other framework
    # augment images so that #images equals nrow*nrow
    label_img = torch.cat((label_img, torch.randn(nrow ** 2 - label_img.size(0), *label_img.size()[1:]) * 255), 0)

    torchvision.utils.save_image(label_img, os.path.join(save_path, 'sprite.png'), nrow=nrow, padding=0)


def append_pbtxt(metadata, label_img, save_path, global_step, tag):
    """Append one `embeddings { ... }` entry to `<save_path>/projector_config.pbtxt`.

    Args:
        metadata: if not None, a metadata_path line is written.
        label_img: if not None, a sprite section is written using
            label_img.size(3) and label_img.size(2) as the image dimensions.
        save_path: existing directory holding the projector config.
        global_step: name of the per-step sub-directory; any value is
            accepted and converted with str().
        tag: embedding tag, combined with the step into the tensor name.
    """
    # os.path.join needs text; the tensor name below already formatted the
    # step as text, so an integer step used to fail only on the path lines.
    step_dir = str(global_step)
    with open(os.path.join(save_path, 'projector_config.pbtxt'), 'a') as f:
        f.write('embeddings {\n')
        f.write('tensor_name: "{}:{}"\n'.format(tag, global_step))
        f.write('tensor_path: "{}"\n'.format(os.path.join(step_dir, 'tensors.tsv')))
        if metadata is not None:
            f.write('metadata_path: "{}"\n'.format(os.path.join(step_dir, 'metadata.tsv')))
        if label_img is not None:
            f.write('sprite {\n')
            f.write('image_path: "{}"\n'.format(os.path.join(step_dir, 'sprite.png')))
            f.write('single_image_dim: {}\n'.format(label_img.size(3)))
            f.write('single_image_dim: {}\n'.format(label_img.size(2)))
            f.write('}\n')
        f.write('}\n')


def make_mat(matlist, save_path):
    """Write `matlist` to `<save_path>/tensors.tsv`, one tab-separated row per line.

    Args:
        matlist: iterable of rows; each row is an iterable of values
            converted with str().
        save_path: existing directory that receives the file.
    """
    with open(os.path.join(save_path, 'tensors.tsv'), 'w') as f:
        for x in matlist:
            x = [str(i) for i in x]
            f.write('\t'.join(x) + '\n')
14 | # ============================================================================== 15 | """Writes events to disk in a logdir.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import logging 22 | import os.path 23 | import socket 24 | import threading 25 | import time 26 | 27 | import six 28 | 29 | from .src import event_pb2 30 | from .record_writer import RecordWriter 31 | 32 | 33 | def directory_check(path): 34 | '''Initialize the directory for log files.''' 35 | # If the direcotry does not exist, create it! 36 | if not os.path.exists(path): 37 | os.makedirs(path) 38 | 39 | 40 | class EventsWriter(object): 41 | '''Writes `Event` protocol buffers to an event file.''' 42 | 43 | def __init__(self, file_prefix): 44 | ''' 45 | Events files have a name of the form 46 | '/some/file/path/events.out.tfevents.[timestamp].[hostname]' 47 | ''' 48 | self._file_prefix = file_prefix + ".out.tfevents." + str(time.time())[:10] + "." + socket.gethostname() 49 | 50 | # Open(Create) the log file with the particular form of name. 51 | logging.basicConfig(filename=self._file_prefix) 52 | 53 | self._num_outstanding_events = 0 54 | 55 | self._py_recordio_writer = RecordWriter(self._file_prefix) 56 | 57 | # Initialize an event instance. 58 | self._event = event_pb2.Event() 59 | 60 | self._event.wall_time = time.time() 61 | 62 | self.write_event(self._event) 63 | 64 | def write_event(self, event): 65 | '''Append "event" to the file.''' 66 | 67 | # Check if event is of type event_pb2.Event proto. 
68 | if not isinstance(event, event_pb2.Event): 69 | raise TypeError("Expected an event_pb2.Event proto, " 70 | " but got %s" % type(event)) 71 | return self._write_serialized_event(event.SerializeToString()) 72 | 73 | def _write_serialized_event(self, event_str): 74 | self._num_outstanding_events += 1 75 | self._py_recordio_writer.write(event_str) 76 | 77 | def flush(self): 78 | '''Flushes the event file to disk.''' 79 | self._num_outstanding_events = 0 80 | return True 81 | 82 | def close(self): 83 | '''Call self.flush().''' 84 | return_value = self.flush() 85 | return return_value 86 | 87 | 88 | class EventFileWriter(object): 89 | """Writes `Event` protocol buffers to an event file. 90 | The `EventFileWriter` class creates an event file in the specified directory, 91 | and asynchronously writes Event protocol buffers to the file. The Event file 92 | is encoded using the tfrecord format, which is similar to RecordIO. 93 | @@__init__ 94 | @@add_event 95 | @@flush 96 | @@close 97 | """ 98 | 99 | def __init__(self, logdir, max_queue=10, flush_secs=120): 100 | """Creates a `EventFileWriter` and an event file to write to. 101 | On construction the summary writer creates a new event file in `logdir`. 102 | This event file will contain `Event` protocol buffers, which are written to 103 | disk via the add_event method. 104 | The other arguments to the constructor control the asynchronous writes to 105 | the event file: 106 | * `flush_secs`: How often, in seconds, to flush the added summaries 107 | and events to disk. 108 | * `max_queue`: Maximum number of summaries or events pending to be 109 | written to disk before one of the 'add' calls block. 110 | Args: 111 | logdir: A string. Directory where event file will be written. 112 | max_queue: Integer. Size of the queue for pending events and summaries. 113 | flush_secs: Number. How often, in seconds, to flush the 114 | pending events and summaries to disk. 
115 | """ 116 | self._logdir = logdir 117 | directory_check(self._logdir) 118 | self._event_queue = six.moves.queue.Queue(max_queue) 119 | self._ev_writer = EventsWriter(os.path.join(self._logdir, "events")) 120 | self._closed = False 121 | self._worker = _EventLoggerThread(self._event_queue, self._ev_writer, 122 | flush_secs) 123 | 124 | self._worker.start() 125 | 126 | def get_logdir(self): 127 | """Returns the directory where event file will be written.""" 128 | return self._logdir 129 | 130 | def reopen(self): 131 | """Reopens the EventFileWriter. 132 | Can be called after `close()` to add more events in the same directory. 133 | The events will go into a new events file. 134 | Does nothing if the EventFileWriter was not closed. 135 | """ 136 | if self._closed: 137 | self._closed = False 138 | 139 | def add_event(self, event): 140 | """Adds an event to the event file. 141 | Args: 142 | event: An `Event` protocol buffer. 143 | """ 144 | if not self._closed: 145 | self._event_queue.put(event) 146 | 147 | def flush(self): 148 | """Flushes the event file to disk. 149 | Call this method to make sure that all pending events have been written to 150 | disk. 151 | """ 152 | self._event_queue.join() 153 | self._ev_writer.flush() 154 | 155 | def close(self): 156 | """Flushes the event file to disk and close the file. 157 | Call this method when you do not need the summary writer anymore. 158 | """ 159 | self.flush() 160 | self._ev_writer.close() 161 | self._closed = True 162 | 163 | 164 | class _EventLoggerThread(threading.Thread): 165 | """Thread that logs events.""" 166 | 167 | def __init__(self, queue, ev_writer, flush_secs): 168 | """Creates an _EventLoggerThread. 169 | Args: 170 | queue: A Queue from which to dequeue events. 171 | ev_writer: An event writer. Used to log brain events for 172 | the visualizer. 173 | flush_secs: How often, in seconds, to flush the 174 | pending file to disk. 
175 | """ 176 | threading.Thread.__init__(self) 177 | self.daemon = True 178 | self._queue = queue 179 | self._ev_writer = ev_writer 180 | self._flush_secs = flush_secs 181 | # The first event will be flushed immediately. 182 | self._next_event_flush_time = 0 183 | 184 | def run(self): 185 | while True: 186 | event = self._queue.get() 187 | try: 188 | self._ev_writer.write_event(event) 189 | # Flush the event writer every so often. 190 | now = time.time() 191 | if now > self._next_event_flush_time: 192 | self._ev_writer.flush() 193 | # Do it again in two minutes. 194 | self._next_event_flush_time = now + self._flush_secs 195 | finally: 196 | self._queue.task_done() 197 | -------------------------------------------------------------------------------- /classification/tensorboardX/event_file_writer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/event_file_writer.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/graph.py: -------------------------------------------------------------------------------- 1 | from .src.graph_pb2 import GraphDef 2 | from .src.node_def_pb2 import NodeDef 3 | from .src.versions_pb2 import VersionDef 4 | from .src.attr_value_pb2 import AttrValue 5 | from .src.tensor_shape_pb2 import TensorShapeProto 6 | 7 | 8 | def replace(name, scope): 9 | return '/'.join([scope[name], name]) 10 | 11 | 12 | def parse(graph): 13 | scope = {} 14 | for n in graph.nodes(): 15 | inputs = [i.uniqueName() for i in n.inputs()] 16 | for i in range(1, len(inputs)): 17 | scope[inputs[i]] = n.scopeName() 18 | 19 | uname = next(n.outputs()).uniqueName() 20 | assert n.scopeName() != '', '{} has empty scope name'.format(n) 21 | scope[uname] = n.scopeName() 22 | scope['0'] = 'input' 23 | 24 | nodes = [] 25 | for n in 
graph.nodes(): 26 | attrs = {k: n[k] for k in n.attributeNames()} 27 | attrs = str(attrs).replace("'", ' ') # singlequote will be escaped by tensorboard 28 | inputs = [replace(i.uniqueName(), scope) for i in n.inputs()] 29 | uname = next(n.outputs()).uniqueName() 30 | nodes.append({'name': replace(uname, scope), 'op': n.kind(), 'inputs': inputs, 'attr': attrs}) 31 | 32 | for n in graph.inputs(): 33 | uname = n.uniqueName() 34 | if uname not in scope.keys(): 35 | scope[uname] = 'unused' 36 | nodes.append({'name': replace(uname, scope), 'op': 'Parameter', 'inputs': [], 'attr': str(n.type())}) 37 | 38 | return nodes 39 | 40 | 41 | def graph(model, args, verbose=False): 42 | import torch 43 | with torch.onnx.set_training(model, False): 44 | trace, _ = torch.jit.trace(model, args) 45 | torch.onnx._optimize_trace(trace, False) 46 | graph = trace.graph() 47 | if verbose: 48 | print(graph) 49 | list_of_nodes = parse(graph) 50 | nodes = [] 51 | for node in list_of_nodes: 52 | nodes.append( 53 | NodeDef(name=node['name'], op=node['op'], input=node['inputs'], 54 | attr={'lanpa': AttrValue(s=node['attr'].encode(encoding='utf_8'))})) 55 | return GraphDef(node=nodes, versions=VersionDef(producer=22)) 56 | -------------------------------------------------------------------------------- /classification/tensorboardX/graph.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/graph.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/graph_onnx.py: -------------------------------------------------------------------------------- 1 | from .src.graph_pb2 import GraphDef 2 | from .src.node_def_pb2 import NodeDef 3 | from .src.versions_pb2 import VersionDef 4 | from .src.attr_value_pb2 import AttrValue 5 | from .src.tensor_shape_pb2 import 
TensorShapeProto 6 | # from .src.onnx_pb2 import ModelProto 7 | 8 | 9 | def gg(fname): 10 | import onnx # 0.2.1 11 | m = onnx.load(fname) 12 | nodes_proto = [] 13 | nodes = [] 14 | g = m.graph 15 | import itertools 16 | for node in itertools.chain(g.input, g.output): 17 | nodes_proto.append(node) 18 | 19 | for node in nodes_proto: 20 | shapeproto = TensorShapeProto( 21 | dim=[TensorShapeProto.Dim(size=d.dim_value) for d in node.type.tensor_type.shape.dim]) 22 | nodes.append(NodeDef( 23 | name=node.name, 24 | op='Variable', 25 | input=[], 26 | attr={ 27 | 'dtype': AttrValue(type=node.type.tensor_type.elem_type), 28 | 'shape': AttrValue(shape=shapeproto), 29 | }) 30 | ) 31 | 32 | for node in g.node: 33 | attr = [] 34 | for s in node.attribute: 35 | attr.append(' = '.join([str(f[1]) for f in s.ListFields()])) 36 | attr = ', '.join(attr).encode(encoding='utf_8') 37 | 38 | nodes.append(NodeDef( 39 | name=node.output[0], 40 | op=node.op_type, 41 | input=node.input, 42 | attr={'parameters': AttrValue(s=attr)}, 43 | )) 44 | # two pass token replacement, appends opname to object id 45 | mapping = {} 46 | for node in nodes: 47 | mapping[node.name] = node.op + '_' + node.name 48 | 49 | nodes, mapping = updatenodes(nodes, mapping) 50 | mapping = smartGrouping(nodes, mapping) 51 | nodes, mapping = updatenodes(nodes, mapping) 52 | 53 | return GraphDef(node=nodes, versions=VersionDef(producer=22)) 54 | 55 | 56 | def updatenodes(nodes, mapping): 57 | for node in nodes: 58 | newname = mapping[node.name] 59 | node.name = newname 60 | newinput = [] 61 | for inputnode in list(node.input): 62 | newinput.append(mapping[inputnode]) 63 | node.input.remove(inputnode) 64 | node.input.extend(newinput) 65 | newmap = {} 66 | for k, v in mapping.items(): 67 | newmap[v] = v 68 | return nodes, newmap 69 | 70 | 71 | def findnode(nodes, name): 72 | """ input: node name 73 | returns: node object 74 | """ 75 | for n in nodes: 76 | if n.name == name: 77 | return n 78 | 79 | 80 | def parser(s, nodes, 
node): 81 | print(s) 82 | if len(s) == 0: 83 | return 84 | if len(s) > 0: 85 | if s[0] == node.op: 86 | print(s[0], node.name, s[1], node.input) 87 | for n in node.input: 88 | print(n, s[1]) 89 | parser(s[1], nodes, findnode(nodes, n)) 90 | else: 91 | return False 92 | 93 | 94 | # TODO: use recursive parse 95 | 96 | def smartGrouping(nodes, mapping): 97 | # a Fully Conv is: (TODO: check var1.size(0)==var2.size(0)) 98 | # GEMM <-- Variable (c1) 99 | # ^-- Transpose (c2) <-- Variable (c3) 100 | 101 | # a Conv with bias is: (TODO: check var1.size(0)==var2.size(0)) 102 | # Add <-- Conv (c2) <-- Variable (c3) 103 | # ^-- Variable (c1) 104 | # 105 | # gemm = ('Gemm', ('Variable', ('Transpose', ('Variable')))) 106 | 107 | FCcounter = 1 108 | Convcounter = 1 109 | for node in nodes: 110 | if node.op == 'Gemm': 111 | c1 = c2 = c3 = False 112 | for name_in in node.input: 113 | n = findnode(nodes, name_in) 114 | if n.op == 'Variable': 115 | c1 = True 116 | c1name = n.name 117 | if n.op == 'Transpose': 118 | c2 = True 119 | c2name = n.name 120 | if len(n.input) == 1: 121 | nn = findnode(nodes, n.input[0]) 122 | if nn.op == 'Variable': 123 | c3 = True 124 | c3name = nn.name 125 | # print(n.op, n.name, c1, c2, c3) 126 | if c1 and c2 and c3: 127 | # print(c1name, c2name, c3name) 128 | mapping[c1name] = 'FC{}/{}'.format(FCcounter, c1name) 129 | mapping[c2name] = 'FC{}/{}'.format(FCcounter, c2name) 130 | mapping[c3name] = 'FC{}/{}'.format(FCcounter, c3name) 131 | mapping[node.name] = 'FC{}/{}'.format(FCcounter, node.name) 132 | FCcounter += 1 133 | continue 134 | if node.op == 'Add': 135 | c1 = c2 = c3 = False 136 | for name_in in node.input: 137 | n = findnode(nodes, name_in) 138 | if n.op == 'Variable': 139 | c1 = True 140 | c1name = n.name 141 | if n.op == 'Conv': 142 | c2 = True 143 | c2name = n.name 144 | if len(n.input) >= 1: 145 | for nn_name in n.input: 146 | nn = findnode(nodes, nn_name) 147 | if nn.op == 'Variable': 148 | c3 = True 149 | c3name = nn.name 150 | 151 | if c1 
and c2 and c3: 152 | # print(c1name, c2name, c3name) 153 | mapping[c1name] = 'Conv{}/{}'.format(Convcounter, c1name) 154 | mapping[c2name] = 'Conv{}/{}'.format(Convcounter, c2name) 155 | mapping[c3name] = 'Conv{}/{}'.format(Convcounter, c3name) 156 | mapping[node.name] = 'Conv{}/{}'.format(Convcounter, node.name) 157 | Convcounter += 1 158 | return mapping 159 | -------------------------------------------------------------------------------- /classification/tensorboardX/graph_onnx.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/graph_onnx.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/record_writer.py: -------------------------------------------------------------------------------- 1 | """ 2 | To write tf_record into file. Here we use it for tensorboard's event writing. 3 | The code was borrowed from https://github.com/TeamHG-Memex/tensorboard_logger 4 | """ 5 | 6 | import re 7 | import struct 8 | 9 | from .crc32c import crc32c 10 | 11 | _VALID_OP_NAME_START = re.compile('^[A-Za-z0-9.]') 12 | _VALID_OP_NAME_PART = re.compile('[A-Za-z0-9_.\\-/]+') 13 | 14 | 15 | class RecordWriter(object): 16 | def __init__(self, path, flush_secs=2): 17 | self._name_to_tf_name = {} 18 | self._tf_names = set() 19 | self.path = path 20 | self.flush_secs = flush_secs # TODO. flush every flush_secs, not every time. 
21 | self._writer = None 22 | self._writer = open(path, 'wb') 23 | 24 | def write(self, event_str): 25 | w = self._writer.write 26 | header = struct.pack('Q', len(event_str)) 27 | w(header) 28 | w(struct.pack('I', masked_crc32c(header))) 29 | w(event_str) 30 | w(struct.pack('I', masked_crc32c(event_str))) 31 | self._writer.flush() 32 | 33 | 34 | def masked_crc32c(data): 35 | x = u32(crc32c(data)) 36 | return u32(((x >> 15) | u32(x << 17)) + 0xa282ead8) 37 | 38 | 39 | def u32(x): 40 | return x & 0xffffffff 41 | 42 | 43 | def make_valid_tf_name(name): 44 | if not _VALID_OP_NAME_START.match(name): 45 | # Must make it valid somehow, but don't want to remove stuff 46 | name = '.' + name 47 | return '_'.join(_VALID_OP_NAME_PART.findall(name)) 48 | -------------------------------------------------------------------------------- /classification/tensorboardX/record_writer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/record_writer.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__init__.py -------------------------------------------------------------------------------- /classification/tensorboardX/src/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__init__.pyc -------------------------------------------------------------------------------- 
/classification/tensorboardX/src/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__pycache__/attr_value_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/attr_value_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__pycache__/event_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/event_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__pycache__/graph_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/graph_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__pycache__/node_def_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/node_def_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__pycache__/plugin_pr_curve_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/plugin_pr_curve_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__pycache__/resource_handle_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/resource_handle_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__pycache__/summary_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/summary_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__pycache__/tensor_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/tensor_pb2.cpython-36.pyc -------------------------------------------------------------------------------- 
/classification/tensorboardX/src/__pycache__/tensor_shape_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/tensor_shape_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__pycache__/types_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/types_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/__pycache__/versions_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/__pycache__/versions_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/attr_value.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "AttrValueProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/src/tensor.proto"; 10 | import "tensorboardX/src/tensor_shape.proto"; 11 | import "tensorboardX/src/types.proto"; 12 | 13 | // Protocol buffer representing the value for an attr used to configure an Op. 14 | // Comment indicates the corresponding attr type. Only the field matching the 15 | // attr type may be filled. 
16 | message AttrValue { 17 | // LINT.IfChange 18 | message ListValue { 19 | repeated bytes s = 2; // "list(string)" 20 | repeated int64 i = 3 [packed = true]; // "list(int)" 21 | repeated float f = 4 [packed = true]; // "list(float)" 22 | repeated bool b = 5 [packed = true]; // "list(bool)" 23 | repeated DataType type = 6 [packed = true]; // "list(type)" 24 | repeated TensorShapeProto shape = 7; // "list(shape)" 25 | repeated TensorProto tensor = 8; // "list(tensor)" 26 | repeated NameAttrList func = 9; // "list(attr)" 27 | } 28 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) 29 | 30 | oneof value { 31 | bytes s = 2; // "string" 32 | int64 i = 3; // "int" 33 | float f = 4; // "float" 34 | bool b = 5; // "bool" 35 | DataType type = 6; // "type" 36 | TensorShapeProto shape = 7; // "shape" 37 | TensorProto tensor = 8; // "tensor" 38 | ListValue list = 1; // any "list(...)" 39 | 40 | // "func" represents a function. func.name is a function's name or 41 | // a primitive op's name. func.attr.first is the name of an attr 42 | // defined for that function. func.attr.second is the value for 43 | // that attr in the instantiation. 44 | NameAttrList func = 10; 45 | 46 | // This is a placeholder only used in nodes defined inside a 47 | // function. It indicates the attr value will be supplied when 48 | // the function is instantiated. For example, let us suppose a 49 | // node "N" in function "FN". "N" has an attr "A" with value 50 | // placeholder = "foo". When FN is instantiated with attr "foo" 51 | // set to "bar", the instantiated node N's attr A will have been 52 | // given the value "bar". 53 | string placeholder = 9; 54 | } 55 | } 56 | 57 | // A list of attr names and their values. The whole list is attached 58 | // with a string name. E.g., MatMul[T=float]. 
59 | message NameAttrList { 60 | string name = 1; 61 | map<string, AttrValue> attr = 2; 62 | } 63 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/attr_value_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/attr_value_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/event.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "EventProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.util"; 8 | 9 | import "tensorboardX/src/summary.proto"; 10 | 11 | // Protocol buffer representing an event that happened during 12 | // the execution of a Brain model. 13 | message Event { 14 | // Timestamp of the event. 15 | double wall_time = 1; 16 | 17 | // Global step of the event. 18 | int64 step = 2; 19 | 20 | oneof what { 21 | // An event file was started, with the specified version. 22 | // This is used to identify the contents of the record IO files 23 | // easily. Current version is "brain.Event:2". All versions 24 | // start with "brain.Event:". 25 | string file_version = 3; 26 | // An encoded version of a GraphDef. 27 | bytes graph_def = 4; 28 | // A summary was generated. 29 | Summary summary = 5; 30 | // The user output a log message. Not all messages are logged, only ones 31 | // generated via the Python tensorboard_logging module. 32 | LogMessage log_message = 6; 33 | // The state of the session which can be used for restarting after crashes. 34 | SessionLog session_log = 7; 35 | // The metadata returned by running a session.run() call. 
36 | TaggedRunMetadata tagged_run_metadata = 8; 37 | // An encoded version of a MetaGraphDef. 38 | bytes meta_graph_def = 9; 39 | } 40 | } 41 | 42 | // Protocol buffer used for logging messages to the events file. 43 | message LogMessage { 44 | enum Level { 45 | UNKNOWN = 0; 46 | DEBUG = 10; 47 | INFO = 20; 48 | WARN = 30; 49 | ERROR = 40; 50 | FATAL = 50; 51 | } 52 | Level level = 1; 53 | string message = 2; 54 | } 55 | 56 | // Protocol buffer used for logging session state. 57 | message SessionLog { 58 | enum SessionStatus { 59 | STATUS_UNSPECIFIED = 0; 60 | START = 1; 61 | STOP = 2; 62 | CHECKPOINT = 3; 63 | } 64 | 65 | SessionStatus status = 1; 66 | // This checkpoint_path contains both the path and filename. 67 | string checkpoint_path = 2; 68 | string msg = 3; 69 | } 70 | 71 | // For logging the metadata output for a single session.run() call. 72 | message TaggedRunMetadata { 73 | // Tag name associated with this metadata. 74 | string tag = 1; 75 | // Byte-encoded version of the `RunMetadata` proto in order to allow lazy 76 | // deserialization. 
77 | bytes run_metadata = 2; 78 | } 79 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/event_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/event_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "GraphProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/src/node_def.proto"; 10 | //import "tensorflow/core/framework/function.proto"; 11 | import "tensorboardX/src/versions.proto"; 12 | 13 | // Represents the graph of operations 14 | message GraphDef { 15 | repeated NodeDef node = 1; 16 | 17 | // Compatibility versions of the graph. See core/public/version.h for version 18 | // history. The GraphDef version is distinct from the TensorFlow version, and 19 | // each release of TensorFlow will support a range of GraphDef versions. 20 | VersionDef versions = 4; 21 | 22 | // Deprecated single version field; use versions above instead. Since all 23 | // GraphDef changes before "versions" was introduced were forward 24 | // compatible, this field is entirely ignored. 25 | int32 version = 3 [deprecated = true]; 26 | 27 | // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. 28 | // 29 | // "library" provides user-defined functions. 30 | // 31 | // Naming: 32 | // * library.function.name are in a flat namespace. 33 | // NOTE: We may need to change it to be hierarchical to support 34 | // different orgs. E.g., 35 | // { "/google/nn", { ... 
}}, 36 | // { "/google/vision", { ... }} 37 | // { "/org_foo/module_bar", { ... }} 38 | // map<string, FunctionDefLib> named_lib; 39 | // * If node[i].op is the name of one function in "library", 40 | // node[i] is deemed as a function call. Otherwise, node[i].op 41 | // must be a primitive operation supported by the runtime. 42 | // 43 | // 44 | // Function call semantics: 45 | // 46 | // * The callee may start execution as soon as some of its inputs 47 | // are ready. The caller may want to use Tuple() mechanism to 48 | // ensure all inputs are ready in the same time. 49 | // 50 | // * The consumer of return values may start executing as soon as 51 | // the return values the consumer depends on are ready. The 52 | // consumer may want to use Tuple() mechanism to ensure the 53 | // consumer does not start until all return values of the callee 54 | // function are ready. 55 | //FunctionDefLibrary library = 2; 56 | }; 57 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/graph_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: tensorboardX/src/graph.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from tensorboardX.src import node_def_pb2 as tensorboardX_dot_src_dot_node__def__pb2 17 | from tensorboardX.src import versions_pb2 as tensorboardX_dot_src_dot_versions__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='tensorboardX/src/graph.proto', 22 | package='tensorboard', 23 | syntax='proto3', 24 | serialized_pb=_b('\n\x1ctensorboardX/src/graph.proto\x12\x0btensorboard\x1a\x1ftensorboardX/src/node_def.proto\x1a\x1ftensorboardX/src/versions.proto\"n\n\x08GraphDef\x12\"\n\x04node\x18\x01 \x03(\x0b\x32\x14.tensorboard.NodeDef\x12)\n\x08versions\x18\x04 \x01(\x0b\x32\x17.tensorboard.VersionDef\x12\x13\n\x07version\x18\x03 \x01(\x05\x42\x02\x18\x01\x42,\n\x18org.tensorflow.frameworkB\x0bGraphProtosP\x01\xf8\x01\x01\x62\x06proto3') 25 | , 26 | dependencies=[tensorboardX_dot_src_dot_node__def__pb2.DESCRIPTOR,tensorboardX_dot_src_dot_versions__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _GRAPHDEF = _descriptor.Descriptor( 32 | name='GraphDef', 33 | full_name='tensorboard.GraphDef', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='node', full_name='tensorboard.GraphDef.node', index=0, 40 | number=1, type=11, cpp_type=10, label=3, 41 | has_default_value=False, default_value=[], 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | 
name='versions', full_name='tensorboard.GraphDef.versions', index=1, 47 | number=4, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | _descriptor.FieldDescriptor( 53 | name='version', full_name='tensorboard.GraphDef.version', index=2, 54 | number=3, type=5, cpp_type=1, label=1, 55 | has_default_value=False, default_value=0, 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001'))), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=111, 72 | serialized_end=221, 73 | ) 74 | 75 | _GRAPHDEF.fields_by_name['node'].message_type = tensorboardX_dot_src_dot_node__def__pb2._NODEDEF 76 | _GRAPHDEF.fields_by_name['versions'].message_type = tensorboardX_dot_src_dot_versions__pb2._VERSIONDEF 77 | DESCRIPTOR.message_types_by_name['GraphDef'] = _GRAPHDEF 78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 79 | 80 | GraphDef = _reflection.GeneratedProtocolMessageType('GraphDef', (_message.Message,), dict( 81 | DESCRIPTOR = _GRAPHDEF, 82 | __module__ = 'tensorboardX.src.graph_pb2' 83 | # @@protoc_insertion_point(class_scope:tensorboard.GraphDef) 84 | )) 85 | _sym_db.RegisterMessage(GraphDef) 86 | 87 | 88 | DESCRIPTOR.has_options = True 89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\013GraphProtosP\001\370\001\001')) 90 | _GRAPHDEF.fields_by_name['version'].has_options = True 91 | _GRAPHDEF.fields_by_name['version']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001')) 92 | # @@protoc_insertion_point(module_scope) 93 | 
-------------------------------------------------------------------------------- /classification/tensorboardX/src/graph_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/graph_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/node_def.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "NodeProto"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/src/attr_value.proto"; 10 | 11 | message NodeDef { 12 | // The name given to this operator. Used for naming inputs, 13 | // logging, visualization, etc. Unique within a single GraphDef. 14 | // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". 15 | string name = 1; 16 | 17 | // The operation name. There may be custom parameters in attrs. 18 | // Op names starting with an underscore are reserved for internal use. 19 | string op = 2; 20 | 21 | // Each input is "node:src_output" with "node" being a string name and 22 | // "src_output" indicating which output tensor to use from "node". If 23 | // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs 24 | // may optionally be followed by control inputs that have the format 25 | // "^node". 26 | repeated string input = 3; 27 | 28 | // A (possibly partial) specification for the device on which this 29 | // node should be placed. 
30 | // The expected syntax for this string is as follows: 31 | // 32 | // DEVICE_SPEC ::= PARTIAL_SPEC 33 | // 34 | // PARTIAL_SPEC ::= ("/" CONSTRAINT) * 35 | // CONSTRAINT ::= ("job:" JOB_NAME) 36 | // | ("replica:" [1-9][0-9]*) 37 | // | ("task:" [1-9][0-9]*) 38 | // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) 39 | // 40 | // Valid values for this string include: 41 | // * "/job:worker/replica:0/task:1/gpu:3" (full specification) 42 | // * "/job:worker/gpu:3" (partial specification) 43 | // * "" (no specification) 44 | // 45 | // If the constraints do not resolve to a single device (or if this 46 | // field is empty or not present), the runtime will attempt to 47 | // choose a device automatically. 48 | string device = 4; 49 | 50 | // Operation-specific graph-construction-time configuration. 51 | // Note that this should include all attrs defined in the 52 | // corresponding OpDef, including those with a value matching 53 | // the default -- this allows the default to change and makes 54 | // NodeDefs easier to interpret on their own. However, if 55 | // an attr with a default is not specified in this list, the 56 | // default will be used. 57 | // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and 58 | // one of the names from the corresponding OpDef's attr field). 59 | // The values must have a type matching the corresponding OpDef 60 | // attr's type field. 61 | // TODO(josh11b): Add some examples here showing best practices. 62 | map<string, AttrValue> attr = 5; 63 | }; 64 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/node_def_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: tensorboardX/src/node_def.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from tensorboardX.src import attr_value_pb2 as tensorboardX_dot_src_dot_attr__value__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='tensorboardX/src/node_def.proto', 21 | package='tensorboard', 22 | syntax='proto3', 23 | serialized_pb=_b('\n\x1ftensorboardX/src/node_def.proto\x12\x0btensorboard\x1a!tensorboardX/src/attr_value.proto\"\xb5\x01\n\x07NodeDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02op\x18\x02 \x01(\t\x12\r\n\x05input\x18\x03 \x03(\t\x12\x0e\n\x06\x64\x65vice\x18\x04 \x01(\t\x12,\n\x04\x61ttr\x18\x05 \x03(\x0b\x32\x1e.tensorboard.NodeDef.AttrEntry\x1a\x43\n\tAttrEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.tensorboard.AttrValue:\x02\x38\x01\x42*\n\x18org.tensorflow.frameworkB\tNodeProtoP\x01\xf8\x01\x01\x62\x06proto3') 24 | , 25 | dependencies=[tensorboardX_dot_src_dot_attr__value__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | 30 | _NODEDEF_ATTRENTRY = _descriptor.Descriptor( 31 | name='AttrEntry', 32 | full_name='tensorboard.NodeDef.AttrEntry', 33 | filename=None, 34 | file=DESCRIPTOR, 35 | containing_type=None, 36 | fields=[ 37 | _descriptor.FieldDescriptor( 38 | name='key', full_name='tensorboard.NodeDef.AttrEntry.key', index=0, 39 | number=1, type=9, cpp_type=9, label=1, 40 | has_default_value=False, default_value=_b("").decode('utf-8'), 41 | message_type=None, enum_type=None, containing_type=None, 42 | is_extension=False, extension_scope=None, 43 | options=None), 44 | 
_descriptor.FieldDescriptor( 45 | name='value', full_name='tensorboard.NodeDef.AttrEntry.value', index=1, 46 | number=2, type=11, cpp_type=10, label=1, 47 | has_default_value=False, default_value=None, 48 | message_type=None, enum_type=None, containing_type=None, 49 | is_extension=False, extension_scope=None, 50 | options=None), 51 | ], 52 | extensions=[ 53 | ], 54 | nested_types=[], 55 | enum_types=[ 56 | ], 57 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 58 | is_extendable=False, 59 | syntax='proto3', 60 | extension_ranges=[], 61 | oneofs=[ 62 | ], 63 | serialized_start=198, 64 | serialized_end=265, 65 | ) 66 | 67 | _NODEDEF = _descriptor.Descriptor( 68 | name='NodeDef', 69 | full_name='tensorboard.NodeDef', 70 | filename=None, 71 | file=DESCRIPTOR, 72 | containing_type=None, 73 | fields=[ 74 | _descriptor.FieldDescriptor( 75 | name='name', full_name='tensorboard.NodeDef.name', index=0, 76 | number=1, type=9, cpp_type=9, label=1, 77 | has_default_value=False, default_value=_b("").decode('utf-8'), 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None), 81 | _descriptor.FieldDescriptor( 82 | name='op', full_name='tensorboard.NodeDef.op', index=1, 83 | number=2, type=9, cpp_type=9, label=1, 84 | has_default_value=False, default_value=_b("").decode('utf-8'), 85 | message_type=None, enum_type=None, containing_type=None, 86 | is_extension=False, extension_scope=None, 87 | options=None), 88 | _descriptor.FieldDescriptor( 89 | name='input', full_name='tensorboard.NodeDef.input', index=2, 90 | number=3, type=9, cpp_type=9, label=3, 91 | has_default_value=False, default_value=[], 92 | message_type=None, enum_type=None, containing_type=None, 93 | is_extension=False, extension_scope=None, 94 | options=None), 95 | _descriptor.FieldDescriptor( 96 | name='device', full_name='tensorboard.NodeDef.device', index=3, 97 | number=4, type=9, cpp_type=9, label=1, 98 | 
has_default_value=False, default_value=_b("").decode('utf-8'), 99 | message_type=None, enum_type=None, containing_type=None, 100 | is_extension=False, extension_scope=None, 101 | options=None), 102 | _descriptor.FieldDescriptor( 103 | name='attr', full_name='tensorboard.NodeDef.attr', index=4, 104 | number=5, type=11, cpp_type=10, label=3, 105 | has_default_value=False, default_value=[], 106 | message_type=None, enum_type=None, containing_type=None, 107 | is_extension=False, extension_scope=None, 108 | options=None), 109 | ], 110 | extensions=[ 111 | ], 112 | nested_types=[_NODEDEF_ATTRENTRY, ], 113 | enum_types=[ 114 | ], 115 | options=None, 116 | is_extendable=False, 117 | syntax='proto3', 118 | extension_ranges=[], 119 | oneofs=[ 120 | ], 121 | serialized_start=84, 122 | serialized_end=265, 123 | ) 124 | 125 | _NODEDEF_ATTRENTRY.fields_by_name['value'].message_type = tensorboardX_dot_src_dot_attr__value__pb2._ATTRVALUE 126 | _NODEDEF_ATTRENTRY.containing_type = _NODEDEF 127 | _NODEDEF.fields_by_name['attr'].message_type = _NODEDEF_ATTRENTRY 128 | DESCRIPTOR.message_types_by_name['NodeDef'] = _NODEDEF 129 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 130 | 131 | NodeDef = _reflection.GeneratedProtocolMessageType('NodeDef', (_message.Message,), dict( 132 | 133 | AttrEntry = _reflection.GeneratedProtocolMessageType('AttrEntry', (_message.Message,), dict( 134 | DESCRIPTOR = _NODEDEF_ATTRENTRY, 135 | __module__ = 'tensorboardX.src.node_def_pb2' 136 | # @@protoc_insertion_point(class_scope:tensorboard.NodeDef.AttrEntry) 137 | )) 138 | , 139 | DESCRIPTOR = _NODEDEF, 140 | __module__ = 'tensorboardX.src.node_def_pb2' 141 | # @@protoc_insertion_point(class_scope:tensorboard.NodeDef) 142 | )) 143 | _sym_db.RegisterMessage(NodeDef) 144 | _sym_db.RegisterMessage(NodeDef.AttrEntry) 145 | 146 | 147 | DESCRIPTOR.has_options = True 148 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), 
_b('\n\030org.tensorflow.frameworkB\tNodeProtoP\001\370\001\001')) 149 | _NODEDEF_ATTRENTRY.has_options = True 150 | _NODEDEF_ATTRENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 151 | # @@protoc_insertion_point(module_scope) 152 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/node_def_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/node_def_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/plugin_pr_curve.proto: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | syntax = "proto3"; 17 | 18 | package tensorboard; 19 | 20 | message PrCurvePluginData { 21 | // Version `0` is the only supported version. 
22 | int32 version = 1; 23 | 24 | uint32 num_thresholds = 2; 25 | } 26 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/plugin_pr_curve_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorboardX/src/plugin_pr_curve.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tensorboardX/src/plugin_pr_curve.proto', 20 | package='tensorboard', 21 | syntax='proto3', 22 | serialized_pb=_b('\n&tensorboardX/src/plugin_pr_curve.proto\x12\x0btensorboard\"<\n\x11PrCurvePluginData\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x16\n\x0enum_thresholds\x18\x02 \x01(\rb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _PRCURVEPLUGINDATA = _descriptor.Descriptor( 29 | name='PrCurvePluginData', 30 | full_name='tensorboard.PrCurvePluginData', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='version', full_name='tensorboard.PrCurvePluginData.version', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='num_thresholds', full_name='tensorboard.PrCurvePluginData.num_thresholds', index=1, 44 | number=2, type=13, cpp_type=3, label=1, 
45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=55, 62 | serialized_end=115, 63 | ) 64 | 65 | DESCRIPTOR.message_types_by_name['PrCurvePluginData'] = _PRCURVEPLUGINDATA 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | PrCurvePluginData = _reflection.GeneratedProtocolMessageType('PrCurvePluginData', (_message.Message,), dict( 69 | DESCRIPTOR = _PRCURVEPLUGINDATA, 70 | __module__ = 'tensorboardX.src.plugin_pr_curve_pb2' 71 | # @@protoc_insertion_point(class_scope:tensorboard.PrCurvePluginData) 72 | )) 73 | _sym_db.RegisterMessage(PrCurvePluginData) 74 | 75 | 76 | # @@protoc_insertion_point(module_scope) 77 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/plugin_pr_curve_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/plugin_pr_curve_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/resource_handle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "ResourceHandle"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // Protocol buffer representing a handle to a tensorflow resource. 
Handles are 10 | // not valid across executions, but can be serialized back and forth from within 11 | // a single run. 12 | message ResourceHandleProto { 13 | // Unique name for the device containing the resource. 14 | string device = 1; 15 | 16 | // Container in which this resource is placed. 17 | string container = 2; 18 | 19 | // Unique name of this resource. 20 | string name = 3; 21 | 22 | // Hash code for the type of the resource. Is only valid in the same device 23 | // and in the same execution. 24 | uint64 hash_code = 4; 25 | 26 | // For debug-only, the name of the type pointed to by this handle, if 27 | // available. 28 | string maybe_type_name = 5; 29 | }; 30 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/resource_handle_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorboardX/src/resource_handle.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tensorboardX/src/resource_handle.proto', 20 | package='tensorboard', 21 | syntax='proto3', 22 | serialized_pb=_b('\n&tensorboardX/src/resource_handle.proto\x12\x0btensorboard\"r\n\x13ResourceHandleProto\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\t\x12\x11\n\tcontainer\x18\x02 \x01(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x11\n\thash_code\x18\x04 \x01(\x04\x12\x17\n\x0fmaybe_type_name\x18\x05 
\x01(\tB/\n\x18org.tensorflow.frameworkB\x0eResourceHandleP\x01\xf8\x01\x01\x62\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _RESOURCEHANDLEPROTO = _descriptor.Descriptor( 29 | name='ResourceHandleProto', 30 | full_name='tensorboard.ResourceHandleProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='device', full_name='tensorboard.ResourceHandleProto.device', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='container', full_name='tensorboard.ResourceHandleProto.container', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='name', full_name='tensorboard.ResourceHandleProto.name', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='hash_code', full_name='tensorboard.ResourceHandleProto.hash_code', index=3, 58 | number=4, type=4, cpp_type=4, label=1, 59 | has_default_value=False, default_value=0, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='maybe_type_name', full_name='tensorboard.ResourceHandleProto.maybe_type_name', index=4, 65 | number=5, type=9, cpp_type=9, label=1, 66 | has_default_value=False, default_value=_b("").decode('utf-8'), 67 | 
message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | ], 71 | extensions=[ 72 | ], 73 | nested_types=[], 74 | enum_types=[ 75 | ], 76 | options=None, 77 | is_extendable=False, 78 | syntax='proto3', 79 | extension_ranges=[], 80 | oneofs=[ 81 | ], 82 | serialized_start=55, 83 | serialized_end=169, 84 | ) 85 | 86 | DESCRIPTOR.message_types_by_name['ResourceHandleProto'] = _RESOURCEHANDLEPROTO 87 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 88 | 89 | ResourceHandleProto = _reflection.GeneratedProtocolMessageType('ResourceHandleProto', (_message.Message,), dict( 90 | DESCRIPTOR = _RESOURCEHANDLEPROTO, 91 | __module__ = 'tensorboardX.src.resource_handle_pb2' 92 | # @@protoc_insertion_point(class_scope:tensorboard.ResourceHandleProto) 93 | )) 94 | _sym_db.RegisterMessage(ResourceHandleProto) 95 | 96 | 97 | DESCRIPTOR.has_options = True 98 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\016ResourceHandleP\001\370\001\001')) 99 | # @@protoc_insertion_point(module_scope) 100 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/resource_handle_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/resource_handle_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/summary.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "SummaryProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import 
"tensorboardX/src/tensor.proto"; 10 | 11 | // Metadata associated with a series of Summary data 12 | message SummaryDescription { 13 | // Hint on how plugins should process the data in this series. 14 | // Supported values include "scalar", "histogram", "image", "audio" 15 | string type_hint = 1; 16 | } 17 | 18 | // Serialization format for histogram module in 19 | // core/lib/histogram/histogram.h 20 | message HistogramProto { 21 | double min = 1; 22 | double max = 2; 23 | double num = 3; 24 | double sum = 4; 25 | double sum_squares = 5; 26 | 27 | // Parallel arrays encoding the bucket boundaries and the bucket values. 28 | // bucket(i) is the count for the bucket i. The range for 29 | // a bucket is: 30 | // i == 0: -DBL_MAX .. bucket_limit(0) 31 | // i != 0: bucket_limit(i-1) .. bucket_limit(i) 32 | repeated double bucket_limit = 6 [packed = true]; 33 | repeated double bucket = 7 [packed = true]; 34 | }; 35 | 36 | // A SummaryMetadata encapsulates information on which plugins are able to make 37 | // use of a certain summary value. 38 | message SummaryMetadata { 39 | message PluginData { 40 | // The name of the plugin this data pertains to. 41 | string plugin_name = 1; 42 | 43 | // The content to store for the plugin. The best practice is for this JSON 44 | // string to be the canonical JSON serialization of a protocol buffer 45 | // defined by the plugin. Converting that protobuf to and from JSON is the 46 | // responsibility of the plugin code, and is not enforced by 47 | // TensorFlow/TensorBoard. 48 | string content = 2; 49 | } 50 | 51 | // A list of plugin data. A single summary value instance may be used by more 52 | // than 1 plugin. 53 | repeated PluginData plugin_data = 1; 54 | }; 55 | 56 | // A Summary is a set of named values to be displayed by the 57 | // visualizer. 58 | // 59 | // Summaries are produced regularly during training, as controlled by 60 | // the "summary_interval_secs" attribute of the training operation. 
61 | // Summaries are also produced at the end of an evaluation. 62 | message Summary { 63 | message Image { 64 | // Dimensions of the image. 65 | int32 height = 1; 66 | int32 width = 2; 67 | // Valid colorspace values are 68 | // 1 - grayscale 69 | // 2 - grayscale + alpha 70 | // 3 - RGB 71 | // 4 - RGBA 72 | // 5 - DIGITAL_YUV 73 | // 6 - BGRA 74 | int32 colorspace = 3; 75 | // Image data in encoded format. All image formats supported by 76 | // image_codec::CoderUtil can be stored here. 77 | bytes encoded_image_string = 4; 78 | } 79 | 80 | message Audio { 81 | // Sample rate of the audio in Hz. 82 | float sample_rate = 1; 83 | // Number of channels of audio. 84 | int64 num_channels = 2; 85 | // Length of the audio in frames (samples per channel). 86 | int64 length_frames = 3; 87 | // Encoded audio data and its associated RFC 2045 content type (e.g. 88 | // "audio/wav"). 89 | bytes encoded_audio_string = 4; 90 | string content_type = 5; 91 | } 92 | 93 | message Value { 94 | // Name of the node that output this summary; in general, the name of a 95 | // TensorSummary node. If the node in question has multiple outputs, then 96 | // a ":\d+" suffix will be appended, like "some_op:13". 97 | // Might not be set for legacy summaries (i.e. those not using the tensor 98 | // value field) 99 | string node_name = 7; 100 | 101 | // Tag name for the data. Will only be used by legacy summaries 102 | // (ie. those not using the tensor value field) 103 | // For legacy summaries, will be used as the title of the graph 104 | // in the visualizer. 105 | // 106 | // Tag is usually "op_name:value_name", where "op_name" itself can have 107 | // structure to indicate grouping. 108 | string tag = 1; 109 | SummaryMetadata metadata = 9; 110 | // Value associated with the tag. 
111 | oneof value { 112 | float simple_value = 2; 113 | bytes obsolete_old_style_histogram = 3; 114 | Image image = 4; 115 | HistogramProto histo = 5; 116 | Audio audio = 6; 117 | TensorProto tensor = 8; 118 | } 119 | } 120 | 121 | // Set of values for the summary. 122 | repeated Value value = 1; 123 | } 124 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/summary_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/summary_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/tensor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorboardX/src/resource_handle.proto"; 10 | import "tensorboardX/src/tensor_shape.proto"; 11 | import "tensorboardX/src/types.proto"; 12 | 13 | // Protocol buffer representing a tensor. 14 | message TensorProto { 15 | DataType dtype = 1; 16 | 17 | // Shape of the tensor. TODO(touts): sort out the 0-rank issues. 18 | TensorShapeProto tensor_shape = 2; 19 | 20 | // Only one of the representations below is set, one of "tensor_contents" and 21 | // the "xxx_val" attributes. We are not using oneof because as oneofs cannot 22 | // contain repeated fields it would require another extra set of messages. 23 | 24 | // Version number. 25 | // 26 | // In version 0, if the "repeated xxx" representations contain only one 27 | // element, that element is repeated to fill the shape. 
This makes it easy 28 | // to represent a constant Tensor with a single value. 29 | int32 version_number = 3; 30 | 31 | // Serialized raw tensor content from either Tensor::AsProtoTensorContent or 32 | // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation 33 | // can be used for all tensor types. The purpose of this representation is to 34 | // reduce serialization overhead during RPC call by avoiding serialization of 35 | // many repeated small items. 36 | bytes tensor_content = 4; 37 | 38 | // Type specific representations that make it easy to create tensor protos in 39 | // all languages. Only the representation corresponding to "dtype" can 40 | // be set. The values hold the flattened representation of the tensor in 41 | // row major order. 42 | 43 | // DT_HALF. Note that since protobuf has no int16 type, we'll have some 44 | // pointless zero padding for each value here. 45 | repeated int32 half_val = 13 [packed = true]; 46 | 47 | // DT_FLOAT. 48 | repeated float float_val = 5 [packed = true]; 49 | 50 | // DT_DOUBLE. 51 | repeated double double_val = 6 [packed = true]; 52 | 53 | // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. 54 | repeated int32 int_val = 7 [packed = true]; 55 | 56 | // DT_STRING 57 | repeated bytes string_val = 8; 58 | 59 | // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real 60 | // and imaginary parts of i-th single precision complex. 61 | repeated float scomplex_val = 9 [packed = true]; 62 | 63 | // DT_INT64 64 | repeated int64 int64_val = 10 [packed = true]; 65 | 66 | // DT_BOOL 67 | repeated bool bool_val = 11 [packed = true]; 68 | 69 | // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real 70 | // and imaginary parts of i-th double precision complex. 
71 | repeated double dcomplex_val = 12 [packed = true]; 72 | 73 | // DT_RESOURCE 74 | repeated ResourceHandleProto resource_handle_val = 14; 75 | }; 76 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/tensor_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/tensor_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/tensor_shape.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffer representing the shape of tensors. 2 | 3 | syntax = "proto3"; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorShapeProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | package tensorboard; 10 | 11 | // Dimensions of a tensor. 12 | message TensorShapeProto { 13 | // One dimension of the tensor. 14 | message Dim { 15 | // Size of the tensor in that dimension. 16 | // This value must be >= -1, but values of -1 are reserved for "unknown" 17 | // shapes (values of -1 mean "unknown" dimension). Certain wrappers 18 | // that work with TensorShapeProto may fail at runtime when deserializing 19 | // a TensorShapeProto containing a dim value of -1. 20 | int64 size = 1; 21 | 22 | // Optional name of the tensor dimension. 23 | string name = 2; 24 | }; 25 | 26 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40} 27 | // for a 30 x 40 2D tensor. If an entry has size -1, this 28 | // corresponds to a dimension of unknown size. The names are 29 | // optional. 30 | // 31 | // The order of entries in "dim" matters: It indicates the layout of the 32 | // values in the tensor in-memory representation. 
33 | // 34 | // The first entry in "dim" is the outermost dimension used to layout the 35 | // values, the last entry is the innermost dimension. This matches the 36 | // in-memory layout of RowMajor Eigen tensors. 37 | // 38 | // If "dim.size()" > 0, "unknown_rank" must be false. 39 | repeated Dim dim = 2; 40 | 41 | // If true, the number of dimensions in the shape is unknown. 42 | // 43 | // If true, "dim.size()" must be 0. 44 | bool unknown_rank = 3; 45 | }; 46 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/tensor_shape_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorboardX/src/tensor_shape.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tensorboardX/src/tensor_shape.proto', 20 | package='tensorboard', 21 | syntax='proto3', 22 | serialized_pb=_b('\n#tensorboardX/src/tensor_shape.proto\x12\x0btensorboard\"{\n\x10TensorShapeProto\x12.\n\x03\x64im\x18\x02 \x03(\x0b\x32!.tensorboard.TensorShapeProto.Dim\x12\x14\n\x0cunknown_rank\x18\x03 \x01(\x08\x1a!\n\x03\x44im\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\tB2\n\x18org.tensorflow.frameworkB\x11TensorShapeProtosP\x01\xf8\x01\x01\x62\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _TENSORSHAPEPROTO_DIM = _descriptor.Descriptor( 29 | name='Dim', 30 | full_name='tensorboard.TensorShapeProto.Dim', 31 | filename=None, 
32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='size', full_name='tensorboard.TensorShapeProto.Dim.size', index=0, 37 | number=1, type=3, cpp_type=2, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='name', full_name='tensorboard.TensorShapeProto.Dim.name', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=142, 62 | serialized_end=175, 63 | ) 64 | 65 | _TENSORSHAPEPROTO = _descriptor.Descriptor( 66 | name='TensorShapeProto', 67 | full_name='tensorboard.TensorShapeProto', 68 | filename=None, 69 | file=DESCRIPTOR, 70 | containing_type=None, 71 | fields=[ 72 | _descriptor.FieldDescriptor( 73 | name='dim', full_name='tensorboard.TensorShapeProto.dim', index=0, 74 | number=2, type=11, cpp_type=10, label=3, 75 | has_default_value=False, default_value=[], 76 | message_type=None, enum_type=None, containing_type=None, 77 | is_extension=False, extension_scope=None, 78 | options=None), 79 | _descriptor.FieldDescriptor( 80 | name='unknown_rank', full_name='tensorboard.TensorShapeProto.unknown_rank', index=1, 81 | number=3, type=8, cpp_type=7, label=1, 82 | has_default_value=False, default_value=False, 83 | message_type=None, enum_type=None, containing_type=None, 84 | is_extension=False, extension_scope=None, 85 | options=None), 86 | ], 87 | extensions=[ 88 | ], 89 | nested_types=[_TENSORSHAPEPROTO_DIM, ], 90 | enum_types=[ 91 | ], 
92 | options=None, 93 | is_extendable=False, 94 | syntax='proto3', 95 | extension_ranges=[], 96 | oneofs=[ 97 | ], 98 | serialized_start=52, 99 | serialized_end=175, 100 | ) 101 | 102 | _TENSORSHAPEPROTO_DIM.containing_type = _TENSORSHAPEPROTO 103 | _TENSORSHAPEPROTO.fields_by_name['dim'].message_type = _TENSORSHAPEPROTO_DIM 104 | DESCRIPTOR.message_types_by_name['TensorShapeProto'] = _TENSORSHAPEPROTO 105 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 106 | 107 | TensorShapeProto = _reflection.GeneratedProtocolMessageType('TensorShapeProto', (_message.Message,), dict( 108 | 109 | Dim = _reflection.GeneratedProtocolMessageType('Dim', (_message.Message,), dict( 110 | DESCRIPTOR = _TENSORSHAPEPROTO_DIM, 111 | __module__ = 'tensorboardX.src.tensor_shape_pb2' 112 | # @@protoc_insertion_point(class_scope:tensorboard.TensorShapeProto.Dim) 113 | )) 114 | , 115 | DESCRIPTOR = _TENSORSHAPEPROTO, 116 | __module__ = 'tensorboardX.src.tensor_shape_pb2' 117 | # @@protoc_insertion_point(class_scope:tensorboard.TensorShapeProto) 118 | )) 119 | _sym_db.RegisterMessage(TensorShapeProto) 120 | _sym_db.RegisterMessage(TensorShapeProto.Dim) 121 | 122 | 123 | DESCRIPTOR.has_options = True 124 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\021TensorShapeProtosP\001\370\001\001')) 125 | # @@protoc_insertion_point(module_scope) 126 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/tensor_shape_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/tensor_shape_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/types.proto: -------------------------------------------------------------------------------- 1 | 
syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TypesProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // LINT.IfChange 10 | enum DataType { 11 | // Not a legal value for DataType. Used to indicate a DataType field 12 | // has not been set. 13 | DT_INVALID = 0; 14 | 15 | // Data types that all computation devices are expected to be 16 | // capable to support. 17 | DT_FLOAT = 1; 18 | DT_DOUBLE = 2; 19 | DT_INT32 = 3; 20 | DT_UINT8 = 4; 21 | DT_INT16 = 5; 22 | DT_INT8 = 6; 23 | DT_STRING = 7; 24 | DT_COMPLEX64 = 8; // Single-precision complex 25 | DT_INT64 = 9; 26 | DT_BOOL = 10; 27 | DT_QINT8 = 11; // Quantized int8 28 | DT_QUINT8 = 12; // Quantized uint8 29 | DT_QINT32 = 13; // Quantized int32 30 | DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. 31 | DT_QINT16 = 15; // Quantized int16 32 | DT_QUINT16 = 16; // Quantized uint16 33 | DT_UINT16 = 17; 34 | DT_COMPLEX128 = 18; // Double-precision complex 35 | DT_HALF = 19; 36 | DT_RESOURCE = 20; 37 | 38 | // TODO(josh11b): DT_GENERIC_PROTO = ??; 39 | // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? 40 | 41 | // Do not use! These are only for parameters. Every enum above 42 | // should have a corresponding value below (verified by types_test). 
43 | DT_FLOAT_REF = 101; 44 | DT_DOUBLE_REF = 102; 45 | DT_INT32_REF = 103; 46 | DT_UINT8_REF = 104; 47 | DT_INT16_REF = 105; 48 | DT_INT8_REF = 106; 49 | DT_STRING_REF = 107; 50 | DT_COMPLEX64_REF = 108; 51 | DT_INT64_REF = 109; 52 | DT_BOOL_REF = 110; 53 | DT_QINT8_REF = 111; 54 | DT_QUINT8_REF = 112; 55 | DT_QINT32_REF = 113; 56 | DT_BFLOAT16_REF = 114; 57 | DT_QINT16_REF = 115; 58 | DT_QUINT16_REF = 116; 59 | DT_UINT16_REF = 117; 60 | DT_COMPLEX128_REF = 118; 61 | DT_HALF_REF = 119; 62 | DT_RESOURCE_REF = 120; 63 | } 64 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) 65 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/types_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorboardX/src/types.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='tensorboardX/src/types.proto', 21 | package='tensorboard', 22 | syntax='proto3', 23 | 
serialized_pb=_b('\n\x1ctensorboardX/src/types.proto\x12\x0btensorboard*\xc2\x05\n\x08\x44\x61taType\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0c\n\x08\x44T_FLOAT\x10\x01\x12\r\n\tDT_DOUBLE\x10\x02\x12\x0c\n\x08\x44T_INT32\x10\x03\x12\x0c\n\x08\x44T_UINT8\x10\x04\x12\x0c\n\x08\x44T_INT16\x10\x05\x12\x0b\n\x07\x44T_INT8\x10\x06\x12\r\n\tDT_STRING\x10\x07\x12\x10\n\x0c\x44T_COMPLEX64\x10\x08\x12\x0c\n\x08\x44T_INT64\x10\t\x12\x0b\n\x07\x44T_BOOL\x10\n\x12\x0c\n\x08\x44T_QINT8\x10\x0b\x12\r\n\tDT_QUINT8\x10\x0c\x12\r\n\tDT_QINT32\x10\r\x12\x0f\n\x0b\x44T_BFLOAT16\x10\x0e\x12\r\n\tDT_QINT16\x10\x0f\x12\x0e\n\nDT_QUINT16\x10\x10\x12\r\n\tDT_UINT16\x10\x11\x12\x11\n\rDT_COMPLEX128\x10\x12\x12\x0b\n\x07\x44T_HALF\x10\x13\x12\x0f\n\x0b\x44T_RESOURCE\x10\x14\x12\x10\n\x0c\x44T_FLOAT_REF\x10\x65\x12\x11\n\rDT_DOUBLE_REF\x10\x66\x12\x10\n\x0c\x44T_INT32_REF\x10g\x12\x10\n\x0c\x44T_UINT8_REF\x10h\x12\x10\n\x0c\x44T_INT16_REF\x10i\x12\x0f\n\x0b\x44T_INT8_REF\x10j\x12\x11\n\rDT_STRING_REF\x10k\x12\x14\n\x10\x44T_COMPLEX64_REF\x10l\x12\x10\n\x0c\x44T_INT64_REF\x10m\x12\x0f\n\x0b\x44T_BOOL_REF\x10n\x12\x10\n\x0c\x44T_QINT8_REF\x10o\x12\x11\n\rDT_QUINT8_REF\x10p\x12\x11\n\rDT_QINT32_REF\x10q\x12\x13\n\x0f\x44T_BFLOAT16_REF\x10r\x12\x11\n\rDT_QINT16_REF\x10s\x12\x12\n\x0e\x44T_QUINT16_REF\x10t\x12\x11\n\rDT_UINT16_REF\x10u\x12\x15\n\x11\x44T_COMPLEX128_REF\x10v\x12\x0f\n\x0b\x44T_HALF_REF\x10w\x12\x13\n\x0f\x44T_RESOURCE_REF\x10xB,\n\x18org.tensorflow.frameworkB\x0bTypesProtosP\x01\xf8\x01\x01\x62\x06proto3') 24 | ) 25 | 26 | _DATATYPE = _descriptor.EnumDescriptor( 27 | name='DataType', 28 | full_name='tensorboard.DataType', 29 | filename=None, 30 | file=DESCRIPTOR, 31 | values=[ 32 | _descriptor.EnumValueDescriptor( 33 | name='DT_INVALID', index=0, number=0, 34 | options=None, 35 | type=None), 36 | _descriptor.EnumValueDescriptor( 37 | name='DT_FLOAT', index=1, number=1, 38 | options=None, 39 | type=None), 40 | _descriptor.EnumValueDescriptor( 41 | name='DT_DOUBLE', index=2, number=2, 42 
| options=None, 43 | type=None), 44 | _descriptor.EnumValueDescriptor( 45 | name='DT_INT32', index=3, number=3, 46 | options=None, 47 | type=None), 48 | _descriptor.EnumValueDescriptor( 49 | name='DT_UINT8', index=4, number=4, 50 | options=None, 51 | type=None), 52 | _descriptor.EnumValueDescriptor( 53 | name='DT_INT16', index=5, number=5, 54 | options=None, 55 | type=None), 56 | _descriptor.EnumValueDescriptor( 57 | name='DT_INT8', index=6, number=6, 58 | options=None, 59 | type=None), 60 | _descriptor.EnumValueDescriptor( 61 | name='DT_STRING', index=7, number=7, 62 | options=None, 63 | type=None), 64 | _descriptor.EnumValueDescriptor( 65 | name='DT_COMPLEX64', index=8, number=8, 66 | options=None, 67 | type=None), 68 | _descriptor.EnumValueDescriptor( 69 | name='DT_INT64', index=9, number=9, 70 | options=None, 71 | type=None), 72 | _descriptor.EnumValueDescriptor( 73 | name='DT_BOOL', index=10, number=10, 74 | options=None, 75 | type=None), 76 | _descriptor.EnumValueDescriptor( 77 | name='DT_QINT8', index=11, number=11, 78 | options=None, 79 | type=None), 80 | _descriptor.EnumValueDescriptor( 81 | name='DT_QUINT8', index=12, number=12, 82 | options=None, 83 | type=None), 84 | _descriptor.EnumValueDescriptor( 85 | name='DT_QINT32', index=13, number=13, 86 | options=None, 87 | type=None), 88 | _descriptor.EnumValueDescriptor( 89 | name='DT_BFLOAT16', index=14, number=14, 90 | options=None, 91 | type=None), 92 | _descriptor.EnumValueDescriptor( 93 | name='DT_QINT16', index=15, number=15, 94 | options=None, 95 | type=None), 96 | _descriptor.EnumValueDescriptor( 97 | name='DT_QUINT16', index=16, number=16, 98 | options=None, 99 | type=None), 100 | _descriptor.EnumValueDescriptor( 101 | name='DT_UINT16', index=17, number=17, 102 | options=None, 103 | type=None), 104 | _descriptor.EnumValueDescriptor( 105 | name='DT_COMPLEX128', index=18, number=18, 106 | options=None, 107 | type=None), 108 | _descriptor.EnumValueDescriptor( 109 | name='DT_HALF', index=19, number=19, 
110 | options=None, 111 | type=None), 112 | _descriptor.EnumValueDescriptor( 113 | name='DT_RESOURCE', index=20, number=20, 114 | options=None, 115 | type=None), 116 | _descriptor.EnumValueDescriptor( 117 | name='DT_FLOAT_REF', index=21, number=101, 118 | options=None, 119 | type=None), 120 | _descriptor.EnumValueDescriptor( 121 | name='DT_DOUBLE_REF', index=22, number=102, 122 | options=None, 123 | type=None), 124 | _descriptor.EnumValueDescriptor( 125 | name='DT_INT32_REF', index=23, number=103, 126 | options=None, 127 | type=None), 128 | _descriptor.EnumValueDescriptor( 129 | name='DT_UINT8_REF', index=24, number=104, 130 | options=None, 131 | type=None), 132 | _descriptor.EnumValueDescriptor( 133 | name='DT_INT16_REF', index=25, number=105, 134 | options=None, 135 | type=None), 136 | _descriptor.EnumValueDescriptor( 137 | name='DT_INT8_REF', index=26, number=106, 138 | options=None, 139 | type=None), 140 | _descriptor.EnumValueDescriptor( 141 | name='DT_STRING_REF', index=27, number=107, 142 | options=None, 143 | type=None), 144 | _descriptor.EnumValueDescriptor( 145 | name='DT_COMPLEX64_REF', index=28, number=108, 146 | options=None, 147 | type=None), 148 | _descriptor.EnumValueDescriptor( 149 | name='DT_INT64_REF', index=29, number=109, 150 | options=None, 151 | type=None), 152 | _descriptor.EnumValueDescriptor( 153 | name='DT_BOOL_REF', index=30, number=110, 154 | options=None, 155 | type=None), 156 | _descriptor.EnumValueDescriptor( 157 | name='DT_QINT8_REF', index=31, number=111, 158 | options=None, 159 | type=None), 160 | _descriptor.EnumValueDescriptor( 161 | name='DT_QUINT8_REF', index=32, number=112, 162 | options=None, 163 | type=None), 164 | _descriptor.EnumValueDescriptor( 165 | name='DT_QINT32_REF', index=33, number=113, 166 | options=None, 167 | type=None), 168 | _descriptor.EnumValueDescriptor( 169 | name='DT_BFLOAT16_REF', index=34, number=114, 170 | options=None, 171 | type=None), 172 | _descriptor.EnumValueDescriptor( 173 | 
name='DT_QINT16_REF', index=35, number=115, 174 | options=None, 175 | type=None), 176 | _descriptor.EnumValueDescriptor( 177 | name='DT_QUINT16_REF', index=36, number=116, 178 | options=None, 179 | type=None), 180 | _descriptor.EnumValueDescriptor( 181 | name='DT_UINT16_REF', index=37, number=117, 182 | options=None, 183 | type=None), 184 | _descriptor.EnumValueDescriptor( 185 | name='DT_COMPLEX128_REF', index=38, number=118, 186 | options=None, 187 | type=None), 188 | _descriptor.EnumValueDescriptor( 189 | name='DT_HALF_REF', index=39, number=119, 190 | options=None, 191 | type=None), 192 | _descriptor.EnumValueDescriptor( 193 | name='DT_RESOURCE_REF', index=40, number=120, 194 | options=None, 195 | type=None), 196 | ], 197 | containing_type=None, 198 | options=None, 199 | serialized_start=46, 200 | serialized_end=752, 201 | ) 202 | _sym_db.RegisterEnumDescriptor(_DATATYPE) 203 | 204 | DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) 205 | DT_INVALID = 0 206 | DT_FLOAT = 1 207 | DT_DOUBLE = 2 208 | DT_INT32 = 3 209 | DT_UINT8 = 4 210 | DT_INT16 = 5 211 | DT_INT8 = 6 212 | DT_STRING = 7 213 | DT_COMPLEX64 = 8 214 | DT_INT64 = 9 215 | DT_BOOL = 10 216 | DT_QINT8 = 11 217 | DT_QUINT8 = 12 218 | DT_QINT32 = 13 219 | DT_BFLOAT16 = 14 220 | DT_QINT16 = 15 221 | DT_QUINT16 = 16 222 | DT_UINT16 = 17 223 | DT_COMPLEX128 = 18 224 | DT_HALF = 19 225 | DT_RESOURCE = 20 226 | DT_FLOAT_REF = 101 227 | DT_DOUBLE_REF = 102 228 | DT_INT32_REF = 103 229 | DT_UINT8_REF = 104 230 | DT_INT16_REF = 105 231 | DT_INT8_REF = 106 232 | DT_STRING_REF = 107 233 | DT_COMPLEX64_REF = 108 234 | DT_INT64_REF = 109 235 | DT_BOOL_REF = 110 236 | DT_QINT8_REF = 111 237 | DT_QUINT8_REF = 112 238 | DT_QINT32_REF = 113 239 | DT_BFLOAT16_REF = 114 240 | DT_QINT16_REF = 115 241 | DT_QUINT16_REF = 116 242 | DT_UINT16_REF = 117 243 | DT_COMPLEX128_REF = 118 244 | DT_HALF_REF = 119 245 | DT_RESOURCE_REF = 120 246 | 247 | 248 | DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE 249 | 
_sym_db.RegisterFileDescriptor(DESCRIPTOR) 250 | 251 | 252 | DESCRIPTOR.has_options = True 253 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\013TypesProtosP\001\370\001\001')) 254 | # @@protoc_insertion_point(module_scope) 255 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/types_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/types_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/src/versions.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorboard; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "VersionsProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // Version information for a piece of serialized data 10 | // 11 | // There are different types of versions for each type of data 12 | // (GraphDef, etc.), but they all have the same common shape 13 | // described here. 14 | // 15 | // Each consumer has "consumer" and "min_producer" versions (specified 16 | // elsewhere). A consumer is allowed to consume this data if 17 | // 18 | // producer >= min_producer 19 | // consumer >= min_consumer 20 | // consumer not in bad_consumers 21 | // 22 | message VersionDef { 23 | // The version of the code that produced this data. 24 | int32 producer = 1; 25 | 26 | // Any consumer below this version is not allowed to consume this data. 27 | int32 min_consumer = 2; 28 | 29 | // Specific consumer versions which are disallowed (e.g. due to bugs). 
30 | repeated int32 bad_consumers = 3; 31 | }; 32 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/versions_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: tensorboardX/src/versions.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tensorboardX/src/versions.proto', 20 | package='tensorboard', 21 | syntax='proto3', 22 | serialized_pb=_b('\n\x1ftensorboardX/src/versions.proto\x12\x0btensorboard\"K\n\nVersionDef\x12\x10\n\x08producer\x18\x01 \x01(\x05\x12\x14\n\x0cmin_consumer\x18\x02 \x01(\x05\x12\x15\n\rbad_consumers\x18\x03 \x03(\x05\x42/\n\x18org.tensorflow.frameworkB\x0eVersionsProtosP\x01\xf8\x01\x01\x62\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _VERSIONDEF = _descriptor.Descriptor( 29 | name='VersionDef', 30 | full_name='tensorboard.VersionDef', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='producer', full_name='tensorboard.VersionDef.producer', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='min_consumer', full_name='tensorboard.VersionDef.min_consumer', index=1, 44 | number=2, 
type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='bad_consumers', full_name='tensorboard.VersionDef.bad_consumers', index=2, 51 | number=3, type=5, cpp_type=1, label=3, 52 | has_default_value=False, default_value=[], 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=48, 69 | serialized_end=123, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['VersionDef'] = _VERSIONDEF 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | VersionDef = _reflection.GeneratedProtocolMessageType('VersionDef', (_message.Message,), dict( 76 | DESCRIPTOR = _VERSIONDEF, 77 | __module__ = 'tensorboardX.src.versions_pb2' 78 | # @@protoc_insertion_point(class_scope:tensorboard.VersionDef) 79 | )) 80 | _sym_db.RegisterMessage(VersionDef) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\n\030org.tensorflow.frameworkB\016VersionsProtosP\001\370\001\001')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /classification/tensorboardX/src/versions_pb2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/tensorboardX/src/versions_pb2.pyc -------------------------------------------------------------------------------- /classification/tensorboardX/summary.pyc: 
def makenp(x, modality=None):
    """Convert ``x`` to a numpy array without modifying the input object.

    :param x: a numpy array, a scalar, a list/tuple, or an object whose type
        name contains ``torch``, ``chainer`` or ``mxnet``
    :param modality: ``'IMG'`` requests image handling; any other value
        returns the converted data unchanged
    :return: numpy array
    :raises NotImplementedError: if the type of ``x`` is not supported
    """
    # Already numpy: only uint8 images are rescaled to float32 in [0, 1].
    if isinstance(x, np.ndarray):
        if modality == 'IMG' and x.dtype == np.uint8:
            return x.astype(np.float32) / 255.0
        return x
    if np.isscalar(x):
        return np.array([x])
    # Dispatch on the type name so that none of the frameworks has to be
    # imported unless an object of that framework is actually passed in.
    type_name = str(type(x))
    if 'torch' in type_name:
        return pytorch_np(x, modality)
    if 'chainer' in type_name:
        return chainer_np(x, modality)
    if 'mxnet' in type_name:
        return mxnet_np(x, modality)
    if isinstance(x, (list, tuple)):
        return np.array(x)
    # Previously this fell through and returned None, which only surfaced as
    # an unrelated error further down the call chain.
    raise NotImplementedError(
        'Got {}, but expected a numpy array, scalar, list/tuple, or a '
        'torch/chainer/mxnet array.'.format(type(x)))


def pytorch_np(x, modality):
    """Convert a torch tensor (or autograd ``Variable``) to a numpy array.

    :param x: torch tensor or Variable
    :param modality: ``'IMG'`` additionally runs ``_prepare_image``
    :return: numpy array
    """
    import torch
    # The Variable class is looked up in two places, as the original code did;
    # only the missing-attribute case triggers the fallback.
    try:
        variable_cls = torch.autograd.Variable
    except AttributeError:
        variable_cls = torch.autograd.variable.Variable
    if isinstance(x, variable_cls):
        x = x.data
    x = x.cpu().numpy()
    if modality == 'IMG':
        x = _prepare_image(x)
    return x


def theano_np(x):
    """Placeholder for theano support; not implemented."""
    raise NotImplementedError('theano arrays are not supported yet')


def caffe2_np(x):
    """Placeholder for caffe2 support; not implemented."""
    raise NotImplementedError('caffe2 arrays are not supported yet')


def mxnet_np(x, modality):
    """Convert an mxnet array to numpy via ``asnumpy()``.

    :param x: object exposing ``asnumpy()``
    :param modality: ``'IMG'`` additionally runs ``_prepare_image``
    :return: numpy array
    """
    x = x.asnumpy()
    if modality == 'IMG':
        x = _prepare_image(x)
    return x


def chainer_np(x, modality):
    """Convert a chainer variable to numpy via ``chainer.cuda.to_cpu(x.data)``.

    :param x: object exposing ``.data``
    :param modality: ``'IMG'`` additionally runs ``_prepare_image``
    :return: numpy array
    """
    import chainer
    x = chainer.cuda.to_cpu(x.data)
    if modality == 'IMG':
        x = _prepare_image(x)
    return x


def make_grid(I, ncols=8):
    """Tile a batch of 3-channel images into one image, filled row by row.

    :param I: numpy array of shape (N, 3, H, W)
    :param ncols: maximum number of images per grid row
    :return: float array of shape (3, H * nrows, W * ncols); cells without an
        image stay zero
    :raises ValueError: if the batch contains no images
    """
    assert isinstance(I, np.ndarray), 'plugin error, should pass numpy array here'
    assert I.ndim == 4 and I.shape[1] == 3
    nimg, _, H, W = I.shape
    if nimg == 0:
        # Would otherwise end in a ZeroDivisionError when computing nrows.
        raise ValueError('make_grid needs at least one image')
    ncols = min(nimg, ncols)
    nrows = int(np.ceil(float(nimg) / ncols))
    canvas = np.zeros((3, H * nrows, W * ncols))
    for i in range(nimg):
        y, x = divmod(i, ncols)  # grid row / column of image i
        canvas[:, y * H:(y + 1) * H, x * W:(x + 1) * W] = I[i]
    return canvas


def _prepare_image(I):
    """Bring an image array into H x W x 3 layout.

    Accepted inputs: N x C x H x W with C in (1, 3) (tiled with ``make_grid``),
    a 3-D array, or a 2-D H x W array. Single-channel data is replicated to
    three channels.

    :param I: numpy array with 2, 3 or 4 dimensions
    :return: numpy array with the channel axis moved last
    """
    assert isinstance(I, np.ndarray), 'plugin error, should pass numpy array here'
    assert I.ndim == 2 or I.ndim == 3 or I.ndim == 4
    if I.ndim == 4:  # NCHW
        if I.shape[1] == 1:  # N1HW
            I = np.concatenate((I, I, I), 1)  # N3HW
        assert I.shape[1] == 3
        I = make_grid(I)  # 3xHxW
    if I.ndim == 3 and I.shape[0] == 1:  # 1xHxW
        I = np.concatenate((I, I, I), 0)  # 3xHxW
    if I.ndim == 2:  # HxW
        I = np.expand_dims(I, 0)  # 1xHxW
        I = np.concatenate((I, I, I), 0)  # 3xHxW
    # NOTE(review): 3-D input is assumed to be channel-first here; an array
    # that is already HxWx3 would be transposed as well - confirm with callers.
    I = I.transpose(1, 2, 0)

    return I
attention(y)).pow(2).mean() 12 | 13 | -------------------------------------------------------------------------------- /classification/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from utils.misc import * 4 | from utils.logger import * 5 | from utils.visualize import * 6 | from utils.eval import * 7 | 8 | # progress bar 9 | import os 10 | import sys 11 | from utils.progress.bar import Bar as Bar 12 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 13 | -------------------------------------------------------------------------------- /classification/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /classification/utils/__pycache__/eval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/__pycache__/eval.cpython-36.pyc -------------------------------------------------------------------------------- /classification/utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangjiongw/Knowledge-Distillation-PyTorch/39e1d70b7f13ea3a59d2b657de35d2fd799fcc65/classification/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /classification/utils/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- 
def accuracy(output, target, topk=(1,)):
    """
    Computes the precision@k for the specified values of k
    :param output: score tensor indexed as (sample, class)
    :param target: tensor holding one class index per sample
    :param topk: iterable of k values to evaluate
    :return: list with one 0-dim tensor per k, holding the percentage of
        samples whose target is among the k highest-scoring classes
    """
    # Never ask for more predictions than there are classes; topk() would
    # raise otherwise. For k >= number of classes the score is then 100.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` derives from a transposed tensor, so the slice may be
        # non-contiguous; view(-1) on it raises in recent PyTorch versions.
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def savefig(fname, dpi=None):
    """
    Save the current matplotlib figure.
    :param fname: file name
    :param dpi: dpi set to 150 by default
    :return:
    """
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)


def plot_overlap(logger, names=None):
    """
    Draw the selected columns of one logger onto the current axes.
    :param logger: Logger whose recorded numbers are plotted
    :param names: column names to plot; all columns of the logger by default
    :return: list of legend labels, one per plotted column
    """
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + '(' + name + ')' for name in names]


class Logger(object):
    """
    Save training process to log file with simple plot function.

    The log is a tab-separated text file: one header row with the column
    names, then one row of numbers per call to append().
    """
    def __init__(self, fpath, title=None, resume=False):
        """
        :param fpath: path of the log file; None keeps the log in memory only
        :param title: prefix used in plot legends
        :param resume: if True, load the existing file and append to it
        """
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        # always defined, so plot()/append() work even without a file
        self.names = []
        self.numbers = {}
        if fpath is not None:
            if resume:
                with open(fpath, 'r') as history:
                    self.names = history.readline().rstrip().split('\t')
                    self.numbers = {name: [] for name in self.names}
                    for line in history:
                        if not line.strip():
                            continue
                        values = line.rstrip().split('\t')
                        for name, value in zip(self.names, values):
                            # stored as numbers so resumed history plots
                            # like freshly appended values
                            self.numbers[name].append(float(value))
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        """
        Define the columns and write the header row.
        :param names: list of column names
        :return:
        """
        if self.resume and self.names:
            # header and history were loaded from the resumed file; writing
            # the header again would corrupt the log and drop the history
            return
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.numbers[name] = []
        if self.file is not None:
            for name in self.names:
                self.file.write(name)
                self.file.write('\t')
            self.file.write('\n')
            self.file.flush()

    def append(self, numbers):
        """
        Record one row of values, one per column.
        :param numbers: sequence with the same length as the column names
        :return:
        """
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            if self.file is not None:
                self.file.write("{0:.6f}".format(num))
                self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        if self.file is not None:
            self.file.write('\n')
            self.file.flush()

    def plot(self, names=None):
        """
        Plot the selected columns on the current axes with a legend and grid.
        :param names: column names to plot; all columns by default
        :return:
        """
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying log file, if one was opened."""
        if self.file is not None:
            self.file.close()


class LoggerMonitor(object):
    """
    Load and visualize multiple logs.
    """
    def __init__(self, paths):
        """
        paths is a dictionary with {name:filepath} pair
        :param paths:
        """
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            self.loggers.append(logger)

    def plot(self, names=None):
        """
        Overlay the selected columns of every loaded log in a new figure.
        :param names: column names to plot; all columns of each log by default
        :return:
        """
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)


if __name__ == '__main__':
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])

    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1

    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    paths = {
        'resadvnet20': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }

    field = ['Valid Acc.']

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.pdf', dpi=300)
def get_mean_and_std(dataset, num_workers=2):
    """
    Compute the mean and std value of dataset.

    Statistics are taken per sample and per channel and then averaged over
    the dataset, so the returned std is the mean of per-sample stds.
    :param dataset: dataset yielding (input, target) pairs; inputs are indexed
        as [:, channel, :, :] after batching with batch_size=1
    :param num_workers: number of DataLoader worker processes
    :return: (mean, std) tensors with one entry per input channel
    :raises ValueError: if the dataset yields no samples
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    mean = None
    std = None
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        if mean is None:
            # channel count is taken from the data instead of assuming RGB
            mean = torch.zeros(inputs.size(1))
            std = torch.zeros(inputs.size(1))
        for i in range(inputs.size(1)):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    if mean is None:
        raise ValueError('cannot compute mean and std of an empty dataset')
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    """
    Initialize layer parameters.

    Conv2d weights get Kaiming-normal (fan_out) init, BatchNorm2d is reset to
    weight 1 / bias 0, Linear weights are drawn from N(0, 1e-3); every bias
    that exists is zeroed.
    :param net: module whose sub-modules are initialized in place
    :return:
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` is ambiguous for a multi-element tensor
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight and bias are None when the layer is built with affine=False
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


def mkdir_p(path):
    """
    make dir if not exist
    :param path: directory to create, including missing parents
    :return:
    :raises OSError: if creation fails for any reason other than the
        directory already existing
    """
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        # an existing directory is fine; an existing file or any other
        # failure is re-raised
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise


class AverageMeter(object):
    """
    Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record a new value.
        :param val: latest value (e.g. a batch average)
        :param n: weight of the value (e.g. the batch size)
        :return:
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
class Infinite(object):
    """Base class for indicators without a known end (counters, spinners).

    Tracks the current index and a moving average of the time per step.
    Keyword arguments passed to the constructor are set as attributes.
    """
    file = stderr
    sma_window = 10  # Simple Moving Average window

    def __init__(self, *args, **kwargs):
        self.index = 0
        self.start_ts = time()
        self.avg = 0
        self._ts = self.start_ts
        self._xput = deque(maxlen=self.sma_window)
        for key, val in kwargs.items():
            setattr(self, key, val)

    def __getitem__(self, key):
        # lets format strings such as '%(index)d' read public attributes
        # via `template % self`; private names are hidden
        if key.startswith('_'):
            return None
        return getattr(self, key, None)

    @property
    def elapsed(self):
        """Whole seconds since the indicator was created."""
        return int(time() - self.start_ts)

    @property
    def elapsed_td(self):
        """Elapsed time as a timedelta."""
        return timedelta(seconds=self.elapsed)

    def update_avg(self, n, dt):
        """Fold a step of n items taking dt seconds into the moving average."""
        if n > 0:
            self._xput.append(dt / n)
            self.avg = sum(self._xput) / len(self._xput)

    def update(self):
        """Redraw hook; subclasses override it to render the indicator."""
        pass

    def start(self):
        """Start hook; no-op by default."""
        pass

    def finish(self):
        """Finish hook; no-op by default."""
        pass

    def next(self, n=1):
        """Advance the index by n, refresh the timing average and redraw."""
        now = time()
        dt = now - self._ts
        self.update_avg(n, dt)
        self._ts = now
        self.index = self.index + n
        self.update()

    def iter(self, it):
        """Yield the items of it, advancing once per item; always finishes."""
        try:
            for x in it:
                yield x
                self.next()
        finally:
            self.finish()


class Progress(Infinite):
    """Indicator with a known maximum (`max` keyword, 100 by default)."""

    def __init__(self, *args, **kwargs):
        super(Progress, self).__init__(*args, **kwargs)
        self.max = kwargs.get('max', 100)

    @property
    def eta(self):
        """Estimated whole seconds left, from the moving average."""
        return int(ceil(self.avg * self.remaining))

    @property
    def eta_td(self):
        """Estimated time left as a timedelta."""
        return timedelta(seconds=self.eta)

    @property
    def percent(self):
        """Completed fraction scaled to 0..100."""
        return self.progress * 100

    @property
    def progress(self):
        """Completed fraction, capped at 1."""
        if self.max == 0:
            # an empty job (e.g. iter() over an empty sequence) has nothing
            # to divide by; report no progress instead of raising
            return 0
        return min(1, self.index / self.max)

    @property
    def remaining(self):
        """Number of steps left, never negative."""
        return max(self.max - self.index, 0)

    def start(self):
        self.update()

    def goto(self, index):
        """Jump to an absolute index."""
        incr = index - self.index
        self.next(incr)

    def iter(self, it):
        """Like Infinite.iter, but first takes max from len(it) when available."""
        try:
            self.max = len(it)
        except TypeError:
            # iterables without a length keep the configured max
            pass

        try:
            for x in it:
                yield x
                self.next()
        finally:
            self.finish()
-------------------------------------------------------------------------------- /classification/utils/progress/bar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
class Bar(WritelnMixin, Progress):
    """Single-line progress bar: message, prefix, filled/empty cells, suffix."""
    width = 32
    message = ''
    suffix = '%(index)d/%(max)d'
    bar_prefix = ' |'
    bar_suffix = '| '
    empty_fill = ' '
    fill = '#'
    hide_cursor = True

    def update(self):
        """Render the bar for the current progress and rewrite the line."""
        n_filled = int(self.width * self.progress)
        n_empty = self.width - n_filled

        # message and suffix are templates filled from this object's attributes
        head = self.message % self
        done = self.fill * n_filled
        todo = self.empty_fill * n_empty
        tail = self.suffix % self
        self.writeln(''.join([head, self.bar_prefix, done, todo,
                              self.bar_suffix, tail]))


class ChargingBar(Bar):
    """Bar drawn with block characters and a percentage suffix."""
    suffix = '%(percent)d%%'
    bar_prefix = ' '
    bar_suffix = ' '
    empty_fill = '∙'
    fill = '█'


class FillingSquaresBar(ChargingBar):
    """ChargingBar variant drawn with squares."""
    empty_fill = '▢'
    fill = '▣'


class FillingCirclesBar(ChargingBar):
    """ChargingBar variant drawn with circles."""
    empty_fill = '◯'
    fill = '◉'


class IncrementalBar(Bar):
    """Bar whose last cell steps through partial-block phases."""
    phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')

    def update(self):
        """Render full cells, one partial cell if any, then empty cells."""
        phase_count = len(self.phases)
        exact_fill = self.width * self.progress
        full_cells = int(exact_fill)  # Number of full chars
        # Phase of last char, from the fractional part of the fill
        partial_idx = int((exact_fill - full_cells) * phase_count)
        empty_cells = self.width - full_cells  # Number of empty chars

        head = self.message % self
        done = self.phases[-1] * full_cells
        if partial_idx > 0:
            partial = self.phases[partial_idx]
        else:
            partial = ''
        # the partial cell, when present, takes the place of one empty cell
        todo = self.empty_fill * max(0, empty_cells - len(partial))
        tail = self.suffix % self
        self.writeln(''.join([head, self.bar_prefix, done, partial, todo,
                              self.bar_suffix, tail]))


class PixelBar(IncrementalBar):
    """IncrementalBar drawn with braille patterns."""
    phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿')


class ShadyBar(IncrementalBar):
    """IncrementalBar drawn with shade characters."""
    phases = (' ', '░', '▒', '▓', '█')
class Counter(WriteMixin, Infinite):
    """Writes the running index."""
    message = ''
    hide_cursor = True

    def update(self):
        """Write the current index as text."""
        text = str(self.index)
        self.write(text)


class Countdown(WriteMixin, Progress):
    """Writes how many steps are left."""
    hide_cursor = True

    def update(self):
        """Write the remaining step count as text."""
        text = str(self.remaining)
        self.write(text)


class Stack(WriteMixin, Progress):
    """Single character that fills up as progress advances."""
    phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█')
    hide_cursor = True

    def update(self):
        """Write the phase glyph matching the current progress."""
        phase_count = len(self.phases)
        # cap at the last glyph so progress == 1 stays in range
        idx = min(phase_count - 1, int(self.progress * phase_count))
        self.write(self.phases[idx])


class Pie(Stack):
    """Stack variant drawn with circle glyphs."""
    phases = ('○', '◔', '◑', '◕', '●')
HIDE_CURSOR = '\x1b[?25l'
SHOW_CURSOR = '\x1b[?25h'


class WriteMixin(object):
    """Mixin that rewrites a short text in place using backspaces.

    Expects the class it is mixed into to provide `file` and `message`.
    Nothing is written unless `file` is a terminal.
    """
    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super(WriteMixin, self).__init__(**kwargs)
        self._width = 0
        if message:
            self.message = message

        if not self.file.isatty():
            return
        if self.hide_cursor:
            print(HIDE_CURSOR, end='', file=self.file)
        print(self.message, end='', file=self.file)
        self.file.flush()

    def write(self, s):
        """Overwrite the previously written text with s."""
        if not self.file.isatty():
            return
        # step back over the old text, then pad so leftovers are blanked
        rewind = '\b' * self._width
        padded = s.ljust(self._width)
        print(rewind + padded, end='', file=self.file)
        self._width = max(self._width, len(s))
        self.file.flush()

    def finish(self):
        """Restore the cursor if it was hidden."""
        if self.file.isatty() and self.hide_cursor:
            print(SHOW_CURSOR, end='', file=self.file)


class WritelnMixin(object):
    """Mixin that redraws a whole terminal line.

    Expects the class it is mixed into to provide `file`.
    Nothing is written unless `file` is a terminal.
    """
    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super(WritelnMixin, self).__init__(**kwargs)
        if message:
            self.message = message

        if self.file.isatty() and self.hide_cursor:
            print(HIDE_CURSOR, end='', file=self.file)

    def clearln(self):
        """Return to column 0 and erase the line."""
        if not self.file.isatty():
            return
        print('\r\x1b[K', end='', file=self.file)

    def writeln(self, line):
        """Replace the current line with line."""
        if not self.file.isatty():
            return
        self.clearln()
        print(line, end='', file=self.file)
        self.file.flush()

    def finish(self):
        """End the line and restore the cursor if it was hidden."""
        if not self.file.isatty():
            return
        print(file=self.file)
        if self.hide_cursor:
            print(SHOW_CURSOR, end='', file=self.file)


class SigIntMixin(object):
    """
    Registers a signal handler that calls finish on SIGINT
    """

    def __init__(self, *args, **kwargs):
        super(SigIntMixin, self).__init__(*args, **kwargs)
        signal(SIGINT, self._sigint_handler)

    def _sigint_handler(self, signum, frame):
        # clean up the display, then exit with status 0
        self.finish()
        exit(0)
# functions to show an image
def make_image(img, mean=(0, 0, 0), std=(1, 1, 1)):
    """
    De-normalize a channel-first image tensor and return it channel-last.
    :param img: image tensor indexed as [channel, row, column]
    :param mean: per-channel mean added back
    :param std: per-channel std multiplied back
    :return: numpy array with axes reordered to (1, 2, 0)
    """
    # work on a copy: the input may be (a view of) the caller's batch
    img = img.clone()
    for i in range(min(img.size(0), len(mean), len(std))):
        img[i] = img[i] * std[i] + mean[i]  # de-normalize
    npimg = img.numpy()
    return np.transpose(npimg, (1, 2, 0))


def gauss(x, a, b, c):
    """
    Element-wise Gaussian bump a * exp(-(x - b)^2 / (2 * c^2)).
    :param x: input tensor
    :param a: peak height
    :param b: peak position
    :param c: width
    :return: tensor with the same shape as x
    """
    return torch.exp(-torch.pow(torch.add(x, -b), 2).div(2*c*c)).mul(a)


def colorize(x):
    """
    Converts a one-channel gray-scale image to a color heatmap image
    :param x: input gray-scale image; 2-D (H, W), 3-D (1, H, W) or
        4-D (N, 1, H, W)
    :return: heatmap with 3 channels in place of the single input channel,
        values capped at 1
    :raises ValueError: if x has an unsupported rank or more than one channel
    """
    if x.dim() == 2:
        # torch.unsqueeze has no `out` argument; build the 3-D view instead
        x = x.unsqueeze(0)
    if x.dim() == 3:
        channel_dim = 0
    elif x.dim() == 4:
        channel_dim = 1
    else:
        raise ValueError('colorize expects a 2-D, 3-D or 4-D tensor, got %d-D' % x.dim())
    if x.size(channel_dim) != 1:
        raise ValueError('colorize expects a single-channel image')

    red = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
    green = gauss(x, 1, .5, .3)
    blue = gauss(x, 1, .2, .3)
    # the red channel sums two bumps and can exceed 1, so cap every layout
    return torch.cat([red, green, blue], dim=channel_dim).clamp(max=1)


def show_batch(images, mean=(2, 2, 2), std=(0.5, 0.5, 0.5)):
    """
    Display a batch of images as one grid.
    :param images: batch of image tensors passed to torchvision's make_grid
    :param mean: per-channel mean used for de-normalization
    :param std: per-channel std used for de-normalization
    :return:
    """
    images = make_image(torchvision.utils.make_grid(images), mean, std)
    plt.imshow(images)
    plt.show()


def _resize_mask(mask, images):
    """Nearest-neighbour resize of mask to the spatial size of images."""
    # assumes both tensors are (N, C, H, W) -- TODO confirm against callers
    return torch.nn.functional.interpolate(mask, size=tuple(images.size()[2:]), mode='nearest')


def show_mask_single(images, mask, mean=(2, 2, 2), std=(0.5, 0.5, 0.5)):
    """
    Plot the image grid above the same grid blended with one mask.
    :param images: batch of normalized images
    :param mask: mask tensor; it is resized to the image size and must be
        expandable to the shape of images
    :param mean: per-channel mean used for de-normalization
    :param std: per-channel std used for de-normalization
    :return:
    """
    # save for adding mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * std[i] + mean[i]  # denormalize

    images = make_image(torchvision.utils.make_grid(images), mean, std)
    plt.subplot(2, 1, 1)
    plt.imshow(images)
    plt.axis('off')

    # the original called an undefined `upsampling` helper here
    mask = _resize_mask(mask, im_data)

    # 30% image, 70% mask
    mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
    plt.subplot(2, 1, 2)
    plt.imshow(mask)
    plt.axis('off')


def show_mask(images, masklist, mean=(2, 2, 2), std=(0.5, 0.5, 0.5)):
    """
    Plot the image grid followed by one blended grid per mask.
    :param images: batch of normalized images
    :param masklist: sequence of objects whose `.data` tensor is a mask; each
        is moved to the CPU, resized to the image size and blended
    :param mean: per-channel mean used for de-normalization
    :param std: per-channel std used for de-normalization
    :return:
    """
    # save for adding mask
    im_data = images.clone()
    for i in range(0, 3):
        im_data[:, i, :, :] = im_data[:, i, :, :] * std[i] + mean[i]  # unnormalize

    images = make_image(torchvision.utils.make_grid(images), mean, std)
    plt.subplot(1+len(masklist), 1, 1)
    plt.imshow(images)
    plt.axis('off')

    for i, entry in enumerate(masklist):
        mask = entry.data.cpu()
        # the original called an undefined `upsampling` helper here
        mask = _resize_mask(mask, im_data)

        # 30% image, 70% mask
        mask = make_image(torchvision.utils.make_grid(0.3 * im_data + 0.7 * mask.expand_as(im_data)))
        plt.subplot(1+len(masklist), 1, i + 2)
        plt.imshow(mask)
        plt.axis('off')

# x = torch.zeros(1, 3, 3)
# out = colorize(x)
# out_im = make_image(out)
# plt.imshow(out_im)
# plt.show()