├── .github └── workflows │ ├── python-publish.yml │ ├── test_torch_1110.yml │ ├── test_torch_181.yml │ ├── test_torch_1_12_1.yml │ └── test_torch_200.yml ├── .gitignore ├── AUTHORS ├── LICENSE ├── README.md ├── README_CN.md ├── assets ├── concat.png ├── conv-conv.png ├── conv-fc.png ├── densenet_dep.png ├── dep.png ├── dep1.png ├── dep2.png ├── dep3.png ├── group_sparsity.png ├── intro.jpg ├── intro.png ├── isomorphic_pruning.png ├── residual.png └── split.png ├── examples ├── LLMs │ ├── eval_ppl.py │ ├── prune_llm.py │ └── readme.md ├── README.md ├── latency │ └── measure_latency.py ├── notebook │ ├── 0 - QuickStart.ipynb │ ├── 1 - Customize Your Own Pruners.ipynb │ └── 2 - Exploring Dependency Groups.ipynb ├── timm_models │ ├── prune_timm_models.py │ └── readme.md ├── torchvision_models │ ├── readme.md │ ├── torchvision_global_pruning.py │ └── torchvision_pruning.py ├── transformers │ ├── draw_acc_curve.py │ ├── finetune.py │ ├── measure_latency.py │ ├── presets.py │ ├── prune_hf_bert.py │ ├── prune_hf_swin.py │ ├── prune_hf_vit.py │ ├── prune_timm_vit.py │ ├── readme.md │ ├── sampler.py │ ├── scripts │ │ ├── finetune_hf_vit_b_16_l1_uniform.sh │ │ ├── finetune_hf_vit_b_16_taylor_uniform.sh │ │ ├── finetune_timm_deit_b_16_taylor_uniform.sh │ │ ├── finetune_timm_vit_b_16_hessian_uniform.sh │ │ ├── finetune_timm_vit_b_16_l1_uniform.sh │ │ ├── finetune_timm_vit_b_16_l2_uniform.sh │ │ ├── finetune_timm_vit_b_16_taylor_bottleneck.sh │ │ ├── finetune_timm_vit_b_16_taylor_uniform.sh │ │ ├── prune_hf_vit_b_16_l1_uniform.sh │ │ ├── prune_hf_vit_b_16_taylor_uniform.sh │ │ ├── prune_timm_deit_b_16_taylor_uniform.sh │ │ ├── prune_timm_vit_b_16_hessian_uniform.sh │ │ ├── prune_timm_vit_b_16_l1_uniform.sh │ │ ├── prune_timm_vit_b_16_l2_uniform.sh │ │ ├── prune_timm_vit_b_16_taylor_bottleneck.sh │ │ ├── prune_timm_vit_b_16_taylor_uniform.sh │ │ ├── prune_timm_vit_b_16_taylor_uniform_global.sh │ │ ├── test_pretrained_hf_vit_b_16.sh │ │ ├── test_pretrained_timm_deit_b_16.sh │ │ └── test_pretrained_timm_vit_b_16.sh │ ├── transforms.py │ └── utils.py ├── yolov5 │ ├── detect_after_pruning.py │ └── readme.md ├── yolov7 │ ├── readme.md │ ├── yolov7_detect_pruned.py │ └── yolov7_train_pruned.py └── yolov8 │ ├── readme.md │ └── yolov8_pruning.py ├── reproduce ├── benchmark_importance_criteria.py ├── benchmark_latency.py ├── engine │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── cifar │ │ │ ├── __init__.py │ │ │ ├── densenet.py │ │ │ ├── googlenet.py │ │ │ ├── inceptionv3.py │ │ │ ├── inceptionv4.py │ │ │ ├── mobilenetv2.py │ │ │ ├── nasnet.py │ │ │ ├── preactresnet.py │ │ │ ├── resnet.py │ │ │ ├── resnet_tiny.py │ │ │ ├── resnext.py │ │ │ ├── senet.py │ │ │ ├── swin.py │ │ │ ├── vgg.py │ │ │ ├── vit.py │ │ │ └── xception.py │ │ ├── graph │ │ │ ├── __init__.py │ │ │ └── dgcnn.py │ │ └── imagenet │ │ │ ├── __init__.py │ │ │ └── vision_transformer.py │ └── utils │ │ ├── __init__.py │ │ ├── datasets │ │ ├── __init__.py │ │ └── modelnet40.py │ │ ├── evaluator.py │ │ ├── imagenet_utils │ │ ├── __init__.py │ │ ├── presets.py │ │ ├── sampler.py │ │ ├── transforms.py │ │ └── utils.py │ │ ├── metrics.py │ │ └── utils.py ├── main.py ├── main_imagenet.py ├── readme.md ├── registry.py ├── requirements.txt ├── run │ ├── cifar10 │ │ └── prune │ │ │ ├── cifar10-global-group_sl-resnet56 │ │ │ └── cifar10-global-group_sl-resnet56.txt │ │ │ ├── cifar10-global-group_sl-resnet56_2.55x │ │ │ └── cifar10-global-group_sl-resnet56.txt │ │ │ ├── cifar10-global-growing_reg-resnet56 │ │ │ └── 
cifar10-global-growing_reg-resnet56.txt │ │ │ ├── cifar10-global-l1-resnet56 │ │ │ └── cifar10-global-l1-resnet56.txt │ │ │ └── cifar10-global-slim-resnet56 │ │ │ └── cifar10-global-slim-resnet56.txt │ └── cifar100 │ │ └── prune │ │ └── cifar100-global-group_sl-vgg19 │ │ └── cifar100-global-group_sl-vgg19.txt ├── scripts │ ├── pretrain │ │ └── cifar_pretrain.sh │ └── prune │ │ ├── ablation │ │ ├── global_group_norm.sh │ │ ├── global_group_sl.sh │ │ ├── global_group_sl_p2.sh │ │ ├── global_l1.sh │ │ ├── global_l1_group_conv.sh │ │ ├── global_lamp.sh │ │ ├── global_rand.sh │ │ ├── global_slim.sh │ │ ├── local_group_norm.sh │ │ ├── local_l1.sh │ │ ├── local_l1_group_conv.sh │ │ └── local_rand.sh │ │ ├── cifar │ │ ├── bn_pruner.sh │ │ ├── group_pruner.sh │ │ └── l1_norm_pruner.sh │ │ ├── imagenet │ │ ├── densenet_gsl.sh │ │ ├── mobilenetv2_group_norm.sh │ │ ├── mobilenetv2_group_sl.sh │ │ ├── next50_group_norm.sh │ │ ├── next50_group_sl.sh │ │ ├── regnet_group_norm.sh │ │ ├── regnet_group_sl.sh │ │ ├── resnet50_group_norm.sh │ │ ├── resnet50_group_sl.sh │ │ ├── vgg_group_norm copy.sh │ │ └── vit_group_norm.sh │ │ ├── modelnet40 │ │ └── global_group_norm.sh │ │ └── ppi │ │ └── global_group_norm.sh └── tools │ └── draw.py ├── requirements.txt ├── setup.py ├── tests ├── graph_drawing.py ├── test_backward.py ├── test_benchmark.py ├── test_concat.py ├── test_concat_split.py ├── test_customized_layer.py ├── test_dependency_graph.py ├── test_dependency_lenet.py ├── test_flops.py ├── test_fully_connected_layers.py ├── test_group_prune.py ├── test_hessian_importance.py ├── test_importance_reduction.py ├── test_interactive_pruner.py ├── test_isomorphic.py ├── test_load.py ├── test_multiple_inputs_and_outputs.py ├── test_non_feature_dim_cat.py ├── test_print_tool.py ├── test_pruner.py ├── test_pruning_fn.py ├── test_regularization.py ├── test_reshape.py ├── test_score_normalization.py ├── test_serialization.py ├── test_single_channel_output.py ├── test_soft_pruning.py ├── test_split.py ├── test_taylor_importance.py ├── test_unused_parameters.py └── test_unwrapped_parameters.py └── torch_pruning ├── __init__.py ├── _helpers.py ├── dependency.py ├── ops.py ├── pruner ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── base_pruner.py │ ├── batchnorm_scale_pruner.py │ ├── compatibility.py │ ├── group_norm_pruner.py │ ├── growing_reg_pruner.py │ └── scheduler.py ├── function.py └── importance.py ├── serialization.py └── utils ├── __init__.py ├── benchmark.py ├── compute_mat_grad.py ├── op_counter.py └── utils.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 |   release:
13 |     types: [published]
14 |
15 | permissions:
16 |   contents: read
17 |
18 | jobs:
19 |   deploy:
20 |
21 |     runs-on: ubuntu-latest
22 |
23 |     steps:
24 |     - uses: actions/checkout@v3
25 |     - name: Set up Python
26 |       uses: actions/setup-python@v3
27 |       with:
28 |         python-version: '3.x'
29 |     - name: Install dependencies
30 |       run: |
31 |         python -m pip install --upgrade pip
32 |         pip install build
33 |     - name: Build package
34 |       run: python -m build
35 |     - name: Publish package
36 |       uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 |       with:
38 |         user: __token__
39 |         password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.github/workflows/test_torch_1110.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Test-Pytorch-1.11.0
5 |
6 | on: workflow_dispatch
7 |
8 | permissions:
9 |   contents: read
10 |
11 | jobs:
12 |   build:
13 |
14 |     runs-on: ubuntu-latest
15 |
16 |     steps:
17 |     - uses: actions/checkout@v3
18 |     - name: Set up Python 3.9
19 |       uses: actions/setup-python@v3
20 |       with:
21 |         python-version: "3.9"
22 |         cache: 'pip' # caching pip dependencies
23 |     - name: Install dependencies
24 |       run: |
25 |         python -m pip install --upgrade pip
26 |         pip install flake8 pytest torch==1.11.0 torchvision
27 |         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
28 |     - name: Test with pytest
29 |       run: |
30 |         # Run the suite once, skipping the tests that are incompatible with torch 1.11.0.
31 |         pytest --ignore=tests/test_unwrapped_parameters.py --ignore=tests/test_backward.py --ignore=tests/test_benchmark.py
32 |
--------------------------------------------------------------------------------
/.github/workflows/test_torch_181.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Test-Pytorch-1.8.1
5 |
6 | on: workflow_dispatch
7 |
8 | permissions:
9 |   contents: read
10 |
11 | jobs:
12 |   build:
13 |
14 |     runs-on: ubuntu-latest
15 |
16 |     steps:
17 |     - uses: actions/checkout@v3
18 |     - name: Set up Python 3.9
19 |       uses: actions/setup-python@v3
20 |       with:
21 |         python-version: "3.9"
22 |         cache: 'pip' # caching pip dependencies
23 |     - name: Install dependencies
24 |       run: |
25 |         python -m pip install --upgrade pip
26 |         pip install flake8 pytest torch==1.8.1 torchvision==0.9.1
27 |         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
28 |     - name: Test with pytest
29 |       run: |
30 |         pytest --ignore=tests/test_unwrapped_parameters.py --ignore=tests/test_backward.py --ignore=tests/test_concat_split.py --ignore=tests/test_serialization.py --ignore=tests/test_non_feature_dim_cat.py --ignore=tests/test_benchmark.py
31 |
--------------------------------------------------------------------------------
/.github/workflows/test_torch_1_12_1.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see:
https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Test-Pytorch-1.12.1 5 | 6 | on: workflow_dispatch 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python 3.9 19 | uses: actions/setup-python@v3 20 | with: 21 | python-version: "3.9" 22 | cache: 'pip' # caching pip dependencies 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install flake8 pytest torch==1.12.1 torchvision==0.13.1 27 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 28 | - name: Test with pytest 29 | run: | 30 | pytest --ignore=tests/test_unwrapped_parameters.py --ignore=tests/test_backward.py --ignore=tests/test_benchmark.py 31 | -------------------------------------------------------------------------------- /.github/workflows/test_torch_200.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Test-Pytorch-2.0.0 5 | 6 | on: workflow_dispatch 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python 3.10 19 | uses: actions/setup-python@v3 20 | with: 21 | python-version: "3.10" 22 | cache: 'pip' # caching pip dependencies 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install flake8 pytest torch torchvision 27 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 28 | - name: Test with pytest 29 | run: | 30 | pytest --ignore=tests/test_concat_split.py --ignore=tests/test_benchmark.py 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | data 3 | *.pth 4 | .vscode 5 | *.egg-info 6 | dist 7 | build 8 | checkpoints 9 | run 10 | text.txt 11 | run* 12 | baselines 13 | *.txt 14 | ours.sh 15 | benchmarks/draw 16 | tests/*.png 17 | tests/*.pdf 18 | modelnet40_ply_hdf5_2048.zip 19 | modelnet40_ply_hdf5_2048 20 | benchmarks/draw 21 | torch_pruning_v1.0 22 | torch_pruning_bak 23 | *.log 24 | .ipynb_checkpoints 25 | .idea/ 26 | .pytest_cache 27 | benchmarks/run 28 | output 29 | acc.png 30 | pretrained 31 | data 32 | 33 | cache/ -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Gongfan Fang, gongfan@u.nus.edu 2 | Xinyin Ma, maxinyin@u.nus.edu 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Gongfan Fang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished 
to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assets/concat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/concat.png -------------------------------------------------------------------------------- /assets/conv-conv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/conv-conv.png -------------------------------------------------------------------------------- /assets/conv-fc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/conv-fc.png -------------------------------------------------------------------------------- /assets/densenet_dep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/densenet_dep.png -------------------------------------------------------------------------------- /assets/dep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/dep.png -------------------------------------------------------------------------------- /assets/dep1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/dep1.png -------------------------------------------------------------------------------- /assets/dep2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/dep2.png -------------------------------------------------------------------------------- /assets/dep3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/dep3.png -------------------------------------------------------------------------------- /assets/group_sparsity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/group_sparsity.png -------------------------------------------------------------------------------- /assets/intro.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/intro.jpg
--------------------------------------------------------------------------------
/assets/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/intro.png
--------------------------------------------------------------------------------
/assets/isomorphic_pruning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/isomorphic_pruning.png
--------------------------------------------------------------------------------
/assets/residual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/residual.png
--------------------------------------------------------------------------------
/assets/split.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VainF/Torch-Pruning/6ca5595d5cb4366d2d8f11ae6879fc4b61b9727d/assets/split.png
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | * [LLMs on Hugging Face](LLMs): Llama-2/3, Phi-3, Qwen2/2.5
4 | * [timm](timm_models)
5 | * [SlimSAM for Segment Anything Models (SAM)](https://github.com/czg1225/SlimSAM)
6 | * [Diffusion Models](https://github.com/VainF/Diff-Pruning)
7 | * [Transformers](transformers): ViTs, BERT, Swin, etc.
8 | * [Isomorphic Pruning (ViTs & ConvNeXt)](https://github.com/VainF/Isomorphic-Pruning)
9 | * YOLO: [yolov5](yolov5), [yolov7](yolov7), [yolov8](yolov8)
10 | * [torchvision models](torchvision_models): Faster R-CNN, SSD, ResNe(X)t, DenseNet, RegNet, DeepLab, etc.
11 |
--------------------------------------------------------------------------------
/examples/transformers/draw_acc_curve.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import os
3 | # Read lines such as "Epoch: [250] Total time: ... Test: Acc@1 77.984 Acc@5 94.114" from a log file and record the accuracies.
4 | plt.style.use('ggplot')
5 |
6 | def parse_acc_from_file(file_path):
7 |     with open(file_path, 'r') as f:
8 |         lines = f.readlines()
9 |     acc = []
10 |     for line in lines:
11 |         if 'Test: Acc@1' in line:
12 |             acc.append(float(line.split(' ')[-3]))
13 |     return acc
14 |
15 | log_dict = {
16 |     'Hessian-uniform': 'output/vit_b_16_pruning_hessian_uniform/train.log',
17 |     'Taylor-uniform': 'output/vit_b_16_pruning_taylor_uniform/train.log',
18 |     'Taylor-bottleneck': 'output/vit_b_16_pruning_taylor_bottleneck/train.log',
19 |     'L1-uniform': 'output/vit_b_16_pruning_l1_uniform/train.log',
20 |     'L2-uniform': 'output/vit_b_16_pruning_l2_uniform/train.log',
21 | }
22 |
23 | plt.figure(figsize=(8, 4), dpi=200)
24 | for exp_name, log_path in log_dict.items():
25 |     acc = parse_acc_from_file(log_path)
26 |     plt.plot(acc, label=exp_name)
27 |     print(exp_name, "| Last Epoch:", acc[-1], "| Best Epoch:", max(acc))
28 | #plt.plot(acc_random, label='Random-uniform')
29 | plt.xlabel('Epoch')
30 | plt.ylabel('Accuracy')
31 | plt.legend(loc='lower right')
32 |
33 | # save the figure
34 | plt.savefig('acc.png')
--------------------------------------------------------------------------------
/examples/transformers/measure_latency.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import time
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
4 | import torch_pruning as tp
5 | import torch
6 | import timm
7 | import torch.nn.functional as F
8 |
9 | import argparse
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--model', type=str, default='vit_base_patch16_224', help='model name or path')
12 | parser.add_argument('--batch_size', type=int, default=1, help='batch size')
13 | args = parser.parse_args()
14 |
15 | def forward(self, x):
16 |     B, N, C = x.shape
17 |     qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
18 |     q, k, v = qkv.unbind(0)
19 |     q, k = self.q_norm(q), self.k_norm(k)
20 |
21 |     if self.fused_attn:
22 |         x = F.scaled_dot_product_attention(
23 |             q, k, v,
24 |             dropout_p=self.attn_drop.p,
25 |         )
26 |     else:
27 |         q = q * self.scale
28 |         attn = q @ k.transpose(-2, -1)
29 |         attn = attn.softmax(dim=-1)
30 |         attn = self.attn_drop(attn)
31 |         x = attn @ v
32 |
33 |     x = x.transpose(1, 2).reshape(B, N, -1)
34 |     x = self.proj(x)
35 |     x = self.proj_drop(x)
36 |     return x
37 |
38 | def main():
39 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40 |     if os.path.isfile(args.model):
41 |         loaded_pth = torch.load(args.model, map_location='cpu')
42 |         if isinstance(loaded_pth, dict):
43 |             model = loaded_pth['model'].to(device)
44 |         else:
45 |             model = loaded_pth.to(device)
46 |     else:
47 |         model = timm.create_model(args.model, pretrained=True).to(device)
48 |
49 |     for m in model.modules():
50 |         if isinstance(m, timm.models.vision_transformer.Attention):
51 |             m.forward = forward.__get__(m, timm.models.vision_transformer.Attention) # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
52 |
53 |     example_input = torch.rand(args.batch_size, 3, 224, 224).to(device)
54 |     macs, params = tp.utils.count_ops_and_params(model, example_input)
55 |     latency_mu, latency_std = estimate_latency(model, example_input)
56 |     print(f"MACs: {macs/1e9:.2f} G, \tParams: {params/1e6:.2f} M, \tLatency: {latency_mu:.2f} ms +- {latency_std:.2f} ms")
57 |
58 | def estimate_latency(model, example_inputs, repetitions=300):
59 |     import numpy as np 60
| starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 61 | timings=np.zeros((repetitions,1)) 62 | 63 | for _ in range(50): 64 | _ = model(example_inputs) 65 | 66 | with torch.no_grad(): 67 | for rep in range(repetitions): 68 | starter.record() 69 | _ = model(example_inputs) 70 | ender.record() 71 | # WAIT FOR GPU SYNC 72 | torch.cuda.synchronize() 73 | curr_time = starter.elapsed_time(ender) 74 | timings[rep] = curr_time 75 | 76 | mean_syn = np.sum(timings) / repetitions 77 | std_syn = np.std(timings) 78 | return mean_syn, std_syn 79 | 80 | if __name__=='__main__': 81 | main() -------------------------------------------------------------------------------- /examples/transformers/presets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms.functional import InterpolationMode 3 | 4 | 5 | def get_module(use_v2): 6 | # We need a protected import to avoid the V2 warning in case just V1 is used 7 | if use_v2: 8 | import torchvision.transforms.v2 9 | 10 | return torchvision.transforms.v2 11 | else: 12 | import torchvision.transforms 13 | 14 | return torchvision.transforms 15 | 16 | 17 | class ClassificationPresetTrain: 18 | # Note: this transform assumes that the input to forward() are always PIL 19 | # images, regardless of the backend parameter. We may change that in the 20 | # future though, if we change the output type from the dataset. 21 | def __init__( 22 | self, 23 | *, 24 | crop_size, 25 | mean=(0.485, 0.456, 0.406), 26 | std=(0.229, 0.224, 0.225), 27 | interpolation=InterpolationMode.BILINEAR, 28 | hflip_prob=0.5, 29 | auto_augment_policy=None, 30 | ra_magnitude=9, 31 | augmix_severity=3, 32 | random_erase_prob=0.0, 33 | backend="pil", 34 | use_v2=False, 35 | ): 36 | T = get_module(use_v2) 37 | 38 | transforms = [] 39 | backend = backend.lower() 40 | if backend == "tensor": 41 | transforms.append(T.PILToTensor()) 42 | elif backend != "pil": 43 | raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}") 44 | 45 | transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True)) 46 | if hflip_prob > 0: 47 | transforms.append(T.RandomHorizontalFlip(hflip_prob)) 48 | if auto_augment_policy is not None: 49 | if auto_augment_policy == "ra": 50 | transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) 51 | elif auto_augment_policy == "ta_wide": 52 | transforms.append(T.TrivialAugmentWide(interpolation=interpolation)) 53 | elif auto_augment_policy == "augmix": 54 | transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity)) 55 | else: 56 | aa_policy = T.AutoAugmentPolicy(auto_augment_policy) 57 | transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation)) 58 | 59 | if backend == "pil": 60 | transforms.append(T.PILToTensor()) 61 | 62 | transforms.extend( 63 | [ 64 | T.ConvertImageDtype(torch.float), 65 | T.Normalize(mean=mean, std=std), 66 | ] 67 | ) 68 | if random_erase_prob > 0: 69 | transforms.append(T.RandomErasing(p=random_erase_prob)) 70 | 71 | self.transforms = T.Compose(transforms) 72 | 73 | def __call__(self, img): 74 | return self.transforms(img) 75 | 76 | 77 | class ClassificationPresetEval: 78 | def __init__( 79 | self, 80 | *, 81 | crop_size, 82 | resize_size=256, 83 | mean=(0.485, 0.456, 0.406), 84 | std=(0.229, 0.224, 0.225), 85 | interpolation=InterpolationMode.BILINEAR, 86 | backend="pil", 87 | use_v2=False, 88 | ): 89 | T = get_module(use_v2) 
90 | transforms = [] 91 | backend = backend.lower() 92 | if backend == "tensor": 93 | transforms.append(T.PILToTensor()) 94 | elif backend != "pil": 95 | raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}") 96 | 97 | transforms += [ 98 | T.Resize(resize_size, interpolation=interpolation, antialias=True), 99 | T.CenterCrop(crop_size), 100 | ] 101 | 102 | if backend == "pil": 103 | transforms.append(T.PILToTensor()) 104 | 105 | transforms += [ 106 | T.ConvertImageDtype(torch.float), 107 | T.Normalize(mean=mean, std=std), 108 | ] 109 | 110 | self.transforms = T.Compose(transforms) 111 | 112 | def __call__(self, img): 113 | return self.transforms(img) 114 | -------------------------------------------------------------------------------- /examples/transformers/prune_hf_bert.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, BertModel 2 | import torch 3 | from transformers.models.bert.modeling_bert import BertSelfAttention 4 | import torch_pruning as tp 5 | 6 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 7 | model = BertModel.from_pretrained("bert-base-uncased") 8 | #print(model) 9 | hf_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 10 | example_inputs = {'input_ids': hf_inputs['input_ids'], 'token_type_ids': hf_inputs['token_type_ids'], 'attention_mask': hf_inputs['attention_mask']} 11 | 12 | #outputs = model(**example_inputs) 13 | #last_hidden_states = outputs.last_hidden_state 14 | 15 | imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean") 16 | base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) 17 | num_heads = {} 18 | 19 | # All heads should be pruned simultaneously, so we group channels by head. 20 | for m in model.modules(): 21 | if isinstance(m, BertSelfAttention): 22 | num_heads[m.query] = m.num_attention_heads 23 | num_heads[m.key] = m.num_attention_heads 24 | num_heads[m.value] = m.num_attention_heads 25 | 26 | pruner = tp.pruner.BasePruner( 27 | model, 28 | example_inputs, 29 | global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers. 
30 | importance=imp, # importance criterion for parameter selection 31 | iterative_steps=1, # the number of iterations to achieve target pruning ratio 32 | pruning_ratio=0.5, 33 | num_heads=num_heads, 34 | prune_head_dims=False, 35 | prune_num_heads=True, 36 | head_pruning_ratio=0.5, 37 | output_transform=lambda out: out.pooler_output.sum(), 38 | ignored_layers=[model.pooler], 39 | ) 40 | 41 | for g in pruner.step(interactive=True): 42 | #print(g) 43 | g.prune() 44 | 45 | # Modify the attention head size and all head size after pruning 46 | for m in model.modules(): 47 | if isinstance(m, BertSelfAttention): 48 | print("Num heads: %d, head size: %d =>"%(m.num_attention_heads, m.attention_head_size)) 49 | m.num_attention_heads = pruner.num_heads[m.query] 50 | m.attention_head_size = m.query.out_features // m.num_attention_heads 51 | m.all_head_size = m.query.out_features 52 | print("Num heads: %d, head size: %d"%(m.num_attention_heads, m.attention_head_size)) 53 | print() 54 | 55 | print(model) 56 | test_output = model(**example_inputs) 57 | pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs) 58 | print("Base MACs: %f M, Pruned MACs: %f M"%(base_macs/1e6, pruned_macs/1e6)) 59 | print("Base Params: %f M, Pruned Params: %f M"%(base_params/1e6, pruned_params/1e6)) 60 | -------------------------------------------------------------------------------- /examples/transformers/prune_hf_swin.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoImageProcessor, AutoModelForImageClassification 2 | from PIL import Image 3 | import requests 4 | import torch 5 | import torch.nn as nn 6 | from typing import Sequence 7 | import torch_pruning as tp 8 | from transformers.models.swin.modeling_swin import SwinSelfAttention, SwinPatchMerging 9 | 10 | 11 | class SwinPatchMergingPruner(tp.BasePruningFunc): 12 | 13 | def prune_out_channels(self, layer: nn.Module, idxs: list): 14 | tp.prune_linear_out_channels(layer.reduction, idxs) 15 | return layer 16 | 17 | def prune_in_channels(self, layer: nn.Module, idxs: Sequence[int]) -> nn.Module: 18 | dim = layer.dim 19 | idxs_repeated = idxs + \ 20 | [i+dim for i in idxs] + \ 21 | [i+2*dim for i in idxs] + \ 22 | [i+3*dim for i in idxs] 23 | tp.prune_linear_in_channels(layer.reduction, idxs_repeated) 24 | tp.prune_layernorm_out_channels(layer.norm, idxs_repeated) 25 | return layer 26 | 27 | def get_out_channels(self, layer): 28 | return layer.reduction.out_features 29 | 30 | def get_in_channels(self, layer): 31 | return layer.dim 32 | 33 | url = "http://images.cocodataset.org/val2017/000000039769.jpg" 34 | image = Image.open(requests.get(url, stream=True).raw) 35 | 36 | processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224") 37 | model = AutoModelForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224") 38 | 39 | example_inputs = processor(images=image, return_tensors="pt")["pixel_values"] 40 | inputs = processor(images=image, return_tensors="pt") 41 | outputs = model(**inputs) 42 | logits = outputs.logits 43 | # model predicts one of the 1000 ImageNet classes 44 | predicted_class_idx = logits.argmax(-1).item() 45 | print("Predicted class:", model.config.id2label[predicted_class_idx]) 46 | 47 | 48 | print(model) 49 | imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean") 50 | base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) 51 | num_heads = {} 52 | 53 | ignored_layers = 
[model.classifier]
54 | # All heads should be pruned simultaneously, so we group channels by head.
55 | for m in model.modules():
56 |     if isinstance(m, SwinSelfAttention):
57 |         num_heads[m.query] = m.num_attention_heads
58 |         num_heads[m.key] = m.num_attention_heads
59 |         num_heads[m.value] = m.num_attention_heads
60 |
61 | pruner = tp.pruner.BasePruner(
62 |     model,
63 |     example_inputs,
64 |     global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers.
65 |     importance=imp, # importance criterion for parameter selection
66 |     iterative_steps=1, # the number of iterations to achieve target pruning ratio
67 |     pruning_ratio=0.5,
68 |     num_heads=num_heads,
69 |     output_transform=lambda out: out.logits.sum(),
70 |     ignored_layers=ignored_layers,
71 |     customized_pruners={SwinPatchMerging: SwinPatchMergingPruner()},
72 |     root_module_types=(nn.Linear, nn.LayerNorm, SwinPatchMerging),
73 | )
74 |
75 | for g in pruner.step(interactive=True):
76 |     #print(g)
77 |     g.prune()
78 |
79 | print(model)
80 |
81 | # Modify the attention head size and all head size after pruning
82 | for m in model.modules():
83 |     if isinstance(m, SwinSelfAttention):
84 |         m.attention_head_size = m.query.out_features // m.num_attention_heads
85 |         m.all_head_size = m.query.out_features
86 |
87 | test_output = model(example_inputs)
88 | pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
89 | print("Base MACs: %f G, Pruned MACs: %f G"%(base_macs/1e9, pruned_macs/1e9))
90 | print("Base Params: %f M, Pruned Params: %f M"%(base_params/1e6, pruned_params/1e6))
91 |
92 |
93 |
--------------------------------------------------------------------------------
/examples/transformers/sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 |
7 | class RASampler(torch.utils.data.Sampler):
8 |     """Sampler that restricts data loading to a subset of the dataset for distributed training,
9 |     with repeated augmentation.
10 |     It ensures that each augmented version of a sample will be visible to a
11 |     different process (GPU).
12 |     Heavily based on 'torch.utils.data.DistributedSampler'.
13 | 14 | This is borrowed from the DeiT Repo: 15 | https://github.com/facebookresearch/deit/blob/main/samplers.py 16 | """ 17 | 18 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3): 19 | if num_replicas is None: 20 | if not dist.is_available(): 21 | raise RuntimeError("Requires distributed package to be available!") 22 | num_replicas = dist.get_world_size() 23 | if rank is None: 24 | if not dist.is_available(): 25 | raise RuntimeError("Requires distributed package to be available!") 26 | rank = dist.get_rank() 27 | self.dataset = dataset 28 | self.num_replicas = num_replicas 29 | self.rank = rank 30 | self.epoch = 0 31 | self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas)) 32 | self.total_size = self.num_samples * self.num_replicas 33 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 34 | self.shuffle = shuffle 35 | self.seed = seed 36 | self.repetitions = repetitions 37 | 38 | def __iter__(self): 39 | if self.shuffle: 40 | # Deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.seed + self.epoch) 43 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 44 | else: 45 | indices = list(range(len(self.dataset))) 46 | 47 | # Add extra samples to make it evenly divisible 48 | indices = [ele for ele in indices for i in range(self.repetitions)] 49 | indices += indices[: (self.total_size - len(indices))] 50 | assert len(indices) == self.total_size 51 | 52 | # Subsample 53 | indices = indices[self.rank : self.total_size : self.num_replicas] 54 | assert len(indices) == self.num_samples 55 | 56 | return iter(indices[: self.num_selected_samples]) 57 | 58 | def __len__(self): 59 | return self.num_selected_samples 60 | 61 | def set_epoch(self, epoch): 62 | self.epoch = epoch 63 | -------------------------------------------------------------------------------- /examples/transformers/scripts/finetune_hf_vit_b_16_l1_uniform.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 finetune.py \ 2 | --model "output/pruned/hf_vit_base_patch16_224_pruned_l1_uniform.pth" \ 3 | --epochs 300 \ 4 | --batch-size 256 \ 5 | --opt adamw \ 6 | --lr 0.00015 \ 7 | --wd 0.3 \ 8 | --lr-scheduler cosineannealinglr \ 9 | --lr-warmup-method linear \ 10 | --lr-warmup-epochs 0 \ 11 | --lr-warmup-decay 0.033 \ 12 | --amp \ 13 | --label-smoothing 0.11 \ 14 | --mixup-alpha 0.2 \ 15 | --auto-augment ra \ 16 | --clip-grad-norm 1 \ 17 | --ra-sampler \ 18 | --cutmix-alpha 1.0 \ 19 | --data-path "data/imagenet" \ 20 | --output-dir output/hf_vit_b_16_pruning_l1_uniform \ 21 | --is_huggingface \ -------------------------------------------------------------------------------- /examples/transformers/scripts/finetune_hf_vit_b_16_taylor_uniform.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 finetune.py \ 2 | --model "output/pruned/hf_vit_base_patch16_224_pruned_taylor_uniform.pth" \ 3 | --epochs 300 \ 4 | --batch-size 256 \ 5 | --opt adamw \ 6 | --lr 0.00015 \ 7 | --wd 0.3 \ 8 | --lr-scheduler cosineannealinglr \ 9 | --lr-warmup-method linear \ 10 | --lr-warmup-epochs 0 \ 11 | --lr-warmup-decay 0.033 \ 12 | --amp \ 13 | --label-smoothing 0.11 \ 14 | --mixup-alpha 0.2 \ 15 | --auto-augment ra \ 16 | --clip-grad-norm 1 \ 17 | --ra-sampler \ 18 | --cutmix-alpha 1.0 \ 19 | --data-path "data/imagenet" \ 20 | --output-dir 
output/hf_vit_b_16_pruning_taylor_uniform \ 21 | --is_huggingface \ -------------------------------------------------------------------------------- /examples/transformers/scripts/finetune_timm_deit_b_16_taylor_uniform.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 finetune.py \ 2 | --model "output/pruned/deit_base_patch16_224_pruned_taylor_uniform.pth" \ 3 | --epochs 300 \ 4 | --batch-size 256 \ 5 | --opt adamw \ 6 | --lr 0.00015 \ 7 | --wd 0.3 \ 8 | --lr-scheduler cosineannealinglr \ 9 | --lr-warmup-method linear \ 10 | --lr-warmup-epochs 0 \ 11 | --lr-warmup-decay 0.033 \ 12 | --amp \ 13 | --label-smoothing 0.11 \ 14 | --mixup-alpha 0.2 \ 15 | --auto-augment ra \ 16 | --clip-grad-norm 1 \ 17 | --ra-sampler \ 18 | --random-erase 0.25 \ 19 | --cutmix-alpha 1.0 \ 20 | --data-path "data/imagenet" \ 21 | --output-dir output/deit_b_16_pruning_taylor_uniform \ 22 | --use_imagenet_mean_std \ -------------------------------------------------------------------------------- /examples/transformers/scripts/finetune_timm_vit_b_16_hessian_uniform.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 finetune.py \ 2 | --model "output/pruned/vit_base_patch16_224_pruned_hessian_uniform.pth" \ 3 | --epochs 300 \ 4 | --batch-size 256 \ 5 | --opt adamw \ 6 | --lr 0.00015 \ 7 | --wd 0.3 \ 8 | --lr-scheduler cosineannealinglr \ 9 | --lr-warmup-method linear \ 10 | --lr-warmup-epochs 0 \ 11 | --lr-warmup-decay 0.033 \ 12 | --amp \ 13 | --label-smoothing 0.11 \ 14 | --mixup-alpha 0.2 \ 15 | --auto-augment ra \ 16 | --clip-grad-norm 1 \ 17 | --ra-sampler \ 18 | --cutmix-alpha 1.0 \ 19 | --data-path "data/imagenet" \ 20 | --output-dir output/vit_b_16_pruning_hessian_uniform -------------------------------------------------------------------------------- /examples/transformers/scripts/finetune_timm_vit_b_16_l1_uniform.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 finetune.py \ 2 | --model "output/pruned/vit_base_patch16_224_pruned_l1_uniform.pth" \ 3 | --epochs 300 \ 4 | --batch-size 256 \ 5 | --opt adamw \ 6 | --lr 0.00015 \ 7 | --wd 0.3 \ 8 | --lr-scheduler cosineannealinglr \ 9 | --lr-warmup-method linear \ 10 | --lr-warmup-epochs 0 \ 11 | --lr-warmup-decay 0.033 \ 12 | --amp \ 13 | --label-smoothing 0.11 \ 14 | --mixup-alpha 0.2 \ 15 | --auto-augment ra \ 16 | --clip-grad-norm 1 \ 17 | --ra-sampler \ 18 | --cutmix-alpha 1.0 \ 19 | --data-path "data/imagenet" \ 20 | --output-dir output/vit_b_16_pruning_l1_uniform 21 | -------------------------------------------------------------------------------- /examples/transformers/scripts/finetune_timm_vit_b_16_l2_uniform.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=8 finetune.py \ 2 | --model "output/pruned/vit_base_patch16_224_pruned_l2_uniform.pth" \ 3 | --epochs 300 \ 4 | --batch-size 256 \ 5 | --opt adamw \ 6 | --lr 0.00015 \ 7 | --wd 0.3 \ 8 | --lr-scheduler cosineannealinglr \ 9 | --lr-warmup-method linear \ 10 | --lr-warmup-epochs 0 \ 11 | --lr-warmup-decay 0.033 \ 12 | --amp \ 13 | --label-smoothing 0.11 \ 14 | --mixup-alpha 0.2 \ 15 | --auto-augment ra \ 16 | --clip-grad-norm 1 \ 17 | --ra-sampler \ 18 | --cutmix-alpha 1.0 \ 19 | --data-path "data/imagenet" \ 20 | --output-dir output/vit_b_16_pruning_l2_uniform -------------------------------------------------------------------------------- 
/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_bottleneck.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node=8 finetune.py \
2 |     --model "output/pruned/vit_base_patch16_224_pruned_taylor_bottleneck.pth" \
3 |     --epochs 300 \
4 |     --batch-size 256 \
5 |     --opt adamw \
6 |     --lr 0.00015 \
7 |     --wd 0.3 \
8 |     --lr-scheduler cosineannealinglr \
9 |     --lr-warmup-method linear \
10 |     --lr-warmup-epochs 0 \
11 |     --lr-warmup-decay 0.033 \
12 |     --amp \
13 |     --label-smoothing 0.11 \
14 |     --mixup-alpha 0.2 \
15 |     --auto-augment ra \
16 |     --clip-grad-norm 1 \
17 |     --ra-sampler \
18 |     --cutmix-alpha 1.0 \
19 |     --data-path "data/imagenet" \
20 |     --output-dir output/vit_b_16_pruning_taylor_bottleneck
--------------------------------------------------------------------------------
/examples/transformers/scripts/finetune_timm_vit_b_16_taylor_uniform.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node=8 finetune.py \
2 |     --model "output/pruned/vit_base_patch16_224_pruned_taylor_uniform.pth" \
3 |     --epochs 300 \
4 |     --batch-size 256 \
5 |     --opt adamw \
6 |     --lr 0.00015 \
7 |     --wd 0.3 \
8 |     --lr-scheduler cosineannealinglr \
9 |     --lr-warmup-method linear \
10 |     --lr-warmup-epochs 0 \
11 |     --lr-warmup-decay 0.033 \
12 |     --amp \
13 |     --label-smoothing 0.11 \
14 |     --mixup-alpha 0.2 \
15 |     --auto-augment ra \
16 |     --clip-grad-norm 1 \
17 |     --ra-sampler \
18 |     --random-erase 0.25 \
19 |     --cutmix-alpha 1.0 \
20 |     --data-path "data/imagenet" \
21 |     --output-dir output/vit_b_16_pruning_taylor_uniform_v2
--------------------------------------------------------------------------------
/examples/transformers/scripts/prune_hf_vit_b_16_l1_uniform.sh:
--------------------------------------------------------------------------------
1 | python prune_hf_vit.py \
2 |     --model_name google/vit-base-patch16-224 \
3 |     --pruning_type l1 \
4 |     --pruning_ratio 0.5 \
5 |     --taylor_batchs 10 \
6 |     --data_path data/imagenet \
7 |     --train_batch_size 64 \
8 |     --val_batch_size 64 \
9 |     --save_as output/pruned/hf_vit_base_patch16_224_pruned_l1_uniform.pth \
--------------------------------------------------------------------------------
/examples/transformers/scripts/prune_hf_vit_b_16_taylor_uniform.sh:
--------------------------------------------------------------------------------
1 | python prune_hf_vit.py \
2 |     --model_name google/vit-base-patch16-224 \
3 |     --pruning_type taylor \
4 |     --pruning_ratio 0.5 \
5 |     --taylor_batchs 10 \
6 |     --data_path data/imagenet \
7 |     --test_accuracy \
8 |     --train_batch_size 64 \
9 |     --val_batch_size 64 \
10 |     --save_as output/pruned/hf_vit_base_patch16_224_pruned_taylor_uniform.pth \
--------------------------------------------------------------------------------
/examples/transformers/scripts/prune_timm_deit_b_16_taylor_uniform.sh:
--------------------------------------------------------------------------------
1 | python prune_timm_vit.py \
2 |     --model_name deit_base_distilled_patch16_224 \
3 |     --pruning_type taylor \
4 |     --pruning_ratio 0.54 \
5 |     --taylor_batchs 50 \
6 |     --data_path data/imagenet \
7 |     --train_batch_size 64 \
8 |     --val_batch_size 64 \
9 |     --save_as output/pruned/deit_base_patch16_224_pruned_taylor_uniform.pth \
10 |     --use_imagenet_mean_std \
--------------------------------------------------------------------------------
/examples/transformers/scripts/prune_timm_vit_b_16_hessian_uniform.sh:
--------------------------------------------------------------------------------
1 | python prune_timm_vit.py \
2 |     --model_name vit_base_patch16_224 \
3 |     --pruning_type hessian \
4 |     --pruning_ratio 0.5 \
5 |     --taylor_batchs 10 \
6 |     --test_accuracy \
7 |     --data_path data/imagenet \
8 |     --train_batch_size 64 \
9 |     --val_batch_size 64 \
10 |     --save_as output/pruned/vit_base_patch16_224_pruned_hessian_uniform.pth \
--------------------------------------------------------------------------------
/examples/transformers/scripts/prune_timm_vit_b_16_l1_uniform.sh:
--------------------------------------------------------------------------------
1 | python prune_timm_vit.py \
2 |     --model_name vit_base_patch16_224 \
3 |     --pruning_type l1 \
4 |     --pruning_ratio 0.5 \
5 |     --taylor_batchs 10 \
6 |     --data_path data/imagenet \
7 |     --train_batch_size 64 \
8 |     --val_batch_size 64 \
9 |     --save_as output/pruned/vit_base_patch16_224_pruned_l1_uniform.pth \
--------------------------------------------------------------------------------
/examples/transformers/scripts/prune_timm_vit_b_16_l2_uniform.sh:
--------------------------------------------------------------------------------
1 | python prune_timm_vit.py \
2 |     --model_name vit_base_patch16_224 \
3 |     --pruning_type l2 \
4 |     --pruning_ratio 0.5 \
5 |     --taylor_batchs 10 \
6 |     --data_path data/imagenet \
7 |     --test_accuracy \
8 |     --train_batch_size 64 \
9 |     --val_batch_size 64 \
10 |     --save_as output/pruned/vit_base_patch16_224_pruned_l2_uniform.pth \
--------------------------------------------------------------------------------
/examples/transformers/scripts/prune_timm_vit_b_16_taylor_bottleneck.sh:
--------------------------------------------------------------------------------
1 | python prune_timm_vit.py \
2 |     --model_name vit_base_patch16_224 \
3 |     --pruning_type taylor \
4 |     --pruning_ratio 0.73 \
5 |     --taylor_batchs 10 \
6 |     --data_path data/imagenet \
7 |     --bottleneck \
8 |     --train_batch_size 64 \
9 |     --val_batch_size 64 \
10 |     --save_as output/pruned/vit_base_patch16_224_pruned_taylor_bottleneck.pth \
--------------------------------------------------------------------------------
/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform.sh:
--------------------------------------------------------------------------------
1 | python prune_timm_vit.py \
2 |     --model_name vit_base_patch16_224 \
3 |     --pruning_type taylor \
4 |     --pruning_ratio 0.54 \
5 |     --taylor_batchs 50 \
6 |     --data_path data/imagenet \
7 |     --train_batch_size 64 \
8 |     --val_batch_size 64 \
9 |     --save_as output/pruned/vit_base_patch16_224_pruned_taylor_uniform.pth \
--------------------------------------------------------------------------------
/examples/transformers/scripts/prune_timm_vit_b_16_taylor_uniform_global.sh:
--------------------------------------------------------------------------------
1 | python prune_timm_vit.py \
2 |     --model_name vit_base_patch16_224 \
3 |     --pruning_type taylor \
4 |     --pruning_ratio 0.6 \
5 |     --taylor_batchs 10 \
6 |     --data_path data/imagenet \
7 |     --train_batch_size 64 \
8 |     --val_batch_size 64 \
9 |     --save_as output/pruned/vit_base_patch16_224_pruned_taylor_uniform_global.pth \
10 |     --global_pruning
--------------------------------------------------------------------------------
/examples/transformers/scripts/test_pretrained_hf_vit_b_16.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 |     --model "google/vit-base-patch16-224" \
3 |     --epochs 300 \
4 |     --batch-size 32 \
5 |     --opt adamw \
6 |     --lr 0.00015 \
7 |     --wd 0.3 \
8 |     --lr-scheduler cosineannealinglr \
9 |     --lr-warmup-method linear \
10 |     --lr-warmup-epochs 0 \
11 |     --lr-warmup-decay 0.033 \
12 |     --amp \
13 |     --label-smoothing 0.11 \
14 |     --mixup-alpha 0.2 \
15 |     --auto-augment ra \
16 |     --clip-grad-norm 1 \
17 |     --ra-sampler \
18 |     --cutmix-alpha 1.0 \
19 |     --data-path "data/imagenet" \
20 |     --test-only \
21 |     --interpolation bilinear \
22 |     --is_huggingface \
--------------------------------------------------------------------------------
/examples/transformers/scripts/test_pretrained_timm_deit_b_16.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 |     --model "deit_base_distilled_patch16_224" \
3 |     --epochs 300 \
4 |     --batch-size 32 \
5 |     --opt adamw \
6 |     --lr 0.00015 \
7 |     --wd 0.3 \
8 |     --lr-scheduler cosineannealinglr \
9 |     --lr-warmup-method linear \
10 |     --lr-warmup-epochs 0 \
11 |     --lr-warmup-decay 0.033 \
12 |     --amp \
13 |     --label-smoothing 0.11 \
14 |     --mixup-alpha 0.2 \
15 |     --auto-augment ra \
16 |     --clip-grad-norm 1 \
17 |     --ra-sampler \
18 |     --cutmix-alpha 1.0 \
19 |     --data-path "data/imagenet" \
20 |     --test-only \
21 |     --use_imagenet_mean_std \
22 |
--------------------------------------------------------------------------------
/examples/transformers/scripts/test_pretrained_timm_vit_b_16.sh:
--------------------------------------------------------------------------------
1 | python finetune.py \
2 |     --model "vit_base_patch16_224" \
3 |     --epochs 300 \
4 |     --batch-size 32 \
5 |     --opt adamw \
6 |     --lr 0.00015 \
7 |     --wd 0.3 \
8 |     --lr-scheduler cosineannealinglr \
9 |     --lr-warmup-method linear \
10 |     --lr-warmup-epochs 0 \
11 |     --lr-warmup-decay 0.033 \
12 |     --amp \
13 |     --label-smoothing 0.11 \
14 |     --mixup-alpha 0.2 \
15 |     --auto-augment ra \
16 |     --clip-grad-norm 1 \
17 |     --ra-sampler \
18 |     --cutmix-alpha 1.0 \
19 |     --data-path "data/imagenet" \
20 |     --test-only \
21 |
--------------------------------------------------------------------------------
/examples/yolov7/readme.md:
--------------------------------------------------------------------------------
1 | # YOLOv7 Pruning
2 |
3 | ## 0. Requirements
4 |
5 | ```bash
6 | pip install -r requirements.txt
7 | ```
8 | Tested environment:
9 | ```
10 | PyTorch==1.12.1
11 | Torchvision==0.13.1
12 | ```
13 |
14 | ## 1. Pruning
15 | The following scripts (adapted from [yolov7/detect.py](https://github.com/WongKinYiu/yolov7/blob/main/detect.py) and [yolov7/train.py](https://github.com/WongKinYiu/yolov7/blob/main/train.py)) provide basic examples of pruning YOLOv7. The training part has not been validated yet, since full training is time-consuming.
16 |
17 | Note: [yolov7_detect_pruned.py](https://github.com/VainF/Torch-Pruning/blob/master/benchmarks/prunability/yolov7_detect_pruned.py) does not include any code for fine-tuning.
18 |
19 | ```bash
20 | git clone https://github.com/WongKinYiu/yolov7.git
21 | cp yolov7_detect_pruned.py yolov7/
22 | cp yolov7_train_pruned.py yolov7/
23 | cd yolov7
24 |
25 | # Test only: this script only prunes and tests the YOLOv7 model. The COCO dataset is not required.
26 | python yolov7_detect_pruned.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inference/images/horses.jpg
27 |
28 | # Training with the pruned yolov7 (the training part is not validated)
29 | # Please download the pretrained yolov7.pt from https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt.
30 | python yolov7_train_pruned.py --workers 8 --device 0 --batch-size 1 --data data/coco.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights 'yolov7.pt' --name yolov7 --hyp data/hyp.scratch.p5.yaml 31 | ``` 32 | 33 | #### Screenshot for yolov7_train_pruned.py: 34 | ![image](https://user-images.githubusercontent.com/18592211/232129303-18a61be1-b505-4950-b6a1-c60b4974291b.png) 35 | 36 | 37 | #### Outputs of yolov7_detect_pruned.py: 38 | ``` 39 | Model( 40 | (model): Sequential( 41 | (0): Conv( 42 | (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 43 | (act): SiLU(inplace=True) 44 | ) 45 | ... 46 | (104): RepConv( 47 | (act): SiLU(inplace=True) 48 | (rbr_reparam): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 49 | ) 50 | (105): Detect( 51 | (m): ModuleList( 52 | (0): Conv2d(256, 255, kernel_size=(1, 1), stride=(1, 1)) 53 | (1): Conv2d(512, 255, kernel_size=(1, 1), stride=(1, 1)) 54 | (2): Conv2d(1024, 255, kernel_size=(1, 1), stride=(1, 1)) 55 | ) 56 | ) 57 | ) 58 | ) 59 | 60 | 61 | Model( 62 | (model): Sequential( 63 | (0): Conv( 64 | (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 65 | (act): SiLU(inplace=True) 66 | ) 67 | ... 68 | (104): RepConv( 69 | (act): SiLU(inplace=True) 70 | (rbr_reparam): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 71 | ) 72 | (105): Detect( 73 | (m): ModuleList( 74 | (0): Conv2d(128, 255, kernel_size=(1, 1), stride=(1, 1)) 75 | (1): Conv2d(256, 255, kernel_size=(1, 1), stride=(1, 1)) 76 | (2): Conv2d(512, 255, kernel_size=(1, 1), stride=(1, 1)) 77 | ) 78 | ) 79 | ) 80 | ) 81 | Before Pruning: MACs=6.413721 G, #Params=0.036905 G 82 | After Pruning: MACs=1.639895 G, #Params=0.009347 G 83 | ``` 84 | 85 | -------------------------------------------------------------------------------- /reproduce/benchmark_latency.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import resnet50 as model_entry 2 | import sys, os 3 | import time 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 5 | import torch_pruning as tp 6 | import torch 7 | 8 | 9 | def main(): 10 | 11 | model = model_entry(pretrained=True).to('cuda:0') 12 | example_input = torch.rand(32, 3, 224, 224).to('cuda:0') 13 | importance = tp.importance.MagnitudeImportance(p=2) 14 | iterative_steps = 20 15 | pruner = tp.pruner.MagnitudePruner( 16 | model = model, 17 | example_inputs=example_input, 18 | importance=importance, 19 | pruning_ratio=1, 20 | iterative_steps=iterative_steps, 21 | round_to=2, 22 | ) 23 | 24 | # Before Pruning 25 | macs, params = tp.utils.count_ops_and_params(model, example_input) 26 | latency_mu, latency_std = estimate_latency(model, example_input) 27 | # print all with .2f 28 | print(f"[Iter 0] \tPruning ratio: 0.00, \tMACs: {macs/1e9:.2f} G, \tParams: {params/1e6:.2f} M, \tLatency: {latency_mu:.2f} ms +- {latency_std:.2f} ms") 29 | 30 | for iter in range(iterative_steps): 31 | pruner.step() 32 | _macs, _params = tp.utils.count_ops_and_params(model, example_input) 33 | latency_mu, latency_std = estimate_latency(model, example_input) 34 | current_pruning_ratio = 1 / iterative_steps * (iter + 1) 35 | print(f"[Iter {iter+1}] \tPruning ratio: {current_pruning_ratio:.2f}, \tMACs: {_macs/1e9:.2f} G, \tParams: {_params/1e6:.2f} M, \tLatency: {latency_mu:.2f} ms +- {latency_std:.2f} ms") 36 | 37 | # uncomment the following lines to profile 38 | #with torch.autograd.profiler.profile(use_cuda=True) as prof: 39 | 
# with torch.no_grad(): 40 | # for _ in range(50): 41 | # _ = model(example_input) 42 | #print(prof) 43 | 44 | def estimate_latency(model, example_inputs, repetitions=50): 45 | import numpy as np 46 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 47 | timings=np.zeros((repetitions,1)) 48 | 49 | for _ in range(5): 50 | _ = model(example_inputs) 51 | 52 | with torch.no_grad(): 53 | for rep in range(repetitions): 54 | starter.record() 55 | _ = model(example_inputs) 56 | ender.record() 57 | # WAIT FOR GPU SYNC 58 | torch.cuda.synchronize() 59 | curr_time = starter.elapsed_time(ender) 60 | timings[rep] = curr_time 61 | 62 | mean_syn = np.sum(timings) / repetitions 63 | std_syn = np.std(timings) 64 | return mean_syn, std_syn 65 | 66 | if __name__=='__main__': 67 | main() -------------------------------------------------------------------------------- /reproduce/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from . import models, utils 2 | -------------------------------------------------------------------------------- /reproduce/engine/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cifar, imagenet, graph -------------------------------------------------------------------------------- /reproduce/engine/models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | densenet, 3 | googlenet, 4 | inceptionv4, 5 | inceptionv3, 6 | mobilenetv2, 7 | preactresnet, 8 | resnet_tiny, 9 | resnet, 10 | resnext, 11 | vgg, 12 | vit, 13 | swin, 14 | nasnet, 15 | senet, 16 | xception 17 | ) -------------------------------------------------------------------------------- /reproduce/engine/models/cifar/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """mobilenetv2 in pytorch 2 | [1] Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen 3 | MobileNetV2: Inverted Residuals and Linear Bottlenecks 4 | https://arxiv.org/abs/1801.04381 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class LinearBottleNeck(nn.Module): 13 | 14 | def __init__(self, in_channels, out_channels, stride, t=6, num_classes=100): 15 | super().__init__() 16 | 17 | self.residual = nn.Sequential( 18 | nn.Conv2d(in_channels, in_channels * t, 1), 19 | nn.BatchNorm2d(in_channels * t), 20 | nn.ReLU6(inplace=True), 21 | 22 | nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t), 23 | nn.BatchNorm2d(in_channels * t), 24 | nn.ReLU6(inplace=True), 25 | 26 | nn.Conv2d(in_channels * t, out_channels, 1), 27 | nn.BatchNorm2d(out_channels) 28 | ) 29 | 30 | self.stride = stride 31 | self.in_channels = in_channels 32 | self.out_channels = out_channels 33 | 34 | def forward(self, x): 35 | 36 | residual = self.residual(x) 37 | 38 | if self.stride == 1 and self.in_channels == self.out_channels: 39 | residual += x 40 | 41 | return residual 42 | 43 | class MobileNetV2(nn.Module): 44 | 45 | def __init__(self, num_classes=100): 46 | super().__init__() 47 | 48 | self.pre = nn.Sequential( 49 | nn.Conv2d(3, 32, 1, padding=1), 50 | nn.BatchNorm2d(32), 51 | nn.ReLU6(inplace=True) 52 | ) 53 | 54 | self.stage1 = LinearBottleNeck(32, 16, 1, 1) 55 | self.stage2 = self._make_stage(2, 16, 24, 2, 6) 56 | self.stage3 = self._make_stage(3, 24, 32, 2, 6) 57 | self.stage4 = 
-------------------------------------------------------------------------------- /reproduce/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from . import models, utils 2 | -------------------------------------------------------------------------------- /reproduce/engine/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cifar, imagenet, graph -------------------------------------------------------------------------------- /reproduce/engine/models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | densenet, 3 | googlenet, 4 | inceptionv4, 5 | inceptionv3, 6 | mobilenetv2, 7 | preactresnet, 8 | resnet_tiny, 9 | resnet, 10 | resnext, 11 | vgg, 12 | vit, 13 | swin, 14 | nasnet, 15 | senet, 16 | xception 17 | ) -------------------------------------------------------------------------------- /reproduce/engine/models/cifar/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """mobilenetv2 in pytorch 2 | [1] Mark Sandler, Andrew Howard, Menglong Zhu, Andrey Zhmoginov, Liang-Chieh Chen 3 | MobileNetV2: Inverted Residuals and Linear Bottlenecks 4 | https://arxiv.org/abs/1801.04381 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class LinearBottleNeck(nn.Module): 13 | 14 | def __init__(self, in_channels, out_channels, stride, t=6, num_classes=100): 15 | super().__init__() 16 | 17 | self.residual = nn.Sequential( 18 | nn.Conv2d(in_channels, in_channels * t, 1), 19 | nn.BatchNorm2d(in_channels * t), 20 | nn.ReLU6(inplace=True), 21 | 22 | nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t), 23 | nn.BatchNorm2d(in_channels * t), 24 | nn.ReLU6(inplace=True), 25 | 26 | nn.Conv2d(in_channels * t, out_channels, 1), 27 | nn.BatchNorm2d(out_channels) 28 | ) 29 | 30 | self.stride = stride 31 | self.in_channels = in_channels 32 | self.out_channels = out_channels 33 | 34 | def forward(self, x): 35 | 36 | residual = self.residual(x) 37 | 38 | if self.stride == 1 and self.in_channels == self.out_channels: 39 | residual += x 40 | 41 | return residual 42 | 43 | class MobileNetV2(nn.Module): 44 | 45 | def __init__(self, num_classes=100): 46 | super().__init__() 47 | 48 | self.pre = nn.Sequential( 49 | nn.Conv2d(3, 32, 1, padding=1), 50 | nn.BatchNorm2d(32), 51 | nn.ReLU6(inplace=True) 52 | ) 53 | 54 | self.stage1 = LinearBottleNeck(32, 16, 1, 1) 55 | self.stage2 = self._make_stage(2, 16, 24, 2, 6) 56 | self.stage3 = self._make_stage(3, 24, 32, 2, 6) 57 | self.stage4 = self._make_stage(4, 32, 64, 2, 6) 58 | self.stage5 = self._make_stage(3, 64, 96, 1, 6) 59 | self.stage6 = self._make_stage(3, 96, 160, 1, 6) 60 | self.stage7 = LinearBottleNeck(160, 320, 1, 6) 61 | 62 | self.conv1 = nn.Sequential( 63 | nn.Conv2d(320, 1280, 1), 64 | nn.BatchNorm2d(1280), 65 | nn.ReLU6(inplace=True) 66 | ) 67 | 68 | self.conv2 = nn.Conv2d(1280, num_classes, 1) 69 | 70 | def forward(self, x): 71 | x = self.pre(x) 72 | x = self.stage1(x) 73 | x = self.stage2(x) 74 | x = self.stage3(x) 75 | x = self.stage4(x) 76 | x = self.stage5(x) 77 | x = self.stage6(x) 78 | x = self.stage7(x) 79 | x = self.conv1(x) 80 | x = F.adaptive_avg_pool2d(x, 1) 81 | x = self.conv2(x) 82 | x = x.view(x.size(0), -1) 83 | 84 | return x 85 | 86 | def _make_stage(self, repeat, in_channels, out_channels, stride, t): 87 | 88 | layers = [] 89 | layers.append(LinearBottleNeck(in_channels, out_channels, stride, t)) 90 | 91 | while repeat - 1: 92 | layers.append(LinearBottleNeck(out_channels, out_channels, 1, t)) 93 | repeat -= 1 94 | 95 | return nn.Sequential(*layers) 96 | 97 | def mobilenetv2(num_classes=100): 98 | return MobileNetV2(num_classes=num_classes) -------------------------------------------------------------------------------- /reproduce/engine/models/cifar/preactresnet.py: -------------------------------------------------------------------------------- 1 | """preactresnet in pytorch 2 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 3 | Identity Mappings in Deep Residual Networks 4 | https://arxiv.org/abs/1603.05027 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class PreActBasic(nn.Module): 12 | 13 | expansion = 1 14 | def __init__(self, in_channels, out_channels, stride): 15 | super().__init__() 16 | self.residual = nn.Sequential( 17 | nn.BatchNorm2d(in_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(out_channels, out_channels * PreActBasic.expansion, kernel_size=3, padding=1) 23 | ) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride != 1 or in_channels != out_channels * PreActBasic.expansion: 27 | self.shortcut = nn.Conv2d(in_channels, out_channels * PreActBasic.expansion, 1, stride=stride) 28 | 29 | def forward(self, x): 30 | 31 | res = self.residual(x) 32 | shortcut = self.shortcut(x) 33 | 34 | return res + shortcut 35 | 36 | 37 | class PreActBottleNeck(nn.Module): 38 | 39 | expansion = 4 40 | def __init__(self, in_channels, out_channels, stride): 41 | super().__init__() 42 | 43 | self.residual = nn.Sequential( 44 | nn.BatchNorm2d(in_channels), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(in_channels, out_channels, 1, stride=stride), 47 | 48 | nn.BatchNorm2d(out_channels), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(out_channels, out_channels, 3, padding=1), 51 | 52 | nn.BatchNorm2d(out_channels), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(out_channels, out_channels * PreActBottleNeck.expansion, 1) 55 | ) 56 | 57 | self.shortcut = nn.Sequential() 58 | 59 | if stride != 1 or in_channels != out_channels * PreActBottleNeck.expansion: 60 | self.shortcut = nn.Conv2d(in_channels, out_channels * PreActBottleNeck.expansion, 1, stride=stride) 61 | 62 | def forward(self, x): 63 | 64 | res = self.residual(x) 65 | shortcut = self.shortcut(x) 66 | 67 | return res + shortcut 68 | 69 | class PreActResNet(nn.Module): 70 | 71 | def __init__(self, block, num_block, num_classes=100): 72 |
super().__init__() 73 | self.input_channels = 64 74 | 75 | self.pre = nn.Sequential( 76 | nn.Conv2d(3, 64, 3, padding=1), 77 | nn.BatchNorm2d(64), 78 | nn.ReLU(inplace=True) 79 | ) 80 | 81 | self.stage1 = self._make_layers(block, num_block[0], 64, 1) 82 | self.stage2 = self._make_layers(block, num_block[1], 128, 2) 83 | self.stage3 = self._make_layers(block, num_block[2], 256, 2) 84 | self.stage4 = self._make_layers(block, num_block[3], 512, 2) 85 | 86 | self.linear = nn.Linear(self.input_channels, num_classes) 87 | 88 | def _make_layers(self, block, block_num, out_channels, stride): 89 | layers = [] 90 | 91 | layers.append(block(self.input_channels, out_channels, stride)) 92 | self.input_channels = out_channels * block.expansion 93 | 94 | while block_num - 1: 95 | layers.append(block(self.input_channels, out_channels, 1)) 96 | self.input_channels = out_channels * block.expansion 97 | block_num -= 1 98 | 99 | return nn.Sequential(*layers) 100 | 101 | def forward(self, x): 102 | x = self.pre(x) 103 | 104 | x = self.stage1(x) 105 | x = self.stage2(x) 106 | x = self.stage3(x) 107 | x = self.stage4(x) 108 | 109 | x = F.adaptive_avg_pool2d(x, 1) 110 | x = x.view(x.size(0), -1) 111 | x = self.linear(x) 112 | 113 | return x 114 | 115 | def preactresnet18(num_classes): 116 | return PreActResNet(PreActBasic, [2, 2, 2, 2], num_classes=num_classes) 117 | 118 | def preactresnet34(num_classes): 119 | return PreActResNet(PreActBasic, [3, 4, 6, 3], num_classes=num_classes) 120 | 121 | def preactresnet50(num_classes): 122 | return PreActResNet(PreActBottleNeck, [3, 4, 6, 3], num_classes=num_classes) 123 | 124 | def preactresnet101(num_classes): 125 | return PreActResNet(PreActBottleNeck, [3, 4, 23, 3], num_classes=num_classes) 126 | 127 | def preactresnet152(num_classes): 128 | return PreActResNet(PreActBottleNeck, [3, 8, 36, 3], num_classes=num_classes) -------------------------------------------------------------------------------- /reproduce/engine/models/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | # ResNet for CIFAR (32x32) 2 | # 2019.07.24-Changed output of forward function 3 | # Huawei Technologies Co., Ltd. 
4 | # taken from https://github.com/huawei-noah/Data-Efficient-Model-Compression/blob/master/DAFL/resnet.py 5 | # for comparison with DAFL 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1): 41 | super(Bottleneck, self).__init__() 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(self.expansion*planes) 54 | ) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = F.relu(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out += self.shortcut(x) 61 | out = F.relu(out) 62 | return out 63 | 64 | 65 | class ResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(ResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 76 | self.linear = nn.Linear(512*block.expansion, num_classes) 77 | 78 | def _make_layer(self, block, planes, num_blocks, stride): 79 | strides = [stride] + [1]*(num_blocks-1) 80 | layers = [] 81 | for stride in strides: 82 | layers.append(block(self.in_planes, planes, stride)) 83 | self.in_planes = planes * block.expansion 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x, return_features=False): 87 | x = self.conv1(x) 88 | x = self.bn1(x) 89 | out = F.relu(x) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 94 | out = F.adaptive_avg_pool2d(out, (1,1)) 95 | feature = out.view(out.size(0), -1) 96 | out = self.linear(feature) 97 | 98 | if return_features: 99 | return out, feature 100 | else: 101 | return out 102 | 103 | def 
resnet18(num_classes=10): 104 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 105 | 106 | def resnet34(num_classes=10): 107 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 108 | 109 | def resnet50(num_classes=10): 110 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 111 | 112 | def resnet101(num_classes=10): 113 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 114 | 115 | def resnet152(num_classes=10): 116 | return ResNet(Bottleneck, [3,8,36,3], num_classes) -------------------------------------------------------------------------------- /reproduce/engine/models/cifar/resnext.py: -------------------------------------------------------------------------------- 1 | """resnext in pytorch 2 | [1] Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, Kaiming He. 3 | Aggregated Residual Transformations for Deep Neural Networks 4 | https://arxiv.org/abs/1611.05431 5 | """ 6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | #only implements ResNext bottleneck c 13 | 14 | 15 | #"""This strategy exposes a new dimension, which we call “cardinality” 16 | #(the size of the set of transformations), as an essential factor 17 | #in addition to the dimensions of depth and width.""" 18 | CARDINALITY = 32 19 | DEPTH = 4 20 | BASEWIDTH = 64 21 | 22 | #"""The grouped convolutional layer in Fig. 3(c) performs 32 groups 23 | #of convolutions whose input and output channels are 4-dimensional. 24 | #The grouped convolutional layer concatenates them as the outputs 25 | #of the layer.""" 26 | 27 | class ResNextBottleNeckC(nn.Module): 28 | 29 | def __init__(self, in_channels, out_channels, stride): 30 | super().__init__() 31 | 32 | C = CARDINALITY #How many groups a feature map is split into 33 | 34 | #"""We note that the input/output width of the template is fixed as 35 | #256-d (Fig. 3), and all widths are doubled each time 36 | #when the feature map is subsampled (see Table 1).""" 37 | 38 | D = int(DEPTH * out_channels / BASEWIDTH) #number of channels per group 39 | self.split_transforms = nn.Sequential( 40 | nn.Conv2d(in_channels, C * D, kernel_size=1, groups=C, bias=False), 41 | nn.BatchNorm2d(C * D), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(C * D, C * D, kernel_size=3, stride=stride, groups=C, padding=1, bias=False), 44 | nn.BatchNorm2d(C * D), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(C * D, out_channels * 4, kernel_size=1, bias=False), 47 | nn.BatchNorm2d(out_channels * 4), 48 | ) 49 | 50 | self.shortcut = nn.Sequential() 51 | 52 | if stride != 1 or in_channels != out_channels * 4: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_channels, out_channels * 4, stride=stride, kernel_size=1, bias=False), 55 | nn.BatchNorm2d(out_channels * 4) 56 | ) 57 | 58 | def forward(self, x): 59 | return F.relu(self.split_transforms(x) + self.shortcut(x)) 60 | 61 | class ResNext(nn.Module): 62 | 63 | def __init__(self, block, num_blocks, num_classes=100): 64 | super().__init__() 65 | self.in_channels = 64 66 | 67 | self.conv1 = nn.Sequential( 68 | nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False), 69 | nn.BatchNorm2d(64), 70 | nn.ReLU(inplace=True) 71 | ) 72 | 73 | self.conv2 = self._make_layer(block, num_blocks[0], 64, 1) 74 | self.conv3 = self._make_layer(block, num_blocks[1], 128, 2) 75 | self.conv4 = self._make_layer(block, num_blocks[2], 256, 2) 76 | self.conv5 = self._make_layer(block, num_blocks[3], 512, 2) 77 | self.avg = nn.AdaptiveAvgPool2d((1, 1)) 78 | self.fc = nn.Linear(512 * 4, num_classes) 79 | 80 | def forward(self, x): 81 | x = self.conv1(x) 82 | x = self.conv2(x) 83 | x = self.conv3(x) 84 | x = self.conv4(x) 85 | x = self.conv5(x) 86 | x = self.avg(x) 87 | x = x.view(x.size(0), -1) 88 | x = self.fc(x) 89 | return x 90 | 91 | def _make_layer(self, block, num_block, out_channels, stride): 92 | """Building resnext block 93 | Args: 94 | block: block type(default resnext bottleneck c) 95 | num_block: number of blocks per layer 96 | out_channels: output channels per block 97 | stride: block stride 98 | Returns: 99 | a resnext layer 100 | """ 101 | strides = [stride] + [1] * (num_block - 1) 102 | layers = [] 103 | for stride in strides: 104 | layers.append(block(self.in_channels, out_channels, stride)) 105 | self.in_channels = out_channels * 4 106 | 107 | return nn.Sequential(*layers) 108 | 109 | def resnext50(num_classes): 110 | """ return a resnext50(c32x4d) network 111 | """ 112 | return ResNext(ResNextBottleNeckC, [3, 4, 6, 3], num_classes=num_classes) 113 | 114 | def resnext101(num_classes): 115 | """ return a resnext101(c32x4d) network 116 | """ 117 | return ResNext(ResNextBottleNeckC, [3, 4, 23, 3], num_classes=num_classes) 118 | 119 | def resnext152(num_classes): 120 | """ return a resnext152(c32x4d) network 121 | """ 122 | return ResNext(ResNextBottleNeckC, [3, 4, 36, 3], num_classes=num_classes) 123 |
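Editor's note: the comments in resnext.py above describe "cardinality" as the number of grouped transformations. The sketch below (illustrative, not part of the repository) makes the grouped-convolution channel math concrete for the first stage, where out_channels=64 gives D = int(4 * 64 / 64) = 4 channels per group:

import torch
import torch.nn as nn
C, D = 32, 4  # cardinality and per-group width, matching CARDINALITY/DEPTH/BASEWIDTH above
g = nn.Conv2d(C * D, C * D, kernel_size=3, padding=1, groups=C, bias=False)
# each of the 128 filters only sees the D=4 channels of its own group
assert g.weight.shape == (C * D, D, 3, 3)
assert g(torch.randn(1, C * D, 8, 8)).shape == (1, C * D, 8, 8)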
-------------------------------------------------------------------------------- /reproduce/engine/models/graph/__init__.py: -------------------------------------------------------------------------------- 1 | from .dgcnn import dgcnn, pointnet -------------------------------------------------------------------------------- /reproduce/engine/models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import ( 2 | resnet50, 3 | densenet121, 4 | mobilenet_v2, 5 | googlenet, 6 | inception_v3, 7 | squeezenet1_1, 8 | vgg16_bn, 9 | vgg19_bn, 10 | mnasnet1_0, 11 | alexnet, 12 | ) 13 | 14 | try: 15 | from torchvision.models import regnet_x_1_6gf 16 | from torchvision.models import resnext50_32x4d 17 | from .vision_transformer import vit_b_16 18 | except ImportError: # older torchvision versions do not provide these models 19 | regnet_x_1_6gf = None 20 | resnext50_32x4d = None 21 | vit_b_16 = None -------------------------------------------------------------------------------- /reproduce/engine/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import evaluator, metrics, utils, datasets 2 | from .utils import get_logger, MagnitudeRecover 3 | -------------------------------------------------------------------------------- /reproduce/engine/utils/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .modelnet40 import ModelNet40 -------------------------------------------------------------------------------- /reproduce/engine/utils/datasets/modelnet40.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | https://github.com/WangYueFt/dgcnn/blob/master/pytorch/data.py 5 | 6 | @Author: Yue Wang 7 | @Contact: yuewangx@mit.edu 8 | @File: data.py 9 | @Time: 2018/10/13 6:21 PM 10 | """ 11 | 12 | 13 | import os 14 | import sys 15 | import glob 16 | import h5py 17 | import numpy as np 18 | from torch.utils.data import Dataset 19 | 20 | 21 | def download(path): 22 | DATA_DIR = path 23 | if not os.path.exists(DATA_DIR): 24 | os.makedirs(DATA_DIR) 25 | if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')): 26 | www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' 27 | zipfile = os.path.basename(www) 28 | os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile)) 29 | os.system('mv %s %s' % (zipfile[:-4], DATA_DIR)) 30 | os.system('mv %s %s' % (zipfile, DATA_DIR)) 31 | 32 | 33 | def load_data(data_root, partition, download_path='data'): 34 | download(download_path) 35 | 36 | DATA_DIR = data_root 37 | all_data = [] 38 | all_label = [] 39 | for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5'%partition)): 40 | f = h5py.File(h5_name, 'r') 41 | data = f['data'][:].astype('float32') 42 | label = f['label'][:].astype('int64') 43 | f.close() 44 | all_data.append(data) 45 | all_label.append(label) 46 | all_data = np.concatenate(all_data, axis=0) 47 | all_label = np.concatenate(all_label, axis=0) 48 | return all_data, all_label 49 | 50 | 51 | def translate_pointcloud(pointcloud): 52 | xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) 53 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) 54 | 55 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') 56 | return translated_pointcloud 57 | 58 | 59 | def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): 60 | N, C = pointcloud.shape 61 | pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) 62 | return pointcloud 63 | 64 | 65 | class ModelNet40(Dataset): 66 | def __init__(self, data_root, num_points, partition='train'): 67 | self.data, self.label = load_data(data_root, partition) 68 | self.num_points = num_points 69 | self.partition = partition 70 | 71 | def __getitem__(self, item): 72 | pointcloud = self.data[item][:self.num_points] 73 | label = self.label[item] 74 | if self.partition == 'train': 75 | pointcloud = translate_pointcloud(pointcloud) 76 | np.random.shuffle(pointcloud) 77 | pointcloud = np.transpose(pointcloud, (1, 0)) 78 | label = np.squeeze(label) 79 | return pointcloud, label 80 | 81 | def __len__(self): 82 | return self.data.shape[0] 83 | 84 | 85 | if __name__ == '__main__': 86 | train = ModelNet40('data', 1024) # ModelNet40 takes (data_root, num_points, partition); 'data' matches the default download path 87 | test = ModelNet40('data', 1024, 'test') 88 | for data, label in train: 89 | print(data.shape) 90 | print(label.shape) -------------------------------------------------------------------------------- /reproduce/engine/utils/evaluator.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch.nn.functional as F 3 | import torch 4 | from . import metrics 5 | 6 | class Evaluator(object): 7 | def __init__(self, metric, dataloader): 8 | self.dataloader = dataloader 9 | self.metric = metric 10 | 11 | def eval(self, model, device=None, progress=False): 12 | self.metric.reset() 13 | with torch.no_grad(): 14 | for i, (inputs, targets) in enumerate( tqdm(self.dataloader, disable=not progress) ): 15 | inputs, targets = inputs.to(device or 'cuda'), targets.to(device or 'cuda') # honor the device argument; defaults to CUDA as before 16 | outputs = model( inputs ) 17 | self.metric.update(outputs, targets) 18 | return self.metric.get_results() 19 | 20 | def __call__(self, *args, **kwargs): 21 | return self.eval(*args, **kwargs) 22 | 23 | def classification_evaluator(dataloader): 24 | metric = metrics.MetricCompose({ 25 | 'Acc': metrics.TopkAccuracy(), 26 | 'Loss': metrics.RunningLoss(torch.nn.CrossEntropyLoss(reduction='sum')) 27 | }) 28 | return Evaluator( metric, dataloader=dataloader) 29 | 30 | def segmentation_evaluator(dataloader, num_classes, ignore_idx=255): # NOTE: relies on ConfusionMatrix/mIoU metrics that are not defined in metrics.py 31 | cm = metrics.ConfusionMatrix(num_classes, ignore_idx=ignore_idx) 32 | metric = metrics.MetricCompose({ 33 | 'mIoU': metrics.mIoU(cm), 34 | 'Acc': metrics.Accuracy(), 35 | 'Loss': metrics.RunningLoss(torch.nn.CrossEntropyLoss(reduction='sum')) 36 | }) 37 | return Evaluator( metric, dataloader=dataloader)
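Editor's note: a hypothetical usage sketch for classification_evaluator above (not part of the repository); CIFAR-10 and resnet18 are placeholders for any loader/model pair:

import torch, torchvision
import torchvision.transforms as T
val_set = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=T.ToTensor())
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128)
model = torchvision.models.resnet18(num_classes=10).eval().cuda()
evaluator = classification_evaluator(val_loader)
results = evaluator(model, device=torch.device('cuda'))  # {'Acc': (top1, top5), 'Loss': mean cross-entropy}
print('Top-1: %.2f%%  Top-5: %.2f%%  Loss: %.4f' % (results['Acc'][0], results['Acc'][1], results['Loss']))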
-------------------------------------------------------------------------------- /reproduce/engine/utils/imagenet_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import presets, sampler, transforms, utils -------------------------------------------------------------------------------- /reproduce/engine/utils/imagenet_utils/presets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import autoaugment, transforms 3 | from torchvision.transforms.functional import InterpolationMode 4 | 5 | 6 | class ClassificationPresetTrain: 7 | def __init__( 8 | self, 9 | *, 10 | crop_size, 11 | mean=(0.485, 0.456, 0.406), 12 | std=(0.229, 0.224, 0.225), 13 | interpolation=InterpolationMode.BILINEAR, 14 | hflip_prob=0.5, 15 | auto_augment_policy=None, 16 | ra_magnitude=9, 17 | augmix_severity=3, 18 | random_erase_prob=0.0, 19 | ): 20 | trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] 21 | if hflip_prob > 0: 22 | trans.append(transforms.RandomHorizontalFlip(hflip_prob)) 23 | if auto_augment_policy is not None: 24 | if auto_augment_policy == "ra": 25 | trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) 26 | elif auto_augment_policy == "ta_wide": 27 | trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) 28 | elif auto_augment_policy == "augmix": 29 | trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity)) 30 | else: 31 | aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) 32 | trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) 33 | trans.extend( 34 | [ 35 | transforms.PILToTensor(), 36 | transforms.ConvertImageDtype(torch.float), 37 | transforms.Normalize(mean=mean, std=std), 38 | ] 39 | ) 40 | if random_erase_prob > 0: 41 | trans.append(transforms.RandomErasing(p=random_erase_prob)) 42 | 43 | self.transforms = transforms.Compose(trans) 44 | 45 | def __call__(self, img): 46 | return self.transforms(img) 47 | 48 | 49 | class ClassificationPresetEval: 50 | def __init__( 51 | self, 52 | *, 53 | crop_size, 54 | resize_size=256, 55 | mean=(0.485, 0.456, 0.406), 56 | std=(0.229, 0.224, 0.225), 57 | interpolation=InterpolationMode.BILINEAR, 58 | ): 59 | 60 | self.transforms = transforms.Compose( 61 | [ 62 | transforms.Resize(resize_size, interpolation=interpolation), 63 | transforms.CenterCrop(crop_size), 64 | transforms.PILToTensor(), 65 | transforms.ConvertImageDtype(torch.float), 66 | transforms.Normalize(mean=mean, std=std), 67 | ] 68 | ) 69 | 70 | def __call__(self, img): 71 | return self.transforms(img) -------------------------------------------------------------------------------- /reproduce/engine/utils/imagenet_utils/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class RASampler(torch.utils.data.Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset for distributed, 9 | with repeated augmentation. 10 | It ensures that each augmented version of a sample will be visible to a 11 | different process (GPU). 12 | Heavily based on 'torch.utils.data.DistributedSampler'. 13 | This is borrowed from the DeiT Repo: 14 | https://github.com/facebookresearch/deit/blob/main/samplers.py 15 | """ 16 | 17 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3): 18 | if num_replicas is None: 19 | if not dist.is_available(): 20 | raise RuntimeError("Requires distributed package to be available!") 21 | num_replicas = dist.get_world_size() 22 | if rank is None: 23 | if not dist.is_available(): 24 | raise RuntimeError("Requires distributed package to be available!") 25 | rank = dist.get_rank() 26 | self.dataset = dataset 27 | self.num_replicas = num_replicas 28 | self.rank = rank 29 | self.epoch = 0 30 | self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas)) 31 | self.total_size = self.num_samples * self.num_replicas 32 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 33 | self.shuffle = shuffle 34 | self.seed = seed 35 | self.repetitions = repetitions 36 | 37 | def __iter__(self): 38 | if self.shuffle: 39 | # Deterministically shuffle based on epoch 40 | g = torch.Generator() 41 | g.manual_seed(self.seed + self.epoch) 42 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 43 | else: 44 | indices = list(range(len(self.dataset))) 45 | 46 | # Add extra samples to make it evenly divisible 47 | indices = [ele for ele in indices for i in range(self.repetitions)] 48 | indices += indices[: (self.total_size - len(indices))] 49 | assert len(indices) == self.total_size 50 | 51 | # Subsample 52 | indices = indices[self.rank : self.total_size : self.num_replicas] 53 | assert len(indices) == self.num_samples 54 | 55 | return iter(indices[: self.num_selected_samples]) 56 | 57 | def __len__(self): 58 | return self.num_selected_samples 59 | 60 | def set_epoch(self, epoch): 61 | self.epoch = epoch
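Editor's note: RASampler above is a drop-in replacement for torch.utils.data.DistributedSampler when training with repeated augmentation. A single-process sketch follows (hypothetical, not part of the repository; num_replicas/rank would normally come from torch.distributed):

import torch, torchvision
import torchvision.transforms as T
train_set = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=T.ToTensor())
# passing num_replicas/rank explicitly means no distributed process group is required here
train_sampler = RASampler(train_set, num_replicas=1, rank=0, shuffle=True, repetitions=3)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, sampler=train_sampler)
for epoch in range(2):
    train_sampler.set_epoch(epoch)  # reshuffle deterministically each epoch
    for images, targets in train_loader:
        pass  # one training step would go here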
-------------------------------------------------------------------------------- /reproduce/engine/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from abc import ABC, abstractmethod 4 | from typing import Callable, Union, Any, Mapping, Sequence 5 | 6 | __all__=['Accuracy', 'TopkAccuracy'] 7 | 8 | 9 | 10 | 11 | 12 | class Metric(ABC): 13 | @abstractmethod 14 | def update(self, pred, target): 15 | """ Overridden by subclasses """ 16 | raise NotImplementedError() 17 | 18 | @abstractmethod 19 | def get_results(self): 20 | """ Overridden by subclasses """ 21 | raise NotImplementedError() 22 | 23 | @abstractmethod 24 | def reset(self): 25 | """ Overridden by subclasses """ 26 | raise NotImplementedError() 27 | 28 | class MetricCompose(dict): 29 | def __init__(self, metric_dict: Mapping): 30 | self._metric_dict = metric_dict 31 | 32 | @property 33 | def metrics(self): 34 | return self._metric_dict 35 | 36 | @torch.no_grad() 37 | def update(self, outputs, targets): 38 | for key, metric in self._metric_dict.items(): 39 | if isinstance(metric, Metric): 40 | metric.update(outputs, targets) 41 | 42 | def get_results(self): 43 | results = {} 44 | for key, metric in self._metric_dict.items(): 45 | if isinstance(metric, Metric): 46 | results[key] = metric.get_results() 47 | return results 48 | 49 | def reset(self): 50 | for key, metric in self._metric_dict.items(): 51 | if isinstance(metric, Metric): 52 | metric.reset() 53 | 54 | def __getitem__(self, name): 55 | return self._metric_dict[name] 56 | 57 | class Accuracy(Metric): 58 | def __init__(self): 59 | self.reset() 60 | 61 | @torch.no_grad() 62 | def update(self, outputs, targets): 63 | outputs = outputs.max(1)[1] 64 | self._correct += ( outputs.view(-1)==targets.view(-1) ).sum() 65 | self._cnt += torch.numel( targets ) 66 | 67 | def get_results(self): 68 | return (self._correct / self._cnt * 100.).detach().cpu() 69 | 70 | def reset(self): 71 | self._correct = self._cnt = 0.0 72 | 73 | 74 | class TopkAccuracy(Metric): 75 | def __init__(self, topk=(1, 5)): 76 | self._topk = topk 77 | self.reset() 78 | 79 | @torch.no_grad() 80 | def update(self, outputs, targets): 81 | for k in self._topk: 82 | _, topk_outputs = outputs.topk(k, dim=1, largest=True, sorted=True) 83 | correct = topk_outputs.eq( targets.view(-1, 1).expand_as(topk_outputs) ) 84 | self._correct[k] += correct[:, :k].view(-1).float().sum(0).item() 85 | self._cnt += len(targets) 86 | 87 | def get_results(self): 88 | return tuple( self._correct[k] / self._cnt * 100. for k in self._topk ) 89 | 90 | def reset(self): 91 | self._correct = {k: 0 for k in self._topk} 92 | self._cnt = 0.0 93 | 94 | class RunningLoss(Metric): 95 | def __init__(self, loss_fn, is_batch_average=False): 96 | self.reset() 97 | self.loss_fn = loss_fn 98 | self.is_batch_average = is_batch_average 99 | 100 | @torch.no_grad() 101 | def update(self, outputs, targets): 102 | self._accum_loss += self.loss_fn(outputs, targets) 103 | if self.is_batch_average: 104 | self._cnt += 1 105 | else: 106 | self._cnt += len(outputs) 107 | 108 | def get_results(self): 109 | return (self._accum_loss / self._cnt).detach().cpu() 110 | 111 | def reset(self): 112 | self._accum_loss = self._cnt = 0.0 -------------------------------------------------------------------------------- /reproduce/engine/utils/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import logging 3 | import os, sys 4 | from termcolor import colored 5 | import copy 6 | import numpy as np 7 | import torch 8 | 9 | class MagnitudeRecover(): 10 | def __init__(self, model, reg=1e-3): 11 | self.rec = {} 12 | self.reg = reg 13 | self.cnt = 0 14 | with torch.no_grad(): 15 | for name, p in model.named_parameters(): 16 | norm = p.pow(2).mean() 17 | self.rec[name] = norm 18 | 19 | def regularize(self, model): 20 | with torch.no_grad(): 21 | for name, p in model.named_parameters(): 22 | if name in self.rec: 23 | target_norm = self.rec[name] 24 | if p.data.pow(2).mean() > target_norm: 25 | self.rec.pop(name) 26 | continue 27 | p.grad.data += -self.reg * p.data 28 | if self.cnt%1000==0: 29 | print(name, p.pow(2).mean(), target_norm) 30 | self.cnt+=1 31 | 32 | def flatten_dict(dic): 33 | flattned = dict() 34 | def _flatten(prefix, d): 35 | for k, v in d.items(): 36 | if isinstance(v, dict): 37 | if prefix is None: 38 | _flatten( k, v ) 39 | else: 40 | _flatten( prefix+'/%s'%k, v ) 41 | else: 42 | if prefix is None: 43 | flattned[k] = v 44 | else: 45 | flattned[ prefix+'/%s'%k ] = v 46 | 47 | _flatten(None, dic) 48 | return flattned 49 | 50 | class _ColorfulFormatter(logging.Formatter): 51 | def __init__(self, *args, **kwargs): 52 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 53 | 54 | def formatMessage(self, record): 55 | log = super(_ColorfulFormatter, self).formatMessage(record) 56 | 57 | if record.levelno == logging.WARNING: 58 | prefix = colored("WARNING", "yellow", attrs=["blink"]) 59 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 60 | prefix =
colored("ERROR", "red", attrs=["blink", "underline"]) 61 | else: 62 | return log 63 | 64 | return prefix + " " + log 65 | 66 | def get_logger(name='train', output=None, color=True): 67 | logger = logging.getLogger(name) 68 | logger.setLevel(logging.DEBUG) 69 | logger.propagate = False 70 | 71 | # STDOUT 72 | stdout_handler = logging.StreamHandler( stream=sys.stdout ) 73 | stdout_handler.setLevel( logging.DEBUG ) 74 | 75 | plain_formatter = logging.Formatter( 76 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" ) 77 | if color: 78 | formatter = _ColorfulFormatter( 79 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 80 | datefmt="%m/%d %H:%M:%S") 81 | else: 82 | formatter = plain_formatter 83 | stdout_handler.setFormatter(formatter) 84 | 85 | logger.addHandler(stdout_handler) 86 | 87 | # FILE 88 | if output is not None: 89 | if output.endswith('.txt') or output.endswith('.log'): 90 | os.makedirs(os.path.dirname(output), exist_ok=True) 91 | filename = output 92 | else: 93 | os.makedirs(output, exist_ok=True) 94 | filename = os.path.join(output, "log.txt") 95 | file_handler = logging.FileHandler(filename) 96 | file_handler.setFormatter(plain_formatter) 97 | file_handler.setLevel(logging.DEBUG) 98 | logger.addHandler(file_handler) 99 | return logger -------------------------------------------------------------------------------- /reproduce/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.8.1 2 | termcolor 3 | einops 4 | tqdm 5 | torchvision>=0.9.1 6 | tensorboard 7 | 8 | # for main_gat.py 9 | h5py 10 | networkx 11 | igraph 12 | -------------------------------------------------------------------------------- /reproduce/scripts/pretrain/cifar_pretrain.sh: -------------------------------------------------------------------------------- 1 | # CIFAR-10 2 | python main.py --mode pretrain --dataset cifar10 --model resnet56 --lr 0.1 --total-epochs 200 --lr-decay-milestones 120,150,180 3 | python main.py --mode pretrain --dataset cifar10 --model vgg19 --lr 0.1 --total-epochs 200 --lr-decay-milestones 120,150,180 4 | 5 | # CIFAR-100 6 | python main.py --mode pretrain --dataset cifar100 --model resnet56 --lr 0.1 --total-epochs 200 --lr-decay-milestones 120,150,180 7 | python main.py --mode pretrain --dataset cifar100 --model vgg19 --lr 0.1 --total-epochs 200 --lr-decay-milestones 120,150,180 8 | python main.py --mode pretrain --dataset cifar100 --model densenet121 --lr 0.1 --total-epochs 200 --lr-decay-milestones 120,150,180 9 | python main.py --mode pretrain --dataset cifar100 --model googlenet --lr 0.1 --total-epochs 200 --lr-decay-milestones 120,150,180 10 | python main.py --mode pretrain --dataset cifar100 --model mobilenetv2 --lr 0.05 --total-epochs 200 --lr-decay-milestones 120,150,180 11 | python main.py --mode pretrain --dataset cifar100 --model resnext50 --lr 0.1 --total-epochs 200 --lr-decay-milestones 120,150,180 -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/global_group_norm.sh: -------------------------------------------------------------------------------- 1 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_norm --speed_up 1.5 --soft_rank 0.25 --global_pruning 2 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_norm 
--speed_up 3.0 --soft_rank 0.25 --global_pruning 3 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_norm --speed_up 6.0 --soft_rank 0.25 --global_pruning 4 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_norm --speed_up 12.0 --soft_rank 0.25 --global_pruning 5 | 6 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_norm --speed_up 1.5 --soft_rank 0.25 --global_pruning 7 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_norm --speed_up 3.0 --soft_rank 0.25 --global_pruning 8 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_norm --speed_up 6.0 --soft_rank 0.25 --global_pruning 9 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_norm --speed_up 12.0 --soft_rank 0.25 --global_pruning 10 | 11 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_norm --speed_up 1.5 --soft_rank 0.25 --global_pruning 12 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_norm --speed_up 3.0 --soft_rank 0.25 --global_pruning 13 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_norm --speed_up 6.0 --soft_rank 0.25 --global_pruning 14 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_norm --speed_up 12.0 --soft_rank 0.25 --global_pruning 15 | 16 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_norm --speed_up 1.5 --soft_rank 0.25 --global_pruning 17 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_norm --speed_up 3.0 --soft_rank 0.25 --global_pruning 18 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_norm --speed_up 6.0 --soft_rank 0.25 --global_pruning 19 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_norm --speed_up 12.0 --soft_rank 0.25 --global_pruning 20 | 21 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_norm --speed_up 1.5 --soft_rank 0.25 --global_pruning 22 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_norm --speed_up 3.0 --soft_rank 0.25 --global_pruning 23 | python main.py --mode prune --model googlenet --batch_size 128 --restore 
run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_norm --speed_up 6.0 --soft_rank 0.25 --global_pruning 24 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_norm --speed_up 12.0 --soft_rank 0.25 --global_pruning 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/global_group_sl.sh: -------------------------------------------------------------------------------- 1 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_sl --speed_up 1.5 --soft_rank 0.25 --reg 2e-4 --global_pruning 2 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_sl --speed_up 3.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 3 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_sl --speed_up 6.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 4 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_sl --speed_up 12.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 5 | 6 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_sl --speed_up 1.5 --soft_rank 0.25 --reg 2e-4 --global_pruning 7 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_sl --speed_up 3.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 8 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_sl --speed_up 6.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 9 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_sl --speed_up 12.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 10 | 11 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_sl --speed_up 1.5 --soft_rank 0.25 --reg 2e-4 --global_pruning 12 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_sl --speed_up 3.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 13 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_sl --speed_up 6.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 14 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_sl --speed_up 12.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 15 | 16 | #python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_sl --speed_up 1.5 --soft_rank 0.25 --reg 2e-4 --global_pruning 17 | #python main.py 
--mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_sl --speed_up 3.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 18 | #python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_sl --speed_up 6.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 19 | #python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_sl --speed_up 12.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 20 | # 21 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_sl --speed_up 1.5 --soft_rank 0.25 --reg 2e-4 --global_pruning 22 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_sl --speed_up 3.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 23 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_sl --speed_up 6.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 24 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_sl --speed_up 12.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/global_group_sl_p2.sh: -------------------------------------------------------------------------------- 1 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_sl --speed_up 1.5 --soft_rank 0.25 --reg 1e-3 --global_pruning 2 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_sl --speed_up 3.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore 3 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_sl --speed_up 6.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore 4 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_sl --speed_up 12.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore# 5 | 6 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_sl --speed_up 1.5 --soft_rank 0.25 --reg 1e-3 --global_pruning 7 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_sl --speed_up 3.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore 8 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_sl --speed_up 6.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore 9 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset 
cifar100 --method group_sl --speed_up 12.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore # 10 | 11 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_sl --speed_up 1.5 --soft_rank 0.25 --reg 1e-3 --global_pruning 12 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_sl --speed_up 3.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore 13 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_sl --speed_up 6.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore 14 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_sl --speed_up 12.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore 15 | 16 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_sl --speed_up 1.5 --soft_rank 0.25 --reg 2e-4 --global_pruning 17 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_sl --speed_up 3.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 18 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_sl --speed_up 6.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 19 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_sl --speed_up 12.0 --soft_rank 0.25 --reg 2e-4 --global_pruning --sl_restore 20 | 21 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_sl --speed_up 1.5 --soft_rank 0.25 --reg 1e-3 --global_pruning 22 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_sl --speed_up 3.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore 23 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_sl --speed_up 6.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore 24 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_sl --speed_up 12.0 --soft_rank 0.25 --reg 1e-3 --global_pruning --sl_restore 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/global_l1.sh: -------------------------------------------------------------------------------- 1 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1 --speed_up 1.5 --global_pruning --soft_rank 0.25 2 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1 --speed_up 3.0 --global_pruning --soft_rank 0.25 3 | python main.py --mode 
prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1 --speed_up 6.0 --global_pruning --soft_rank 0.25 4 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1 --speed_up 12.0 --global_pruning --soft_rank 0.25 5 | 6 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1 --speed_up 1.5 --global_pruning --soft_rank 0.25 7 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1 --speed_up 3.0 --global_pruning --soft_rank 0.25 8 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1 --speed_up 6.0 --global_pruning --soft_rank 0.25 9 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1 --speed_up 12.0 --global_pruning --soft_rank 0.25 10 | 11 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1 --speed_up 1.5 --global_pruning --soft_rank 0.25 12 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1 --speed_up 3.0 --global_pruning --soft_rank 0.25 13 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1 --speed_up 6.0 --global_pruning --soft_rank 0.25 14 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1 --speed_up 12.0 --global_pruning --soft_rank 0.25 15 | 16 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1 --speed_up 1.5 --global_pruning --soft_rank 0.25 17 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1 --speed_up 3.0 --global_pruning --soft_rank 0.25 18 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1 --speed_up 6.0 --global_pruning --soft_rank 0.25 19 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1 --speed_up 12.0 --global_pruning --soft_rank 0.25 20 | 21 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1 --speed_up 1.5 --global_pruning --soft_rank 0.25 22 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1 --speed_up 3.0 --global_pruning --soft_rank 0.25 23 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1 --speed_up 6.0 --global_pruning --soft_rank 0.25 24 | python main.py --mode prune --model googlenet --batch_size 128 --restore 
run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1 --speed_up 12.0 --global_pruning --soft_rank 0.25 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/global_l1_group_conv.sh: -------------------------------------------------------------------------------- 1 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1_group_conv --speed_up 1.5 --global_pruning --soft_rank 0.25 2 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1_group_conv --speed_up 3.0 --global_pruning --soft_rank 0.25 3 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1_group_conv --speed_up 6.0 --global_pruning --soft_rank 0.25 4 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1_group_conv --speed_up 12.0 --global_pruning --soft_rank 0.25 5 | 6 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1_group_conv --speed_up 1.5 --global_pruning --soft_rank 0.25 7 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1_group_conv --speed_up 3.0 --global_pruning --soft_rank 0.25 8 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1_group_conv --speed_up 6.0 --global_pruning --soft_rank 0.25 9 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1_group_conv --speed_up 12.0 --global_pruning --soft_rank 0.25 10 | 11 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1_group_conv --speed_up 1.5 --global_pruning --soft_rank 0.25 12 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1_group_conv --speed_up 3.0 --global_pruning --soft_rank 0.25 13 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1_group_conv --speed_up 6.0 --global_pruning --soft_rank 0.25 14 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1_group_conv --speed_up 12.0 --global_pruning --soft_rank 0.25 15 | 16 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1_group_conv --speed_up 1.5 --global_pruning --soft_rank 0.25 17 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1_group_conv --speed_up 3.0 --global_pruning --soft_rank 0.25 18 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1_group_conv --speed_up 6.0 --global_pruning --soft_rank 
0.25 19 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1_group_conv --speed_up 12.0 --global_pruning --soft_rank 0.25 20 | 21 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1_group_conv --speed_up 1.5 --global_pruning --soft_rank 0.25 22 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1_group_conv --speed_up 3.0 --global_pruning --soft_rank 0.25 23 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1_group_conv --speed_up 6.0 --global_pruning --soft_rank 0.25 24 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1_group_conv --speed_up 12.0 --global_pruning --soft_rank 0.25 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/global_lamp.sh: -------------------------------------------------------------------------------- 1 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method lamp --speed_up 1.5 --global 2 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method lamp --speed_up 3.0 --global 3 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method lamp --speed_up 6.0 --global 4 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method lamp --speed_up 12.0 --global 5 | 6 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method lamp --speed_up 1.5 --global 7 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method lamp --speed_up 3.0 --global 8 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method lamp --speed_up 6.0 --global 9 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method lamp --speed_up 12.0 --global 10 | 11 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method lamp --speed_up 1.5 --global 12 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method lamp --speed_up 3.0 --global 13 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method lamp --speed_up 6.0 --global 14 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method lamp --speed_up 12.0 --global 15 | 16 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore 
run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method lamp --speed_up 1.5 --global 17 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method lamp --speed_up 3.0 --global 18 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method lamp --speed_up 6.0 --global 19 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method lamp --speed_up 12.0 --global 20 | 21 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method lamp --speed_up 1.5 --global 22 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method lamp --speed_up 3.0 --global 23 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method lamp --speed_up 6.0 --global 24 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method lamp --speed_up 12.0 --global 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/global_rand.sh: -------------------------------------------------------------------------------- 1 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method random --speed_up 1.5 --global 2 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method random --speed_up 3.0 --global 3 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method random --speed_up 6.0 --global 4 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method random --speed_up 12.0 --global 5 | # 6 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method random --speed_up 1.5 --global 7 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method random --speed_up 3.0 --global 8 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method random --speed_up 6.0 --global 9 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method random --speed_up 12.0 --global 10 | # 11 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method random --speed_up 1.5 --global 12 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method random --speed_up 3.0 --global 13 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 
--method random --speed_up 6.0 --global 14 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method random --speed_up 12.0 --global 15 | 16 | #python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method random --speed_up 1.5 --global 17 | #python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method random --speed_up 3.0 --global 18 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method random --speed_up 6.0 --global 19 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method random --speed_up 12.0 --global 20 | 21 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method random --speed_up 1.5 --global 22 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method random --speed_up 3.0 --global 23 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method random --speed_up 6.0 --global 24 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method random --speed_up 12.0 --global 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/global_slim.sh: -------------------------------------------------------------------------------- 1 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method slim --speed_up 1.5 --reg 1e-4 --global 2 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method slim --speed_up 3.0 --reg 1e-4 --global 3 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method slim --speed_up 6.0 --reg 1e-4 --global 4 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method slim --speed_up 12.0 --reg 1e-4 --global 5 | 6 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method slim --speed_up 1.5 --reg 1e-4 --global 7 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method slim --speed_up 3.0 --reg 1e-4 --global 8 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method slim --speed_up 6.0 --reg 1e-4 --global 9 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method slim --speed_up 12.0 --reg 1e-4 --global 10 | 11 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth 
--dataset cifar100 --method slim --speed_up 1.5 --reg 1e-4 --global 12 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method slim --speed_up 3.0 --reg 1e-4 --global 13 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method slim --speed_up 6.0 --reg 1e-4 --global 14 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method slim --speed_up 12.0 --reg 1e-4 --global 15 | 16 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method slim --speed_up 1.5 --reg 1e-4 --global 17 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method slim --speed_up 3.0 --reg 1e-4 --global 18 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method slim --speed_up 6.0 --reg 1e-4 --global 19 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method slim --speed_up 12.0 --reg 1e-4 --global 20 | 21 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method slim --speed_up 1.5 --reg 1e-4 --global 22 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method slim --speed_up 3.0 --reg 1e-4 --global 23 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method slim --speed_up 6.0 --reg 1e-4 --global 24 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method slim --speed_up 12.0 --reg 1e-4 --global 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/local_group_norm.sh: -------------------------------------------------------------------------------- 1 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_norm --speed_up 1.5 --soft_rank 0.5 2 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_norm --speed_up 3.0 --soft_rank 0.5 3 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_norm --speed_up 6.0 --soft_rank 0.5 4 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method group_norm --speed_up 12.0 --soft_rank 0.5 5 | # 6 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_norm --speed_up 
3.0 --soft_rank 0.5 8 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_norm --speed_up 6.0 --soft_rank 0.5 9 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method group_norm --speed_up 12.0 --soft_rank 0.5 10 | # 11 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_norm --speed_up 1.5 --soft_rank 0.0 12 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_norm --speed_up 3.0 --soft_rank 0.0 13 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_norm --speed_up 6.0 --soft_rank 0.0 14 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method group_norm --speed_up 12.0 --soft_rank 0.0 15 | 16 | #python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_norm --speed_up 1.5 --soft_rank 0.5 17 | #python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_norm --speed_up 3.0 --soft_rank 0.5 18 | #python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_norm --speed_up 6.0 --soft_rank 0.5 19 | #python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method group_norm --speed_up 12.0 --soft_rank 0.5 20 | 21 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_norm --speed_up 1.5 --soft_rank 0.5 22 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_norm --speed_up 3.0 --soft_rank 0.5 23 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_norm --speed_up 6.0 --soft_rank 0.5 24 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method group_norm --speed_up 12.0 --soft_rank 0.5 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/local_l1.sh: -------------------------------------------------------------------------------- 1 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1 --speed_up 1.5 --global_pruning 2 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1 --speed_up 3.0 --global_pruning 3 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1 --speed_up 6.0 --global_pruning 4 | #python main.py 
--mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1 --speed_up 12.0 --global_pruning 5 | # 6 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1 --speed_up 1.5 --global_pruning 7 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1 --speed_up 3.0 --global_pruning 8 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1 --speed_up 6.0 --global_pruning 9 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1 --speed_up 12.0 --global_pruning 10 | # 11 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1 --speed_up 1.5 --global_pruning 12 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1 --speed_up 3.0 --global_pruning 13 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1 --speed_up 6.0 --global_pruning 14 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1 --speed_up 12.0 --global_pruning 15 | 16 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1 --speed_up 1.5 17 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1 --speed_up 3.0 18 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1 --speed_up 6.0 19 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1 --speed_up 12.0 20 | 21 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1 --speed_up 1.5 --global_pruning 22 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1 --speed_up 3.0 --global_pruning 23 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1 --speed_up 6.0 --global_pruning 24 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1 --speed_up 12.0 --global_pruning 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/local_l1_group_conv.sh: -------------------------------------------------------------------------------- 1 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1_group_conv --speed_up 1.5 2 | python main.py --mode prune --model 
resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1_group_conv --speed_up 3.0 3 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1_group_conv --speed_up 6.0 4 | python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method l1_group_conv --speed_up 12.0 5 | 6 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1_group_conv --speed_up 1.5 7 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1_group_conv --speed_up 3.0 8 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1_group_conv --speed_up 6.0 9 | python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method l1_group_conv --speed_up 12.0 10 | 11 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1_group_conv --speed_up 1.5 12 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1_group_conv --speed_up 3.0 13 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1_group_conv --speed_up 6.0 14 | python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method l1_group_conv --speed_up 12.0 15 | 16 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1_group_conv --speed_up 1.5 17 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1_group_conv --speed_up 3.0 18 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1_group_conv --speed_up 6.0 19 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method l1_group_conv --speed_up 12.0 20 | 21 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1_group_conv --speed_up 1.5 22 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1_group_conv --speed_up 3.0 23 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1_group_conv --speed_up 6.0 24 | python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method l1_group_conv --speed_up 12.0 25 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ablation/local_rand.sh: 
-------------------------------------------------------------------------------- 1 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method random --speed_up 1.5 --global 2 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method random --speed_up 3.0 --global 3 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method random --speed_up 6.0 --global 4 | #python main.py --mode prune --model resnet56 --batch_size 128 --restore run/cifar100/pretrain/cifar100_resnet56.pth --dataset cifar100 --method random --speed_up 12.0 --global 5 | # 6 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method random --speed_up 1.5 --global 7 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method random --speed_up 3.0 --global 8 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method random --speed_up 6.0 --global 9 | #python main.py --mode prune --model vgg19 --batch_size 128 --restore run/cifar100/pretrain/cifar100_vgg19.pth --dataset cifar100 --method random --speed_up 12.0 --global 10 | # 11 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method random --speed_up 1.5 --global 12 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method random --speed_up 3.0 --global 13 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method random --speed_up 6.0 --global 14 | #python main.py --mode prune --model densenet121 --batch_size 128 --restore run/cifar100/pretrain/cifar100_densenet121.pth --dataset cifar100 --method random --speed_up 12.0 --global 15 | 16 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method random --speed_up 1.5 17 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method random --speed_up 3.0 18 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method random --speed_up 6.0 19 | python main.py --mode prune --model mobilenetv2 --batch_size 128 --restore run/cifar100/pretrain/cifar100_mobilenetv2.pth --dataset cifar100 --method random --speed_up 12.0 20 | 21 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method random --speed_up 1.5 --global 22 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method random --speed_up 3.0 --global 23 | #python main.py --mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method random --speed_up 6.0 --global 24 | #python main.py 
--mode prune --model googlenet --batch_size 128 --restore run/cifar100/pretrain/cifar100_googlenet.pth --dataset cifar100 --method random --speed_up 12.0 --global -------------------------------------------------------------------------------- /reproduce/scripts/prune/cifar/bn_pruner.sh: -------------------------------------------------------------------------------- 1 | # python main.py --mode pretrain --dataset cifar10 --model resnet56 --lr 0.1 --total-epochs 200 --lr-decay-milestones 120,150,180 2 | 3 | python main.py --mode prune --model resnet56 --batch-size 128 --restore run/cifar10/pretrain/cifar10_resnet56.pth --dataset cifar10 --method slim --speed-up 2.11 --global-pruning --reg 1e-5 -------------------------------------------------------------------------------- /reproduce/scripts/prune/cifar/group_pruner.sh: -------------------------------------------------------------------------------- 1 | # python main.py --mode pretrain --dataset cifar10 --model resnet56 --lr 0.1 --total-epochs 200 --lr-decay-milestones 120,150,180 2 | 3 | python main.py --mode prune --model resnet56 --batch-size 128 --restore run/cifar10/pretrain/cifar10_resnet56.pth --dataset cifar10 --method group_sl --speed-up 2.11 --global-pruning --reg 5e-4 4 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/cifar/l1_norm_pruner.sh: -------------------------------------------------------------------------------- 1 | # python main.py --mode pretrain --dataset cifar10 --model resnet56 --lr 0.1 --total-epochs 200 --lr-decay-milestones 120,150,180 2 | 3 | python main.py --mode prune --model resnet56 --method l1 --batch-size 128 --restore run/cifar10/pretrain/cifar10_resnet56.pth --dataset cifar10 --speed-up 2.11 --global-pruning --reg 1e-5 -------------------------------------------------------------------------------- /reproduce/scripts/prune/imagenet/densenet_gsl.sh: -------------------------------------------------------------------------------- 1 | OMP_NUM_THREADS=4 python -m torch.distributed.launch --nproc_per_node=4 --master_port 18119 --use_env main_imagenet.py --model densenet121 --epochs 90 --batch-size 512 --lr-step-size 30 --lr 0.08 --prune --method group_sl --global-pruning --soft-keeping-ratio 0.25 --pretrained --output-dir run/imagenet/densenet121_sl --target-flops 1.38 --sl-epochs 30 --sl-lr 0.08 --sl-lr-step-size 10 --cache-dataset --reg 1e-4 --print-freq 100 --workers 16 --amp -------------------------------------------------------------------------------- /reproduce/scripts/prune/imagenet/mobilenetv2_group_norm.sh: -------------------------------------------------------------------------------- 1 | OMP_NUM_THREADS=4 python -m torch.distributed.launch --nproc_per_node=8 --master_port 18101 --use_env main_imagenet.py --model mobilenet_v2 --pretrained --epochs 300 --batch-size 256 --lr 0.045 --wd 0.00004 --lr-step-size 1 --lr-gamma 0.98 --prune --cache-dataset --method group_norm --global_pruning --soft_keeping_ratio 0.5 --pretrained --target_flops 0.15 --output-dir run/imagenet/mobilenetv2_gnorm --amp -------------------------------------------------------------------------------- /reproduce/scripts/prune/imagenet/mobilenetv2_group_sl.sh: -------------------------------------------------------------------------------- 1 | # 8 GPUs 2 | OMP_NUM_THREADS=4 python -m torch.distributed.launch --nproc_per_node=8 --master_port 18122 --use_env main_imagenet.py --model mobilenet_v2 --pretrained --epochs 150 --batch-size 32 --lr 0.0045 --wd 0.00004 
--lr-step-size 1 --lr-gamma 0.98 --prune --cache-dataset --method group_sl --global-pruning --pretrained --target-flops 0.15 --output-dir run/imagenet/mobilenet_gsl --reg 1e-4 --sl-epochs 150 --sl-lr 0.0045 --sl-lr-step-size 1 --print-freq 100 --amp --max-pruning-ratio 0.7 3 | 4 | # 4 GPUs, 2048 bz, linear lr scaling 5 | CUDA_VISIBLE_DEVICES=4,5,6,7 OMP_NUM_THREADS=4 python -m torch.distributed.launch --nproc_per_node=4 --master_port 18122 --use_env main_imagenet.py --model mobilenet_v2 --pretrained --epochs 150 --batch-size 512 --lr 0.036 --wd 0.00004 --lr-step-size 1 --lr-gamma 0.98 --prune --cache-dataset --method group_sl --global-pruning --pretrained --target-flops 0.15 --output-dir run/imagenet/mobilenet_gsl --reg 1e-4 --sl-epochs 150 --sl-lr 0.036 --sl-lr-step-size 1 --print-freq 100 --amp --max-pruning-ratio 0.7 -------------------------------------------------------------------------------- /reproduce/scripts/prune/imagenet/next50_group_norm.sh: -------------------------------------------------------------------------------- 1 | OMP_NUM_THREADS=4 python -m torch.distributed.launch --nproc_per_node=8 --master_port 18113 --use_env main_imagenet.py --model resnext50_32x4d --epochs 100 --batch-size 256 --lr 0.08 --prune --cache-dataset --method group_norm --soft-keeping-ratio 0.5 --pretrained --output-dir run/imagenet/next50_gnorm --target-flops 2.11 --global-pruning --print-freq 100 --workers 8 -------------------------------------------------------------------------------- /reproduce/scripts/prune/imagenet/next50_group_sl.sh: -------------------------------------------------------------------------------- 1 | OMP_NUM_THREADS=4 python -m torch.distributed.launch --nproc_per_node=4 --master_port 18113 --use_env main_imagenet.py --model resnext50_32x4d --epochs 100 --batch-size 512 --lr 0.08 --prune --cache-dataset --method group_sl --soft-keeping-ratio 0.25 --pretrained --output-dir run/imagenet/next50_gsl --target-flops 2.11 --global-pruning --print-freq 100 --workers 8 -------------------------------------------------------------------------------- /reproduce/scripts/prune/imagenet/regnet_group_norm.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch --nproc_per_node=8 --master_port 18208 --use_env main_imagenet.py --model regnet_x_1_6gf --epochs 100 --batch-size 256 --wd 0.00005 --lr 0.08 --lr-scheduler=cosineannealinglr --lr-warmup-method=linear --lr-warmup-epochs=5 --lr-warmup-decay=0.1 --prune --cache-dataset --method group_norm --soft-keeping-ratio 0.6 --pretrained --output-dir run/imagenet/regnet_x_1_6gf_gnorm --target-flops 0.8 --global-pruning --workers 16 --print-freq 100 --amp -------------------------------------------------------------------------------- /reproduce/scripts/prune/imagenet/regnet_group_sl.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch --nproc_per_node=4 --master_port 18211 --use_env main_imagenet.py --model regnet_x_1_6gf --epochs 60 --batch-size 256 --wd 0.00005 --lr 0.04 --lr-scheduler=cosineannealinglr --lr-warmup-method=linear --prune --cache-dataset --method group_sl --global-pruning --pretrained --output-dir run/imagenet/regnet_x_1_6gf_gsl --target-flops 0.8 --max-pruning-ratio 0.7 --sl-epochs 60 --sl-lr 0.04 --cache-dataset --reg 1e-4 --amp --data-path ~/Datasets/imagenet &> run/imagenet/regnet_gsl.log -------------------------------------------------------------------------------- 
/reproduce/scripts/prune/imagenet/resnet50_group_norm.sh: -------------------------------------------------------------------------------- 1 | OMP_NUM_THREADS=4 python -m torch.distributed.launch --nproc_per_node=8 --master_port 18113 --use_env main_imagenet.py --model resnet50 --epochs 90 --batch-size 256 --lr 0.08 --prune --cache-dataset --method group_norm --soft-keeping-ratio 0.5 --pretrained --output-dir run/imagenet/resnet50_gnorm --target-flops 2.04 --global-pruning --print-freq 100 --workers 8 -------------------------------------------------------------------------------- /reproduce/scripts/prune/imagenet/resnet50_group_sl.sh: -------------------------------------------------------------------------------- 1 | OMP_NUM_THREADS=4 python -m torch.distributed.launch --nproc_per_node=8 --master_port 18119 --use_env main_imagenet.py --model resnet50 --epochs 90 --batch-size 256 --lr-step-size 30 --lr 0.08 --prune --method group_sl --global-pruning --soft-keeping-ratio 0.5 --pretrained --output-dir run/imagenet/resnet50_sl --target-flops 2.04 --sl-epochs 30 --sl-lr 0.08 --sl-lr-step-size 10 --cache-dataset --reg 1e-4 --print-freq 100 --workers 16 --amp -------------------------------------------------------------------------------- /reproduce/scripts/prune/imagenet/vgg_group_norm copy.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch --nproc_per_node=2 --master_port 18101 --use_env main_imagenet.py --model vgg19_bn --pretrained --epochs 90 --batch-size 64 --lr 0.01 --wd 0.00004 --lr-step-size 1 --lr-gamma 0.98 --prune --cache-dataset --method group_norm --global_pruning --soft_keeping_ratio 0.5 --pretrained --target_flops 2.0 --output-dir run/imagenet_log/mob_ckpt -------------------------------------------------------------------------------- /reproduce/scripts/prune/imagenet/vit_group_norm.sh: -------------------------------------------------------------------------------- 1 | OMP_NUM_THREADS=4 python -m torch.distributed.launch --nproc_per_node=8 --master_port 18101 --use_env main_imagenet.py --model vit_b_16 --epochs 300 --batch-size 512 --opt adamw --lr 0.003 --wd 0.3 --lr-scheduler cosineannealinglr --lr-warmup-method linear --lr-warmup-epochs 30 --lr-warmup-decay 0.033 --amp --label-smoothing 0.11 --mixup-alpha 0.2 --auto-augment ra --clip-grad-norm 1 --ra-sampler --cutmix-alpha 1.0 --model-ema --prune --cache-dataset --method group_norm --global_pruning --soft_keeping_ratio 0.5 --pretrained --target_flops 10 --output-dir run/imagenet/vit_norm --amp &> run/imagenet/vit_group_norm.log -------------------------------------------------------------------------------- /reproduce/scripts/prune/modelnet40/global_group_norm.sh: -------------------------------------------------------------------------------- 1 | python main.py --mode pretrain --dataset modelnet40 --model dgcnn --lr 0.1 --total-epochs 250 --lr-decay-milestones 100,150,200 --batch-size 32 --output-dir run/modelnet40 2 | 3 | python main.py --mode prune --model dgcnn --restore run/modelnet40/pretrain/modelnet40_dgcnn.pth --dataset modelnet40 --method group_norm --speed-up 4.0 --soft-rank 0.5 --global --lr 0.01 --total-epochs 100 --lr-decay-milestones 50,80 --batch-size 32 --output-dir run/modelnet40 4 | -------------------------------------------------------------------------------- /reproduce/scripts/prune/ppi/global_group_norm.sh: -------------------------------------------------------------------------------- 1 | python main_gat.py --prune --restore 
run/ppi/pretrain/best_gat.pth --method group_norm --speed-up 8.0 --soft-rank 0.0 --global 2 | 3 | python main_gat.py --prune --restore run/ppi/pretrain/best_gat.pth --method group_sl --speed-up 8.0 --soft-rank 0.5 --global-pruning --reg 1e-4 -------------------------------------------------------------------------------- /reproduce/tools/draw.py: -------------------------------------------------------------------------------- 1 | #load all 2 | import torch 3 | import matplotlib.pyplot as plt 4 | params_record, loss_record, acc_record = torch.load('record.pth') 5 | 6 | # change the plt style 7 | plt.style.use('bmh') 8 | 9 | color_dict = { 10 | 'Group Hessian': "C0", 11 | 'Single-layer Hessian': "C0", 12 | 13 | 'Random': "C1", 14 | 15 | 'Group L1': "C2", 16 | 'Single-layer L1': "C2", 17 | 18 | 'Group Slimming': "C3", 19 | 'Single-layer Slimming': "C3", 20 | 21 | 'Group Taylor': "C4", 22 | 'Single-layer Taylor': "C4" 23 | } 24 | 25 | plt.figure() 26 | for imp_name in params_record.keys(): 27 | # use dash if 'single-layer' is in the name, use the same color as the group version 28 | plt.plot(params_record[imp_name], acc_record[imp_name], label=imp_name, linestyle='--' if 'Single-layer' in imp_name else '-', color=color_dict[imp_name]) 29 | plt.xlabel('#Params') 30 | plt.ylabel('Accuracy') 31 | plt.legend() 32 | # remove white space 33 | plt.tight_layout() 34 | plt.savefig(f'params_acc_final.png') 35 | 36 | plt.figure() 37 | for imp_name in params_record.keys(): 38 | plt.plot(params_record[imp_name], loss_record[imp_name], label=imp_name, linestyle='--' if 'Single-layer' in imp_name else '-', color=color_dict[imp_name]) 39 | plt.xlabel('#Params') 40 | plt.ylabel('Loss') 41 | plt.legend() 42 | # remove white space 43 | plt.tight_layout() 44 | plt.savefig(f'params_loss_final.png') 45 | 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.12 2 | numpy -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="torch-pruning", 8 | version="v1.5.2", 9 | author="Gongfan Fang", 10 | author_email="gongfan@u.nus.edu", 11 | description="Towards Any Structural Pruning", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/VainF/Torch-Pruning", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | install_requires=['torch', 'numpy'], 22 | python_requires='>=3.6', 23 | ) 24 | -------------------------------------------------------------------------------- /tests/graph_drawing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys, os 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | import torch_pruning as tp 5 | from torchvision.models import densenet121, resnet18, googlenet, vgg16_bn 6 | import torch.nn as nn 7 | from torchvision.models.vision_transformer import VisionTransformer, vit_b_16 8 | import matplotlib.pyplot as plt 9 | 10 | class Net(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 
| self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1) 14 | self.conv2 = nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1) 15 | self.conv3 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1) 16 | 17 | def forward(self, x): 18 | x = self.conv1(x) 19 | skip = x 20 | x = self.conv2(x) 21 | x = self.conv3(x) 22 | x += skip 23 | return x 24 | 25 | model = densenet121()  # alternatives: resnet18(), googlenet(), vgg16_bn(), vit_b_16(), Net() 26 | 27 | unwrapped_parameters = None 28 | round_to = None 29 | if isinstance( 30 | model, VisionTransformer 31 | ): # Torchvision uses a static hidden_dim for reshape 32 | round_to = model.encoder.layers[0].num_heads 33 | unwrapped_parameters = [model.class_token, model.encoder.pos_embedding] 34 | 35 | DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224), unwrapped_parameters=unwrapped_parameters) 36 | tp.utils.draw_dependency_graph(DG, save_as='draw_dep_graph.png', title=None) 37 | tp.utils.draw_groups(DG, save_as='draw_groups.png', title=None) 38 | tp.utils.draw_computational_graph(DG, save_as='draw_comp_graph.png', title=None) 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /tests/test_benchmark.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 3 | 4 | from torchvision.models import resnet18 5 | import torch 6 | import torch_pruning as tp 7 | 8 | device = torch.device('cuda:0') 9 | model = resnet18().eval().to(device) 10 | example_inputs = torch.randn(32, 3, 224, 224).to(device) 11 | 12 | # Test forward in eval mode 13 | print("====== Forward (Inference with torch.no_grad) ======") 14 | with torch.no_grad(): 15 | latency_mu, latency_std = tp.utils.benchmark.measure_latency(model, example_inputs, repeat=300) 16 | print('latency: {:.4f} +/- {:.4f} ms'.format(latency_mu, latency_std)) 17 | 18 | memory = tp.utils.benchmark.measure_memory(model, example_inputs, device=device) 19 | print('memory: {:.4f} MB'.format(memory/ (1024)**2)) 20 | 21 | example_inputs_bs1 = torch.randn(1, 3, 224, 224).to(device) 22 | fps = tp.utils.benchmark.measure_fps(model, example_inputs_bs1, repeat=300) 23 | print('fps: {:.4f}'.format(fps)) 24 | 25 | example_inputs = torch.randn(256, 3, 224, 224).to(device) 26 | throughput = tp.utils.benchmark.measure_throughput(model, example_inputs, repeat=300) 27 | print('throughput (bz=256): {:.4f} images/s'.format(throughput)) 28 | 29 | print("====== Forward & Backward ======") 30 | # Test forward & backward 31 | def run_fn(model, example_inputs): 32 | output = model(example_inputs) 33 | loss = output.sum() 34 | loss.backward() 35 | return loss 36 | 37 | latency_mu, latency_std = tp.utils.benchmark.measure_latency(model, example_inputs, repeat=300, run_fn=run_fn) 38 | print('latency: {:.4f} +/- {:.4f} ms'.format(latency_mu, latency_std)) 39 | 40 | memory = tp.utils.benchmark.measure_memory(model, example_inputs, device=device, run_fn=run_fn) 41 | print('memory: {:.4f} MB'.format(memory/ (1024)**2)) 42 | 43 | example_inputs_bs1 = torch.randn(1, 3, 224, 224).to(device) 44 | fps = tp.utils.benchmark.measure_fps(model, example_inputs_bs1, repeat=300, run_fn=run_fn) 45 | print('fps: {:.4f}'.format(fps)) 46 | 47 | example_inputs = torch.randn(256, 3, 224, 224).to(device) 48 | throughput = tp.utils.benchmark.measure_throughput(model, example_inputs, repeat=300, run_fn=run_fn) 49 | print('throughput (bz=256): {:.4f} 
images/s'.format(throughput)) -------------------------------------------------------------------------------- /tests/test_concat.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | import torch_pruning as tp 7 | import torch.nn as nn 8 | 9 | class Net(nn.Module): 10 | def __init__(self, in_dim): 11 | super().__init__() 12 | self.block1 = nn.Sequential( 13 | nn.Conv2d(in_dim, in_dim, 1), 14 | nn.BatchNorm2d(in_dim), 15 | nn.GELU(), 16 | nn.Conv2d(in_dim, in_dim, 1), 17 | nn.BatchNorm2d(in_dim) 18 | ) 19 | self.parallel_path = nn.Sequential( 20 | nn.Conv2d(in_dim, in_dim, 1), 21 | nn.BatchNorm2d(in_dim), 22 | nn.GELU(), 23 | nn.Conv2d(in_dim, in_dim//2, 1), 24 | nn.BatchNorm2d(in_dim//2) 25 | ) 26 | self.block2 = nn.Sequential( 27 | nn.Conv2d(in_dim * 2 + in_dim//2, in_dim, 1), 28 | nn.BatchNorm2d(in_dim) 29 | ) 30 | 31 | def forward(self, x): 32 | x = self.block1(x) 33 | x2 = self.parallel_path(x) 34 | x = torch.cat([x, x, x2], dim=1) 35 | x = self.block2(x) 36 | return x 37 | 38 | def test_pruner(): 39 | model = Net(512) 40 | print(model) 41 | # Global metrics 42 | example_inputs = torch.randn(1, 512, 7, 7) 43 | imp = tp.importance.MagnitudeImportance(p=2) 44 | ignored_layers = [] 45 | 46 | # DO NOT prune the final classifier! 47 | for m in model.modules(): 48 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 49 | ignored_layers.append(m) 50 | 51 | iterative_steps = 5 52 | pruner = tp.pruner.MagnitudePruner( 53 | model, 54 | example_inputs, 55 | importance=imp, 56 | iterative_steps=iterative_steps, 57 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 58 | ignored_layers=ignored_layers, 59 | ) 60 | 61 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 62 | for i in range(iterative_steps): 63 | pruner.step() 64 | print(model) 65 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 66 | 67 | print(model(example_inputs).shape) 68 | print( 69 | " Iter %d/%d, Params: %.2f => %.2f" 70 | % (i+1, iterative_steps, base_nparams, nparams) 71 | ) 72 | print( 73 | " Iter %d/%d, MACs: %.2f => %.2f" 74 | % (i+1, iterative_steps, base_macs, macs) 75 | ) 76 | # finetune your model here 77 | # finetune(model) 78 | # ... 
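# Editor's note (added comments, not part of the original test): because forward()
# concatenates [x, x, x2] along dim=1, block2's first conv must drop two input
# channels for every output channel pruned from block1 (x enters the cat twice) and
# one for every channel pruned from parallel_path. Assuming the module names used in
# this file, the following invariant should still hold after every pruner.step():
# assert model.block2[0].in_channels == 2 * model.block1[-1].num_features + model.parallel_path[-1].num_features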
79 | 80 | if __name__=='__main__': 81 | test_pruner() -------------------------------------------------------------------------------- /tests/test_concat_split.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | import torch_pruning as tp 7 | import torch.nn as nn 8 | 9 | class Net(nn.Module): 10 | def __init__(self, in_dim): 11 | super().__init__() 12 | self.block1 = nn.Sequential( 13 | nn.Conv2d(in_dim, in_dim, 1), 14 | nn.BatchNorm2d(in_dim), 15 | nn.GELU(), 16 | nn.Conv2d(in_dim, 2*in_dim, 1), 17 | nn.BatchNorm2d(2*in_dim) 18 | ) 19 | self.parallel_path = nn.Sequential( 20 | nn.Conv2d(in_dim, in_dim, 1), 21 | nn.BatchNorm2d(in_dim), 22 | nn.GELU(), 23 | nn.Conv2d(in_dim, in_dim, 1), 24 | nn.BatchNorm2d(in_dim) 25 | ) 26 | self.block2 = nn.Sequential( 27 | nn.Conv2d(in_dim * 2, in_dim, 1), 28 | nn.BatchNorm2d(in_dim) 29 | ) 30 | 31 | def forward(self, x): 32 | x3 = self.parallel_path(x) 33 | x = self.block1(x) 34 | x1_ch = self.block2[0].in_channels - self.parallel_path[-1].num_features 35 | x2_ch = self.block1[-1].num_features - x1_ch 36 | x1, x2 = torch.split(x, (x1_ch, x2_ch), 1) 37 | x = torch.cat([x1, x3], dim=1) 38 | x = self.block2(x) 39 | return x + x2 40 | 41 | def test_pruner(): 42 | model = Net(10) 43 | print(model) 44 | # Global metrics 45 | example_inputs = torch.randn(1, 10, 7, 7) 46 | imp = tp.importance.RandomImportance() 47 | ignored_layers = [] 48 | 49 | # DO NOT prune the final classifier! 50 | for m in model.modules(): 51 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 52 | ignored_layers.append(m) 53 | 54 | iterative_steps = 1 55 | pruner = tp.pruner.MagnitudePruner( 56 | model, 57 | example_inputs, 58 | importance=imp, 59 | iterative_steps=iterative_steps, 60 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 61 | ignored_layers=ignored_layers, 62 | ) 63 | 64 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 65 | for i in range(iterative_steps): 66 | for g in pruner.step(interactive=True): 67 | #print(g.details()) 68 | g.prune() 69 | print(model) 70 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 71 | print(model(example_inputs).shape) 72 | print( 73 | " Iter %d/%d, Params: %.2f => %.2f" 74 | % (i+1, iterative_steps, base_nparams, nparams ) 75 | ) 76 | print( 77 | " Iter %d/%d, MACs: %.2f => %.2f " 78 | % (i+1, iterative_steps, base_macs, macs) 79 | ) 80 | # finetune your model here 81 | # finetune(model) 82 | # ... 
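# Editor's note (added comments, not part of the original test): torch.split makes
# the channel bookkeeping bidirectional here. forward() recomputes x1's width from
# block2's input channels minus parallel_path's output, and the residual "x + x2"
# further ties x2's width to block2's output, so pruning any of block1,
# parallel_path, or block2 must update the others consistently. Under the module
# names in this file, this invariant should survive pruning:
# x1_ch = model.block2[0].in_channels - model.parallel_path[-1].num_features
# assert model.block1[-1].num_features - x1_ch == model.block2[-1].num_features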
83 | 84 | if __name__=='__main__': 85 | test_pruner() -------------------------------------------------------------------------------- /tests/test_customized_layer.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import torch_pruning as tp 10 | from typing import Sequence 11 | 12 | ############ 13 | # Customize your layer 14 | # 15 | class CustomizedLayer(nn.Module): 16 | def __init__(self, in_dim): 17 | super().__init__() 18 | self.in_dim = in_dim 19 | self.scale = nn.Parameter(torch.Tensor(self.in_dim)) 20 | self.bias = nn.Parameter(torch.Tensor(self.in_dim)) 21 | self.fc = nn.Linear(self.in_dim, self.in_dim) 22 | 23 | def forward(self, x): 24 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() 25 | x = torch.div(x, norm) 26 | return self.fc(x * self.scale + self.bias) 27 | 28 | def __repr__(self): 29 | return "CustomizedLayer(in_dim=%d)"%(self.in_dim) 30 | 31 | class FullyConnectedNet(nn.Module): 32 | """https://github.com/VainF/Torch-Pruning/issues/21""" 33 | def __init__(self, input_size, num_classes, HIDDEN_UNITS): 34 | super().__init__() 35 | self.fc1 = nn.Linear(input_size, HIDDEN_UNITS) 36 | self.customized_layer = CustomizedLayer(HIDDEN_UNITS) 37 | self.fc2 = nn.Linear(HIDDEN_UNITS, num_classes) 38 | 39 | def forward(self, x): 40 | x = F.relu(self.fc1(x)) 41 | x = self.customized_layer(x) 42 | y_hat = self.fc2(x) 43 | return y_hat 44 | 45 | ############################ 46 | # Implement your pruning function for the customized layer 47 | # You should implement the following class functions: 48 | # 1. prune_out_channels 49 | # 2. prune_in_channels 50 | # 3. get_out_channels 51 | # 4. get_in_channels 52 | 53 | class MyPruner(tp.pruner.BasePruningFunc): 54 | 55 | def prune_out_channels(self, layer: CustomizedLayer, idxs: Sequence[int]) -> nn.Module: 56 | keep_idxs = list(set(range(layer.in_dim)) - set(idxs)) 57 | keep_idxs.sort() 58 | layer.in_dim = layer.in_dim-len(idxs) 59 | layer.scale = self._prune_parameter_and_grad(layer.scale, keep_idxs, pruning_dim=0) 60 | layer.bias = self._prune_parameter_and_grad(layer.bias, keep_idxs, pruning_dim=0) 61 | tp.prune_linear_in_channels(layer.fc, idxs) 62 | tp.prune_linear_out_channels(layer.fc, idxs) 63 | return layer 64 | 65 | def get_out_channels(self, layer): 66 | return layer.in_dim 67 | 68 | # identical functions 69 | prune_in_channels = prune_out_channels 70 | get_in_channels = get_out_channels 71 | 72 | class MyLinearPruner(tp.function.LinearPruner): 73 | def prune_out_channels(self, layer: nn.Linear, idxs: Sequence[int]) -> nn.Linear: 74 | print("MyLinearPruner applied to layer: ", layer) 75 | return super().prune_out_channels(layer, idxs) 76 | 77 | def prune_in_channels(self, layer: nn.Linear, idxs: Sequence[int]) -> nn.Linear: 78 | print("MyLinearPruner applied to layer: ", layer) 79 | return super().prune_in_channels(layer, idxs) 80 | 81 | def test_customization(): 82 | model = FullyConnectedNet(128, 10, 256) 83 | 84 | DG = tp.DependencyGraph() 85 | 86 | # 1. Register your customized layer 87 | my_pruner = MyPruner() 88 | DG.register_customized_layer( 89 | CustomizedLayer, 90 | my_pruner) 91 | 92 | my_linear_pruner = MyLinearPruner() 93 | DG.register_customized_layer( 94 | nn.Linear, my_linear_pruner 95 | ) 96 | 97 | # 2. 
Build dependency graph 98 | DG.build_dependency(model, example_inputs=torch.randn(1,128)) 99 | 100 | # 3. get a pruning group according to the dependency graph. idxs is the indices of pruned filters. 101 | pruning_group = DG.get_pruning_group( model.fc1, my_linear_pruner.prune_out_channels, idxs=[0, 1, 6] ) 102 | print(pruning_group) 103 | 104 | # 4. execute this group (prune the model) 105 | pruning_group.prune() 106 | print("The pruned model:\n", model) 107 | print("Output: ", model(torch.randn(1,128)).shape) 108 | 109 | assert model.fc1.out_features==253 and model.customized_layer.in_dim==253 and model.fc2.in_features==253 110 | 111 | if __name__=='__main__': 112 | test_customization() -------------------------------------------------------------------------------- /tests/test_dependency_graph.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 3 | 4 | import torch 5 | from torchvision.models import resnet18 6 | import torch_pruning as tp 7 | 8 | def test_depgraph(): 9 | model = resnet18().eval() 10 | 11 | # 1. build dependency graph for resnet18 12 | DG = tp.DependencyGraph() 13 | DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224)) 14 | 15 | # 2. Select channels for pruning, here we prune the channels indexed by [2, 6, 9]. 16 | pruning_idxs = [2, 6, 9] 17 | pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs ) 18 | 19 | print("Pruning Group:") 20 | print(pruning_group.details()) # or print(pruning_group) 21 | 22 | # 3. prune all grouped layers that are coupled with model.conv1 23 | if DG.check_pruning_group(pruning_group): 24 | pruning_group.prune() 25 | 26 | print("After pruning:") 27 | print(model) 28 | 29 | for group in DG.get_all_groups(): 30 | # handle groups in sequential order 31 | idxs = [2,4,6] # my pruning indices 32 | group.prune(idxs=idxs) 33 | print(model) 34 | 35 | #groups = list(DG.get_all_groups()) 36 | #print("Num groups: %d"%(len(groups))) 37 | 38 | #for g in groups: 39 | # print(g) 40 | 41 | if __name__=='__main__': 42 | test_depgraph() -------------------------------------------------------------------------------- /tests/test_dependency_lenet.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch_pruning as tp 7 | 8 | class Model(nn.Module): 9 | def __init__(self): 10 | super(Model, self).__init__() 11 | self.conv1 = nn.Conv2d(1, 6, 5) 12 | self.relu1 = nn.ReLU() 13 | self.pool1 = nn.MaxPool2d(2) 14 | self.conv2 = nn.Conv2d(6, 16, 5) 15 | self.relu2 = nn.ReLU() 16 | self.pool2 = nn.MaxPool2d(2) 17 | self.fc1 = nn.Linear(256, 120) 18 | self.relu3 = nn.ReLU() 19 | self.fc2 = nn.Linear(120, 84) 20 | self.relu4 = nn.ReLU() 21 | self.fc3 = nn.Linear(84, 10) 22 | self.relu5 = nn.ReLU() 23 | 24 | def forward(self, x): 25 | y = self.conv1(x) 26 | y = self.relu1(y) 27 | y = self.pool1(y) 28 | y = self.conv2(y) 29 | y = self.relu2(y) 30 | y = self.pool2(y) 31 | y = y.view(y.shape[0], -1) 32 | y = self.fc1(y) 33 | y = self.relu3(y) 34 | y = self.fc2(y) 35 | y = self.relu4(y) 36 | y = self.fc3(y) 37 | y = self.relu5(y) 38 | return y 39 | 40 | def test_lenet(): 41 | model = Model() 42 | 43 | # build layer dependency for the LeNet-like model 44 | DG = tp.DependencyGraph() 45 | 
DG.build_dependency(model, example_inputs=torch.randn(1,1,28,28)) 46 | # get a pruning group according to the dependency graph. idxs is the indices of pruned filters. 47 | pruning_idxs = [0, 2, 6] 48 | pruning_group = DG.get_pruning_group( model.conv2, tp.prune_conv_out_channels, idxs=pruning_idxs ) 49 | print(pruning_group) 50 | # execute this group (prune the model) 51 | if DG.check_pruning_group(pruning_group): 52 | pruning_group.prune() 53 | 54 | print("The pruned model: \n", model) 55 | print("Output:", model(torch.randn(1,1,28,28)).shape) 56 | 57 | if __name__=='__main__': 58 | test_lenet() -------------------------------------------------------------------------------- /tests/test_flops.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | from torchvision.models import resnet50 as entry 7 | import torch_pruning as tp 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | def test_pruner(): 12 | model = entry() 13 | print(model) 14 | # Global metrics 15 | example_inputs = torch.randn(1, 3, 224, 224) 16 | imp = tp.importance.MagnitudeImportance(p=2) 17 | ignored_layers = [] 18 | 19 | # DO NOT prune the final classifier! 20 | for m in model.modules(): 21 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 22 | ignored_layers.append(m) 23 | 24 | iterative_steps = 1 25 | pruner = tp.pruner.MagnitudePruner( 26 | model, 27 | example_inputs, 28 | importance=imp, 29 | global_pruning=True, 30 | iterative_steps=iterative_steps, 31 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 32 | ignored_layers=ignored_layers, 33 | ) 34 | 35 | base_macs, base_nparams, base_layer_macs, base_layer_params = tp.utils.count_ops_and_params(model, example_inputs, layer_wise=True) 36 | for i in range(iterative_steps): 37 | pruner.step() 38 | macs, nparams, layer_macs, layer_params = tp.utils.count_ops_and_params(model, example_inputs, layer_wise=True) 39 | print(model) 40 | print(model(example_inputs).shape) 41 | print( 42 | " Iter %d/%d, Params: %.2f M => %.2f M" 43 | % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6) 44 | ) 45 | print( 46 | " Iter %d/%d, MACs: %.2f G => %.2f G" 47 | % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9) 48 | ) 49 | 50 | for name, module in model.named_modules(): 51 | if name=='': 52 | name = 'ALL' 53 | print(name, layer_macs[module]/1e9, "G ", layer_params[module]/1e6, "M") 54 | 55 | # finetune your model here 56 | # finetune(model) 57 | # ... 
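# Editor's note (added comments, not part of the original test): with layer_wise=True,
# count_ops_and_params returns two extra dicts keyed by the module objects
# themselves, which is why the loop above indexes layer_macs[module] and
# layer_params[module] while iterating named_modules(). The root module appears
# under the empty name '' (printed as 'ALL'), and its entry aggregates the whole
# model, so it should agree with the scalar totals macs and nparams.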
58 | 59 | if __name__=='__main__': 60 | test_pruner() -------------------------------------------------------------------------------- /tests/test_fully_connected_layers.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import torch_pruning as tp 9 | 10 | class FullyConnectedNet(nn.Module): 11 | """https://github.com/VainF/Torch-Pruning/issues/21""" 12 | def __init__(self, input_size, num_classes, HIDDEN_UNITS): 13 | super().__init__() 14 | self.fc1 = nn.Linear(input_size, HIDDEN_UNITS) 15 | self.fc2 = nn.Linear(HIDDEN_UNITS, HIDDEN_UNITS) 16 | self.fc3 = nn.Linear(HIDDEN_UNITS, num_classes) 17 | 18 | def forward(self, x): 19 | x = F.relu(self.fc1(x)) 20 | skip=x 21 | x = F.relu(self.fc2(x)) 22 | x = x+skip 23 | x = self.fc3(x) 24 | return x 25 | 26 | def test_fc(): 27 | model = FullyConnectedNet(128, 10, 256) 28 | 29 | # Build dependency graph 30 | DG = tp.DependencyGraph() 31 | DG.build_dependency(model, example_inputs=torch.randn(1,128)) 32 | 33 | # get a pruning group according to the dependency graph. 34 | pruning_group = DG.get_pruning_group( model.fc1, tp.prune_linear_out_channels, idxs=[0, 4, 6] ) 35 | print(pruning_group) 36 | 37 | # execute the group (prune the model) 38 | pruning_group.prune() 39 | print(model) 40 | 41 | print("The pruned model: \n", model) 42 | print("Output:", model(torch.randn(1,128)).shape) 43 | 44 | if __name__=='__main__': 45 | test_fc() -------------------------------------------------------------------------------- /tests/test_group_prune.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 3 | 4 | import torch 5 | from torchvision.models import resnet18 6 | import torch_pruning as tp 7 | 8 | def test_depgraph(): 9 | model = resnet18(pretrained=True).eval() 10 | # 1. build dependency graph for resnet18 11 | DG = tp.DependencyGraph() 12 | DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224)) 13 | # 2. Select channels for pruning, here we prune the channels indexed by [2, 6, 9]. 14 | pruning_idxs = [2, 6, 9] 15 | pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs ) 16 | pruning_group.prune() 17 | affected_weights1 = [] 18 | for dep, _ in pruning_group: 19 | module = dep.target.module 20 | if hasattr(module, 'weight'): 21 | affected_weights1.append(module.weight.detach()) 22 | if hasattr(module, 'bias') and module.bias is not None: 23 | affected_weights1.append(module.bias.detach()) 24 | 25 | model = resnet18(pretrained=True).eval() 26 | # 1. build dependency graph for resnet18 27 | DG = tp.DependencyGraph() 28 | DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224)) 29 | # 2.
Select channels for pruning 30 | pruning_idxs = [1, 2, 3, 4] # placeholder indices; the call to prune([2,6,9]) below overrides them 31 | pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs ) 32 | pruning_group.prune([2,6,9]) 33 | affected_weights2 = [] 34 | for dep, _ in pruning_group: 35 | module = dep.target.module 36 | if hasattr(module, 'weight'): 37 | affected_weights2.append(module.weight.detach()) 38 | if hasattr(module, 'bias') and module.bias is not None: 39 | affected_weights2.append(module.bias.detach()) 40 | 41 | for w1, w2 in zip(affected_weights1, affected_weights2): 42 | assert torch.allclose(w1, w2) 43 | 44 | if __name__=='__main__': 45 | test_depgraph() -------------------------------------------------------------------------------- /tests/test_hessian_importance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models import resnet18 3 | import torch_pruning as tp 4 | 5 | def test_hessian(): 6 | model = resnet18(pretrained=True) 7 | 8 | # Importance criteria 9 | example_inputs = torch.randn(1, 3, 224, 224) 10 | imp = tp.importance.HessianImportance() 11 | 12 | ignored_layers = [] 13 | for m in model.modules(): 14 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 15 | ignored_layers.append(m) # DO NOT prune the final classifier! 16 | 17 | iterative_steps = 1 # number of pruning iterations (>1 enables progressive pruning) 18 | pruner = tp.pruner.BasePruner( 19 | model, 20 | example_inputs, 21 | importance=imp, 22 | iterative_steps=iterative_steps, 23 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 24 | ignored_layers=ignored_layers, 25 | ) 26 | 27 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 28 | for i in range(iterative_steps): 29 | if isinstance(imp, tp.importance.HessianImportance): 30 | # loss = F.cross_entropy(model(images), targets) 31 | dummy_inputs = torch.randn(10, 3, 224, 224) 32 | output = model(dummy_inputs) 33 | # compute loss for each sample 34 | loss = torch.nn.functional.cross_entropy(output, torch.randint(0, 1000, (len(dummy_inputs),)), reduction='none').to(output.device) 35 | imp.zero_grad() # clear accumulated gradients 36 | for l in loss: 37 | model.zero_grad() # clear gradients 38 | l.backward(retain_graph=True) # single-sample gradient 39 | imp.accumulate_grad(model) # accumulate g^2 40 | 41 | #for g in pruner.DG.get_all_groups(ignored_layers=pruner.ignored_layers, root_module_types=pruner.root_module_types): 42 | # print(len(imp(g)) == len(imp2(g))) 43 | 44 | for g in pruner.step(interactive=True): 45 | g.prune() 46 | 47 | print(model) 48 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 49 | print(f"MACs: {macs/base_macs:.2f}, #Params: {nparams/base_nparams:.2f}") 50 | # finetune your model here 51 | # finetune(model) 52 | # ...
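    # Once pruned, the compact model can be serialized with torch_pruning's helpers
    # (a sketch following tests/test_serialization.py further below; the file name is arbitrary):
    # torch.save(tp.state_dict(model), 'resnet18_hessian_pruned.pth')
    # new_model = resnet18()
    # tp.load_state_dict(new_model, state_dict=torch.load('resnet18_hessian_pruned.pth', map_location='cpu', weights_only=False))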
53 | 54 | if __name__=="__main__": 55 | test_hessian() -------------------------------------------------------------------------------- /tests/test_importance_reduction.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 3 | 4 | import torch 5 | from torchvision.models import resnet18 6 | import torch_pruning as tp 7 | model = resnet18() 8 | 9 | # Global metrics 10 | def test_imp(): 11 | DG = tp.DependencyGraph() 12 | example_inputs = torch.randn(1,3,224,224) 13 | DG.build_dependency(model, example_inputs=example_inputs) 14 | pruning_idxs = list( range( DG.get_out_channels(model.conv1) )) 15 | pruning_group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs) 16 | 17 | random_importance = tp.importance.RandomImportance() 18 | rand_imp = random_importance(pruning_group) 19 | print("Random: ", rand_imp) 20 | 21 | magnitude_importance = tp.importance.MagnitudeImportance(p=1, group_reduction=None, normalizer=None, bias=True) 22 | mag_imp_raw = magnitude_importance(pruning_group) 23 | print("L-1 Norm, No Reduction: ", mag_imp_raw) 24 | 25 | magnitude_importance = tp.importance.MagnitudeImportance(p=1, normalizer=None, bias=True) 26 | mag_imp = magnitude_importance(pruning_group) 27 | print("L-1 Norm, Group Mean: ", mag_imp) 28 | assert torch.allclose(mag_imp, mag_imp_raw.mean(0)) 29 | 30 | magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction=None, normalizer=None, bias=True) 31 | mag_imp_raw = magnitude_importance(pruning_group) 32 | print("L-2 Norm, No Reduction: ", mag_imp_raw) 33 | 34 | magnitude_importance = tp.importance.MagnitudeImportance(p=2, normalizer=None, bias=True) 35 | mag_imp = magnitude_importance(pruning_group) 36 | print("L-2 Norm, Group Mean: ", mag_imp) 37 | assert torch.allclose(mag_imp, mag_imp_raw.mean(0)) 38 | 39 | magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='sum', normalizer=None, bias=True) 40 | mag_imp = magnitude_importance(pruning_group) 41 | print("L-2 Norm, Group Sum: ", mag_imp) 42 | assert torch.allclose(mag_imp, mag_imp_raw.sum(0)) 43 | 44 | magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='max', normalizer=None, bias=True) 45 | mag_imp = magnitude_importance(pruning_group) 46 | print("L-2 Norm, Group Max: ", mag_imp) 47 | assert torch.allclose(mag_imp, mag_imp_raw.max(0)[0]) 48 | 49 | magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='gate', normalizer=None, bias=True) 50 | mag_imp = magnitude_importance(pruning_group) 51 | print("L-2 Norm, Group Gate: ", mag_imp) 52 | assert torch.allclose(mag_imp, mag_imp_raw[-1]) 53 | 54 | magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='prod', normalizer=None, bias=True) 55 | mag_imp = magnitude_importance(pruning_group) 56 | print("L-2 Norm, Group Prod: ", mag_imp) 57 | print(mag_imp, torch.prod(mag_imp_raw, dim=0)) 58 | assert torch.allclose(mag_imp, torch.prod(mag_imp_raw, dim=0)) 59 | 60 | bn_scale_importance = tp.importance.BNScaleImportance(normalizer=None) 61 | bn_imp = bn_scale_importance(pruning_group) 62 | print("BN Scaling, Group mean: ", bn_imp) 63 | 64 | lamp_importance = tp.importance.LAMPImportance(bias=True) 65 | lamp_imp = lamp_importance(pruning_group) 66 | print("LAMP: ", lamp_imp) 67 | 68 | model(example_inputs).sum().backward() 69 | taylor_importance = 
tp.importance.TaylorImportance(normalizer='mean', bias=True) 70 | taylor_imp = taylor_importance(pruning_group) 71 | print("Taylor Importance", taylor_imp) 72 | 73 | if __name__=='__main__': 74 | test_imp() -------------------------------------------------------------------------------- /tests/test_interactive_pruner.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | from torchvision.models import resnet18 as entry 7 | import torch_pruning as tp 8 | 9 | def test_interactive_pruner(): 10 | model = entry() 11 | print(model) 12 | # Global metrics 13 | example_inputs = torch.randn(1, 3, 224, 224) 14 | imp = tp.importance.MagnitudeImportance(p=2) 15 | ignored_layers = [] 16 | 17 | # DO NOT prune the final classifier! 18 | for m in model.modules(): 19 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 20 | ignored_layers.append(m) 21 | 22 | iterative_steps = 5 23 | pruner = tp.pruner.MagnitudePruner( 24 | model, 25 | example_inputs, 26 | importance=imp, 27 | iterative_steps=iterative_steps, 28 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 29 | ignored_layers=ignored_layers, 30 | ) 31 | 32 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 33 | for i in range(iterative_steps): 34 | for group in pruner.step(interactive=True): 35 | print(group) 36 | group.prune() 37 | 38 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 39 | print(model) 40 | print(model(example_inputs).shape) 41 | print( 42 | " Iter %d/%d, Params: %.2f M => %.2f M" 43 | % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6) 44 | ) 45 | print( 46 | " Iter %d/%d, MACs: %.2f G => %.2f G" 47 | % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9) 48 | ) 49 | # finetune your model here 50 | # finetune(model) 51 | # ... 52 | 53 | if __name__=='__main__': 54 | test_interactive_pruner() 55 | -------------------------------------------------------------------------------- /tests/test_load.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | from torchvision.models import resnet18 as entry 7 | import torch_pruning as tp 8 | 9 | def test_pruner(): 10 | model = entry() 11 | print(model) 12 | # Global metrics 13 | example_inputs = torch.randn(1, 3, 224, 224) 14 | imp = tp.importance.MagnitudeImportance(p=2) 15 | ignored_layers = [] 16 | 17 | # DO NOT prune the final classifier! 
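    # (pruning the 1000-way head would change the number of output classes,
    # so the loop below registers it in ignored_layers and the pruner skips it)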
18 | for m in model.modules(): 19 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 20 | ignored_layers.append(m) 21 | 22 | iterative_steps = 5 23 | pruner = tp.pruner.MagnitudePruner( 24 | model, 25 | example_inputs, 26 | importance=imp, 27 | iterative_steps=iterative_steps, 28 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 29 | ignored_layers=ignored_layers, 30 | ) 31 | 32 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 33 | for i in range(iterative_steps): 34 | pruner.step() 35 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 36 | print(model) 37 | print(model(example_inputs).shape) 38 | print( 39 | " Iter %d/%d, Params: %.2f M => %.2f M" 40 | % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6) 41 | ) 42 | print( 43 | " Iter %d/%d, MACs: %.2f G => %.2f G" 44 | % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9) 45 | ) 46 | 47 | state_dict = { 48 | 'model': model.state_dict(), 49 | 'pruning': pruner.pruning_history(), 50 | } 51 | torch.save(state_dict, 'pruned_model.pth') 52 | # Create a new model and pruner 53 | model = entry() 54 | DG = tp.DependencyGraph().build_dependency(model, example_inputs) 55 | state_dict = torch.load('pruned_model.pth') 56 | DG.load_pruning_history(state_dict['pruning']) 57 | model.load_state_dict(state_dict['model']) 58 | print(model) 59 | 60 | if __name__=='__main__': 61 | test_pruner() -------------------------------------------------------------------------------- /tests/test_multiple_inputs_and_outputs.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import torch_pruning as tp 10 | 11 | class FullyConnectedNet(nn.Module): 12 | """https://github.com/VainF/Torch-Pruning/issues/21""" 13 | 14 | def __init__(self, input_sizes, output_sizes): 15 | super().__init__() 16 | 17 | self.fc1 = nn.Linear(input_sizes[0], output_sizes[0]) 18 | self.fc2 = nn.Linear(input_sizes[1], output_sizes[1]) 19 | self.fc3 = nn.Linear(sum(output_sizes), 1000) 20 | 21 | def forward(self, x1, x2): 22 | x1 = F.relu(self.fc1(x1)) 23 | x2 = F.relu(self.fc2(x2)) 24 | x3 = F.relu(self.fc3(torch.cat([x1, x2], dim=1))) 25 | return x1, x2, x3 26 | 27 | def test_multi_io(): 28 | model = FullyConnectedNet([128, 64], [32, 32]) 29 | 30 | # Build dependency graph 31 | DG = tp.DependencyGraph() 32 | DG.build_dependency(model, example_inputs={'x1': torch.randn(1, 128), 'x2': torch.randn(1, 64)}) 33 | 34 | # get a pruning group according to the dependency graph. idxs is the indices of pruned filters. 
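    # (pruning fc1's outputs also narrows the first slice of the torch.cat that
    # feeds fc3, so the group below is expected to cover both layers)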
35 | pruning_group = DG.get_pruning_group( 36 | model.fc1, tp.prune_linear_out_channels, idxs=[0, 2, 4] 37 | ) 38 | print(pruning_group) 39 | 40 | # execute this group (prune the model) 41 | pruning_group.prune() 42 | 43 | print("The pruned model: \n", model) 44 | print("Output:") 45 | for o in model(torch.randn(1, 128), torch.randn(1, 64)): 46 | print('\t', o.shape) 47 | 48 | if __name__=='__main__': 49 | test_multi_io() -------------------------------------------------------------------------------- /tests/test_non_feature_dim_cat.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | import torch_pruning as tp 7 | import torch.nn as nn 8 | 9 | class Net(nn.Module): 10 | def __init__(self, in_dim): 11 | super().__init__() 12 | self.block1 = nn.Sequential( 13 | nn.Conv2d(in_dim, in_dim, 1), 14 | nn.BatchNorm2d(in_dim), 15 | nn.GELU(), 16 | nn.Conv2d(in_dim, in_dim, 1), 17 | nn.BatchNorm2d(in_dim) 18 | ) 19 | self.parallel_path = nn.Sequential( 20 | nn.Conv2d(in_dim, in_dim, 1), 21 | nn.BatchNorm2d(in_dim), 22 | nn.GELU(), 23 | nn.Conv2d(in_dim, in_dim, 1), 24 | nn.BatchNorm2d(in_dim) 25 | ) 26 | 27 | self.conv1 = nn.Conv2d(in_dim, in_dim, 1) 28 | self.conv2 = nn.Conv2d(in_dim, in_dim, 1) 29 | 30 | def forward(self, x): 31 | x1 = self.block1(x) 32 | x2 = self.parallel_path(x) 33 | x = torch.cat([x1, x2], dim=2) 34 | x = self.conv1(x) 35 | x1, x2 = torch.split(x, [x1.shape[2], x2.shape[2]], dim=2) 36 | x = self.conv2(x1) 37 | return x 38 | 39 | def test_pruner(): 40 | model = Net(512) 41 | print(model) 42 | # Global metrics 43 | example_inputs = torch.randn(1, 512, 7, 7) 44 | imp = tp.importance.MagnitudeImportance(p=2) 45 | ignored_layers = [] 46 | 47 | # DO NOT prune the final classifier! 48 | for m in model.modules(): 49 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 50 | ignored_layers.append(m) 51 | 52 | iterative_steps = 1 53 | pruner = tp.pruner.MagnitudePruner( 54 | model, 55 | example_inputs, 56 | importance=imp, 57 | iterative_steps=iterative_steps, 58 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 59 | ignored_layers=ignored_layers, 60 | ) 61 | 62 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 63 | for i in range(iterative_steps): 64 | pruner.step() 65 | print(model) 66 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 67 | 68 | print(model(example_inputs).shape) 69 | print( 70 | " Iter %d/%d, Params: %.2f => %.2f" 71 | % (i+1, iterative_steps, base_nparams, nparams) 72 | ) 73 | print( 74 | " Iter %d/%d, MACs: %.2f => %.2f" 75 | % (i+1, iterative_steps, base_macs, macs) 76 | ) 77 | # finetune your model here 78 | # finetune(model) 79 | # ... 80 | 81 | if __name__=='__main__': 82 | test_pruner() -------------------------------------------------------------------------------- /tests/test_print_tool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models import resnet18 3 | import torch_pruning as tp 4 | 5 | model = resnet18(pretrained=True) 6 | example_inputs = torch.randn(1, 3, 224, 224) 7 | 8 | # 1. Importance criterion, here we calculate the L2 Norm of grouped weights as the importance score 9 | imp = tp.importance.GroupMagnitudeImportance(p=2) 10 | 11 | # 2. 
Initialize a pruner with the model and the importance criterion 12 | ignored_layers = [] 13 | for m in model.modules(): 14 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 15 | ignored_layers.append(m) # DO NOT prune the final classifier! 16 | 17 | pruner = tp.pruner.BasePruner( # We can always choose BasePruner if sparse training is not required. 18 | model, 19 | example_inputs, 20 | importance=imp, 21 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 22 | # pruning_ratio_dict = {model.conv1: 0.2, model.layer2: 0.8}, # customized pruning ratios for layers or blocks 23 | ignored_layers=ignored_layers, 24 | round_to=8, # It's recommended to round dims/channels to 4x or 8x for acceleration. Please see: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html 25 | ) 26 | 27 | # 3. Prune the model 28 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 29 | tp.utils.print_tool.before_pruning(model) # or print(model) 30 | pruner.step() 31 | tp.utils.print_tool.after_pruning(model) # or print(model), this tool will show the difference 32 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 33 | print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M") -------------------------------------------------------------------------------- /tests/test_pruning_fn.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import alexnet 2 | import sys, os 3 | 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 5 | import torch_pruning as tp 6 | 7 | def test_pruning_fn(): 8 | model = alexnet() 9 | print("Before pruning: ") 10 | print(model.features[:4]) 11 | print(model.features[0].weight.shape) 12 | print(model.features[3].weight.shape) 13 | 14 | tp.prune_conv_out_channels(model.features[0], idxs=[0, 1, 3, 4]) 15 | tp.prune_conv_in_channels(model.features[3], idxs=[0, 1, 3, 4]) 16 | 17 | print("\nAfter pruning: ") 18 | print(model.features[:4]) 19 | print(model.features[0].weight.shape) 20 | print(model.features[3].weight.shape) 21 | 22 | if __name__=="__main__": 23 | test_pruning_fn() -------------------------------------------------------------------------------- /tests/test_regularization.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | from torchvision.models import densenet121 as entry 7 | import torch_pruning as tp 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | def test_pruner(): 12 | model = entry(pretrained=True) 13 | print(model) 14 | # Global metrics 15 | example_inputs = torch.randn(1, 3, 224, 224) 16 | 17 | for imp_cls, pruner_cls in [ 18 | [tp.importance.GroupMagnitudeImportance, tp.pruner.GroupNormPruner], 19 | [tp.importance.BNScaleImportance, tp.pruner.BNScalePruner], 20 | [tp.importance.GroupMagnitudeImportance, tp.pruner.GrowingRegPruner], 21 | ]: 22 | if imp_cls == tp.importance.OBDCImportance: 23 | imp = imp_cls(num_classes=1000) 24 | else: 25 | imp = imp_cls() 26 | ignored_layers = [] 27 | # DO NOT prune the final classifier! 
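        # (as in the other tests, the loop below excludes the 1000-class head; each
        # pruner in this test additionally sparsifies weights via update_regularizer()
        # and regularize() before the actual pruning step)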
28 | for m in model.modules(): 29 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 30 | ignored_layers.append(m) 31 | iterative_steps = 5 32 | pruner = pruner_cls( 33 | model, 34 | example_inputs, 35 | importance=imp, 36 | global_pruning=True, 37 | iterative_steps=iterative_steps, 38 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 39 | ignored_layers=ignored_layers, 40 | ) 41 | 42 | for i in range(iterative_steps): 43 | if isinstance(imp, tp.importance.OBDCImportance): 44 | imp._prepare_model(model, pruner) 45 | model(example_inputs).sum().backward() 46 | imp.step() 47 | else: 48 | model(example_inputs).sum().backward() 49 | grad_dict = {} 50 | for p in model.parameters(): 51 | if p.grad is not None: 52 | grad_dict[p] = p.grad.clone() 53 | else: 54 | grad_dict[p] = None 55 | pruner.update_regularizer() 56 | pruner.regularize(model) 57 | for name, p in model.named_parameters(): 58 | if p.grad is not None and grad_dict[p] is not None: 59 | print(name, (grad_dict[p] - p.grad).abs().sum()) 60 | else: 61 | print(name, "has no grad") 62 | for g in pruner.step(interactive=True): 63 | g.prune() 64 | if isinstance(imp, tp.importance.OBDCImportance): 65 | imp._rm_hooks(model) 66 | imp._clear_buffer() 67 | 68 | 69 | if __name__ == "__main__": 70 | test_pruner() 71 | -------------------------------------------------------------------------------- /tests/test_reshape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import sys, os 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 7 | 8 | import torch_pruning as tp 9 | class Net(nn.Module): 10 | def __init__(self): 11 | super(Net, self).__init__() 12 | self.Linear = nn.Linear(in_features=512, out_features=4096) 13 | self.conv1T1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1) 14 | self.conv1T2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1) 15 | self.final = nn.Linear(16384, 10) 16 | def forward(self, x): 17 | x = F.relu(self.Linear(x)) 18 | x = x.view(-1,self.conv1T1.in_channels, 4, 4) 19 | x = F.relu(self.conv1T1(x)) 20 | x = F.relu(self.conv1T2(x)) 21 | x = torch.flatten(x, 1) 22 | x = self.final(x) 23 | return x 24 | 25 | def test_reshape(): 26 | model = Net() 27 | example_inputs = torch.randn(1, 512) 28 | imp = tp.importance.MagnitudeImportance() 29 | ignored_layers = [model.final] 30 | 31 | iterative_steps = 5 32 | pruner = tp.pruner.MagnitudePruner( 33 | model, 34 | example_inputs, 35 | importance=imp, 36 | iterative_steps=iterative_steps, 37 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 38 | ignored_layers=ignored_layers, 39 | root_module_types=[nn.ConvTranspose2d, nn.Linear], 40 | ) 41 | 42 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 43 | for i in range(iterative_steps): 44 | pruner.step() 45 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 46 | print(model) 47 | print( 48 | " Iter %d/%d, Params: %.2f M => %.2f M" 49 | % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6) 50 | ) 51 | print( 52 | " Iter %d/%d, MACs: %.2f G => %.2f G" 53 | % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9) 54 | ) 55 | # finetune your model here 56 | # finetune(model) 57 | # ... 
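        # sanity check: the pruned Linear -> view -> ConvTranspose2d chain should
        # still run end-to-end and keep the 10-class output of the ignored head
        print("Output:", model(example_inputs).shape)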
58 | 59 | if __name__=='__main__': 60 | test_reshape() 61 | 62 | -------------------------------------------------------------------------------- /tests/test_score_normalization.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | from torchvision.models import resnet18 as entry 7 | import torch_pruning as tp 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | def test_pruner(): 12 | 13 | # Global metrics 14 | example_inputs = torch.randn(1, 3, 224, 224) 15 | for normalization in ['mean', 'max', 'gaussian', 'standarization', 'sum']: 16 | model = entry() 17 | print(model) 18 | imp = tp.importance.MagnitudeImportance(p=2, normalizer=normalization) 19 | ignored_layers = [] 20 | 21 | # DO NOT prune the final classifier! 22 | for m in model.modules(): 23 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 24 | ignored_layers.append(m) 25 | 26 | iterative_steps = 5 27 | pruner = tp.pruner.MagnitudePruner( 28 | model, 29 | example_inputs, 30 | importance=imp, 31 | global_pruning=True, 32 | iterative_steps=iterative_steps, 33 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 34 | ignored_layers=ignored_layers, 35 | ) 36 | 37 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 38 | for i in range(iterative_steps): 39 | pruner.step() 40 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 41 | print(model) 42 | print(model(example_inputs).shape) 43 | print( 44 | " Iter %d/%d, Params: %.2f M => %.2f M" 45 | % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6) 46 | ) 47 | print( 48 | " Iter %d/%d, MACs: %.2f G => %.2f G" 49 | % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9) 50 | ) 51 | # finetune your model here 52 | # finetune(model) 53 | # ... 
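        # To inspect what a normalizer does to the raw scores of a single group, a
        # sketch in the style of tests/test_importance_reduction.py above:
        # DG = tp.DependencyGraph().build_dependency(model, example_inputs)
        # group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=list(range(model.conv1.out_channels)))
        # raw = tp.importance.MagnitudeImportance(p=2, normalizer=None)(group)
        # scaled = tp.importance.MagnitudeImportance(p=2, normalizer=normalization)(group)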
54 | 55 | if __name__ == "__main__": 56 | test_pruner() 57 | -------------------------------------------------------------------------------- /tests/test_serialization.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 3 | 4 | import torch 5 | from torchvision.models import vit_b_16 as entry 6 | import torch_pruning as tp 7 | from torchvision.models.vision_transformer import VisionTransformer 8 | 9 | def test_serialization(): 10 | model = entry().eval() 11 | 12 | customized_value = 8 13 | model.customized_value = customized_value 14 | importance = tp.importance.MagnitudeImportance(p=1) 15 | round_to = None 16 | if isinstance( model, VisionTransformer): round_to = model.encoder.layers[0].num_heads 17 | pruner = tp.pruner.MagnitudePruner( 18 | model, 19 | example_inputs=torch.randn(1, 3, 224, 224), 20 | importance=importance, 21 | iterative_steps=1, 22 | pruning_ratio=0.5, 23 | round_to=round_to, 24 | ) 25 | pruner.step() 26 | if isinstance( 27 | model, VisionTransformer 28 | ): # Torchvision relies on the hidden_dim variable for its forward pass, so we have to update this variable after pruning 29 | model.hidden_dim = model.conv_proj.out_channels 30 | true_hidden_dim = model.hidden_dim 31 | print(model.class_token.shape, model.encoder.pos_embedding.shape) 32 | 33 | state_dict = tp.state_dict(model) 34 | torch.save(state_dict, 'test.pth') 35 | 36 | # create a new model 37 | model = entry().eval() 38 | print(model) 39 | 40 | # load the pruned state_dict 41 | loaded_state_dict = torch.load('test.pth', map_location='cpu', weights_only=False) 42 | tp.load_state_dict(model, state_dict=loaded_state_dict) 43 | print(model) 44 | 45 | # test 46 | assert model.customized_value == customized_value 47 | assert model.hidden_dim == true_hidden_dim 48 | print(model.customized_value) # check the user attributes 49 | print(model.hidden_dim) 50 | 51 | out = model(torch.randn(1,3,224,224)) 52 | print(out.shape) 53 | loss = out.sum() 54 | loss.backward() 55 | 56 | if __name__=='__main__': 57 | test_serialization() -------------------------------------------------------------------------------- /tests/test_single_channel_output.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import torch_pruning as tp 5 | 6 | class Model(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1) 10 | self.bn1 = nn.BatchNorm2d(16) 11 | self.conv2 = nn.Conv2d(16, 64, 2, stride=2) 12 | self.bn2 = nn.BatchNorm2d(64) 13 | self.conv3 = nn.Conv2d(64, 1, 3, 1, 1) 14 | self.bn3 = nn.BatchNorm2d(1) 15 | 16 | def forward(self, x): 17 | x = F.relu(self.bn1(self.conv1(x))) 18 | x = F.relu(self.bn2(self.conv2(x))) 19 | x = F.relu(self.bn3(self.conv3(x))) 20 | return x 21 | 22 | def test_single_channel_output(): 23 | model = Model() 24 | example_inputs = torch.randn(1, 3, 224, 224) 25 | DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs) 26 | 27 | all_groups = list(DG.get_all_groups()) 28 | print(all_groups[0]) 29 | assert len(all_groups[0])==3 30 | 31 | if __name__ == "__main__": 32 | test_single_channel_output() -------------------------------------------------------------------------------- /tests/test_soft_pruning.py: -------------------------------------------------------------------------------- 1 | import
sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | from torchvision.models import resnet50 as entry 7 | import torch_pruning as tp 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | def test_soft_pruning(): 12 | model = entry() 13 | print(model) 14 | # Global metrics 15 | example_inputs = torch.randn(1, 3, 224, 224) 16 | imp = tp.importance.MagnitudeImportance(p=2) 17 | ignored_layers = [] 18 | 19 | # DO NOT prune the final classifier! 20 | for m in model.modules(): 21 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 22 | ignored_layers.append(m) 23 | 24 | iterative_steps = 1 25 | pruner = tp.pruner.MagnitudePruner( 26 | model, 27 | example_inputs, 28 | importance=imp, 29 | global_pruning=True, 30 | iterative_steps=iterative_steps, 31 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 32 | ignored_layers=ignored_layers, 33 | ) 34 | 35 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 36 | for i in range(iterative_steps): 37 | 38 | # Soft Pruning 39 | for group in pruner.step(interactive=True): 40 | for dep, idxs in group: 41 | target_layer = dep.target.module 42 | pruning_fn = dep.handler 43 | if pruning_fn in [tp.prune_conv_in_channels, tp.prune_linear_in_channels]: 44 | target_layer.weight.data[:, idxs] *= 0 45 | elif pruning_fn in [tp.prune_conv_out_channels, tp.prune_linear_out_channels]: 46 | target_layer.weight.data[idxs] *= 0 47 | if target_layer.bias is not None: 48 | target_layer.bias.data[idxs] *= 0 49 | elif pruning_fn in [tp.prune_batchnorm_out_channels]: 50 | target_layer.weight.data[idxs] *= 0 51 | target_layer.bias.data[idxs] *= 0 52 | # group.prune() # <= disable hard pruning 53 | print(model.conv1.weight) 54 | 55 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 56 | print(model) 57 | print(model(example_inputs).shape) 58 | print( 59 | " Iter %d/%d, Params: %.2f M => %.2f M" 60 | % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6) 61 | ) 62 | print( 63 | " Iter %d/%d, MACs: %.2f G => %.2f G" 64 | % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9) 65 | ) 66 | # finetune your model here 67 | # finetune(model) 68 | # ... 
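        # soft pruning only zeroes weights and changes no tensor shapes, so the
        # parameter count is expected to be unchanged at this point; hard pruning
        # would additionally call group.prune() as noted above
        assert nparams == base_nparams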
69 | 70 | if __name__ == "__main__": 71 | test_soft_pruning() -------------------------------------------------------------------------------- /tests/test_split.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | import torch_pruning as tp 7 | import torch.nn as nn 8 | 9 | class Net(nn.Module): 10 | def __init__(self, in_dim): 11 | super().__init__() 12 | 13 | self.block1 = nn.Sequential( 14 | nn.Conv2d(in_dim, in_dim, 1), 15 | nn.BatchNorm2d(in_dim), 16 | nn.GELU(), 17 | nn.Conv2d(in_dim, in_dim*4, 1), 18 | nn.BatchNorm2d(in_dim*4) 19 | ) 20 | 21 | self.block2_1 = nn.Sequential( 22 | nn.Conv2d(in_dim, in_dim, 1), 23 | nn.BatchNorm2d(in_dim) 24 | ) 25 | 26 | self.block2_2 = nn.Sequential( 27 | nn.Conv2d(2*in_dim, in_dim, 1), 28 | nn.BatchNorm2d(in_dim) 29 | ) 30 | 31 | def forward(self, x): 32 | x = self.block1(x) 33 | num_ch = x.shape[1] 34 | 35 | c1, c2 = self.block2_1[0].in_channels, self.block2_2[0].in_channels 36 | x1, x2, x3 = torch.split(x, [c1, c1, c2], dim=1) 37 | x1 = self.block2_1(x1) 38 | x2 = self.block2_1(x2) 39 | x3 = self.block2_2(x3) 40 | return x1, x2, x3 41 | 42 | def test_pruner(): 43 | dim = 128 44 | model = Net(dim) 45 | print(model) 46 | # Global metrics 47 | example_inputs = torch.randn(1, dim, 7, 7) 48 | imp = tp.importance.RandomImportance() 49 | ignored_layers = [] 50 | 51 | # DO NOT prune the final classifier! 52 | for m in model.modules(): 53 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 54 | ignored_layers.append(m) 55 | 56 | iterative_steps = 1 57 | pruner = tp.pruner.MagnitudePruner( 58 | model, 59 | example_inputs, 60 | importance=imp, 61 | iterative_steps=iterative_steps, 62 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 63 | ignored_layers=ignored_layers, 64 | ) 65 | for g in pruner.DG.get_all_groups(): 66 | pass 67 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 68 | for i in range(iterative_steps): 69 | for g in pruner.step(interactive=True): 70 | #print(g.details()) 71 | g.prune() 72 | print(model) 73 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 74 | 75 | print([o.shape for o in model(example_inputs)]) 76 | print( 77 | " Iter %d/%d, Params: %.2f M => %.2f M" 78 | % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6) 79 | ) 80 | print( 81 | " Iter %d/%d, MACs: %.2f G => %.2f G" 82 | % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9) 83 | ) 84 | # finetune your model here 85 | # finetune(model) 86 | # ... 87 | 88 | if __name__=='__main__': 89 | test_pruner() -------------------------------------------------------------------------------- /tests/test_taylor_importance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models import resnet18 3 | import torch_pruning as tp 4 | 5 | def test_taylor(): 6 | model = resnet18(pretrained=True) 7 | 8 | # Importance criteria 9 | example_inputs = torch.randn(1, 3, 224, 224) 10 | imp = tp.importance.TaylorImportance() 11 | 12 | ignored_layers = [] 13 | for m in model.modules(): 14 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 15 | ignored_layers.append(m) # DO NOT prune the final classifier! 
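    # TaylorImportance estimates each channel's contribution to the loss from a
    # first-order expansion (weight * gradient), so .grad must be populated by a
    # backward pass before pruner.step() -- see the loop below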
16 | 17 | iterative_steps = 1 # number of pruning iterations (>1 enables progressive pruning) 18 | pruner = tp.pruner.MagnitudePruner( 19 | model, 20 | example_inputs, 21 | importance=imp, 22 | iterative_steps=iterative_steps, 23 | global_pruning=True, 24 | pruning_ratio=0.1, # default ratio: remove 10% of channels; the blocks in pruning_ratio_dict below are pruned by 50% instead 25 | pruning_ratio_dict={model.layer1: 0.5, (model.layer2, model.layer3): 0.5}, 26 | ignored_layers=ignored_layers, 27 | ) 28 | print(model) 29 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 30 | for i in range(iterative_steps): 31 | if isinstance(imp, tp.importance.TaylorImportance): 32 | # loss = F.cross_entropy(model(images), targets) 33 | loss = model(example_inputs).sum() # a dummy loss for TaylorImportance 34 | loss.backward() 35 | pruner.step() 36 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 37 | # finetune your model here 38 | # finetune(model) 39 | # ... 40 | print(model) 41 | 42 | 43 | if __name__=="__main__": 44 | test_taylor() -------------------------------------------------------------------------------- /tests/test_unused_parameters.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import torch_pruning as tp 10 | from typing import Sequence 11 | 12 | ############ 13 | # Customize your layer 14 | # 15 | class CustomizedLayer(nn.Module): 16 | def __init__(self, in_dim): 17 | super().__init__() 18 | self.in_dim = in_dim 19 | self.scale = nn.Parameter(torch.Tensor(self.in_dim)) 20 | self.bias = nn.Parameter(torch.Tensor(self.in_dim)) 21 | self.fc = nn.Linear(self.in_dim, self.in_dim) 22 | 23 | def forward(self, x): 24 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() 25 | x = torch.div(x, norm) 26 | return self.fc(x * self.scale + self.bias) 27 | 28 | def __repr__(self): 29 | return "CustomizedLayer(in_dim=%d)"%(self.in_dim) 30 | 31 | class FullyConnectedNet(nn.Module): 32 | """https://github.com/VainF/Torch-Pruning/issues/21""" 33 | def __init__(self, input_size, num_classes, HIDDEN_UNITS): 34 | super().__init__() 35 | self.fc1 = nn.Linear(input_size, HIDDEN_UNITS) 36 | self.customized_layer = CustomizedLayer(HIDDEN_UNITS) 37 | self.fc2 = nn.Linear(HIDDEN_UNITS, num_classes) 38 | self.unused = nn.Parameter(torch.Tensor(10)) 39 | 40 | def forward(self, x): 41 | x = F.relu(self.fc1(x)) 42 | x = self.customized_layer(x) 43 | y_hat = self.fc2(x) 44 | return y_hat 45 | 46 | ############################ 47 | # Implement your pruning function for the customized layer 48 | # You should implement the following class functions: 49 | # 1. prune_out_channels 50 | # 2. prune_in_channels 51 | # 3. get_out_channels 52 | # 4.
get_in_channels 53 | 54 | class MyPruner(tp.pruner.BasePruningFunc): 55 | 56 | def prune_out_channels(self, layer: CustomizedLayer, idxs: Sequence[int]) -> nn.Module: 57 | keep_idxs = list(set(range(layer.in_dim)) - set(idxs)) 58 | keep_idxs.sort() 59 | layer.in_dim = layer.in_dim-len(idxs) 60 | layer.scale = self._prune_parameter_and_grad(layer.scale, keep_idxs, pruning_dim=0) 61 | layer.bias = self._prune_parameter_and_grad(layer.bias, keep_idxs, pruning_dim=0) 62 | tp.prune_linear_in_channels(layer.fc, idxs) 63 | tp.prune_linear_out_channels(layer.fc, idxs) 64 | return layer 65 | 66 | def get_out_channels(self, layer): 67 | return layer.in_dim 68 | 69 | # identical functions 70 | prune_in_channels = prune_out_channels 71 | get_in_channels = get_out_channels 72 | 73 | class MyLinearPruner(tp.function.LinearPruner): 74 | def prune_out_channels(self, layer: nn.Linear, idxs: Sequence[int]) -> nn.Linear: 75 | print("MyLinearPruner applied to layer: ", layer) 76 | return super().prune_out_channels(layer, idxs) 77 | 78 | def prune_in_channels(self, layer: nn.Linear, idxs: Sequence[int]) -> nn.Linear: 79 | print("MyLinearPruner applied to layer: ", layer) 80 | return super().prune_in_channels(layer, idxs) 81 | 82 | def test_customization(): 83 | model = FullyConnectedNet(128, 10, 256) 84 | 85 | DG = tp.DependencyGraph() 86 | 87 | # 1. Register your customized layer 88 | my_pruner = MyPruner() 89 | DG.register_customized_layer( 90 | CustomizedLayer, 91 | my_pruner) 92 | 93 | my_linear_pruner = MyLinearPruner() 94 | DG.register_customized_layer( 95 | nn.Linear, my_linear_pruner 96 | ) 97 | 98 | # 2. Build dependency graph 99 | DG.build_dependency(model, example_inputs=torch.randn(1,128)) 100 | 101 | # 3. get a pruning group according to the dependency graph. idxs is the indices of pruned filters. 102 | pruning_group = DG.get_pruning_group( model.fc1, my_linear_pruner.prune_out_channels, idxs=[0, 1, 6] ) 103 | print(pruning_group) 104 | 105 | # 4. execute this group (prune the model) 106 | pruning_group.prune() 107 | print("The pruned model:\n", model) 108 | print("Output: ", model(torch.randn(1,128)).shape) 109 | 110 | assert model.fc1.out_features==253 and model.customized_layer.in_dim==253 and model.fc2.in_features==253 111 | 112 | if __name__=='__main__': 113 | test_customization() -------------------------------------------------------------------------------- /tests/test_unwrapped_parameters.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 4 | 5 | import torch 6 | from torchvision.models import convnext_base as entry 7 | import torch_pruning as tp 8 | 9 | model = entry() 10 | print(model) 11 | # Global metrics 12 | example_inputs = torch.randn(1, 3, 224, 224) 13 | imp = tp.importance.MagnitudeImportance(p=2) 14 | ignored_layers = [] 15 | 16 | # DO NOT prune the final classifier!
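# (the loop below collects that head; ConvNeXt additionally keeps a bare `layer_scale`
# nn.Parameter in each CNBlock that belongs to no prunable module, which is why it is
# registered through `unwrapped_parameters` further down)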
17 | for m in model.modules(): 18 | if isinstance(m, torch.nn.Linear) and m.out_features == 1000: 19 | ignored_layers.append(m) 20 | 21 | from torchvision.models.convnext import CNBlock, ConvNeXt 22 | unwrapped_parameters = [] 23 | for m in model.modules(): 24 | if isinstance(m, CNBlock): 25 | unwrapped_parameters.append( (m.layer_scale, 0) ) 26 | 27 | iterative_steps = 5 28 | pruner = tp.pruner.MagnitudePruner( 29 | model, 30 | example_inputs, 31 | importance=imp, 32 | iterative_steps=iterative_steps, 33 | pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} 34 | ignored_layers=ignored_layers, 35 | unwrapped_parameters=unwrapped_parameters 36 | ) 37 | 38 | base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) 39 | for i in range(iterative_steps): 40 | pruner.step() 41 | macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) 42 | print(model) 43 | print(model(example_inputs).shape) 44 | print( 45 | " Iter %d/%d, Params: %.2f M => %.2f M" 46 | % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6) 47 | ) 48 | print( 49 | " Iter %d/%d, MACs: %.2f G => %.2f G" 50 | % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9) 51 | ) 52 | # finetune your model here 53 | # finetune(model) 54 | # ... 55 | 56 | -------------------------------------------------------------------------------- /torch_pruning/__init__.py: -------------------------------------------------------------------------------- 1 | from .pruner import importance 2 | from .dependency import * 3 | from .pruner import * 4 | from . import _helpers, utils 5 | 6 | from .serialization import save, load, state_dict, load_state_dict -------------------------------------------------------------------------------- /torch_pruning/pruner/__init__.py: -------------------------------------------------------------------------------- 1 | from .function import * 2 | from .algorithms import * 3 | from . 
import importance -------------------------------------------------------------------------------- /torch_pruning/pruner/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_pruner import BasePruner 2 | 3 | # Regularization-based pruners 4 | from .batchnorm_scale_pruner import BNScalePruner 5 | from .group_norm_pruner import GroupNormPruner 6 | from .growing_reg_pruner import GrowingRegPruner 7 | 8 | # deprecated 9 | from .compatibility import MetaPruner, MagnitudePruner -------------------------------------------------------------------------------- /torch_pruning/pruner/algorithms/compatibility.py: -------------------------------------------------------------------------------- 1 | from .base_pruner import BasePruner 2 | 3 | MetaPruner=BasePruner # deprecated, for compatibility 4 | MagnitudePruner=BasePruner # deprecated, for compatibility -------------------------------------------------------------------------------- /torch_pruning/pruner/algorithms/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | def linear_scheduler(pruning_ratio: float, steps: int) -> List[float]: 4 | return [((i) / float(steps)) * pruning_ratio for i in range(steps + 1)] 5 | -------------------------------------------------------------------------------- /torch_pruning/serialization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.serialization import DEFAULT_PROTOCOL 3 | import pickle 4 | 5 | load = torch.load 6 | save = torch.save 7 | 8 | def state_dict(model: torch.nn.Module): 9 | """ Returns a dictionary containing the state and attributes of a module. 10 | """ 11 | full_state_dict = {} 12 | attributions = {} 13 | for name, module in model.named_modules(): 14 | # state dicts 15 | full_state_dict[name] = module.__dict__.copy() 16 | module_attr = {} 17 | 18 | # attributes 19 | for attr_name in dir(module): 20 | attr_value = getattr(module, attr_name) 21 | if attr_name=='T_destination': 22 | continue 23 | if not callable(attr_value) and (not attr_name.startswith('__')) and (not attr_name.startswith('_')): 24 | if not isinstance(attr_value, torch.nn.Parameter) and not isinstance(attr_value, torch.Tensor): 25 | module_attr[attr_name] = attr_value 26 | attributions[name] = module_attr 27 | return {'full_state_dict': full_state_dict, 'attributions': attributions} 28 | 29 | def load_state_dict(model: torch.nn.Module, state_dict: dict): 30 | """ Load a model given a state_dict. 31 | """ 32 | 33 | full_state_dict = state_dict['full_state_dict'] 34 | attributions = state_dict['attributions'] 35 | for name, module in model.named_modules(): 36 | # load state dicts 37 | if name in full_state_dict: 38 | module.__dict__.update(full_state_dict[name]) 39 | # load attributes 40 | if name in attributions: 41 | for attr_name, attr_value in attributions[name].items(): 42 | setattr(module, attr_name, attr_value) 43 | return model 44 | -------------------------------------------------------------------------------- /torch_pruning/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .op_counter import count_ops_and_params 3 | from .
import benchmark -------------------------------------------------------------------------------- /torch_pruning/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Latency 4 | def measure_latency(model, example_inputs, repeat=300, warmup=50, run_fn=None): 5 | model.eval() 6 | latency = [] 7 | for _ in range(warmup): 8 | if run_fn is not None: 9 | _ = run_fn(model, example_inputs) 10 | else: 11 | _ = model(example_inputs) 12 | 13 | for i in range(repeat): 14 | start = torch.cuda.Event(enable_timing=True) 15 | end = torch.cuda.Event(enable_timing=True) 16 | start.record() 17 | if run_fn is not None: 18 | _ = run_fn(model, example_inputs) 19 | else: 20 | _ = model(example_inputs) 21 | end.record() 22 | torch.cuda.synchronize() 23 | latency.append(start.elapsed_time(end)) 24 | 25 | latency = torch.tensor(latency) 26 | return latency.mean().item(), latency.std().item() 27 | 28 | # Memory Consumption 29 | def measure_memory(model, example_inputs, device=None, run_fn=None): 30 | """ https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management 31 | """ 32 | torch.cuda.reset_peak_memory_stats() 33 | model.eval() 34 | if run_fn is not None: 35 | _ = run_fn(model, example_inputs) 36 | else: 37 | _ = model(example_inputs) 38 | return torch.cuda.max_memory_allocated(device=device) 39 | 40 | # Frame (Batch) per Second 41 | def measure_fps(model, example_inputs, repeat=300, warmup=50, run_fn=None): 42 | latency_mu, latency_std = measure_latency(model, example_inputs, repeat=repeat, warmup=warmup, run_fn=run_fn) 43 | fps = 1000.0 / latency_mu # 1000 ms = 1 s 44 | return fps 45 | 46 | # Throughput 47 | def measure_throughput(model, example_inputs, repeat=300, warmup=50, run_fn=None): 48 | latency_mu, latency_std = measure_latency(model, example_inputs, repeat=repeat, warmup=warmup, run_fn=run_fn) 49 | throughput = example_inputs.shape[0] / (latency_mu/1000) 50 | return throughput -------------------------------------------------------------------------------- /torch_pruning/utils/compute_mat_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def _extract_patches(x, kernel_size, stride, padding): 6 | """ 7 | :param x: The input feature maps. (batch_size, in_c, h, w) 8 | :param kernel_size: the kernel size of the conv filter (tuple of two elements) 9 | :param stride: the stride of conv operation (tuple of two elements) 10 | :param padding: number of paddings. 
Expected to be a tuple of two elements. 11 | :return: (batch_size, out_h, out_w, in_c*kh*kw) 12 | """ 13 | if padding[0] + padding[1] > 0: 14 | x = F.pad(x, (padding[1], padding[1], padding[0], 15 | padding[0])).data # Actually check dims 16 | x = x.unfold(2, kernel_size[0], stride[0]) 17 | x = x.unfold(3, kernel_size[1], stride[1]) 18 | x = x.transpose_(1, 2).transpose_(2, 3).contiguous() 19 | x = x.view( 20 | x.size(0), x.size(1), x.size(2), 21 | x.size(3) * x.size(4) * x.size(5)) 22 | return x 23 | 24 | 25 | 26 | def try_contiguous(x): 27 | if not x.is_contiguous(): 28 | x = x.contiguous() 29 | 30 | return x 31 | 32 | 33 | class ComputeMatGrad: 34 | 35 | @classmethod 36 | def __call__(cls, input, grad_output, layer): 37 | if isinstance(layer, nn.Linear): 38 | grad = cls.linear(input, grad_output, layer) 39 | elif isinstance(layer, nn.Conv2d): 40 | grad = cls.conv2d(input, grad_output, layer) 41 | else: 42 | raise NotImplementedError 43 | return grad 44 | 45 | @staticmethod 46 | def linear(input, grad_output, layer): 47 | """ 48 | :param input: batch_size * input_dim 49 | :param grad_output: batch_size * output_dim 50 | :param layer: [nn.Linear] weight of shape output_dim * input_dim 51 | :return: batch_size * output_dim * (input_dim + [1 if with bias]) 52 | """ 53 | with torch.no_grad(): 54 | if layer.bias is not None: 55 | input = torch.cat([input, input.new(input.size(0), 1).fill_(1)], 1) 56 | input = input.unsqueeze(1) 57 | grad_output = grad_output.unsqueeze(2) 58 | grad = torch.bmm(grad_output, input) 59 | return grad 60 | 61 | @staticmethod 62 | def conv2d(input, grad_output, layer): 63 | """ 64 | :param input: batch_size * in_c * in_h * in_w 65 | :param grad_output: batch_size * out_c * h * w 66 | :param layer: [nn.Conv2d] 67 | :return: batch_size * out_c * (in_c*k_h*k_w + [1 if with bias]) 68 | """ 69 | with torch.no_grad(): 70 | input = _extract_patches(input, layer.kernel_size, layer.stride, layer.padding) 71 | input = input.view(-1, input.size(-1)) # b * hw * in_c*kh*kw 72 | grad_output = grad_output.transpose(1, 2).transpose(2, 3) 73 | grad_output = try_contiguous(grad_output).view(grad_output.size(0), -1, grad_output.size(-1)) 74 | # b * hw * out_c 75 | if layer.bias is not None: 76 | input = torch.cat([input, input.new(input.size(0), 1).fill_(1)], 1) 77 | input = input.view(grad_output.size(0), -1, input.size(-1)) # b * hw * in_c*kh*kw 78 | grad = torch.einsum('abm,abn->amn', (grad_output, input)) 79 | return grad --------------------------------------------------------------------------------
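# A minimal usage sketch for ComputeMatGrad (shapes assumed from the docstrings above; illustrative only):
# import torch, torch.nn as nn
# layer = nn.Linear(4, 3)
# x, grad_out = torch.randn(2, 4), torch.randn(2, 3)
# per_sample_grad = ComputeMatGrad()(x, grad_out, layer)  # -> (2, 3, 5): weight grads plus a bias column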