├── .gitignore ├── LICENSE ├── README.md ├── config ├── llama2-7b-wikitext.yaml ├── llama3.1-8b-gsm8k.yaml ├── llama3.1-8b-mmlu.yaml ├── llama3.1-8b-wikitext.yaml ├── llama3.2-3b-gsm8k.yaml ├── llama3.2-3b-mmlu.yaml ├── llama3.2-3b-wikitext.yaml ├── resnet18-imagenet1k.yaml ├── resnet50-imagenet1k.yaml ├── swin-imagenet1k.yaml └── vit-imagenet1k.yaml ├── csrc ├── README.md ├── build_cutlass.sh ├── environment.sh ├── kernel │ ├── bcmm.cu │ ├── bindings.cpp │ ├── bmm.cu │ ├── bmw.cu │ ├── include │ │ ├── bcmm.h │ │ ├── bmm.h │ │ ├── bmw.h │ │ ├── common.h │ │ └── qbmw.h │ ├── qbmw.cu │ └── scaler.cu ├── pyproject.toml ├── setup.py └── t2c_gemm.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt ├── deprecated ├── execute │ ├── bert │ │ ├── mrpc.py │ │ ├── mrpc_t2c.py │ │ ├── sst2.py │ │ └── sst2_t2c.py │ └── imagenet │ │ ├── main.py │ │ ├── ptq.py │ │ ├── reload.py │ │ ├── t2c.py │ │ ├── vit.py │ │ └── vit_prune.py └── scripts │ ├── bert │ ├── mrpc.sh │ ├── mrpc_t2c.sh │ ├── sst2.sh │ └── sst2_t2c.sh │ └── imagenet │ ├── pretrain.sh │ ├── ptq.sh │ ├── resnet50_ptq_lsq_adaround.sh │ ├── resnet50_ptq_lsq_minmax_channel.sh │ ├── resnet50_ptq_minmax_minmax_channel.sh │ ├── resnet50_ptq_qdrop_adaround.sh │ ├── resnet50_ptq_qdrop_minmax_channel.sh │ ├── resnet50_t2c_lsq_adaround.sh │ ├── resnet50_t2c_lsq_minmax_channel.sh │ ├── resnet50_t2c_qdrop_adaround.sh │ ├── resnet50_t2c_qdrop_minmax_channel.sh │ ├── resnet50_t2c_reload_lsq_adaround.sh │ ├── resnet50_t2c_reload_lsq_minmax_channel.sh │ ├── resnet50_t2c_reload_qdrop_adaround.sh │ ├── resnet50_t2c_reload_qdrop_minmax_channel.sh │ ├── swin-t2c-smoothquant.sh │ ├── swin-vit-ptq-minmax.sh │ ├── swin-vit-ptq-smoothquant.sh │ ├── swin-vit-t2c-minmax.sh │ ├── swin-vit-t2c-reload-minmax.sh │ ├── t2c.sh │ ├── t2c_cnn.sh │ ├── vit-ptq-adaround-lsq.sh │ ├── vit-ptq-adaround-qdrop.sh │ ├── vit-ptq-minmax-lsq.sh │ ├── vit-ptq-minmax.sh │ ├── vit-ptq-smoothquant.sh │ ├── vit-t2c-adaround-lsq.sh │ ├── vit-t2c-adaround-qdrop.sh │ ├── vit-t2c-minmax-lsq.sh │ ├── vit-t2c-minmax-qdrop.sh │ ├── vit-t2c-minmax.sh │ ├── vit-t2c-reload-adaround-lsq.sh │ ├── vit-t2c-reload-adaround-qdrop.sh │ ├── vit-t2c-reload-minmax-lsq.sh │ ├── vit-t2c-reload-minmax-qdrop.sh │ ├── vit-t2c-reload-minmax.sh │ ├── vit-t2c-reload-smoothquant.sh │ └── vit-t2c-smoothquant.sh ├── figs ├── DualPathDesign.png ├── Figure1.png ├── icon.png └── torch2chip_workflow.png ├── llm ├── gsm8k.py ├── mmlu.py └── wikitext.py ├── lm-dataset └── gsm8k │ ├── gsm8k_test.jsonl │ └── train.jsonl ├── prune └── resnet.py ├── requirements.txt ├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── base.py │ ├── data_utils.py │ ├── llm │ │ ├── __init__.py │ │ ├── hf.py │ │ └── math │ │ │ ├── __init__.py │ │ │ └── grader.py │ └── vision │ │ └── imagenet.py ├── hardware │ ├── systolic_array.py │ └── writer.py ├── models │ ├── __init__.py │ ├── auto_map.py │ ├── cifar │ │ ├── mobilenetv1.py │ │ ├── resnet.py │ │ └── vit.py │ ├── imagenet │ │ └── mobilenetv1.py │ └── lm │ │ ├── __init__.py │ │ ├── configuration_retnet.py │ │ └── retnet.py ├── module │ ├── __init__.py │ ├── attention.py │ ├── base.py │ ├── fuse.py │ ├── mlp.py │ └── ops.py ├── profiler │ └── profiler.py ├── pruner │ ├── __init__.py │ ├── base.py │ ├── element.py │ └── nm.py ├── quantization │ ├── __init__.py │ ├── adaround.py │ ├── lsq.py │ ├── minmax.py │ ├── mxint.py │ ├── observer.py │ ├── qdrop.py │ └── smoothquant.py ├── stage │ ├── base.py │ ├── calib.py │ └── hf.py ├── t2c │ ├── __init__.py │ ├── 
convert.py │ ├── fusers │ │ ├── bert.py │ │ ├── fusers.py │ │ ├── lm.py │ │ ├── mobilenet.py │ │ ├── resnet.py │ │ ├── vgg.py │ │ └── vit.py │ └── t2c.py ├── trainer │ ├── __init__.py │ ├── base.py │ ├── ddp.py │ ├── llm │ │ ├── __init__.py │ │ ├── evaluator.py │ │ ├── metrics.py │ │ ├── ptq.py │ │ └── utils.py │ ├── loss.py │ ├── pruning.py │ ├── scheduler.py │ └── vision │ │ ├── __init__.py │ │ ├── ptq.py │ │ └── smoothquant.py └── utils │ ├── __init__.py │ ├── ddp.py │ ├── get_data.py │ └── utils.py └── vision ├── resnet.py ├── swin.py └── vit.py /.gitignore: -------------------------------------------------------------------------------- 1 | save/ 2 | __pycache__/ 3 | *.csv 4 | *.pt 5 | *.zip 6 | *.npy 7 | *.sub 8 | wandb/ 9 | tmp/ 10 | *.err 11 | *.out 12 | *.tar.gz 13 | .vscode/ 14 | build/ 15 | submodules/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jian Meng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /config/llama2-7b-wikitext.yaml: -------------------------------------------------------------------------------- 1 | seed: 5000 2 | 3 | model: 4 | model_type: "meta-llama/Llama-2-7b-hf" 5 | 6 | dataset: 7 | name: "wikitext" 8 | path: "wikitext-2-raw-v1" 9 | split: "test" 10 | 11 | save: 12 | run_dir: "save/Llama-2-7b-hf-wikitext/" 13 | logger: "inference.log" 14 | 15 | eval: 16 | chunk_size: 2048 17 | n_samples: 16 18 | 19 | quantization: 20 | wbit: 8 21 | abit: 8 22 | num_samples: 2048 23 | wqtype: smooth_quant 24 | xqtype: smooth_quant_token 25 | rescale_out: True 26 | 27 | train: 28 | batch_size: 2048 29 | 30 | smooth: 31 | alpha: 0.85 32 | 33 | t2c: 34 | swl: 32 35 | sfl: 26 36 | 37 | export: 38 | export_samples: 0 -------------------------------------------------------------------------------- /config/llama3.1-8b-gsm8k.yaml: -------------------------------------------------------------------------------- 1 | seed: 5000 2 | 3 | model: 4 | model_type: "meta-llama/Llama-3.1-8B-Instruct" 5 | 6 | dataset: 7 | name: "gsm8k" 8 | train: "/scratch/dataset/gsm8k/train.jsonl" 9 | test: "/scratch/dataset/gsm8k/gsm8k_test.jsonl" 10 | split: "test" 11 | 12 | save: 13 | run_dir: "save/llama3.1-8B-gsm8k/" 14 | logger: "inference.log" 15 | 16 | eval: 17 | cot: True 18 | max_gen_toks: 512 19 | 20 | quantization: 21 | wbit: 8 22 | abit: 8 23 | num_samples: 2048 24 | wqtype: smooth_quant 25 | xqtype: smooth_quant_token 26 | rescale_out: True 27 | 28 | train: 29 | batch_size: 512 30 | 31 | smooth: 32 | alpha: 0.85 33 | 34 | t2c: 35 | swl: 32 36 | sfl: 26 37 | 38 | export: 39 | export_samples: 0 -------------------------------------------------------------------------------- /config/llama3.1-8b-mmlu.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: "meta-llama/Llama-3.1-8B-Instruct" 3 | 4 | dataset: 5 | name: "mmlu" 6 | train: "lm-dataset/mmlu/data/" 7 | test: "lm-dataset/mmlu/data/test/" 8 | cot: "lm-dataset/mmlu/data/dev/" 9 | nshot: 5 10 | split: "test" 11 | 12 | save: 13 | run_dir: "save/llama3.1-8B_mmlu/" 14 | logger: "inference.log" 15 | 16 | eval: 17 | cot: True 18 | max_gen_toks: 256 19 | 20 | quantization: 21 | wbit: 8 22 | abit: 8 23 | num_samples: 512 24 | wqtype: smooth_quant 25 | xqtype: smooth_quant_token 26 | rescale_out: True 27 | 28 | train: 29 | batch_size: 512 30 | 31 | smooth: 32 | alpha: 0.85 33 | 34 | t2c: 35 | swl: 32 36 | sfl: 26 37 | 38 | export: 39 | export_samples: 0 -------------------------------------------------------------------------------- /config/llama3.1-8b-wikitext.yaml: -------------------------------------------------------------------------------- 1 | seed: 5000 2 | 3 | model: 4 | model_type: "meta-llama/Llama-3.1-8B-Instruct" 5 | 6 | dataset: 7 | name: "wikitext" 8 | path: "wikitext-2-raw-v1" 9 | split: "test" 10 | 11 | save: 12 | run_dir: "save/Llama-3.1-8b-wikitext/" 13 | logger: "inference.log" 14 | 15 | eval: 16 | chunk_size: 2048 17 | n_samples: 16 18 | 19 | quantization: 20 | wbit: 8 21 | abit: 8 22 | num_samples: 2048 23 | wqtype: smooth_quant 24 | xqtype: smooth_quant_token 25 | rescale_out: True 26 | 27 | train: 28 | batch_size: 2048 29 | 30 | smooth: 31 | alpha: 0.85 32 | 33 | t2c: 34 | swl: 32 35 | sfl: 26 36 | 37 | export: 38 | export_samples: 0 -------------------------------------------------------------------------------- /config/llama3.2-3b-gsm8k.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | model_type: "meta-llama/Llama-3.2-3B-Instruct" 3 | 4 | dataset: 5 | name: "gsm8k" 6 | train: "lm-dataset/gsm8k/train.jsonl" 7 | test: "lm-dataset/gsm8k/gsm8k_test.jsonl" 8 | split: "test" 9 | 10 | save: 11 | run_dir: "save/llama3.2-3B_gsm8k/" 12 | logger: "inference.log" 13 | 14 | eval: 15 | cot: True 16 | max_gen_toks: 512 17 | 18 | quantization: 19 | wbit: 8 20 | abit: 8 21 | num_samples: 512 22 | wqtype: smooth_quant 23 | xqtype: smooth_quant_token 24 | rescale_out: True 25 | 26 | train: 27 | batch_size: 512 28 | 29 | smooth: 30 | alpha: 0.85 31 | 32 | t2c: 33 | swl: 32 34 | sfl: 26 35 | 36 | export: 37 | export_samples: 0 -------------------------------------------------------------------------------- /config/llama3.2-3b-mmlu.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: "meta-llama/Llama-3.2-3B-Instruct" 3 | 4 | dataset: 5 | name: "mmlu" 6 | train: "lm-dataset/mmlu/data/" 7 | test: "lm-dataset/mmlu/data/test/" 8 | cot: "lm-dataset/mmlu/data/dev/" 9 | nshot: 5 10 | split: "test" 11 | 12 | save: 13 | run_dir: "save/llama3.2-3B_mmlu/" 14 | logger: "inference.log" 15 | 16 | eval: 17 | cot: True 18 | max_gen_toks: 256 19 | 20 | quantization: 21 | wbit: 8 22 | abit: 8 23 | num_samples: 512 24 | wqtype: smooth_quant 25 | xqtype: smooth_quant_token 26 | rescale_out: True 27 | 28 | train: 29 | batch_size: 512 30 | 31 | smooth: 32 | alpha: 0.85 33 | 34 | t2c: 35 | swl: 32 36 | sfl: 26 37 | 38 | export: 39 | export_samples: 0 -------------------------------------------------------------------------------- /config/llama3.2-3b-wikitext.yaml: -------------------------------------------------------------------------------- 1 | seed: 5000 2 | 3 | model: 4 | model_type: "meta-llama/Llama-3.2-3B-Instruct" 5 | 6 | dataset: 7 | name: "wikitext" 8 | path: "wikitext-2-raw-v1" 9 | split: "test" 10 | 11 | save: 12 | run_dir: "save/Llama-3.2-3b-wikitext/" 13 | logger: "inference.log" 14 | 15 | eval: 16 | chunk_size: 2048 17 | n_samples: 16 18 | 19 | quantization: 20 | wbit: 8 21 | abit: 8 22 | num_samples: 2048 23 | wqtype: smooth_quant 24 | xqtype: smooth_quant_token 25 | rescale_out: True 26 | 27 | train: 28 | batch_size: 2048 29 | 30 | smooth: 31 | alpha: 0.85 32 | 33 | t2c: 34 | swl: 32 35 | sfl: 26 36 | 37 | export: 38 | export_samples: 0 -------------------------------------------------------------------------------- /config/resnet18-imagenet1k.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: "resnet18" 3 | 4 | dataset: 5 | name: "ImageNet-1K" 6 | path: "/share/seo/imagenet/" 7 | split: "train" 8 | train_dir: "/share/seo/imagenet/train/" 9 | test_dir: "/share/seo/imagenet/val/" 10 | samples: 512 11 | num_workers: 16 12 | num_classes: 1000 13 | 14 | save: 15 | run_dir: "save/imagenet1K/resnet18/w8a8/" 16 | logger: "training.log" 17 | 18 | quantization: 19 | wbit: 8 20 | abit: 8 21 | wqtype: adaround 22 | xqtype: lsq 23 | requires_grad: True 24 | 25 | train: 26 | lr: 0.01 27 | momentum: 0.9 28 | epochs: 1 29 | weight_decay: 0.0001 30 | batch_size: 128 31 | loss_type: "mse" 32 | optim_type: sgd 33 | lr_sch: "step" 34 | mix_prec: True 35 | smoothing: 0.1 36 | schedule: [30, 60, 90] 37 | 38 | prune: 39 | drate: 0.5 40 | prune_ratio: 0.85 41 | type: "element" 42 | warmup: 1 43 | final_epoch: 80 44 | prune_freq: 1000 45 | 46 | t2c: 47 | swl: 32 48 | sfl: 26 
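All of the YAML files under `config/` share this layout (`model`, `dataset`, `save`, `quantization`, `train`, `t2c`, plus optional `prune`/`smooth`/`eval` blocks). As a rough illustration of how such a file can be consumed, the snippet below parses one config and reads its quantization block; the `load_config` helper is hypothetical and not part of the repository.

```python
# Hypothetical helper (not part of the repository): parse a Torch2Chip-style YAML
# config and inspect its quantization settings. Requires PyYAML.
import yaml

def load_config(path: str) -> dict:
    """Read a YAML config such as config/resnet18-imagenet1k.yaml into a dict."""
    with open(path, "r") as f:
        return yaml.safe_load(f)

if __name__ == "__main__":
    cfg = load_config("config/resnet18-imagenet1k.yaml")
    q = cfg["quantization"]
    # for the file above: wbit=8, abit=8, wqtype="adaround", xqtype="lsq"
    print(q["wbit"], q["abit"], q["wqtype"], q["xqtype"])
```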
-------------------------------------------------------------------------------- /config/resnet50-imagenet1k.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: "resnet50" 3 | 4 | dataset: 5 | name: "ImageNet-1K" 6 | path: "/share/seo/imagenet/" 7 | split: "train" 8 | train_dir: "/share/seo/imagenet/train/" 9 | test_dir: "/share/seo/imagenet/val/" 10 | samples: 512 11 | num_workers: 16 12 | num_classes: 1000 13 | 14 | save: 15 | run_dir: "save/imagenet1K/resnet50/sparse/w8a8/" 16 | logger: "training.log" 17 | 18 | quantization: 19 | wbit: 8 20 | abit: 8 21 | wqtype: adaround 22 | xqtype: lsq 23 | requires_grad: True 24 | 25 | train: 26 | lr: 0.01 27 | momentum: 0.9 28 | epochs: 1 29 | weight_decay: 0.0001 30 | batch_size: 128 31 | loss_type: "mse" 32 | optim_type: sgd 33 | lr_sch: "step" 34 | mix_prec: True 35 | smoothing: 0.1 36 | schedule: [30, 60, 90] 37 | 38 | prune: 39 | drate: 0.5 40 | prune_ratio: 0.85 41 | type: "element" 42 | warmup: 1 43 | final_epoch: 80 44 | prune_freq: 1000 45 | 46 | t2c: 47 | swl: 32 48 | sfl: 26 -------------------------------------------------------------------------------- /config/swin-imagenet1k.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: "swin_tiny_patch4_window7_224" 3 | 4 | dataset: 5 | name: "ImageNet-1K" 6 | path: "/scratch/dataset/imagenet-1k/" 7 | split: "train" 8 | train_dir: "/scratch/dataset/imagenet-1k/train/" 9 | test_dir: "/scratch/dataset/imagenet-1k/val/" 10 | samples: 500 11 | num_workers: 16 12 | mean: [0.5, 0.5, 0.5] 13 | std: [0.5, 0.5, 0.5] 14 | 15 | save: 16 | run_dir: "save/imagenet1K/swin_tiny_patch4_window7_224/lsq_adaround/w8a8/" 17 | logger: "inference.log" 18 | 19 | quantization: 20 | wbit: 8 21 | abit: 8 22 | wqtype: adaround 23 | xqtype: lsq_token 24 | requires_grad: True 25 | 26 | train: 27 | lr: 0.0001 28 | epochs: 1 29 | weight_decay: 1e-4 30 | batch_size: 128 31 | loss_type: "mse" 32 | optim_type: adam 33 | 34 | t2c: 35 | swl: 32 36 | sfl: 26 -------------------------------------------------------------------------------- /config/vit-imagenet1k.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_type: "vit_small_patch16_224" 3 | 4 | dataset: 5 | name: "ImageNet-1K" 6 | path: "/scratch/dataset/imagenet-1k/" 7 | split: "train" 8 | train_dir: "/scratch/dataset/imagenet-1k/train/" 9 | test_dir: "/scratch/dataset/imagenet-1k/val/" 10 | samples: 500 11 | num_workers: 16 12 | mean: [0.5, 0.5, 0.5] 13 | std: [0.5, 0.5, 0.5] 14 | 15 | save: 16 | run_dir: "save/imagenet1K/vit_small/lsq_adaround/w8a8/" 17 | logger: "inference.log" 18 | 19 | quantization: 20 | wbit: 8 21 | abit: 8 22 | wqtype: adaround 23 | xqtype: lsq 24 | requires_grad: True 25 | 26 | train: 27 | lr: 0.0001 28 | epochs: 1 29 | weight_decay: 1e-4 30 | batch_size: 128 31 | loss_type: "mse" 32 | optim_type: adam 33 | 34 | t2c: 35 | swl: 32 36 | sfl: 26 -------------------------------------------------------------------------------- /csrc/README.md: -------------------------------------------------------------------------------- 1 | # Customized MatMul Kernel of Torch2Chip 2 | 3 | The customized cuda kernel is built based on CUTLASS, with the support of multi-dimensional matrix multiplication. 4 | 5 | ### Installation 6 | 7 | Make sure you have NVCC 12.1 or 12.4 installed. Start from the home of `torch2chip`. 
8 | 9 | ``` 10 | cd csrc 11 | source environment.sh 12 | bash build_cutlass.sh 13 | pip install . 14 | ``` 15 | The dedicated kernels will be installed as the `t2c_gemm` package to your anaconda environment. 16 | 17 | ### Usage 18 | 19 | > Version = 0.1.0 20 | 21 | Supports INT8 matrix multiplication between 3-D and 2-D tensors (`_QBaseLinear`), and between 4-D and 4-D tensors (`BatchHeadIntMatMul`). 22 | 23 | 24 | **[Warning]**: The current MatMul kernel doesn't support the MXINT quantizer due to the fine-grained group-wise shifting. We will update the dedicated CUDA kernel implementation soon, stay tuned. -------------------------------------------------------------------------------- /csrc/build_cutlass.sh: -------------------------------------------------------------------------------- 1 | export CUDACXX=/usr/local/cuda-12.4/bin/nvcc 2 | export CC=/usr/bin/gcc 3 | export CXX=/usr/bin/g++ 4 | cd submodules 5 | git clone https://github.com/NVIDIA/cutlass.git 6 | 7 | cd cutlass 8 | rm -rf build 9 | mkdir -p build && cd build 10 | cmake .. -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_UNITY_BUILD_ENABLED=ON 11 | make -j 16 -------------------------------------------------------------------------------- /csrc/environment.sh: -------------------------------------------------------------------------------- 1 | export T2C_CUDA_ROOT=$PWD 2 | export CUTLASS_PATH="$T2C_CUDA_ROOT/submodules/cutlass" 3 | export CUDA_PATH="/usr/local/cuda-12.4/" 4 | export PATH="$CUDA_PATH/bin:$PATH" 5 | 6 | # CUDA 7 | export CPATH="$CUDA_PATH/include:$CPATH" 8 | export C_INCLUDE_PATH="$CUDA_PATH/include:$C_INCLUDE_PATH" 9 | export CPLUS_INCLUDE_PATH="$CUDA_PATH/include:$CPLUS_INCLUDE_PATH" 10 | export LD_LIBRARY_PATH="$CUDA_PATH/lib64:$LD_LIBRARY_PATH" 11 | 12 | # CUTLASS 13 | export CPATH=$CUTLASS_PATH/tools/util/include:$CUTLASS_PATH/include:$CPATH 14 | export C_INCLUDE_PATH=$CUTLASS_PATH/tools/util/include:$CUTLASS_PATH/include:$C_INCLUDE_PATH 15 | export CPLUS_INCLUDE_PATH=$CUTLASS_PATH/tools/util/include:$CUTLASS_PATH/include:$CPLUS_INCLUDE_PATH -------------------------------------------------------------------------------- /csrc/kernel/bcmm.cu: -------------------------------------------------------------------------------- 1 | #include "include/bcmm.h" 2 | #include "include/common.h" 3 | #include "cutlass/core_io.h" 4 | #include "cutlass/gemm/device/gemm.h" 5 | #include "cutlass/gemm/device/gemm_batched.h" 6 | #include "cutlass/numeric_types.h" 7 | #include "cutlass/util/host_tensor.h" 8 | 9 | // the returned data type = torch::Tensor 10 | torch::Tensor bcmm_int8(torch::Tensor A, torch::Tensor B, float alpha) { 11 | int batch_size = A.size(0); 12 | int H = A.size(1); 13 | int M = A.size(2); 14 | int N = B.size(2); 15 | int K = A.size(3); 16 | 17 | // pad the height and width of the input matrices 18 | int int_pad_A = (M + 15) / 16 * 16; 19 | int int_pad_B = (N + 15) / 16 * 16; 20 | int int_pad_K = (K + 15) / 16 * 16; 21 | 22 | torch::Tensor A_pad = A; 23 | torch::Tensor B_pad = B; 24 | 25 | if (int_pad_A > M) { 26 | A_pad = torch::cat({A, torch::zeros({batch_size, H, int_pad_A - M, K}, A.options())}, 2); 27 | } 28 | 29 | if (int_pad_B > N) { 30 | B_pad = torch::cat({B, torch::zeros({batch_size, H, int_pad_B - N, K}, B.options())}, 2); 31 | } 32 | 33 | if (int_pad_K > K) { 34 | A_pad = torch::cat({A_pad, torch::zeros({batch_size, H, int_pad_A, int_pad_K - K}, A.options())}, 3); 35 | B_pad = torch::cat({B_pad, torch::zeros({batch_size, H, int_pad_B, int_pad_K - K}, B.options())}, 3); 36 | } 37 | 38 | // 
automatically define the data type of output tensor C 39 | auto C = torch::empty({batch_size, H, int_pad_A, int_pad_B}, torch::dtype(torch::kFloat32).device(A_pad.device())); 40 | 41 | int lda = A_pad.size(3); 42 | int ldb = B_pad.size(3); 43 | int ldc = C.size(3); 44 | 45 | // define the layout of memory storage for each matrix 46 | using LayoutA = cutlass::layout::RowMajor; 47 | using LayoutB = cutlass::layout::ColumnMajor; 48 | using LayoutC = cutlass::layout::RowMajor; 49 | 50 | // define data type and accumulation precision 51 | using FormatInputA = int8_t; 52 | using FormatInputB = int8_t; 53 | using FormatOutputC = float; 54 | using FormatAccumulator = int32_t; 55 | using ElementComputeEpilogue = float; 56 | 57 | // deivce-dependent definition 58 | #if CUDA_ARCH >= 800 59 | using EpilogueOp = cutlass::epilogue::thread::LinearCombination::value, FormatAccumulator, ElementComputeEpilogue>; 60 | 61 | using Gemm = cutlass::gemm::device::GemmBatched< 62 | FormatInputA, LayoutA, FormatInputB, LayoutB, FormatOutputC, LayoutC, FormatAccumulator, cutlass::arch::OpClassTensorOp, 63 | cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, EpilogueOp>; 64 | 65 | #elif CUDA_ARCH >= 750 66 | using EpilogueOp = cutlass::epilogue::thread::LinearCombination::value, FormatAccumulator, ElementComputeEpilogue>; 67 | 68 | using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< 69 | cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, 70 | FormatInputA, FormatInputB, FormatOutputC, FormatAccumulator>; 71 | 72 | using Gemm = cutlass::gemm::device::GemmBatched< 73 | FormatInputA, LayoutA, FormatInputB, LayoutB, FormatOutputC, LayoutC, FormatAccumulator, cutlass::arch::OpClassTensorOp, 74 | cutlass::arch::Sm75, DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, DefaultGemmCfg::InstructionShape, EpilogueOp>; 75 | 76 | #elif CUDA_ARCH >= 700 77 | using EpilogueOp = cutlass::epilogue::thread::LinearCombination; 78 | 79 | using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< 80 | cutlass::arch::OpClassSimt, cutlass::arch::Sm70 81 | FormatInputA, FormatInputB, FormatOutputC, FormatAccumulator>; 82 | 83 | using Gemm = cutlass::gemm::device::GemmBatched< 84 | FormatInputA, LayoutA, FormatInputB, LayoutB, FormatOutputC, LayoutC, FormatAccumulator, cutlass::arch::OpClassSimt, 85 | cutlass::arch::Sm70, DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, DefaultGemmCfg::InstructionShape, EpilogueOp>; 86 | #else 87 | #error "Unsupported GPU type" 88 | #endif 89 | 90 | // stride between two matrices within each batch 91 | long long int batch_stride_A = int_pad_A * int_pad_K; 92 | long long int batch_stride_B = int_pad_B * int_pad_K; 93 | long long int batch_stride_C = int_pad_A * int_pad_B; 94 | int batch_count = H * batch_size; 95 | 96 | // Define the operation of Gemm 97 | Gemm gemm_op; 98 | // Argument of Gemm Op 99 | typename Gemm::Arguments arguments{ 100 | {int_pad_A, int_pad_B, int_pad_K}, {A_pad.data_ptr(), lda}, 101 | batch_stride_A, {B_pad.data_ptr(), ldb}, 102 | batch_stride_B, {C.data_ptr(), ldc}, 103 | batch_stride_C, {C.data_ptr(), ldc}, 104 | batch_stride_C, {alpha, 0}, 105 | batch_count}; 106 | 107 | // request extra space for GEMM operation 108 | size_t workspace_size = Gemm::get_workspace_size(arguments); 109 | 110 | // allocate workspace memory 111 | cutlass::device_memory::allocation workspace(workspace_size); 112 | 113 | // Check the problem size is supported or not 
114 | cutlass::Status status = gemm_op.can_implement(arguments); 115 | if (status != cutlass::Status::kSuccess) { 116 | throw std::runtime_error("cutlass cannot implement"); 117 | } 118 | 119 | // Initialize CUTLASS kernel with arguments and workspace pointer 120 | status = gemm_op.initialize(arguments, workspace.get()); 121 | if (status != cutlass::Status::kSuccess) { 122 | throw std::runtime_error("cutlass cannot initialize"); 123 | } 124 | 125 | status = gemm_op(); 126 | if (status != cutlass::Status::kSuccess) { 127 | throw std::runtime_error("cutlass cannot run"); 128 | } 129 | 130 | if (int_pad_A > M){ 131 | return C.slice(2, 0, M).slice(3, 0, N); 132 | } else { 133 | return C; 134 | } 135 | 136 | } -------------------------------------------------------------------------------- /csrc/kernel/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "include/bmm.h" 2 | #include "include/bmw.h" 3 | #include "include/qbmw.h" 4 | #include "include/bcmm.h" 5 | #include 6 | #include 7 | #include 8 | 9 | // Binding the function to Python 10 | PYBIND11_MODULE(t2c_gemm, m) { 11 | // Function wrapper 12 | m.def("bmm_int8", &bmm_int8, 13 | pybind11::arg("A"), pybind11::arg("B"), pybind11::arg("alpha"), 14 | R"pbdoc( 15 | Batched matrix multiplication with int8 inputs. 16 | 17 | Args: 18 | A (torch.Tensor): Tensor of type int8 (3-D). 19 | B (torch.Tensor): Tensor of type int8 (3-D). 20 | alpha (float): Scalar value for scaling the result. 21 | 22 | Returns: 23 | torch.Tensor: Result of the batched matrix multiplication. 24 | )pbdoc"); 25 | 26 | m.def("bmw_int8", &bmw_int8, 27 | pybind11::arg("A"), pybind11::arg("W"), pybind11::arg("alpha"), 28 | R"pbdoc( 29 | Batched matrix multiplication with int8 inputs. 30 | 31 | Args: 32 | A (torch.Tensor): Input Activation Tensor of type int8 (3-D). 33 | W (torch.Tensor): Weight Tensor of type int8 (2-D). 34 | alpha (float): Scalar value for scaling the result. 35 | 36 | Returns: 37 | torch.Tensor: Result of the batched matrix multiplication. 38 | )pbdoc"); 39 | 40 | m.def("qbmw", &qbmw, 41 | pybind11::arg("A"), pybind11::arg("W"), 42 | R"pbdoc( 43 | Batched matrix multiplication with int8 inputs. 44 | 45 | Args: 46 | A (torch.Tensor): Input Activation Tensor of type int8 (3-D). 47 | W (torch.Tensor): Weight Tensor of type int8 (2-D). 48 | 49 | Returns: 50 | torch.Tensor: Result of the batched matrix multiplication. 51 | )pbdoc"); 52 | 53 | m.def("qbmw_scaled", &qbmw_scaled, 54 | pybind11::arg("A"), pybind11::arg("W"), pybind11::arg("scale"), 55 | R"pbdoc( 56 | Batched matrix multiplication with int8 inputs. 57 | 58 | Args: 59 | A (torch.Tensor): Input Activation Tensor of type int8 (3-D). 60 | W (torch.Tensor): Weight Tensor of type int8 (2-D). 61 | scale (torch.Tensor): Scaling factor. 62 | 63 | Returns: 64 | torch.Tensor: Result of the batched matrix multiplication. 65 | )pbdoc"); 66 | 67 | m.def("qbmw_quantized", &qbmw_quantized, 68 | pybind11::arg("A"), pybind11::arg("W"), pybind11::arg("scale"), 69 | R"pbdoc( 70 | Batched matrix multiplication with int8 inputs. 71 | 72 | Args: 73 | A (torch.Tensor): Input Activation Tensor of type int8 (3-D). 74 | W (torch.Tensor): Weight Tensor of type int8 (2-D). 75 | scale (torch.Tensor): Scaling factor. 76 | 77 | Returns: 78 | torch.Tensor: Result of the batched matrix multiplication. 
79 | )pbdoc"); 80 | 81 | m.def("bcmm_int8", &bcmm_int8, 82 | pybind11::arg("A"), pybind11::arg("B"), pybind11::arg("alpha"), 83 | R"pbdoc( 84 | Batched matrix multiplication with int8 inputs. 85 | 86 | Args: 87 | A (torch.Tensor): Tensor of type int8 (4-D). 88 | B (torch.Tensor): Tensor of type int8 (4-D). 89 | alpha (float): Scalar value for scaling the result. 90 | 91 | Returns: 92 | torch.Tensor: Result of the batched matrix multiplication. 93 | )pbdoc"); 94 | } -------------------------------------------------------------------------------- /csrc/kernel/bmm.cu: -------------------------------------------------------------------------------- 1 | #include "include/bmm.h" 2 | #include "include/common.h" 3 | #include "cutlass/core_io.h" 4 | #include "cutlass/gemm/device/gemm.h" 5 | #include "cutlass/gemm/device/gemm_batched.h" 6 | #include "cutlass/numeric_types.h" 7 | #include "cutlass/util/host_tensor.h" 8 | 9 | // the return type of the function is torch Tensor 10 | torch::Tensor bmm_int8(torch::Tensor A, torch::Tensor B, float alpha) { 11 | int batch_size = A.size(0); 12 | int M = A.size(1); 13 | int N = B.size(1); 14 | int K = A.size(2); 15 | 16 | // pad the height and width of the input matrices 17 | int int_pad_A = (M + 7) / 8 * 8; 18 | int int_pad_B = (N + 7) / 8 * 8; 19 | 20 | torch::Tensor A_pad; 21 | torch::Tensor B_pad; 22 | 23 | if (int_pad_A > M) { 24 | A_pad = torch::cat({A, torch::zeros({batch_size, int_pad_A - M, K}, A.options())}, 1); 25 | } else { 26 | A_pad = A; 27 | } 28 | 29 | if (int_pad_B > N){ 30 | B_pad = torch::cat({B, torch::zeros({batch_size, int_pad_B - N, K}, B.options())}, 1); 31 | } else { 32 | B_pad = B; 33 | } 34 | 35 | // automatically define the data type of output tensor C 36 | auto C = torch::empty({batch_size, int_pad_A, int_pad_B}, torch::dtype(torch::kFloat32).device(A_pad.device())); 37 | 38 | // define the leading dimension of each matrix 39 | int lda = A_pad.size(2); 40 | int ldb = B_pad.size(2); 41 | int ldc = C.size(2); 42 | 43 | // define the layout of memory storage for each matrix 44 | using LayoutA = cutlass::layout::RowMajor; 45 | using LayoutB = cutlass::layout::ColumnMajor; 46 | using LayoutC = cutlass::layout::RowMajor; 47 | 48 | // define data type and accumulation precision 49 | using FormatInputA = int8_t; 50 | using FormatInputB = int8_t; 51 | using FormatOutputC = float; 52 | using FormatAccumulator = int32_t; 53 | using ElementComputeEpilogue = float; 54 | 55 | // deivce-dependent definition 56 | #if CUDA_ARCH >= 800 57 | using EpilogueOp = cutlass::epilogue::thread::LinearCombination::value, FormatAccumulator, ElementComputeEpilogue>; 58 | 59 | using Gemm = cutlass::gemm::device::GemmBatched< 60 | FormatInputA, LayoutA, FormatInputB, LayoutB, FormatOutputC, LayoutC, FormatAccumulator, cutlass::arch::OpClassTensorOp, 61 | cutlass::arch::Sm80, cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, EpilogueOp>; 62 | 63 | #elif CUDA_ARCH >= 750 64 | using EpilogueOp = cutlass::epilogue::thread::LinearCombination::value, FormatAccumulator, ElementComputeEpilogue>; 65 | 66 | using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< 67 | cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, 68 | FormatInputA, FormatInputB, FormatOutputC, FormatAccumulator>; 69 | 70 | using Gemm = cutlass::gemm::device::GemmBatched< 71 | FormatInputA, LayoutA, FormatInputB, LayoutB, FormatOutputC, LayoutC, FormatAccumulator, cutlass::arch::OpClassTensorOp, 72 | cutlass::arch::Sm75, 
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, DefaultGemmCfg::InstructionShape, EpilogueOp>; 73 | 74 | #elif CUDA_ARCH >= 700 75 | using EpilogueOp = cutlass::epilogue::thread::LinearCombination; 76 | 77 | using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< 78 | cutlass::arch::OpClassSimt, cutlass::arch::Sm70 79 | FormatInputA, FormatInputB, FormatOutputC, FormatAccumulator>; 80 | 81 | using Gemm = cutlass::gemm::device::GemmBatched< 82 | FormatInputA, LayoutA, FormatInputB, LayoutB, FormatOutputC, LayoutC, FormatAccumulator, cutlass::arch::OpClassSimt, 83 | cutlass::arch::Sm70, DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, DefaultGemmCfg::InstructionShape, EpilogueOp>; 84 | #else 85 | #error "Unsupported GPU type" 86 | #endif 87 | 88 | // stride between two matrices within each batch 89 | long long int batch_stride_A = int_pad_A * K; 90 | long long int batch_stride_B = int_pad_B * K; 91 | long long int batch_stride_C = int_pad_A * int_pad_B; 92 | 93 | // Define the operation of Gemm 94 | Gemm gemm_op; 95 | // Argument of Gemm Op 96 | typename Gemm::Arguments arguments{ 97 | {int_pad_A, int_pad_B, K}, {A_pad.data_ptr(), lda}, 98 | batch_stride_A, {B_pad.data_ptr(), ldb}, 99 | batch_stride_B, {C.data_ptr(), ldc}, 100 | batch_stride_C, {C.data_ptr(), ldc}, 101 | batch_stride_C, {alpha, 0}, 102 | batch_size}; 103 | 104 | // request extra space for GEMM operation 105 | size_t workspace_size = Gemm::get_workspace_size(arguments); 106 | 107 | // allocate workspace memory 108 | cutlass::device_memory::allocation workspace(workspace_size); 109 | 110 | // Check the problem size is supported or not 111 | cutlass::Status status = gemm_op.can_implement(arguments); 112 | if (status != cutlass::Status::kSuccess) { 113 | throw std::runtime_error("cutlass cannot implement"); 114 | } 115 | 116 | // Initialize CUTLASS kernel with arguments and workspace pointer 117 | status = gemm_op.initialize(arguments, workspace.get()); 118 | if (status != cutlass::Status::kSuccess) { 119 | throw std::runtime_error("cutlass cannot initialize"); 120 | } 121 | 122 | status = gemm_op(); 123 | if (status != cutlass::Status::kSuccess) { 124 | throw std::runtime_error("cutlass cannot run"); 125 | } 126 | 127 | if (int_pad_A > M){ 128 | return C.slice(1, 0, M).slice(2, 0, N); 129 | } else { 130 | return C; 131 | } 132 | } 133 | 134 | -------------------------------------------------------------------------------- /csrc/kernel/bmw.cu: -------------------------------------------------------------------------------- 1 | #include "include/bmw.h" 2 | #include "include/common.h" 3 | #include "cutlass/core_io.h" 4 | #include "cutlass/gemm/device/gemm.h" 5 | #include "cutlass/gemm/device/gemm_batched.h" 6 | #include "cutlass/numeric_types.h" 7 | #include "cutlass/util/host_tensor.h" 8 | 9 | torch::Tensor bmw_int8(torch::Tensor A, torch::Tensor W, torch::Tensor alpha) { 10 | int batch_size = A.size(0); 11 | int M = A.size(1); 12 | int N = W.size(0); 13 | int K = A.size(2); 14 | 15 | // automatically define the data type of output tensor C 16 | auto C = torch::empty({batch_size, M, N}, torch::dtype(torch::kFloat32).device(A.device())); 17 | 18 | // define the leading dimension of each matrix 19 | int lda = A.size(2); 20 | int ldb = W.size(1); 21 | int ldc = C.size(2); 22 | 23 | // define the memory layout 24 | using LayoutA = cutlass::layout::RowMajor; 25 | using LayoutW = cutlass::layout::ColumnMajor; 26 | using LayoutC = cutlass::layout::RowMajor; 27 | 28 | // define data type and 
accumulation precision 29 | using FormatInputA = int8_t; 30 | using FormatInputW = int8_t; 31 | using FormatOutputC = float; 32 | using FormatAccumulator = int32_t; 33 | using ElementComputeEpilogue = float; 34 | 35 | // device-dependent architecture 36 | #if CUDA_ARCH >= 800 37 | using EpilogueOp = cutlass::epilogue::thread::LinearCombination::value, FormatAccumulator, ElementComputeEpilogue>; 38 | 39 | using Gemm = cutlass::gemm::device::GemmBatched< 40 | FormatInputA, LayoutA, FormatInputW, LayoutW, FormatOutputC, LayoutC, FormatAccumulator, cutlass::arch::OpClassTensorOp, 41 | cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, EpilogueOp>; 42 | 43 | #elif CUDA_ARCH >= 750 44 | using EpilogueOp = cutlass::epilogue::thread::LinearCombination::value, FormatAccumulator, ElementComputeEpilogue>; 45 | 46 | using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< 47 | cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, 48 | FormatInputA, FormatInputW, FormatOutputC, FormatAccumulator>; 49 | 50 | using Gemm = cutlass::gemm::device::GemmBatched< 51 | FormatInputA, LayoutA, FormatInputW, LayoutW, FormatOutputC, LayoutC, FormatAccumulator, cutlass::arch::OpClassTensorOp, 52 | cutlass::arch::Sm75, DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, DefaultGemmCfg::InstructionShape, EpilogueOp>; 53 | 54 | #elif CUDA_ARCH >= 700 55 | using EpilogueOp = cutlass::epilogue::thread::LinearCombination; 56 | 57 | using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< 58 | cutlass::arch::OpClassSimt, cutlass::arch::Sm70 59 | FormatInputA, FormatInputW, FormatOutputC, FormatAccumulator>; 60 | 61 | using Gemm = cutlass::gemm::device::GemmBatched< 62 | FormatInputA, LayoutA, FormatInputW, LayoutW, FormatOutputC, LayoutC, FormatAccumulator, cutlass::arch::OpClassSimt, 63 | cutlass::arch::Sm70, DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, DefaultGemmCfg::InstructionShape, EpilogueOp>; 64 | 65 | #else 66 | #error "Unsupported GPU type" 67 | #endif 68 | 69 | cutlass::gemm::GemmCoord problem_size(M, N, K); 70 | 71 | // stride between two matrices within each batch 72 | long long int batch_stride_A = M * K; 73 | long long int batch_stride_W = 0; 74 | long long int batch_stride_C = M * N; 75 | 76 | // Define the operation of Gemm 77 | Gemm gemm_op; 78 | 79 | // Argument of Gemm Op 80 | typename Gemm::Arguments arguments{ 81 | {M, N, K}, 82 | {A.data_ptr(), lda}, batch_stride_A, 83 | {W.data_ptr(), ldb}, batch_stride_W, 84 | {C.data_ptr(), ldc}, batch_stride_C, 85 | {C.data_ptr(), ldc}, batch_stride_C, 86 | {1.0f, 0}, 87 | batch_size 88 | }; 89 | 90 | // request extra space for GEMM operation 91 | size_t workspace_size = Gemm::get_workspace_size(arguments); 92 | 93 | // allocate workspace memory 94 | cutlass::device_memory::allocation workspace(workspace_size); 95 | 96 | // Check the problem size is supported or not 97 | cutlass::Status status = gemm_op.can_implement(arguments); 98 | if (status != cutlass::Status::kSuccess) { 99 | throw std::runtime_error("cutlass cannot implement"); 100 | } 101 | 102 | // Initialize CUTLASS kernel with arguments and workspace pointer 103 | status = gemm_op.initialize(arguments, workspace.get()); 104 | if (status != cutlass::Status::kSuccess) { 105 | throw std::runtime_error("cutlass cannot initialize"); 106 | } 107 | 108 | status = gemm_op(); 109 | if (status != cutlass::Status::kSuccess) { 110 | throw std::runtime_error("cutlass cannot 
run"); 111 | } 112 | 113 | torch::Tensor Y = (C * alpha); 114 | return Y; 115 | 116 | } 117 | -------------------------------------------------------------------------------- /csrc/kernel/include/bcmm.h: -------------------------------------------------------------------------------- 1 | #ifndef BCMM_H 2 | #define BCMM_H 3 | #include 4 | #include 5 | 6 | torch::Tensor bcmm_int8(torch::Tensor A, torch::Tensor B, float alpha); 7 | 8 | #endif // BCMM_H -------------------------------------------------------------------------------- /csrc/kernel/include/bmm.h: -------------------------------------------------------------------------------- 1 | #ifndef BMM_H 2 | #define BMM_H 3 | #include 4 | #include 5 | 6 | torch::Tensor bmm_int8(torch::Tensor A, torch::Tensor B, float alpha); 7 | 8 | #endif // BMM_H -------------------------------------------------------------------------------- /csrc/kernel/include/bmw.h: -------------------------------------------------------------------------------- 1 | #ifndef BMW_H 2 | #define BMW_H 3 | #include 4 | #include 5 | 6 | torch::Tensor bmw_int8(torch::Tensor A, torch::Tensor W, torch::Tensor alpha); 7 | 8 | #endif // BMW_H -------------------------------------------------------------------------------- /csrc/kernel/include/common.h: -------------------------------------------------------------------------------- 1 | #ifndef COMMON_H 2 | #define COMMON_H 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #endif // !COMMON -------------------------------------------------------------------------------- /csrc/kernel/include/qbmw.h: -------------------------------------------------------------------------------- 1 | #ifndef QBMW_H 2 | #define QBMW_H 3 | #include 4 | #include 5 | 6 | torch::Tensor qbmw(torch::Tensor A, torch::Tensor W); 7 | 8 | torch::Tensor qbmw_scaled(torch::Tensor A, torch::Tensor W, torch::Tensor scale); 9 | 10 | torch::Tensor qbmw_quantized(torch::Tensor A, torch::Tensor W, torch::Tensor scale); 11 | 12 | #endif // BMW_H -------------------------------------------------------------------------------- /csrc/kernel/qbmw.cu: -------------------------------------------------------------------------------- 1 | #include "include/qbmw.h" 2 | #include "include/common.h" 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | 18 | __global__ void quantize_kernel( 19 | const float* __restrict__ input, 20 | const float* __restrict__ scale, 21 | int8_t* __restrict__ output, 22 | int M, 23 | int K 24 | ) { 25 | int row = blockIdx.y * blockDim.y + threadIdx.y; 26 | int col = blockIdx.x * blockDim.x + threadIdx.x; 27 | 28 | if (row < M && col < K) { 29 | int index = row * K + col; 30 | float scaled_value = input[index] * scale[index]; 31 | 32 | int8_t quantized_value = static_cast(roundf(scaled_value)); 33 | 34 | // clamp the range 35 | output[index] = max(-128, min(127, quantized_value)); 36 | } 37 | } 38 | 39 | torch::Tensor quantize_tensor(const torch::Tensor& input, const torch::Tensor& scale) { 40 | 41 | // shape of input 42 | int M = input.size(0); 43 | int K = input.size(1); 44 | 45 | auto output = torch::empty({M, K}, torch::dtype(torch::kInt8).device(input.device())); 46 | 47 | // get pointers 48 | const float* input_ptr = input.data_ptr(); 49 | const float* scale_ptr = scale.data_ptr(); 50 | int8_t* output_ptr = output.data_ptr(); 51 | 52 | // define CUDA block and grid sizes 53 | dim3 blockDim(16, 16); 
54 | dim3 gridDim((K + blockDim.x - 1) / blockDim.x, (M + blockDim.y - 1) / blockDim.y); 55 | 56 | // launch the kernel 57 | quantize_kernel<<>>( 58 | input_ptr, scale_ptr, output_ptr, M, K 59 | ); 60 | 61 | return output; 62 | } 63 | 64 | torch::Tensor qbmw(torch::Tensor A, torch::Tensor W) { 65 | auto M = A.size(0); 66 | auto N = W.size(0); 67 | auto K = A.size(1); 68 | 69 | // automatically define the data type of output tensor C 70 | auto C = torch::empty({M, N}, torch::dtype(torch::kInt32).device(A.device())); 71 | 72 | using ElementOutput = int32_t; 73 | using ElementAccumulator = int32_t; 74 | using ElementComputeEpilogue = int32_t; 75 | using ElementInputA = int8_t; 76 | using ElementInputW = int8_t; 77 | 78 | using LayoutInputA = cutlass::layout::RowMajor; 79 | using LayoutInputW = cutlass::layout::ColumnMajor; 80 | using LayoutOutput = cutlass::layout::RowMajor; 81 | 82 | using EpilogueOp = cutlass::epilogue::thread::LinearCombination< 83 | ElementOutput, 128 / cutlass::sizeof_bits::value, 84 | ElementAccumulator, ElementComputeEpilogue, 85 | cutlass::epilogue::thread::ScaleType::NoBetaScaling>; 86 | 87 | using Gemm = cutlass::gemm::device::Gemm< 88 | ElementInputA, LayoutInputA, ElementInputW, LayoutInputW, 89 | ElementOutput, LayoutOutput, ElementAccumulator, 90 | cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, 91 | cutlass::gemm::GemmShape<128, 128, 64>, 92 | cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, 93 | EpilogueOp, 94 | cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; 95 | 96 | auto input_size = cutlass::MatrixCoord(M, K); 97 | auto weight_size = cutlass::MatrixCoord(K, N); 98 | auto output_size = cutlass::MatrixCoord(M, N); 99 | 100 | auto device = A.device(); 101 | cutlass::gemm::GemmCoord problem_size(M, N, K); 102 | 103 | cutlass::TensorRef input_ref( 104 | A.data_ptr(), LayoutInputA::packed(input_size)); 105 | 106 | cutlass::TensorRef weight_ref( 107 | W.data_ptr(), LayoutInputW::packed(weight_size)); 108 | 109 | cutlass::TensorRef out_ref( 110 | C.data_ptr(), LayoutOutput::packed(output_size)); 111 | 112 | typename Gemm::Arguments arguments{ 113 | problem_size, 114 | input_ref, 115 | weight_ref, 116 | out_ref, 117 | out_ref, 118 | {1, 0}, 1}; 119 | 120 | Gemm gemm_op; 121 | 122 | size_t workspace_size = Gemm::get_workspace_size(arguments); 123 | cutlass::device_memory::allocation workspace(workspace_size); 124 | 125 | // Check the problem size is supported or not 126 | cutlass::Status status = gemm_op.can_implement(arguments); 127 | if (status != cutlass::Status::kSuccess) { 128 | throw std::runtime_error("cutlass cannot implement"); 129 | } 130 | 131 | status = gemm_op.initialize(arguments, workspace.get()); 132 | if (status != cutlass::Status::kSuccess) { 133 | throw std::runtime_error("cutlass cannot initialize"); 134 | } 135 | 136 | status = gemm_op(); 137 | if (status != cutlass::Status::kSuccess) { 138 | throw std::runtime_error("cutlass cannot run"); 139 | } 140 | 141 | return C; 142 | 143 | } 144 | 145 | torch::Tensor qbmw_scaled(torch::Tensor A, torch::Tensor W, torch::Tensor scale) { 146 | 147 | torch::Tensor out = qbmw(A, W); 148 | torch::Tensor Y = (out * scale); 149 | 150 | Y = Y.to(torch::kFloat16); 151 | 152 | return Y; 153 | } 154 | 155 | torch::Tensor qbmw_quantized(torch::Tensor A, torch::Tensor W, torch::Tensor scale) { 156 | torch::Tensor out = qbmw(A, W); 157 | out = out.to(torch::kFloat32); 158 | 159 | torch::Tensor Y = quantize_tensor(out, scale); 160 | return Y; 161 | } 
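The kernels above are exposed to Python through `bindings.cpp` as the `t2c_gemm` module (`bmm_int8`, `bmw_int8`, `bcmm_int8`, `qbmw`, `qbmw_scaled`, `qbmw_quantized`). The snippet below is an illustrative usage sketch rather than an official example: it assumes the extension was built successfully via `pip install .` in `csrc/` on an Sm80-class GPU, and it follows the operand layout used in `bmm.cu`, where the second operand is stored as `(batch, N, K)` so the kernel computes `alpha * A @ B^T`.

```python
# Illustrative usage sketch (not part of the repository): call the t2c_gemm INT8
# batched GEMM and compare it against a PyTorch reference.
import torch
import t2c_gemm

B, M, N, K = 4, 64, 96, 128
A = torch.randint(-128, 128, (B, M, K), dtype=torch.int8, device="cuda")
W = torch.randint(-128, 128, (B, N, K), dtype=torch.int8, device="cuda")  # second operand stored as (B, N, K)
alpha = 0.05  # e.g. the product of the activation and weight scales

# INT8 x INT8 with INT32 accumulation, scaled by alpha and returned as FP32 of shape (B, M, N)
Y = t2c_gemm.bmm_int8(A, W, alpha)

ref = alpha * torch.bmm(A.float(), W.float().transpose(1, 2))
print(Y.shape, (Y - ref).abs().max().item())
```

The same pattern applies to `bcmm_int8` for the 4-D attention-style products, with operands shaped `(batch, heads, M, K)` and `(batch, heads, N, K)`.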
-------------------------------------------------------------------------------- /csrc/kernel/scaler.cu: -------------------------------------------------------------------------------- 1 | #include "include/qbmw.h" 2 | #include "include/common.h" 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | -------------------------------------------------------------------------------- /csrc/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "torch", "wheel", "numpy"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "t2c_gemm" 7 | version = "0.1.0" 8 | description = "t2c_gemm: CUDA Extension of Torch2Chip" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | license = { text = "MIT" } 12 | 13 | [tool.setuptools] 14 | py-modules = [] -------------------------------------------------------------------------------- /csrc/setup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CppExtension 4 | 5 | compute_capability = torch.cuda.get_device_capability() 6 | cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10 7 | print(cuda_arch) 8 | 9 | setup( 10 | name="t2c_gemm", 11 | ext_modules=[ 12 | CppExtension( 13 | name="t2c_gemm", 14 | sources=[ 15 | "kernel/bmm.cu", 16 | "kernel/bmw.cu", 17 | "kernel/qbmw.cu", 18 | "kernel/bcmm.cu", 19 | "kernel/bindings.cpp", 20 | ], 21 | include_dirs=["t2c_gemm/kernel/include"], 22 | extra_compile_args={ 23 | "cxx": ["-std=c++17", "-O3"], 24 | "nvcc": [ 25 | "-O3", 26 | "-std=c++17", 27 | "-U__CUDA_NO_HALF_OPERATORS__", 28 | "-U__CUDA_NO_HALF_CONVERSIONS__", 29 | "-U__CUDA_NO_HALF2_OPERATORS__", 30 | f"-DCUDA_ARCH={cuda_arch}" 31 | ], 32 | }, 33 | ), 34 | ], 35 | cmdclass={"build_ext": BuildExtension}, 36 | ) -------------------------------------------------------------------------------- /csrc/t2c_gemm.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.2 2 | Name: t2c_gemm 3 | Version: 0.1.0 4 | Summary: t2c_gemm: CUDA Extension of Torch2Chip 5 | License: MIT 6 | Requires-Python: >=3.8 7 | Description-Content-Type: text/markdown 8 | -------------------------------------------------------------------------------- /csrc/t2c_gemm.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | pyproject.toml 3 | setup.py 4 | kernel/bcmm.cu 5 | kernel/bindings.cpp 6 | kernel/bmm.cu 7 | kernel/bmw.cu 8 | kernel/qbmw.cu 9 | t2c_gemm.egg-info/PKG-INFO 10 | t2c_gemm.egg-info/SOURCES.txt 11 | t2c_gemm.egg-info/dependency_links.txt 12 | t2c_gemm.egg-info/top_level.txt -------------------------------------------------------------------------------- /csrc/t2c_gemm.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /csrc/t2c_gemm.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | t2c_gemm 2 | -------------------------------------------------------------------------------- /deprecated/scripts/bert/mrpc.sh: 
-------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | epochs=1 8 | batch_size=32 9 | lr=1e-4 10 | loss=mse 11 | weight_decay=1e-4 12 | dataset="mrpc" 13 | log_file="training.log" 14 | 15 | wbit=8 16 | abit=8 17 | xqtype="minmax_token" 18 | wqtype="minmax_channel" 19 | num_samples=512 20 | 21 | save_path="./save/${dataset}/BERT-BASE/${xqtype}_${wqtype}/BERT-BASE_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_all/" 22 | 23 | python3 -W ignore ./bert/mrpc.py \ 24 | --save_path ${save_path} \ 25 | --epochs ${epochs} \ 26 | --log_file ${log_file} \ 27 | --lr ${lr} \ 28 | --weight-decay ${weight_decay} \ 29 | --batch_size ${batch_size} \ 30 | --loss_type ${loss} \ 31 | --mixed_prec True \ 32 | --optimizer adam \ 33 | --wqtype ${wqtype} \ 34 | --xqtype ${xqtype} \ 35 | --wbit ${wbit} \ 36 | --abit ${abit} \ 37 | --num_samples ${num_samples} \ -------------------------------------------------------------------------------- /deprecated/scripts/bert/mrpc_t2c.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | epochs=1 8 | batch_size=32 9 | lr=1e-4 10 | loss=mse 11 | weight_decay=1e-4 12 | dataset="mrpc" 13 | log_file="training.log" 14 | 15 | wbit=8 16 | abit=8 17 | xqtype="minmax_token" 18 | wqtype="minmax_channel" 19 | num_samples=512 20 | 21 | pre_trained="./save/mrpc/BERT-BASE/minmax_token_minmax_channel/BERT-BASE_w8_a8_lr1e-4_batch32_mseloss_all/model_best.pth.tar" 22 | save_path="./save/mrpc/BERT-BASE/minmax_token_minmax_channel/BERT-BASE_w8_a8_lr1e-4_batch32_mseloss_all/t2c/" 23 | 24 | python3 -W ignore ./bert/mrpc_t2c.py \ 25 | --model "bert" \ 26 | --save_path ${save_path} \ 27 | --epochs ${epochs} \ 28 | --log_file ${log_file} \ 29 | --lr ${lr} \ 30 | --weight-decay ${weight_decay} \ 31 | --batch_size ${batch_size} \ 32 | --loss_type ${loss} \ 33 | --mixed_prec True \ 34 | --optimizer adam \ 35 | --wqtype ${wqtype} \ 36 | --xqtype ${xqtype} \ 37 | --wbit ${wbit} \ 38 | --abit ${abit} \ 39 | --num_samples ${num_samples} \ 40 | --resume ${pre_trained} \ 41 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/bert/sst2.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | epochs=1 8 | batch_size=32 9 | lr=1e-4 10 | loss=mse 11 | weight_decay=1e-4 12 | dataset="sst2" 13 | log_file="training.log" 14 | 15 | wbit=8 16 | abit=8 17 | xqtype="smooth_token" 18 | wqtype="smooth_channel" 19 | num_samples=512 20 | 21 | 22 | save_path="./save/${dataset}/BERT-BASE/${xqtype}_${wqtype}/BERT-BASE_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_all/" 23 | 24 | python3 -W ignore ./bert/sst2.py \ 25 | --save_path ${save_path} \ 26 | --epochs ${epochs} \ 27 | --log_file ${log_file} \ 28 | --lr ${lr} \ 29 | --weight-decay ${weight_decay} \ 30 | --batch_size ${batch_size} \ 31 | --loss_type ${loss} \ 32 | --mixed_prec True \ 33 | --optimizer adam \ 34 | --wqtype ${wqtype} \ 35 | --xqtype ${xqtype} \ 36 | --wbit ${wbit} \ 37 | --abit ${abit} \ 38 | --num_samples ${num_samples} \ -------------------------------------------------------------------------------- /deprecated/scripts/bert/sst2_t2c.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | epochs=1 8 | batch_size=32 9 | lr=1e-4 10 | loss=mse 11 | weight_decay=1e-4 12 | dataset="sst2" 13 | log_file="training.log" 14 | 15 | wbit=8 16 | abit=8 17 | xqtype="lsq" 18 | wqtype="minmax_channel" 19 | num_samples=512 20 | 21 | pre_trained="./save/sst2/BERT-BASE/lsq_minmax_channel/BERT-BASE_w8_a8_lr1e-4_batch32_mseloss_all/model_best.pth.tar" 22 | save_path="./save/sst2/BERT-BASE/lsq_minmax_channel/BERT-BASE_w8_a8_lr1e-4_batch32_mseloss_all/t2c/" 23 | 24 | python3 -W ignore ./bert/sst2_t2c.py \ 25 | --model "bert" \ 26 | --save_path ${save_path} \ 27 | --epochs ${epochs} \ 28 | --log_file ${log_file} \ 29 | --lr ${lr} \ 30 | --weight-decay ${weight_decay} \ 31 | --batch_size ${batch_size} \ 32 | --loss_type ${loss} \ 33 | --mixed_prec True \ 34 | --optimizer adam \ 35 | --wqtype ${wqtype} \ 36 | --xqtype ${xqtype} \ 37 | --wbit ${wbit} \ 38 | --abit ${abit} \ 39 | --num_samples ${num_samples} \ 40 | --resume ${pre_trained} \ 41 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/pretrain.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model=mobilenetv1 4 | epochs=150 5 | batch_size=256 6 | lr=0.1 7 | weight_decay=1e-4 8 | dataset="imagenet" 9 | log_file="training.log" 10 | loss_type="soft_ce" 11 | 12 | save_path="/home/jm2787/MLSys24/T2C/save/${dataset}/${model}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss/" 13 | 14 | torchrun --nproc_per_node=1 --master_port 48002 ./imagenet/main.py \ 15 | --save_path ${save_path} \ 16 | --model ${model} \ 17 | --epochs ${epochs} \ 18 | --log_file ${log_file} \ 19 | --lr ${lr} \ 20 | --weight-decay ${weight_decay} \ 21 | --batch_size ${batch_size} \ 22 | --loss_type ${loss_type} \ 23 | --dataset ${dataset} \ 24 | --optimizer sgd \ 25 | --train_dir "/share/seo/imagenet/train/" \ 26 | --val_dir "/share/seo/imagenet/val/" \ 27 | --mixup 0.8 \ 28 | --cutmix 1.0 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/ptq.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=mobilenetv1 8 | epochs=1 9 | batch_size=64 10 | lr=1e-3 11 | loss=mse 12 | weight_decay=4e-5 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="qdrop" 19 | wqtype="minmax_channel" 20 | num_samples=1024 21 | ttype=ptq 22 | layer_train=True 23 | 24 | pre_trained="/home/jm2787/MLSys24/T2C/save/imagenet/mobilenetv1/mobilenetv1_w_a_lr0.1_batch256_loss/checkpoint.pth.tar" 25 | save_path="./save/${dataset}/${model}/${ttype}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_layer_train${layer_train}/" 26 | 27 | python3 -W ignore ./imagenet/ptq.py \ 28 | --save_path ${save_path} \ 29 | --model ${model} \ 30 | --epochs ${epochs} \ 31 | --log_file ${log_file} \ 32 | --lr ${lr} \ 33 | --weight-decay ${weight_decay} \ 34 | --batch_size ${batch_size} \ 35 | --loss_type ${loss} \ 36 | --dataset ${dataset} \ 37 | --mixed_prec True \ 38 | --optimizer adam \ 39 | --trainer ${ttype} \ 40 | --wqtype ${wqtype} \ 41 | --xqtype ${xqtype} \ 42 | --wbit ${wbit} \ 43 | --abit ${abit} \ 44 | --num_samples ${num_samples} \ 45 | --fine_tune \ 46 | --resume ${pre_trained} \ 47 | --train_dir "/share/seo/imagenet/train/" \ 48 | --val_dir "/share/seo/imagenet/val/" \ 49 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_ptq_lsq_adaround.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=1 9 | batch_size=64 10 | lr=1e-3 11 | loss=mse 12 | weight_decay=4e-5 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="lsq" 19 | wqtype="adaround" 20 | num_samples=1024 21 | ttype=ptq 22 | layer_train=True 23 | 24 | save_path="./save/${dataset}/${model}/${ttype}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_layer_train${layer_train}/" 25 | 26 | python3 -W ignore ./imagenet/ptq.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_ptq_lsq_minmax_channel.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=1 9 | batch_size=64 10 | lr=1e-3 11 | loss=mse 12 | weight_decay=4e-5 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="lsq" 19 | wqtype="minmax_channel" 20 | num_samples=1024 21 | ttype=ptq 22 | layer_train=True 23 | 24 | save_path="./save/${dataset}/${model}/${ttype}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_layer_train${layer_train}/" 25 | 26 | python3 -W ignore ./imagenet/ptq.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_ptq_minmax_minmax_channel.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=1 9 | batch_size=64 10 | lr=1e-3 11 | loss=mse 12 | weight_decay=4e-5 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="minmax" 19 | wqtype="minmax_channel" 20 | num_samples=1024 21 | ttype=ptq 22 | layer_train=False 23 | 24 | save_path="./save/${dataset}/${model}/${ttype}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_layer_train${layer_train}/" 25 | 26 | python3 -W ignore ./imagenet/ptq.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_ptq_qdrop_adaround.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=1 9 | batch_size=64 10 | lr=1e-3 11 | loss=mse 12 | weight_decay=4e-5 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="qdrop" 19 | wqtype="adaround" 20 | num_samples=1024 21 | ttype=ptq 22 | layer_train=True 23 | 24 | save_path="./save/${dataset}/${model}/${ttype}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_layer_train${layer_train}/" 25 | 26 | python3 -W ignore ./imagenet/ptq.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_ptq_qdrop_minmax_channel.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=1 9 | batch_size=64 10 | lr=1e-3 11 | loss=mse 12 | weight_decay=4e-5 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="qdrop" 19 | wqtype="minmax_channel" 20 | num_samples=1024 21 | ttype=ptq 22 | layer_train=True 23 | 24 | save_path="./save/${dataset}/${model}/${ttype}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_layer_train${layer_train}/" 25 | 26 | python3 -W ignore ./imagenet/ptq.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_t2c_lsq_adaround.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=200 9 | batch_size=64 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="training.log" 15 | wbit=8 16 | abit=8 17 | xqtype="lsq" 18 | wqtype="adaround" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/resnet50/ptq/lsq_adaround/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/" 22 | pre_trained="./save/imagenet/resnet50/ptq/lsq_adaround/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --resume ${pre_trained} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer ptq \ 38 | --swl 32 \ 39 | --sfl 26 \ 40 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_t2c_lsq_minmax_channel.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=200 9 | batch_size=64 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="training.log" 15 | wbit=8 16 | abit=8 17 | xqtype="lsq" 18 | wqtype="minmax_channel" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/resnet50/ptq/lsq_minmax_channel/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/" 22 | pre_trained="./save/imagenet/resnet50/ptq/lsq_minmax_channel/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --resume ${pre_trained} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer ptq \ 38 | --swl 32 \ 39 | --sfl 26 \ 40 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_t2c_qdrop_adaround.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=200 9 | batch_size=64 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="training.log" 15 | wbit=8 16 | abit=8 17 | xqtype="qdrop" 18 | wqtype="adaround" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/resnet50/ptq/qdrop_adaround/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/" 22 | pre_trained="./save/imagenet/resnet50/ptq/qdrop_adaround/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --resume ${pre_trained} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer ptq \ 38 | --swl 32 \ 39 | --sfl 26 \ 40 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_t2c_qdrop_minmax_channel.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=200 9 | batch_size=64 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="training.log" 15 | wbit=8 16 | abit=8 17 | xqtype="qdrop" 18 | wqtype="minmax_channel" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/resnet50/ptq/qdrop_minmax_channel/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/" 22 | pre_trained="./save/imagenet/resnet50/ptq/qdrop_minmax_channel/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --resume ${pre_trained} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer ptq \ 38 | --swl 32 \ 39 | --sfl 26 \ 40 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_t2c_reload_lsq_adaround.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=200 9 | batch_size=64 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="reload.log" 15 | wbit=8 16 | abit=8 17 | xqtype="lsq" 18 | wqtype="adaround" 19 | 20 | save_path="./save/imagenet/resnet50/ptq/lsq_adaround/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/" 21 | pre_trained="./save/imagenet/resnet50/ptq/lsq_adaround/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/t2c_model.pth.tar" 22 | 23 | python3 -W ignore ./imagenet/reload.py \ 24 | --save_path ${save_path} \ 25 | --model ${model} \ 26 | --batch_size ${batch_size} \ 27 | --resume ${pre_trained} \ 28 | --log_file ${log_file} \ 29 | --fine_tune \ 30 | --wqtype ${wqtype} \ 31 | --xqtype ${xqtype} \ 32 | --wbit ${wbit} \ 33 | --abit ${abit} \ 34 | --dataset ${dataset} \ 35 | --train_dir "/share/seo/imagenet/train/" \ 36 | --val_dir "/share/seo/imagenet/val/" \ 37 | --evaluate \ 38 | --trainer qattn \ 39 | --swl 32 \ 40 | --sfl 26 \ 41 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_t2c_reload_lsq_minmax_channel.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=200 9 | batch_size=64 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="reload.log" 15 | wbit=8 16 | abit=8 17 | xqtype="lsq" 18 | wqtype="minmax_channel" 19 | 20 | save_path="./save/imagenet/resnet50/ptq/lsq_minmax_channel/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/" 21 | pre_trained="./save/imagenet/resnet50/ptq/lsq_minmax_channel/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/t2c_model.pth.tar" 22 | 23 | python3 -W ignore ./imagenet/reload.py \ 24 | --save_path ${save_path} \ 25 | --model ${model} \ 26 | --batch_size ${batch_size} \ 27 | --resume ${pre_trained} \ 28 | --log_file ${log_file} \ 29 | --fine_tune \ 30 | --wqtype ${wqtype} \ 31 | --xqtype ${xqtype} \ 32 | --wbit ${wbit} \ 33 | --abit ${abit} \ 34 | --dataset ${dataset} \ 35 | --train_dir "/share/seo/imagenet/train/" \ 36 | --val_dir "/share/seo/imagenet/val/" \ 37 | --evaluate \ 38 | --trainer qattn \ 39 | --swl 32 \ 40 | --sfl 26 \ 41 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_t2c_reload_qdrop_adaround.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=200 9 | batch_size=64 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="reload.log" 15 | wbit=8 16 | abit=8 17 | xqtype="qdrop" 18 | wqtype="adaround" 19 | 20 | save_path="./save/imagenet/resnet50/ptq/qdrop_adaround/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/" 21 | pre_trained="./save/imagenet/resnet50/ptq/qdrop_adaround/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/t2c_model.pth.tar" 22 | 23 | python3 -W ignore ./imagenet/reload.py \ 24 | --save_path ${save_path} \ 25 | --model ${model} \ 26 | --batch_size ${batch_size} \ 27 | --resume ${pre_trained} \ 28 | --log_file ${log_file} \ 29 | --fine_tune \ 30 | --wqtype ${wqtype} \ 31 | --xqtype ${xqtype} \ 32 | --wbit ${wbit} \ 33 | --abit ${abit} \ 34 | --dataset ${dataset} \ 35 | --train_dir "/share/seo/imagenet/train/" \ 36 | --val_dir "/share/seo/imagenet/val/" \ 37 | --evaluate \ 38 | --trainer qattn \ 39 | --swl 32 \ 40 | --sfl 26 \ 41 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/resnet50_t2c_reload_qdrop_minmax_channel.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=resnet50 8 | epochs=200 9 | batch_size=64 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="reload.log" 15 | wbit=8 16 | abit=8 17 | xqtype="qdrop" 18 | wqtype="minmax_channel" 19 | 20 | save_path="./save/imagenet/resnet50/ptq/qdrop_minmax_channel/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/" 21 | pre_trained="./save/imagenet/resnet50/ptq/qdrop_minmax_channel/resnet50_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/t2c_model.pth.tar" 22 | 23 | python3 -W ignore ./imagenet/reload.py \ 24 | --save_path ${save_path} \ 25 | --model ${model} \ 26 | --batch_size ${batch_size} \ 27 | --resume ${pre_trained} \ 28 | --log_file ${log_file} \ 29 | --fine_tune \ 30 | --wqtype ${wqtype} \ 31 | --xqtype ${xqtype} \ 32 | --wbit ${wbit} \ 33 | --abit ${abit} \ 34 | --dataset ${dataset} \ 35 | --train_dir "/share/seo/imagenet/train/" \ 36 | --val_dir "/share/seo/imagenet/val/" \ 37 | --evaluate \ 38 | --trainer qattn \ 39 | --swl 32 \ 40 | --sfl 26 \ 41 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/swin-t2c-smoothquant.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=swin_tiny_patch4_window7_224 8 | epochs=200 9 | batch_size=100 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="t2c.log" 15 | wbit=8 16 | abit=8 17 | xqtype="smooth_token" 18 | wqtype="smooth_channel" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/swin_tiny_patch4_window7_224/smooth_token_smooth_channel/swin_tiny_patch4_window7_224_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 22 | pre_trained="./save/imagenet/swin_tiny_patch4_window7_224/smooth_token_smooth_channel/swin_tiny_patch4_window7_224_w8_a8_lr1e-4_batch100_cross_entropyloss_all/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --batch_size ${batch_size} \ 28 | --resume ${pre_trained} \ 29 | --log_file ${log_file} \ 30 | --fine_tune \ 31 | --wqtype ${wqtype} \ 32 | --xqtype ${xqtype} \ 33 | --wbit ${wbit} \ 34 | --abit ${abit} \ 35 | --dataset ${dataset} \ 36 | --train_dir "/share/seo/imagenet/train/" \ 37 | --val_dir "/share/seo/imagenet/val/" \ 38 | --evaluate \ 39 | --trainer qattn \ 40 | --swl 32 \ 41 | --sfl 26 \ 42 | --export_samples 1 \ 43 | 44 | -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/swin-vit-ptq-minmax.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=swin_tiny_patch4_window7_224 8 | epochs=1 9 | batch_size=100 10 | lr=1e-4 11 | loss=cross_entropy 12 | weight_decay=1e-4 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="minmax_token" 19 | wqtype="minmax_channel" 20 | num_samples=500 21 | ttype=qattn 22 | layer_train=False 23 | 24 | save_path="./save/${dataset}/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_all/" 25 | 26 | python3 -W ignore ./imagenet/vit.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/swin-vit-ptq-smoothquant.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=swin_tiny_patch4_window7_224 8 | epochs=1 9 | batch_size=100 10 | lr=1e-4 11 | loss=cross_entropy 12 | weight_decay=1e-4 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="smooth_token" 19 | wqtype="smooth_channel" 20 | num_samples=500 21 | ttype=smooth_quant 22 | layer_train=True 23 | 24 | save_path="./save/${dataset}/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_all/" 25 | 26 | python3 -W ignore ./imagenet/vit.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/swin-vit-t2c-minmax.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=swin_tiny_patch4_window7_224 8 | batch_size=100 9 | loss=cross_entropy 10 | weight_decay=0.0005 11 | dataset="imagenet" 12 | log_file="t2c.log" 13 | wbit=8 14 | abit=8 15 | xqtype="minmax_token" 16 | wqtype="minmax_channel" 17 | ttype=ptq 18 | 19 | save_path="./save/imagenet/swin_tiny_patch4_window7_224/minmax_token_minmax_channel/swin_tiny_patch4_window7_224_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 20 | pre_trained="./save/imagenet/swin_tiny_patch4_window7_224/minmax_token_minmax_channel/swin_tiny_patch4_window7_224_w8_a8_lr1e-4_batch100_cross_entropyloss_all/model_best.pth.tar" 21 | 22 | python3 -W ignore ./imagenet/t2c.py \ 23 | --save_path ${save_path} \ 24 | --model ${model} \ 25 | --batch_size ${batch_size} \ 26 | --resume ${pre_trained} \ 27 | --log_file ${log_file} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer qattn \ 38 | --swl 32 \ 39 | --sfl 26 \ 40 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/swin-vit-t2c-reload-minmax.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=swin_tiny_patch4_window7_224 8 | batch_size=100 9 | loss=cross_entropy 10 | weight_decay=0.0005 11 | dataset="imagenet" 12 | log_file="reload.log" 13 | wbit=8 14 | abit=8 15 | xqtype="minmax_token" 16 | wqtype="minmax_channel" 17 | ttype=ptq 18 | 19 | save_path="./save/imagenet/swin_tiny_patch4_window7_224/minmax_token_minmax_channel/swin_tiny_patch4_window7_224_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 20 | pre_trained="./save/imagenet/swin_tiny_patch4_window7_224/minmax_token_minmax_channel/swin_tiny_patch4_window7_224_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/t2c_model.pth.tar" 21 | 22 | python3 -W ignore ./imagenet/reload.py \ 23 | --save_path ${save_path} \ 24 | --model ${model} \ 25 | --batch_size ${batch_size} \ 26 | --resume ${pre_trained} \ 27 | --log_file ${log_file} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer qattn \ 38 | 39 | -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/t2c.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_tiny 8 | epochs=200 9 | batch_size=100 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="t2c.log" 15 | wbit=8 16 | abit=8 17 | xqtype="smooth" 18 | wqtype="smooth" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/vit_tiny/smooth_smooth/vit_tiny_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 22 | pre_trained="./save/imagenet/vit_tiny/smooth_smooth/vit_tiny_w8_a8_lr1e-4_batch100_cross_entropyloss_all/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --batch_size ${batch_size} \ 28 | --resume ${pre_trained} \ 29 | --log_file ${log_file} \ 30 | --fine_tune \ 31 | --wqtype ${wqtype} \ 32 | --xqtype ${xqtype} \ 33 | --wbit ${wbit} \ 34 | --abit ${abit} \ 35 | --dataset ${dataset} \ 36 | --train_dir "/share/seo/imagenet/train/" \ 37 | --val_dir "/share/seo/imagenet/val/" \ 38 | --evaluate \ 39 | --trainer qattn \ 40 | --swl 32 \ 41 | --sfl 26 \ 42 | --export_samples 1 \ 43 | 44 | -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/t2c_cnn.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=mobilenetv1 8 | epochs=200 9 | batch_size=64 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="training.log" 15 | wbit=8 16 | abit=8 17 | xqtype="qdrop" 18 | wqtype="minmax_channel" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/mobilenetv1/ptq/qdrop_minmax_channel/mobilenetv1_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/t2c/" 22 | pre_trained="./save/imagenet/mobilenetv1/ptq/qdrop_minmax_channel/mobilenetv1_w8_a8_lr1e-3_batch64_mseloss_layer_trainTrue/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --resume ${pre_trained} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer ptq \ 38 | --swl 32 \ 39 | --sfl 26 \ 40 | --export_samples 8 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-ptq-adaround-lsq.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | epochs=1 9 | batch_size=100 10 | lr=1e-4 11 | loss=cross_entropy 12 | weight_decay=1e-4 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="qdrop" 19 | wqtype="adaround" 20 | num_samples=500 21 | ttype=qattn 22 | layer_train=True 23 | 24 | save_path="./save/${dataset}/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_all/" 25 | 26 | python3 -W ignore ./imagenet/vit.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-ptq-adaround-qdrop.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | epochs=1 9 | batch_size=100 10 | lr=1e-4 11 | loss=cross_entropy 12 | weight_decay=1e-4 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="qdrop" 19 | wqtype="mx_channel" 20 | num_samples=500 21 | ttype=qattn 22 | layer_train=True 23 | 24 | save_path="./save/${dataset}/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_all/" 25 | 26 | python3 -W ignore ./imagenet/vit.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-ptq-minmax-lsq.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_tiny 8 | epochs=1 9 | batch_size=100 10 | lr=1e-4 11 | loss=cross_entropy 12 | weight_decay=1e-4 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="lsq_token" 19 | wqtype="minmax_channel" 20 | num_samples=500 21 | ttype=qattn 22 | layer_train=True 23 | 24 | save_path="./save/${dataset}/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_all/" 25 | 26 | python3 -W ignore ./imagenet/vit.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-ptq-minmax.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_base 8 | epochs=1 9 | batch_size=100 10 | lr=1e-4 11 | loss=cross_entropy 12 | weight_decay=1e-4 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="minmax_token" 19 | wqtype="minmax_channel" 20 | num_samples=500 21 | ttype=qattn 22 | layer_train=False 23 | 24 | save_path="./save/${dataset}/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_all/" 25 | 26 | python3 -W ignore ./imagenet/vit.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-ptq-smoothquant.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | epochs=1 9 | batch_size=100 10 | lr=1e-4 11 | loss=cross_entropy 12 | weight_decay=1e-4 13 | dataset="imagenet" 14 | log_file="training.log" 15 | 16 | wbit=8 17 | abit=8 18 | xqtype="smooth" 19 | wqtype="smooth_channel" 20 | num_samples=500 21 | ttype=smooth_quant 22 | layer_train=False 23 | 24 | save_path="./save/${dataset}/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr${lr}_batch${batch_size}_${loss}loss_all/" 25 | 26 | python3 -W ignore ./imagenet/vit.py \ 27 | --save_path ${save_path} \ 28 | --model ${model} \ 29 | --epochs ${epochs} \ 30 | --log_file ${log_file} \ 31 | --lr ${lr} \ 32 | --weight-decay ${weight_decay} \ 33 | --batch_size ${batch_size} \ 34 | --loss_type ${loss} \ 35 | --dataset ${dataset} \ 36 | --mixed_prec True \ 37 | --optimizer adam \ 38 | --trainer ${ttype} \ 39 | --wqtype ${wqtype} \ 40 | --xqtype ${xqtype} \ 41 | --wbit ${wbit} \ 42 | --abit ${abit} \ 43 | --num_samples ${num_samples} \ 44 | --fine_tune \ 45 | --train_dir "/share/seo/imagenet/train/" \ 46 | --val_dir "/share/seo/imagenet/val/" \ 47 | --layer_trainer ${layer_train} \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-adaround-lsq.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_tiny 8 | epochs=200 9 | batch_size=100 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="t2c.log" 15 | wbit=8 16 | abit=8 17 | xqtype="lsq_token" 18 | wqtype="adaround" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr1e-4_batch100_cross_entropyloss_all/t2c/" 22 | pre_trained="./save/imagenet/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr1e-4_batch100_cross_entropyloss_all/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --batch_size ${batch_size} \ 28 | --resume ${pre_trained} \ 29 | --log_file ${log_file} \ 30 | --fine_tune \ 31 | --wqtype ${wqtype} \ 32 | --xqtype ${xqtype} \ 33 | --wbit ${wbit} \ 34 | --abit ${abit} \ 35 | --dataset ${dataset} \ 36 | --train_dir "/share/seo/imagenet/train/" \ 37 | --val_dir "/share/seo/imagenet/val/" \ 38 | --evaluate \ 39 | --trainer qattn \ 40 | --swl 32 \ 41 | --sfl 26 \ 42 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-adaround-qdrop.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | epochs=200 9 | batch_size=100 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="t2c.log" 15 | wbit=8 16 | abit=8 17 | xqtype="qdrop" 18 | wqtype="adaround" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/vit_small/qdrop_adaround/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 22 | pre_trained="./save/imagenet/vit_small/qdrop_adaround/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --batch_size ${batch_size} \ 28 | --resume ${pre_trained} \ 29 | --log_file ${log_file} \ 30 | --fine_tune \ 31 | --wqtype ${wqtype} \ 32 | --xqtype ${xqtype} \ 33 | --wbit ${wbit} \ 34 | --abit ${abit} \ 35 | --dataset ${dataset} \ 36 | --train_dir "/share/seo/imagenet/train/" \ 37 | --val_dir "/share/seo/imagenet/val/" \ 38 | --evaluate \ 39 | --trainer qattn \ 40 | --swl 32 \ 41 | --sfl 26 \ 42 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-minmax-lsq.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | epochs=200 9 | batch_size=100 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="t2c.log" 15 | wbit=8 16 | abit=8 17 | xqtype="lsq_token" 18 | wqtype="minmax_channel" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr1e-4_batch100_cross_entropyloss_all/t2c/" 22 | pre_trained="./save/imagenet/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr1e-4_batch100_cross_entropyloss_all/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --batch_size ${batch_size} \ 28 | --resume ${pre_trained} \ 29 | --log_file ${log_file} \ 30 | --fine_tune \ 31 | --wqtype ${wqtype} \ 32 | --xqtype ${xqtype} \ 33 | --wbit ${wbit} \ 34 | --abit ${abit} \ 35 | --dataset ${dataset} \ 36 | --train_dir "/share/seo/imagenet/train/" \ 37 | --val_dir "/share/seo/imagenet/val/" \ 38 | --evaluate \ 39 | --trainer qattn \ 40 | --swl 32 \ 41 | --sfl 26 \ 42 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-minmax-qdrop.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | epochs=200 9 | batch_size=100 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="t2c.log" 15 | wbit=8 16 | abit=8 17 | xqtype="qdrop_token" 18 | wqtype="minmax_channel" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr1e-4_batch100_cross_entropyloss_all/t2c/" 22 | pre_trained="./save/imagenet/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr1e-4_batch100_cross_entropyloss_all/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --batch_size ${batch_size} \ 28 | --resume ${pre_trained} \ 29 | --log_file ${log_file} \ 30 | --fine_tune \ 31 | --wqtype ${wqtype} \ 32 | --xqtype ${xqtype} \ 33 | --wbit ${wbit} \ 34 | --abit ${abit} \ 35 | --dataset ${dataset} \ 36 | --train_dir "/share/seo/imagenet/train/" \ 37 | --val_dir "/share/seo/imagenet/val/" \ 38 | --evaluate \ 39 | --trainer qattn \ 40 | --swl 32 \ 41 | --sfl 26 \ 42 | --export_samples 1 \ -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-minmax.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_base 8 | epochs=200 9 | batch_size=100 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="t2c.log" 15 | wbit=8 16 | abit=8 17 | xqtype="minmax_token" 18 | wqtype="minmax_channel" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/vit_base/minmax_token_minmax_channel/vit_base_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 22 | pre_trained="./save/imagenet/vit_base/minmax_token_minmax_channel/vit_base_w8_a8_lr1e-4_batch100_cross_entropyloss_all/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --batch_size ${batch_size} \ 28 | --resume ${pre_trained} \ 29 | --log_file ${log_file} \ 30 | --fine_tune \ 31 | --wqtype ${wqtype} \ 32 | --xqtype ${xqtype} \ 33 | --wbit ${wbit} \ 34 | --abit ${abit} \ 35 | --dataset ${dataset} \ 36 | --train_dir "/share/seo/imagenet/train/" \ 37 | --val_dir "/share/seo/imagenet/val/" \ 38 | --evaluate \ 39 | --trainer qattn \ 40 | --swl 32 \ 41 | --sfl 26 \ 42 | --export_samples 1 \ 43 | 44 | -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-reload-adaround-lsq.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_tiny 8 | batch_size=100 9 | loss=cross_entropy 10 | weight_decay=0.0005 11 | dataset="imagenet" 12 | log_file="reload.log" 13 | wbit=8 14 | abit=8 15 | xqtype="lsq_token" 16 | wqtype="adaround" 17 | ttype=ptq 18 | 19 | save_path="./save/imagenet/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr1e-4_batch100_cross_entropyloss_all/t2c/" 20 | pre_trained="./save/imagenet/${model}/${xqtype}_${wqtype}/${model}_w${wbit}_a${abit}_lr1e-4_batch100_cross_entropyloss_all/t2c/t2c_model.pth.tar" 21 | 22 | python3 -W ignore ./imagenet/reload.py \ 23 | --save_path ${save_path} \ 24 | --model ${model} \ 25 | --batch_size ${batch_size} \ 26 | --resume ${pre_trained} \ 27 | --log_file ${log_file} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer qattn \ 38 | --swl 32 \ 39 | --sfl 26 \ 40 | --export_samples 1 \ 41 | 42 | -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-reload-adaround-qdrop.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | batch_size=100 9 | loss=cross_entropy 10 | weight_decay=0.0005 11 | dataset="imagenet" 12 | log_file="reload.log" 13 | wbit=8 14 | abit=8 15 | xqtype="qdrop" 16 | wqtype="adaround" 17 | ttype=ptq 18 | 19 | save_path="./save/imagenet/vit_small/qdrop_adaround/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 20 | pre_trained="./save/imagenet/vit_small/qdrop_adaround/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/t2c_model.pth.tar" 21 | 22 | python3 -W ignore ./imagenet/reload.py \ 23 | --save_path ${save_path} \ 24 | --model ${model} \ 25 | --batch_size ${batch_size} \ 26 | --resume ${pre_trained} \ 27 | --log_file ${log_file} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer qattn \ 38 | --swl 32 \ 39 | --sfl 26 \ 40 | --export_samples 1 \ 41 | 42 | -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-reload-minmax-lsq.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | batch_size=100 9 | loss=cross_entropy 10 | weight_decay=0.0005 11 | dataset="imagenet" 12 | log_file="reload.log" 13 | wbit=8 14 | abit=8 15 | xqtype="lsq_token" 16 | wqtype="minmax_channel" 17 | ttype=ptq 18 | 19 | save_path="./save/imagenet/vit_small/lsq_token_minmax_channel/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 20 | pre_trained="./save/imagenet/vit_small/lsq_token_minmax_channel/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/t2c_model.pth.tar" 21 | 22 | python3 -W ignore ./imagenet/reload.py \ 23 | --save_path ${save_path} \ 24 | --model ${model} \ 25 | --batch_size ${batch_size} \ 26 | --resume ${pre_trained} \ 27 | --log_file ${log_file} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer qattn \ 38 | 39 | -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-reload-minmax-qdrop.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | batch_size=100 9 | loss=cross_entropy 10 | weight_decay=0.0005 11 | dataset="imagenet" 12 | log_file="reload.log" 13 | wbit=8 14 | abit=8 15 | xqtype="qdrop_token" 16 | wqtype="minmax_channel" 17 | ttype=ptq 18 | 19 | save_path="./save/imagenet/vit_small/qdrop_token_minmax_channel/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 20 | pre_trained="./save/imagenet/vit_small/qdrop_token_minmax_channel/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/t2c_model.pth.tar" 21 | 22 | python3 -W ignore ./imagenet/reload.py \ 23 | --save_path ${save_path} \ 24 | --model ${model} \ 25 | --batch_size ${batch_size} \ 26 | --resume ${pre_trained} \ 27 | --log_file ${log_file} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer qattn \ 38 | 39 | -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-reload-minmax.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_base 8 | batch_size=100 9 | loss=cross_entropy 10 | weight_decay=0.0005 11 | dataset="imagenet" 12 | log_file="reload.log" 13 | wbit=8 14 | abit=8 15 | xqtype="minmax_token" 16 | wqtype="minmax_channel" 17 | ttype=ptq 18 | 19 | save_path="./save/imagenet/vit_base/minmax_token_minmax_channel/vit_base_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 20 | pre_trained="./save/imagenet/vit_base/minmax_token_minmax_channel/vit_base_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/t2c_model.pth.tar" 21 | 22 | python3 -W ignore ./imagenet/reload.py \ 23 | --save_path ${save_path} \ 24 | --model ${model} \ 25 | --batch_size ${batch_size} \ 26 | --resume ${pre_trained} \ 27 | --log_file ${log_file} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer qattn \ 38 | --swl 32 \ 39 | --sfl 26 \ 40 | --export_samples 1 \ 41 | 42 | -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-reload-smoothquant.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | batch_size=100 9 | loss=cross_entropy 10 | weight_decay=0.0005 11 | dataset="imagenet" 12 | log_file="reload.log" 13 | wbit=8 14 | abit=8 15 | xqtype="smooth" 16 | wqtype="smooth_channel" 17 | ttype=ptq 18 | 19 | save_path="./save/imagenet/vit_small/smooth_smooth_channel/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 20 | pre_trained="./save/imagenet/vit_small/smooth_smooth_channel/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/t2c_model.pth.tar" 21 | 22 | python3 -W ignore ./imagenet/reload.py \ 23 | --save_path ${save_path} \ 24 | --model ${model} \ 25 | --batch_size ${batch_size} \ 26 | --resume ${pre_trained} \ 27 | --log_file ${log_file} \ 28 | --fine_tune \ 29 | --wqtype ${wqtype} \ 30 | --xqtype ${xqtype} \ 31 | --wbit ${wbit} \ 32 | --abit ${abit} \ 33 | --dataset ${dataset} \ 34 | --train_dir "/share/seo/imagenet/train/" \ 35 | --val_dir "/share/seo/imagenet/val/" \ 36 | --evaluate \ 37 | --trainer qattn \ 38 | --swl 32 \ 39 | --sfl 26 \ 40 | --export_samples 1 \ 41 | 42 | -------------------------------------------------------------------------------- /deprecated/scripts/imagenet/vit-t2c-smoothquant.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "$DIRECTORY" ]; then 2 | mkdir ./save 3 | fi 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | model=vit_small 8 | epochs=200 9 | batch_size=100 10 | lr=0.1 11 | loss=cross_entropy 12 | weight_decay=0.0005 13 | dataset="imagenet" 14 | log_file="t2c.log" 15 | wbit=8 16 | abit=8 17 | xqtype="smooth" 18 | wqtype="smooth_channel" 19 | ttype=ptq 20 | 21 | save_path="./save/imagenet/vit_small/smooth_smooth_channel/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/t2c/" 22 | pre_trained="./save/imagenet/vit_small/smooth_smooth_channel/vit_small_w8_a8_lr1e-4_batch100_cross_entropyloss_all/model_best.pth.tar" 23 | 24 | python3 -W ignore ./imagenet/t2c.py \ 25 | --save_path ${save_path} \ 26 | --model ${model} \ 27 | --batch_size ${batch_size} \ 28 | --resume ${pre_trained} \ 29 | --log_file ${log_file} \ 30 | --fine_tune \ 31 | --wqtype ${wqtype} \ 32 | --xqtype ${xqtype} \ 33 | --wbit ${wbit} \ 34 | --abit ${abit} \ 35 | --dataset ${dataset} \ 36 | --train_dir "/share/seo/imagenet/train/" \ 37 | --val_dir "/share/seo/imagenet/val/" \ 38 | --evaluate \ 39 | --trainer qattn \ 40 | --swl 32 \ 41 | --sfl 26 \ 42 | --export_samples 1 \ 43 | 44 | -------------------------------------------------------------------------------- /figs/DualPathDesign.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/figs/DualPathDesign.png -------------------------------------------------------------------------------- /figs/Figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/figs/Figure1.png -------------------------------------------------------------------------------- /figs/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/figs/icon.png -------------------------------------------------------------------------------- /figs/torch2chip_workflow.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/figs/torch2chip_workflow.png -------------------------------------------------------------------------------- /llm/gsm8k.py: -------------------------------------------------------------------------------- 1 | """ 2 | Math Task (GSM8K, MATH) 3 | """ 4 | 5 | import os 6 | import sys 7 | sys.path.append("../torch2chip/") 8 | import argparse 9 | 10 | from transformers import AutoTokenizer 11 | from src.stage.base import Execute 12 | from src.trainer.llm.ptq import SmoothQuant 13 | from src.t2c.convert import Llama4Compress 14 | from src.trainer.llm.evaluator import GSM8K 15 | from src.t2c.t2c import T2C 16 | 17 | parser = argparse.ArgumentParser(description='LLM model evaluation against the GSM8K benchmark') 18 | parser.add_argument('--config_dir', type=str, default=None, help="Path to the configuration file (.yaml)") 19 | args = parser.parse_args() 20 | 21 | class GSM8KEval(Execute): 22 | def __init__(self, config_dir): 23 | super().__init__(config_dir) 24 | model = self.create_model() 25 | self.tokenizer = self.prepare_tokenizer() 26 | 27 | wbit = self.config["quantization"]["wbit"] 28 | abit = self.config["quantization"]["abit"] 29 | converter = Llama4Compress(model, wbit=wbit, abit=abit) 30 | 31 | # convert model 32 | self.model = converter.convert() 33 | 34 | # initialize logging 35 | self.logger = self.initialize_logger() 36 | self.task = SmoothQuant(config_dir, self.model, self.tokenizer, self.logger) 37 | 38 | def register_run_dir(self): 39 | super().register_run_dir() 40 | 41 | self.t2c_dir = os.path.join(self.run_dir, "t2c") 42 | self.t2c_model_dir = os.path.join(self.t2c_dir , "t2c_model.pth.tar") 43 | 44 | self.tensors_dir = os.path.join(self.t2c_dir , "tensors") 45 | if not os.path.isdir(self.tensors_dir): 46 | os.makedirs(self.tensors_dir, exist_ok=True) 47 | 48 | def prepare_tokenizer(self): 49 | model_type = self.config["model"]["model_type"] 50 | tokenizer = AutoTokenizer.from_pretrained(model_type, trust_remote_code=True) 51 | 52 | if tokenizer.pad_token_id is None: 53 | if tokenizer.eos_token_id is not None: 54 | tokenizer.pad_token_id = tokenizer.eos_token_id 55 | else: 56 | tokenizer.pad_token_id = 0 57 | 58 | return tokenizer 59 | 60 | def ptq(self): 61 | fake_quantized_model = self.task.run() 62 | return fake_quantized_model 63 | 64 | def t2c(self, fake_quant_model): 65 | t2c = T2C(model=fake_quant_model, config=self.config) 66 | fused_model = t2c.fused_model() 67 | self.print_arch(fused_model, "fused_model") 68 | 69 | evaluator = GSM8K(self.config_dir, fused_model, self.tokenizer) 70 | evaluator.run() 71 | 72 | def run(self): 73 | fake_quant_model = self.ptq() 74 | fused_model = self.t2c(fake_quant_model) 75 | 76 | def starter(): 77 | executor = GSM8KEval(args.config_dir) 78 | executor.run() 79 | 80 | if __name__ == "__main__": 81 | starter() -------------------------------------------------------------------------------- /llm/mmlu.py: -------------------------------------------------------------------------------- 1 | """ 2 | LLM Evaluation with MMLU 3 | """ 4 | 5 | import os 6 | import sys 7 | sys.path.append("../torch2chip/") 8 | 9 | import argparse 10 | 11 | from transformers import AutoTokenizer 12 | from src.stage.base import Execute 13 | from src.trainer.llm.evaluator import MMLU 14 | from src.t2c.convert import Llama4Compress 15 | from src.trainer.llm.ptq import SmoothQuant 16 | from 
src.t2c.t2c import T2C 17 | 18 | parser = argparse.ArgumentParser(description='LLM model evaluation against the MMLU benchmark') 19 | parser.add_argument('--config_dir', type=str, default=None, help="Path to the configuration file (.yaml)") 20 | args = parser.parse_args() 21 | 22 | class MMLUEval(Execute): 23 | def __init__(self, config_dir): 24 | super().__init__(config_dir) 25 | model = self.create_model() 26 | self.tokenizer = self.prepare_tokenizer() 27 | 28 | wbit = self.config["quantization"]["wbit"] 29 | abit = self.config["quantization"]["abit"] 30 | converter = Llama4Compress(model, wbit=wbit, abit=abit) 31 | 32 | # convert model 33 | self.model = converter.convert() 34 | 35 | # initialize logging 36 | self.logger = self.initialize_logger() 37 | self.task = SmoothQuant(config_dir, self.model, self.tokenizer, self.logger) 38 | 39 | def register_run_dir(self): 40 | super().register_run_dir() 41 | 42 | self.t2c_dir = os.path.join(self.run_dir, "t2c") 43 | self.t2c_model_dir = os.path.join(self.t2c_dir , "t2c_model.pth.tar") 44 | 45 | self.tensors_dir = os.path.join(self.t2c_dir , "tensors") 46 | if not os.path.isdir(self.tensors_dir): 47 | os.makedirs(self.tensors_dir, exist_ok=True) 48 | 49 | def prepare_tokenizer(self): 50 | model_type = self.config["model"]["model_type"] 51 | tokenizer = AutoTokenizer.from_pretrained(model_type, padding_side="left") 52 | 53 | tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id 54 | tokenizer.bos_token_id = 1 55 | 56 | return tokenizer 57 | 58 | def t2c(self, fake_quant_model): 59 | t2c = T2C(model=fake_quant_model, config=self.config) 60 | fused_model = t2c.fused_model() 61 | self.print_arch(fused_model, "fused_model") 62 | 63 | evaluator = MMLU(self.config_dir, fused_model, self.tokenizer) 64 | evaluator.run() 65 | 66 | def ptq(self): 67 | fake_quantized_model = self.task.run() 68 | return fake_quantized_model 69 | 70 | def run(self): 71 | fake_quant_model = self.ptq() 72 | fused_model = self.t2c(fake_quant_model) 73 | 74 | def starter(): 75 | executor = MMLUEval(args.config_dir) 76 | executor.run() 77 | 78 | if __name__ == "__main__": 79 | starter() 80 | -------------------------------------------------------------------------------- /llm/wikitext.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compression of llama model series 3 | """ 4 | 5 | import os 6 | import sys 7 | sys.path.append("../torch2chip/") 8 | 9 | import torch 10 | import argparse 11 | 12 | from tqdm import tqdm 13 | from transformers import AutoTokenizer 14 | from src.stage.base import Execute 15 | from src.trainer.llm.ptq import SmoothQuant 16 | from src.t2c.convert import Llama4Compress 17 | from src.trainer.llm.evaluator import WikiText 18 | from src.t2c.t2c import T2C 19 | from src.utils.utils import gpufloat2cpuint 20 | from transformers import set_seed 21 | 22 | parser = argparse.ArgumentParser(description='LLM model evaluation against the WikiText benchmark') 23 | parser.add_argument('--config_dir', type=str, default=None, help="Path to the configuration file (.yaml)") 24 | args = parser.parse_args() 25 | 26 | class CompressLlama(Execute): 27 | def __init__(self, config_dir): 28 | super().__init__(config_dir) 29 | set_seed(self.config["seed"]) 30 | 31 | model = self.create_model() 32 | self.tokenizer = self.prepare_tokenizer() 33 | 34 | wbit = self.config["quantization"]["wbit"] 35 | abit = self.config["quantization"]["abit"] 36 | converter = Llama4Compress(model, wbit=wbit, abit=abit) 
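# Pipeline summary (inferred from the surrounding calls in this file, not from the
# Llama4Compress / SmoothQuant sources themselves):
# - create_model() loads the Llama checkpoint named in the YAML config.
# - Llama4Compress wraps that model so its layers can be fake-quantized at the
#   configured weight (wbit) and activation (abit) bit-widths; convert() below
#   applies the layer swap.
# - The SmoothQuant task then calibrates the fake-quantized model, which t2c()
#   later fuses, re-evaluates on WikiText, and optionally exports per-layer
#   input/output tensors through save().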
37 | 38 | # convert model 39 | self.model = converter.convert() 40 | 41 | # define the target task 42 | self.task = SmoothQuant(config_dir, self.model, self.tokenizer, self.logger) 43 | 44 | def register_run_dir(self): 45 | super().register_run_dir() 46 | 47 | self.t2c_dir = os.path.join(self.run_dir, "t2c") 48 | self.t2c_model_dir = os.path.join(self.t2c_dir , "t2c_model.pth.tar") 49 | 50 | self.tensors_dir = os.path.join(self.t2c_dir , "tensors") 51 | if not os.path.isdir(self.tensors_dir): 52 | os.makedirs(self.tensors_dir, exist_ok=True) 53 | 54 | def prepare_tokenizer(self): 55 | model_type = self.config["model"]["model_type"] 56 | tokenizer = AutoTokenizer.from_pretrained(model_type, trust_remote_code=True) 57 | 58 | if tokenizer.pad_token_id is None: 59 | if tokenizer.eos_token_id is not None: 60 | tokenizer.pad_token_id = tokenizer.eos_token_id 61 | else: 62 | tokenizer.pad_token_id = 0 63 | 64 | return tokenizer 65 | 66 | def ptq(self): 67 | fake_quantized_model = self.task.run() 68 | evaluator = WikiText(self.config_dir, fake_quantized_model, self.tokenizer) 69 | evaluator.run() 70 | 71 | return fake_quantized_model 72 | 73 | def save(self, t2c:T2C): 74 | t2c_model = getattr(t2c, "model") 75 | torch.save(t2c_model.state_dict(), self.t2c_model_dir) 76 | 77 | for k, v in tqdm(t2c.node_dict.items()): 78 | x1, x2, y = v 79 | 80 | x1 = gpufloat2cpuint(x1, torch.int8) 81 | x2 = gpufloat2cpuint(x2, torch.int8) 82 | y = gpufloat2cpuint(y, torch.int32) 83 | 84 | torch.save(x1, os.path.join(self.tensors_dir, f"{k}_x1.pt")) 85 | torch.save(x2, os.path.join(self.tensors_dir, f"{k}_x2.pt")) 86 | torch.save(y, os.path.join(self.tensors_dir, f"{k}_y.pt")) 87 | 88 | 89 | def t2c(self, fake_quant_model): 90 | t2c = T2C(model=fake_quant_model, config=self.config) 91 | fused_model = t2c.fused_model() 92 | self.print_arch(fused_model, "fused_model") 93 | 94 | evaluator = WikiText(self.config_dir, fused_model, self.tokenizer) 95 | self.logger.info(f"\n Evaluating the fused model...") 96 | evaluator.run() 97 | 98 | export_samples = self.config["export"]["export_samples"] 99 | t2c.register_node() 100 | 101 | if export_samples > 0: 102 | evaluator.export_run(export_samples) 103 | self.save(t2c) 104 | 105 | return fused_model 106 | 107 | def run(self): 108 | fake_quant_model = self.ptq() 109 | fused_model = self.t2c(fake_quant_model) 110 | 111 | def starter(): 112 | executor = CompressLlama(args.config_dir) 113 | executor.run() 114 | 115 | if __name__ == "__main__": 116 | starter() -------------------------------------------------------------------------------- /prune/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Weight pruning with CNN 3 | """ 4 | 5 | import os 6 | import sys 7 | sys.path.append("../torch2chip/") 8 | 9 | import argparse 10 | from src.stage.base import Execute 11 | from src.data.vision.imagenet import ImageNet1K 12 | from src.t2c.convert import Vanilla4Compress 13 | from src.trainer.pruning import STrainer 14 | 15 | parser = argparse.ArgumentParser(description='Llama') 16 | parser.add_argument('--config_dir', type=str, default=None, help="Path to the configuration file (.yaml)") 17 | args = parser.parse_args() 18 | 19 | class PruneResNet(Execute): 20 | def __init__(self, config_dir): 21 | super().__init__(config_dir) 22 | 23 | model = self.create_model() 24 | converter = Vanilla4Compress(model=model, wbit=32, abit=32) 25 | self.model = converter.convert() 26 | 27 | # prepare dataloaders 28 | trainloader, testloader = 
self.prepare_dataloader() 29 | 30 | # trainer 31 | self.trainer = STrainer( 32 | model=self.model, 33 | trainloader=trainloader, 34 | validloader=testloader, 35 | config=self.config, 36 | logger=self.logger 37 | ) 38 | 39 | def prepare_dataloader(self): 40 | data_gen = ImageNet1K(self.config_dir) 41 | 42 | trainloader, testloader = data_gen.run() 43 | return trainloader, testloader 44 | 45 | def run(self): 46 | self.trainer.fit() 47 | self.trainer.valid_epoch() 48 | self.logger.info("Baseline Model: Test accuracy = {:.2f}".format(self.trainer.logger_dict["valid_top1"])) 49 | 50 | 51 | if __name__ == "__main__": 52 | executor = PruneResNet(args.config_dir) 53 | executor.run() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.3.0 2 | aiohttp==3.9.5 3 | aiosignal==1.3.1 4 | async-timeout==4.0.3 5 | attrs==23.2.0 6 | certifi==2024.2.2 7 | charset-normalizer==3.3.2 8 | datasets==3.2.0 9 | dill==0.3.8 10 | filelock==3.14.0 11 | frozenlist==1.4.1 12 | fsspec==2024.3.1 13 | fxpmath==0.4.9 14 | huggingface-hub==0.28.1 15 | idna==3.7 16 | Jinja2==3.1.4 17 | MarkupSafe==2.1.5 18 | mpmath==1.3.0 19 | multidict==6.0.5 20 | multiprocess==0.70.16 21 | networkx==3.2.1 22 | ninja==1.11.1.3 23 | numpy==1.26.4 24 | nvidia-cublas-cu12==12.1.3.1 25 | nvidia-cuda-cupti-cu12==12.1.105 26 | nvidia-cuda-nvrtc-cu12==12.1.105 27 | nvidia-cuda-runtime-cu12==12.1.105 28 | nvidia-cudnn-cu12==8.9.2.26 29 | nvidia-cufft-cu12==11.0.2.54 30 | nvidia-curand-cu12==10.3.2.106 31 | nvidia-cusolver-cu12==11.4.5.107 32 | nvidia-cusparse-cu12==12.1.0.106 33 | nvidia-nccl-cu12==2.20.5 34 | nvidia-nvjitlink-cu12==12.5.40 35 | nvidia-nvtx-cu12==12.1.105 36 | packaging==24.0 37 | pandas==2.2.2 38 | pillow==10.3.0 39 | psutil==6.1.1 40 | pyarrow==16.1.0 41 | pyarrow-hotfix==0.6 42 | python-dateutil==2.9.0.post0 43 | pytz==2024.1 44 | PyYAML==6.0.1 45 | regex==2024.5.15 46 | requests==2.32.2 47 | safetensors==0.4.3 48 | six==1.16.0 49 | sympy==1.12 50 | tabulate==0.9.0 51 | timm==1.0.14 52 | tokenizers==0.21.0 53 | torch==2.3.0 54 | torchaudio==2.3.0 55 | torchvision==0.18.0 56 | tqdm==4.66.4 57 | transformers==4.48.2 58 | triton==2.3.0 59 | typing_extensions==4.11.0 60 | tzdata==2024.1 61 | urllib3==2.2.1 62 | xxhash==3.4.1 63 | yarl==1.9.4 64 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base dataset configuration 3 | """ 4 | 5 | import torch 6 | 7 | from src.stage.base import Execute 8 | 9 | # language model dataset mapping: path, package, function, separate path flag 10 | LANGUAGE_DATASET_MAP = { 11 | 'wikitext': ('wikitext-2-raw-v1', 'datasets', 'load_dataset', False), 12 | 'boolq': ('json', 'datasets', 'load_dataset', True), 13 | 'openbookqa': 
('json', 'datasets', 'load_dataset', True), 14 | 'piqa': ('json', 'datasets', 'load_dataset', True), 15 | 'winogrande': ('json', 'datasets', 'load_dataset', True), 16 | 'commonsense_reasoning': ('json', 'datasets', 'load_dataset', True), 17 | } 18 | 19 | class DataStage(Execute): 20 | """ 21 | Base dataset stage for vision and language datasets 22 | """ 23 | def __init__(self, config_dir): 24 | super().__init__(config_dir) 25 | 26 | self.dataset_name = self.config["dataset"]["name"] 27 | self.data_split = self.config["dataset"]["split"] 28 | self.batch_size = self.config["train"]["batch_size"] 29 | 30 | # ddp flag 31 | self.is_ddp = torch.distributed.is_initialized() 32 | 33 | def __len__(self): 34 | return len(self.dataset) 35 | 36 | def __name__(self): 37 | return "BaseDataStage" 38 | 39 | def load_dataset(self): 40 | return [] 41 | 42 | def prepare_transform(self): 43 | pass 44 | 45 | def prepare_loader(self): 46 | pass 47 | 48 | def run(self): 49 | print(f"Preparing dataset {self.dataset_name}") -------------------------------------------------------------------------------- /src/data/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for data loading, etc. 3 | """ 4 | 5 | import os 6 | import json 7 | 8 | def load_json_data(data_path, split) -> list: 9 | """ 10 | read data from dataset file 11 | Args: 12 | args: 13 | 14 | Returns: 15 | """ 16 | 17 | assert split in ["train", "test"], "Data split can only be train or test!" 18 | 19 | file_path = os.path.join(data_path, "test.json") 20 | if not os.path.exists(file_path): 21 | raise FileNotFoundError(f"can not find dataset file : {file_path}") 22 | json_data = json.load(open(file_path, 'r')) 23 | return json_data -------------------------------------------------------------------------------- /src/data/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf import * 2 | 3 | DATA_STAGE_MAP = { 4 | "wikitext": WikiText, 5 | "piqa": PiQA, 6 | "hellaswag": HellaSwag, 7 | "boolq": BoolQ, 8 | "openbookqa": OpenBookQA, 9 | "winogrande": WinoGrande, 10 | "ARC-Challenge": ARCc, 11 | "ARC-Easy": ARCe, 12 | "gsm8k": GSM8K, 13 | "mmlu": MMLU, 14 | "math": MATH500Step 15 | } -------------------------------------------------------------------------------- /src/data/llm/math/grader.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/data/llm/math/grader.py -------------------------------------------------------------------------------- /src/data/vision/imagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ImageNet data preparation 3 | """ 4 | 5 | import os 6 | import torch 7 | import torchvision.transforms as transforms 8 | 9 | from torchvision import datasets 10 | from src.data.base import DataStage 11 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | 13 | class VisionData(DataStage): 14 | """ 15 | Basic stage for vision dataset 16 | """ 17 | def __init__(self, config_dir): 18 | super().__init__(config_dir) 19 | self.train_dir = self.config["dataset"]["train_dir"] 20 | self.test_dir = self.config["dataset"]["test_dir"] 21 | self.num_samples = self.config["dataset"]["samples"] 22 | self.mean = self.config["dataset"].get("mean", IMAGENET_DEFAULT_MEAN) 23 | self.std = self.config["dataset"].get("std", IMAGENET_DEFAULT_STD) 24 
| 25 | def __len__(self): 26 | return self.num_samples 27 | 28 | class ImageNet1K(VisionData): 29 | def __init__(self, config_dir): 30 | super().__init__(config_dir) 31 | self.num_classes = 1000 32 | self.num_workers = self.config["dataset"]["num_workers"] 33 | 34 | def __name__(self): 35 | return "ImageNet-1K" 36 | 37 | def prepare_transform(self): 38 | train = transforms.Compose([ 39 | transforms.RandomResizedCrop(224), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ToTensor(), 42 | transforms.Normalize(self.mean, self.std) 43 | ]) 44 | 45 | test = transforms.Compose([ 46 | transforms.Resize(256), 47 | transforms.CenterCrop(224), 48 | transforms.ToTensor(), 49 | transforms.Normalize(self.mean, self.std) 50 | ]) 51 | 52 | return train, test 53 | 54 | def prepare_loader(self): 55 | trtf, tetf = self.prepare_transform() 56 | 57 | trainset = datasets.ImageFolder(self.train_dir, transform=trtf) 58 | testset = datasets.ImageFolder(self.test_dir, transform=tetf) 59 | 60 | if self.num_samples != -1: 61 | rand = torch.utils.data.RandomSampler(trainset, num_samples=self.num_samples) 62 | sampler = torch.utils.data.BatchSampler(rand, batch_size=self.batch_size, drop_last=False) 63 | else: 64 | sampler = None 65 | 66 | if self.is_ddp: 67 | sampler = torch.utils.data.distributed.DistributedSampler(trainset) 68 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, sampler=sampler) 69 | testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) 70 | 71 | else: 72 | if self.num_samples != -1: 73 | trainloader = torch.utils.data.DataLoader( 74 | trainset, 75 | batch_sampler=sampler, 76 | num_workers=self.num_workers, 77 | pin_memory=True 78 | ) 79 | else: 80 | trainloader = torch.utils.data.DataLoader( 81 | trainset, 82 | batch_size=self.batch_size, 83 | num_workers=self.num_workers, 84 | pin_memory=True 85 | ) 86 | 87 | testloader = torch.utils.data.DataLoader( 88 | testset, 89 | batch_size=self.batch_size, 90 | shuffle=False, 91 | num_workers=self.num_workers, 92 | pin_memory=True 93 | ) 94 | 95 | return trainloader, testloader 96 | 97 | def run(self): 98 | self.logger.info("Preparing ImageNet-1K...") 99 | trainloader, testloader = self.prepare_loader() 100 | self.logger.info(f"Done | Train size: {len(trainloader)} | Test size: {len(testloader)}") 101 | 102 | return trainloader, testloader -------------------------------------------------------------------------------- /src/hardware/systolic_array.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/hardware/systolic_array.py -------------------------------------------------------------------------------- /src/hardware/writer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Export the MatMul Results to PE 3 | """ 4 | 5 | import math 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | class MatMulWriter(object): 10 | def __init__(self, pe:int=4, save_path:str=""): 11 | self.pe = pe 12 | self.save_path = save_path 13 | 14 | def matmul_shape(self, x:torch.Tensor): 15 | shape = list(x.size()) 16 | if len(shape) == 4: 17 | batch, channel, row, col = shape 18 | 19 | elif len(shape) == 3: 20 | batch, channel, row, col = shape[0], 1, shape[1], shape[2] 21 | x = x.view(batch, channel, row, col) 22 | 23 | elif len(shape) == 
2: 24 | batch, channel, row, col = 1, 1, shape[0], shape[1] 25 | x = x.view(batch, channel, row, col) 26 | 27 | return x, batch, channel, row, col 28 | 29 | def round2pe(self, dim:int): 30 | if dim % self.pe != 0: 31 | ndiff = math.ceil(dim / self.pe) 32 | npad = abs(ndiff*self.pe - dim) 33 | else: 34 | npad = 0 35 | 36 | return npad 37 | 38 | def padxy(self, x:torch.Tensor, y:torch.Tensor): 39 | assert len(x.shape) == 2, "two-dim matrix only!" 40 | assert len(y.shape) == 2, "two-dim matrix only!" 41 | 42 | rx, cx = x.shape 43 | ry, cy = y.shape 44 | 45 | assert cx == ry, "inavlid matmul size!" 46 | 47 | rxpad = self.round2pe(rx) 48 | cxpad = self.round2pe(cx) 49 | cypad = self.round2pe(cy) 50 | 51 | px = [0, cxpad, 0, rxpad] 52 | py = [0, cypad, 0, cxpad] 53 | 54 | xpad = F.pad(x, px, value=0.0) 55 | ypad = F.pad(y, py, value=0.0) 56 | return xpad, ypad 57 | 58 | def mat2pe(self, mat:torch.Tensor): 59 | assert len(mat.shape) == 2, "two-dim matrix only!" 60 | r, c = mat.shape 61 | 62 | mat4d = mat.unsqueeze(0).unsqueeze(2) 63 | mat4d = mat.contiguous().view(int(r/self.pe), self.pe, int(c/self.pe), self.pe) 64 | 65 | return mat4d.permute(0,2,1,3) 66 | 67 | def pe_stats(self, tensor:torch.Tensor): 68 | return None 69 | 70 | def tensor2pe(self, x:torch.Tensor, y:torch.Tensor): 71 | x, bx, chx, rx, cx = self.matmul_shape(x) 72 | y, by, chy, ry, cy = self.matmul_shape(y) 73 | 74 | for h in range(chx): 75 | matx = x[0, h, ...] 76 | maty = y[0, h, ...] 77 | 78 | pmatx, pmaty = self.padxy(matx, maty) 79 | 80 | pmatx = self.mat2pe(pmatx) 81 | pmaty = self.mat2pe(pmaty) 82 | 83 | def write(self, x:torch.Tensor, y:torch.Tensor): 84 | self.tensor2pe(x, y) -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/auto_map.py: -------------------------------------------------------------------------------- 1 | """ 2 | Load model architecture from different sources 3 | """ 4 | 5 | import torch 6 | import timm 7 | import torchvision 8 | import src.models as t2c_models 9 | 10 | from transformers import AutoModelForCausalLM 11 | from src.models.lm.retnet import RetNetForCausalLM 12 | 13 | # TODO: expand this list to support more model architectures 14 | MODEL_LIBRARY_MAP = { 15 | 'vit_tiny_patch16_224': ('timm', 'vision_transformer'), 16 | 'vit_small_patch16_224': ('timm', 'vision_transformer'), 17 | 'vit_base_patch16_224': ('timm', 'vision_transformer'), 18 | 'swin_tiny_patch4_window7_224': ('timm', 'swin_transformer'), 19 | 'swin_base_patch4_window7_224': ('timm', 'swin_transformer'), 20 | 'resnet18': ('torchvision', 'models'), 21 | 'resnet34': ('torchvision', 'models'), 22 | 'resnet50': ('torchvision', 'models'), 23 | 'vgg16_bn': ('torchvision', 'models'), 24 | 'Spiral-AI/Spiral-RetNet-3b-base': ('retnet', 'RetNetForCausalLM'), 25 | 'meta-llama/Llama-2-7b-hf': ('transformers', 'AutoModelForCausalLM'), 26 | 'meta-llama/Llama-3.2-1B-Instruct': ('transformers', 'AutoModelForCausalLM'), 27 | 'meta-llama/Llama-3.2-3B-Instruct': ('transformers', 'AutoModelForCausalLM'), 28 | 'meta-llama/Llama-3.1-8B-Instruct': ('transformers', 'AutoModelForCausalLM'), 29 | 'meta-llama/Llama-3.1-8B': ('transformers', 'AutoModelForCausalLM'), 30 | } 31 | 32 | TORCH_WEIGHTS_MAP = { 33 | 
'resnet18': 'ResNet18_Weights', 34 | 'resnet34': 'ResNet34_Weights', 35 | 'resnet50': 'ResNet50_Weights', 36 | 'vgg16_bn': 'VGG16_BN_Weights', 37 | } 38 | 39 | class ModelMap: 40 | def __init__(self, model_name:str): 41 | self.model_name = model_name 42 | 43 | def fetch(self): 44 | if self.model_name not in MODEL_LIBRARY_MAP: 45 | raise ValueError(f"Model: {self.model_name} is unknown! Available models: {MODEL_LIBRARY_MAP.keys()}") 46 | 47 | lib_name, sub_name = MODEL_LIBRARY_MAP[self.model_name] 48 | 49 | if lib_name == "transformers": 50 | model = AutoModelForCausalLM.from_pretrained( 51 | self.model_name, 52 | load_in_8bit=False, 53 | torch_dtype=torch.float16, 54 | device_map="auto", 55 | trust_remote_code=True, 56 | ) 57 | 58 | elif lib_name == "timm": 59 | model_lib = getattr(timm, "models") 60 | sub_lib = getattr(model_lib, sub_name) 61 | model_func = getattr(sub_lib, self.model_name) 62 | 63 | model = model_func(pretrained=True) 64 | 65 | elif lib_name == "torchvision": 66 | model_func = getattr(torchvision.models, self.model_name) 67 | model_weights = getattr(torchvision.models, TORCH_WEIGHTS_MAP[self.model_name]) 68 | 69 | if hasattr(model_weights, "IMAGENET1K_V2"): 70 | model = model_func(weights=model_weights.IMAGENET1K_V2) 71 | else: 72 | model = model_func(weights=model_weights.IMAGENET1K_V1) 73 | 74 | elif lib_name == "t2c_models": 75 | if "RetNet" in self.model_name: 76 | model = RetNetForCausalLM.from_pretrained( 77 | self.model_name 78 | ) 79 | else: 80 | raise ValueError(f"Unknown model library {lib_name}") 81 | 82 | return model -------------------------------------------------------------------------------- /src/models/cifar/mobilenetv1.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/models/cifar/mobilenetv1.py -------------------------------------------------------------------------------- /src/models/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet with re-configured layer for CIFAR dataset 3 | """ 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | __all__ = ["resnet18_cifar", "resnet50_cifar"] 9 | 10 | class BasicBlock(nn.Module): 11 | expansion = 1 12 | def __init__(self, in_planes, planes, stride=1): 13 | super(BasicBlock, self).__init__() 14 | 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.relu = nn.ReLU(inplace=True) 18 | 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = self.conv1(x) 31 | out = self.relu(self.bn1(out)) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | out += self.shortcut(x) 37 | out = self.relu(out) 38 | return out 39 | 40 | class Bottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(Bottleneck, self).__init__() 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = 
nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | 53 | self.shortcut = nn.Sequential() 54 | if stride != 1 or in_planes != self.expansion*planes: 55 | self.shortcut = nn.Sequential( 56 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 57 | nn.BatchNorm2d(self.expansion*planes) 58 | ) 59 | 60 | def forward(self, x): 61 | out = self.relu(self.bn1(self.conv1(x))) 62 | out = self.relu(self.bn2(self.conv2(out))) 63 | out = self.bn3(self.conv3(out)) 64 | out += self.shortcut(x) 65 | out = self.relu(out) 66 | return out 67 | 68 | class ResNet(nn.Module): 69 | def __init__(self, block, num_blocks, num_classes=10): 70 | super(ResNet, self).__init__() 71 | self.in_planes = 64 72 | 73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(64) 75 | self.relu = nn.ReLU(inplace=True) 76 | 77 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 78 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 79 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 80 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 81 | self.linear = nn.Linear(512*block.expansion, num_classes) 82 | 83 | def _make_layer(self, block, planes, num_blocks, stride): 84 | strides = [stride] + [1]*(num_blocks-1) 85 | layers = [] 86 | for stride in strides: 87 | layers.append(block(self.in_planes, planes, stride)) 88 | self.in_planes = planes * block.expansion 89 | return nn.Sequential(*layers) 90 | 91 | def forward(self, x): 92 | out = self.conv1(x) 93 | 94 | out = self.relu(self.bn1(out)) 95 | out = self.layer1(out) 96 | out = self.layer2(out) 97 | out = self.layer3(out) 98 | out = self.layer4(out) 99 | 100 | out = F.avg_pool2d(out, 4) 101 | out = out.view(out.size(0), -1) 102 | out = self.linear(out) 103 | return out 104 | 105 | def resnet18_cifar(num_classes=10): 106 | model = ResNet(block=BasicBlock, num_blocks=[2,2,2,2], num_classes=num_classes) 107 | return model 108 | 109 | def resnet50_cifar(num_classes=10): 110 | model = ResNet(block=Bottleneck, num_blocks=[3,4,6,3], num_classes=num_classes) 111 | return model -------------------------------------------------------------------------------- /src/models/cifar/vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Create a mini version of Vision Transformer based on timm 3 | """ 4 | 5 | import torch 6 | from timm.models.vision_transformer import VisionTransformer 7 | 8 | def vit_tiny_patch8_32(num_classes=10): 9 | model = VisionTransformer(img_size=32, 10 | patch_size=4, num_classes=num_classes, embed_dim=384, depth=7, num_heads=8, mlp_ratio=1.0) 11 | return model 12 | -------------------------------------------------------------------------------- /src/models/imagenet/mobilenetv1.py: -------------------------------------------------------------------------------- 1 | """ 2 | MobileNet-V1 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | class Net(nn.Module): 9 | """ 10 | MobileNetV1 model 11 | """ 12 | def __init__(self, alpha=1.0, num_classes=1000): 13 | super(Net, self).__init__() 14 | self.alpha = alpha # width multiplier of the model 15 | 16 | def conv_bn(inp, oup, stride): 17 | layer = nn.Sequential( 18 | 
nn.Conv2d(inp, oup, 3, stride, padding=1, bias=False), 19 | nn.BatchNorm2d(oup), 20 | nn.ReLU(inplace=True) 21 | ) 22 | return layer 23 | 24 | 25 | def conv_dw(inp, oup, stride): 26 | layer = nn.Sequential( 27 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 28 | nn.BatchNorm2d(inp), 29 | nn.ReLU(inplace=True), 30 | 31 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 32 | nn.BatchNorm2d(oup), 33 | nn.ReLU(inplace=True) 34 | ) 35 | return layer 36 | 37 | self.model = nn.Sequential( 38 | conv_bn(3, int(32*self.alpha), 2), 39 | conv_dw(int(32*self.alpha), int(64*self.alpha), 1), 40 | conv_dw(int(64*self.alpha), int(128*self.alpha), 2), 41 | conv_dw(int(128*self.alpha), int(128*self.alpha), 1), 42 | conv_dw(int(128*self.alpha), int(256*self.alpha), 2), 43 | conv_dw(int(256*self.alpha), int(256*self.alpha), 1), 44 | conv_dw(int(256*self.alpha), int(512*self.alpha), 2), 45 | conv_dw(int(512*self.alpha), int(512*self.alpha), 1), 46 | conv_dw(int(512*self.alpha), int(512*self.alpha), 1), 47 | conv_dw(int(512*self.alpha), int(512*self.alpha), 1), 48 | conv_dw(int(512*self.alpha), int(512*self.alpha), 1), 49 | conv_dw(int(512*self.alpha), int(512*self.alpha), 1), 50 | conv_dw(int(512*self.alpha), int(1024*self.alpha), 2), 51 | conv_dw(int(1024*self.alpha), int(1024*self.alpha), 1), 52 | ) 53 | self.pool = nn.AvgPool2d(7) 54 | self.fc = nn.Linear(int(1024*self.alpha), num_classes) 55 | 56 | def forward(self, x): 57 | x = self.model(x) 58 | x = self.pool(x) 59 | x = x.view(-1, int(1024*self.alpha)) 60 | x = self.fc(x) 61 | return x 62 | 63 | def mobilenetv1(num_classes=1000): 64 | model = Net(num_classes=num_classes) 65 | return model -------------------------------------------------------------------------------- /src/models/lm/__init__.py: -------------------------------------------------------------------------------- 1 | from .retnet import * -------------------------------------------------------------------------------- /src/models/lm/configuration_retnet.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import json 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | 6 | 7 | def load_config_from_json(config_file): 8 | with open(config_file, 'r') as f: 9 | config = json.load(f) 10 | config = RetNetConfig.from_dict(config) 11 | return config 12 | 13 | 14 | @dataclass 15 | class RetNetConfig(PretrainedConfig): 16 | model_type = "retnet" 17 | initializer_range: float = 0.02 18 | activation_fn: str = "gelu" 19 | dropout: float = 0.0 # dropout probability 20 | activation_dropout: float = 0.0 # dropout probability after activation in FFN. 
21 | drop_path_rate: float = 0.0 22 | decoder_embed_dim: int = 768 # decoder embedding dimension 23 | decoder_value_embed_dim: int = 1280 # decoder value embedding dimension 24 | decoder_ffn_embed_dim: int = 1280 # decoder embedding dimension for FFN 25 | decoder_layers: int = 12 # num decoder layers 26 | decoder_retention_heads: int = 3 # num decoder retention heads 27 | decoder_normalize_before: bool = True # apply layernorm before each decoder block 28 | layernorm_embedding: bool = False # add layernorm to embedding 29 | no_scale_embedding: bool = True # if True, dont scale embeddings 30 | recurrent_chunk_size: int = 512 31 | use_lm_decay: bool = False 32 | use_glu: bool = True # use GLU instead of FFN 33 | z_loss_coeff: float = 0.0 # coefficient for z loss: TODO: 1e-4 34 | deepnorm: bool = False 35 | subln: bool = True 36 | use_ffn_rms_norm: bool = False 37 | layernorm_eps: float = 1e-6 38 | tie_word_embeddings: bool = False 39 | 40 | def __init__( 41 | self, 42 | vocab_size: int = 50257, 43 | initializer_range: float = 0.02, 44 | is_decoder: bool = True, 45 | pad_token_id: int = 0, 46 | eos_token_id: int = 0, 47 | output_retentions: bool = False, 48 | use_cache: bool = True, 49 | forward_impl: str = 'parallel', 50 | activation_fn: str = "gelu", 51 | dropout: float = 0.0, # dropout probability 52 | activation_dropout: float = 0.0, # dropout probability after activation in FFN. 53 | drop_path_rate: float = 0.0, 54 | decoder_embed_dim: int = 768, # decoder embedding dimension 55 | decoder_value_embed_dim: int = 1280, # decoder value embedding dimension 56 | decoder_ffn_embed_dim: int = 1280, # decoder embedding dimension for FFN 57 | decoder_layers: int = 12, # num decoder layers 58 | decoder_retention_heads: int = 3, # num decoder retention heads 59 | decoder_normalize_before: bool = True, # apply layernorm before each decoder block 60 | layernorm_embedding: bool = False, # add layernorm to embedding 61 | no_scale_embedding: bool = True, # if True, dont scale embeddings 62 | recurrent_chunk_size: int = 512, 63 | use_glu: bool = True, # use GLU instead of FFN 64 | z_loss_coeff: float = 0.0, # coefficient for z loss: TODO: 1e-4 65 | use_lm_decay: bool = False, 66 | deepnorm: bool = False, 67 | subln: bool = True, 68 | use_ffn_rms_norm: bool = False, # use RMSNorm instead of LayerNorm in FFN 69 | layernorm_eps: float = 1e-6, 70 | tie_word_embeddings: bool = False, 71 | **kwargs): 72 | self.vocab_size = vocab_size 73 | self.initializer_range = initializer_range 74 | self.output_retentions = output_retentions 75 | self.use_lm_decay = use_lm_decay 76 | self.use_glu = use_glu 77 | self.z_loss_coeff = z_loss_coeff 78 | # size related 79 | self.decoder_embed_dim = decoder_embed_dim 80 | self.decoder_value_embed_dim = decoder_value_embed_dim 81 | self.decoder_retention_heads = decoder_retention_heads 82 | self.decoder_ffn_embed_dim = decoder_ffn_embed_dim 83 | self.decoder_layers = decoder_layers 84 | # normalization related 85 | self.decoder_normalize_before = decoder_normalize_before 86 | self.activation_fn = activation_fn 87 | self.dropout = dropout 88 | self.drop_path_rate = drop_path_rate 89 | self.activation_dropout = activation_dropout 90 | self.no_scale_embedding = no_scale_embedding 91 | self.layernorm_embedding = layernorm_embedding 92 | self.deepnorm = deepnorm 93 | self.subln = subln 94 | self.use_ffn_rms_norm = use_ffn_rms_norm 95 | self.layernorm_eps = layernorm_eps 96 | # Blockwise 97 | self.recurrent_chunk_size = recurrent_chunk_size 98 | self.forward_impl = forward_impl 99 | 100 
| if self.deepnorm: 101 | self.decoder_normalize_before = False 102 | self.subln = False 103 | if self.subln: 104 | self.decoder_normalize_before = True 105 | self.deepnorm = False 106 | 107 | super().__init__(is_decoder=is_decoder, 108 | pad_token_id=pad_token_id, 109 | eos_token_id=eos_token_id, 110 | use_cache=use_cache, 111 | tie_word_embeddings=tie_word_embeddings, 112 | **kwargs) 113 | 114 | def override(self, args): 115 | for hp in self.__dict__.keys(): 116 | if getattr(args, hp, None) is not None: 117 | self.__dict__[hp] = getattr(args, hp, None) 118 | -------------------------------------------------------------------------------- /src/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/module/__init__.py -------------------------------------------------------------------------------- /src/module/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from src.module.base import _QBaseLinear 6 | from src.models.lm.retnet import get_activation_fn 7 | from transformers.activations import ACT2FN 8 | 9 | class QLlamaMLP(nn.Module): 10 | def __init__(self, config, rescale_out:bool=False): 11 | super().__init__() 12 | self.config = config 13 | self.hidden_size = config.hidden_size 14 | self.intermediate_size = config.intermediate_size 15 | self.gate_proj = _QBaseLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, rescale_out=rescale_out) 16 | self.up_proj = _QBaseLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, rescale_out=rescale_out) 17 | self.down_proj = _QBaseLinear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, rescale_out=rescale_out) 18 | self.act_fn = ACT2FN[config.hidden_act] 19 | 20 | def forward(self, x): 21 | if self.config.pretraining_tp > 1: 22 | slice = self.intermediate_size // self.config.pretraining_tp 23 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) 24 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) 25 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) 26 | 27 | gate_proj = torch.cat( 28 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 29 | ) 30 | up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) 31 | 32 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) 33 | down_proj = [ 34 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) 35 | ] 36 | down_proj = sum(down_proj) 37 | else: 38 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 39 | 40 | return down_proj 41 | 42 | class QGLU(nn.Module): 43 | def __init__( 44 | self, 45 | embed_dim, 46 | ffn_dim, 47 | activation_fn, 48 | dropout, 49 | activation_dropout, 50 | rescale_out:bool=False 51 | ): 52 | super().__init__() 53 | self.embed_dim = embed_dim 54 | self.activation_fn = get_activation_fn(activation=str(activation_fn)) 55 | self.activation_dropout_module = torch.nn.Dropout(activation_dropout) 56 | self.dropout_module = torch.nn.Dropout(dropout) 57 | self.fc1 = _QBaseLinear(self.embed_dim, ffn_dim, bias=False, rescale_out=rescale_out) 58 | self.fc2 = _QBaseLinear(ffn_dim, self.embed_dim, bias=False, rescale_out=rescale_out) 59 | self.gate = 
_QBaseLinear(self.embed_dim, ffn_dim, bias=False, rescale_out=rescale_out) 60 | 61 | def reset_parameters(self): 62 | self.fc1.reset_parameters() 63 | self.fc2.reset_parameters() 64 | self.gate.reset_parameters() 65 | 66 | def forward(self, x): 67 | x_shape = x.shape 68 | x = x.reshape(-1, x.size(-1)) 69 | g = self.gate(x) 70 | x = self.fc1(x) 71 | x = self.activation_fn(x.float()).type_as(x) * g 72 | x = self.activation_dropout_module(x) 73 | x = self.fc2(x) 74 | x = x.view(x_shape) 75 | x = self.dropout_module(x) 76 | return x 77 | -------------------------------------------------------------------------------- /src/module/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | try: 5 | import t2c_gemm 6 | INTMM = True 7 | except: 8 | print("Torch-gemm is not installed!") 9 | INTMM = False 10 | 11 | class BatchIntMatMul(nn.Module): 12 | def __init__(self, nbit:int): 13 | super().__init__() 14 | self.nbit = nbit 15 | 16 | def forward(self, x:torch.Tensor, y:torch.Tensor) -> torch.Tensor: 17 | x = x.to(torch.int8) 18 | y = y.to(torch.int8) 19 | 20 | z = t2c_gemm.bmm_int8(x, y, 1.0) 21 | return z 22 | 23 | class BatchHeadIntMatMul(nn.Module): 24 | def __init__(self, nbit:int): 25 | super().__init__() 26 | self.nbit = nbit 27 | 28 | def forward(self, x:torch.Tensor, y:torch.Tensor) -> torch.Tensor: 29 | x = x.to(torch.int8) 30 | y = y.to(torch.int8) 31 | 32 | z = t2c_gemm.bcmm_int8(x, y, 1.0) 33 | return z 34 | 35 | class IntActWeight(nn.Module): 36 | def __init__(self, nbit:int, dtype=torch.float32): 37 | super().__init__() 38 | self.register_buffer("scale", torch.ones(1, 1, 1, dtype=torch.float32)) 39 | self.nbit = nbit 40 | self.dtype = dtype 41 | 42 | def forward(self, x:torch.Tensor, y:torch.Tensor) -> torch.Tensor: 43 | x = x.to(torch.int8) 44 | y = y.to(torch.int8) 45 | z = t2c_gemm.bmw_int8(x, y, self.scale) 46 | return z.to(self.dtype) 47 | 48 | class FloatMatMul(nn.Module): 49 | def __init__(self, nbit:int): 50 | super().__init__() 51 | self.nbit = nbit 52 | 53 | def forward(self, x:torch.Tensor, y:torch.Tensor) -> torch.Tensor: 54 | z = torch.matmul(x, y) 55 | return z -------------------------------------------------------------------------------- /src/profiler/profiler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Profiler of FLOPs and MACs 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from typing import List 9 | from src.module.base import IntMatMul 10 | from src.module.fuse import MulShift, MulQuant 11 | 12 | class Profiler(object): 13 | def __init__(self, model:nn.Module): 14 | self.model = model 15 | self.profile = {} 16 | 17 | def get_row_col(self, shape:List): 18 | if len(shape) == 4: 19 | scale, row, col = shape[1], shape[2], shape[3] 20 | elif len(shape) == 3: 21 | scale, row, col = 1.0, shape[1], shape[2] 22 | elif len(shape) == 2: 23 | scale, row, col = 1.0, shape 24 | return scale, row, col 25 | 26 | def flops(self): 27 | total_flops = {} 28 | 29 | for n, m in self.model.named_modules(): 30 | if isinstance(m, IntMatMul): 31 | x_shape = m.x_shape.tolist() 32 | y_shape = m.y_shape.tolist() 33 | 34 | sx, rx, cx = self.get_row_col(x_shape) 35 | sy, ry, cy = self.get_row_col(y_shape) 36 | 37 | assert cx == ry, "ERROR: incorrect MatMul Shape" 38 | flops = (cx + (cx - 1)) * rx * cy 39 | 40 | total_flops[n] = int(flops) * sx 41 | 42 | elif isinstance(m, (MulShift, MulQuant)): 43 | pass 44 | 45 | 46 | return total_flops 47 | 
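A minimal sanity check of the Profiler FLOP formula above, as a standalone Python sketch (the shapes are illustrative assumptions, e.g. a single attention head of vit_base_patch16_224 computing Q @ K^T; they are not read from the repository):

rx, cx = 197, 64   # left operand:  (rx x cx)
ry, cy = 64, 197   # right operand: (ry x cy), with cx == ry
macs = rx * cx * cy                # 2,483,776 multiply-accumulates
flops = (cx + (cx - 1)) * rx * cy  # cx multiplies + (cx - 1) adds per output element -> 4,928,743 FLOPs
print(macs, flops)

For 4-D inputs, the scale factor returned by get_row_col is the channel/head dimension, so the per-matmul count above is simply repeated across heads when it is accumulated into total_flops.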
-------------------------------------------------------------------------------- /src/pruner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/pruner/__init__.py -------------------------------------------------------------------------------- /src/pruner/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base pruner 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import optim 8 | 9 | class CosineDecay(object): 10 | def __init__(self, prune_rate, T_max, eta_min=0.005, last_epoch=-1): 11 | self.sgd = optim.SGD(torch.nn.ParameterList([torch.nn.Parameter(torch.zeros(1))]), lr=prune_rate) 12 | self.cosine_stepper = torch.optim.lr_scheduler.CosineAnnealingLR(self.sgd, T_max, eta_min, last_epoch) 13 | 14 | def step(self): 15 | self.cosine_stepper.step() 16 | 17 | def get_dr(self): 18 | return self.sgd.param_groups[0]['lr'] 19 | 20 | class Pruner(object): 21 | """ 22 | Base pruner 23 | """ 24 | def __init__(self, 25 | model:nn.Module, 26 | prune_ratio:float, 27 | warmup:int, 28 | final_epoch: int, 29 | dataloader, 30 | prune_freq:float, 31 | prune_decay=None, 32 | regrow:bool=True 33 | ): 34 | 35 | self.model = model 36 | self.prune_ratio = prune_ratio 37 | self.final_density = 1 - self.prune_ratio 38 | 39 | self.curr_pr = 0.0 40 | self.prune_rate_decay = prune_decay 41 | 42 | # mask buffer 43 | self.masks = {} 44 | 45 | # iterations 46 | self.steps = 0 47 | 48 | # loader and warmup 49 | self.iter_per_ep = len(dataloader) 50 | self.warmup = warmup 51 | self.final_epoch = final_epoch 52 | 53 | # pruning frequency 54 | self.prune_freq = prune_freq 55 | 56 | # regrow 57 | self.regrow = regrow 58 | 59 | @property 60 | def pr(self): 61 | return self.curr_pr 62 | 63 | @property 64 | def sparsity(self): 65 | self.compute_sparsity() 66 | return self.current_sparsity 67 | 68 | def init_schedule(self): 69 | self.final_step = int((self.final_epoch * self.iter_per_ep) / self.prune_freq) 70 | self.start_step = int((self.warmup * self.iter_per_ep) / self.prune_freq) 71 | self.total_step = self.final_step - self.start_step 72 | 73 | def _param_stats(self): 74 | total_params = 0 75 | spars_params = 0 76 | for n, m in self.model.named_modules(): 77 | if hasattr(m, "mask"): 78 | mask = m.mask.data 79 | total_params += mask.numel() 80 | spars_params += mask[mask.eq(0)].numel() 81 | return total_params, spars_params 82 | 83 | def compute_sparsity(self): 84 | total_params = 0 85 | ones = 0 86 | for n, m in self.model.named_modules(): 87 | if hasattr(m, "mask"): 88 | mask = m.mask.data 89 | total_params += mask.numel() 90 | ones += mask.sum() 91 | self.current_sparsity = 1 - ones / total_params 92 | 93 | def register_masks(self): 94 | for n, m in self.model.named_modules(): 95 | if hasattr(m, "mask"): 96 | self.masks[n] = m.mask 97 | 98 | def apply_masks(self): 99 | for n, m in self.model.named_modules(): 100 | if hasattr(m, "mask"): 101 | m.mask.data.copy_(self.masks[n]) 102 | 103 | def get_weight_grad(self, weight:torch.Tensor): 104 | grad = weight.grad.clone() 105 | return grad 106 | 107 | def step(self): 108 | # update the current prune_rate / probability 109 | self.prune_rate_decay.step() 110 | self.dr = self.prune_rate_decay.get_dr() 111 | 112 | # increment 113 | self.steps += 1 114 | 115 | if self.steps >= int(self.warmup * self.iter_per_ep) and self.steps % self.prune_freq == 0: 116 | if 
self.steps != 0: 117 | self.pruning() 118 | if self.regrow: 119 | self.prune_and_regrow() 120 | 121 | def pruning(self): 122 | """ 123 | Pruning method 124 | """ 125 | pass 126 | 127 | def prune_and_regrow(self): 128 | """ 129 | Plasticity of sparsity 130 | """ 131 | pass -------------------------------------------------------------------------------- /src/quantization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/quantization/__init__.py -------------------------------------------------------------------------------- /src/quantization/adaround.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adaptive Round 3 | """ 4 | 5 | import torch 6 | from src.module.base import _QBase 7 | from src.quantization.observer import BaseObserver, lp_loss 8 | 9 | class AdaRoundObserver(BaseObserver): 10 | def __init__(self, nbit: int, unsigned: bool = True): 11 | super().__init__(nbit, unsigned) 12 | 13 | def quantize(self, x:torch.Tensor, xmin, xmax): 14 | delta = (xmax - xmin) / (self.qub - self.qlb) 15 | 16 | if self.unsigned: 17 | zero_point = self.qlb - torch.round(xmin / delta) 18 | else: 19 | zero_point = torch.tensor(0.0) 20 | 21 | xint = torch.round(x / delta) 22 | xq = torch.clamp(xint - zero_point, self.qlb, self.qub) 23 | xdq = (xq + zero_point) * delta 24 | return xdq, delta, zero_point 25 | 26 | def calculate_qparam(self, x: torch.Tensor): 27 | # update the quantization boundary 28 | self.get_bound(x) 29 | 30 | # quantization parameters 31 | scale, zero_point = torch.tensor(1.0), torch.tensor(0.0) 32 | 33 | best_loss = 1e+10 34 | for i in range(100): 35 | new_min = self.lb * (1.0 - (i * 0.01)) 36 | new_max = self.ub * (1.0 - (i * 0.01)) 37 | 38 | # quantize and dequantize for mse 39 | xdq, new_scale, new_zp = self.quantize(x, new_min, new_max) 40 | loss = lp_loss(xdq, x, p=2.4, reduction='all') 41 | 42 | if loss < best_loss: 43 | best_loss = loss 44 | scale, zero_point = new_scale, new_zp 45 | 46 | self.lb.data = new_min 47 | self.ub.data = new_max 48 | 49 | return scale, zero_point 50 | 51 | 52 | class AdaRound(_QBase): 53 | """ 54 | Weight quantizer AdaRound: "Up or Down? 
Adaptive Rounding for Post-Training Quantization" 55 | https://arxiv.org/abs/2004.10568 56 | """ 57 | def __init__(self, nbit: int, train_flag: bool = True, weights: torch.Tensor=None, unsigned=False): 58 | super().__init__(nbit, train_flag, unsigned) 59 | self.iter = 0 60 | 61 | # initialize the alpha 62 | self.init_flag = True 63 | 64 | # parameters 65 | self.gamma, self.zeta = -0.1, 1.1 66 | self.beta = 2/3 67 | 68 | # define the observer 69 | self.observer = AdaRoundObserver(nbit=self.nbit, unsigned=self.unsigned) 70 | 71 | # register the learnable parameters 72 | self.register_alpha(weights) 73 | 74 | def register_alpha(self, x:torch.Tensor): 75 | self.register_buffer("delta", torch.tensor(1.0)) 76 | 77 | delta, zp = self.observer(x) 78 | 79 | self.delta.copy_(delta) 80 | self.scale.copy_(delta) 81 | self.zero_point.copy_(zp) 82 | 83 | # find the optimal scaling factor first 84 | xfloor = x.div(self.delta).floor() 85 | 86 | # compute alpha 87 | diff = x.div(self.delta).sub(xfloor) 88 | alpha = -torch.log((self.zeta - self.gamma) / (diff - self.gamma) - 1) 89 | self.register_parameter("alpha", torch.nn.Parameter(alpha)) 90 | 91 | def h(self): 92 | return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1) 93 | 94 | def q(self, x:torch.Tensor): 95 | # quantization 96 | xfloor = x.div(self.scale).floor() 97 | soft_shift = self.h() 98 | 99 | # quantize 100 | if self.train_flag or self.training: 101 | xada = xfloor + soft_shift 102 | else: 103 | xada = xfloor + self.alpha.ge(0.0).float() 104 | 105 | xq = xada + self.zero_point 106 | # integer representation 107 | output = torch.clamp(xq, self.observer.qlb, self.observer.qub).sub(self.zero_point) 108 | 109 | # dequantize 110 | if self.dequantize: 111 | output = output.mul(self.scale) 112 | return output 113 | 114 | def trainFunc(self, input: torch.Tensor): 115 | xq = self.q(input) 116 | return xq 117 | 118 | def evalFunc(self, input: torch.Tensor): 119 | xq = self.q(input) 120 | return xq 121 | -------------------------------------------------------------------------------- /src/quantization/observer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Observer of the high precision floating point distribution 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | class BaseObserver(nn.Module): 9 | def __init__(self, nbit:int, unsigned:bool=True): 10 | super().__init__() 11 | 12 | self.nbit = nbit 13 | self.unsigned = unsigned 14 | self.initialize = True 15 | 16 | # quantization 17 | if self.unsigned: 18 | self.qlb = 0 19 | self.qub = 2 ** self.nbit - 1 20 | else: 21 | self.qlb = -2**(self.nbit-1) 22 | self.qub = 2**(self.nbit-1) - 1 23 | 24 | # initialize the floating point boundaries 25 | self.register_range() 26 | 27 | def register_range(self): 28 | # register buffer for the floating point range 29 | self.register_buffer("lb", torch.tensor(float("-inf"))) 30 | self.register_buffer("ub", torch.tensor(float("inf"))) 31 | 32 | def get_bound(self, x:torch.Tensor): 33 | min_val = x.min() 34 | max_val = x.max() 35 | 36 | if self.initialize: 37 | 38 | self.lb.data = min_val 39 | self.ub.data = max_val 40 | 41 | self.initialize = False 42 | else: 43 | lb = torch.min(self.lb, min_val) 44 | ub = torch.max(self.ub, max_val) 45 | 46 | # update bound 47 | self.lb.copy_(lb) 48 | self.ub.copy_(ub) 49 | 50 | def calculate_qparam(self, x:torch.Tensor): 51 | 52 | if self.unsigned: 53 | scale = (self.ub - self.lb) / (self.qub - self.qlb) 54 | zero_point = self.qlb - 
torch.round(self.lb / scale) 55 | else: 56 | max_val_pos = torch.max(-self.lb, self.ub) 57 | scale = max_val_pos / (float(self.qub - self.qlb) / 2) 58 | zero_point = torch.zeros(max_val_pos.size(), dtype=x.dtype, device=max_val_pos.device) 59 | 60 | return scale, zero_point 61 | 62 | def forward(self, x:torch.Tensor): 63 | 64 | if x.dtype != torch.float32: 65 | x = x.to(torch.float32) 66 | 67 | self.get_bound(x) 68 | scale, zero_point = self.calculate_qparam(x) 69 | return scale, zero_point 70 | 71 | 72 | class BaseChannelWiseObserver(BaseObserver): 73 | def __init__(self, nbit: int, unsigned: bool = True, num_channels:int=1): 74 | self.num_channels = num_channels 75 | super().__init__(nbit, unsigned) 76 | 77 | # register the upper and lower bound 78 | self.register_range() 79 | 80 | def register_range(self): 81 | # register buffer for the floating point range 82 | self.register_buffer("lb", torch.ones(self.num_channels, 1).mul(float("-inf"))) 83 | self.register_buffer("ub", torch.ones(self.num_channels, 1).mul(float("inf"))) 84 | 85 | def reshape(self, x): 86 | xr = x.reshape(-1, x.shape[-1]) 87 | return xr 88 | 89 | def get_bound(self, x:torch.Tensor): 90 | xr = self.reshape(x) 91 | 92 | min_val = xr.min(dim=1, keepdim=True)[0] 93 | max_val = xr.max(dim=1, keepdim=True)[0] 94 | 95 | if self.initialize: 96 | self.lb.data = min_val 97 | self.ub.data = max_val 98 | 99 | self.initialize = False 100 | else: 101 | lb = torch.min(self.lb, min_val) 102 | ub = torch.max(self.ub, max_val) 103 | 104 | # update bound 105 | self.lb.copy_(lb) 106 | self.ub.copy_(ub) 107 | 108 | 109 | class BaseTokenWiseObserver(BaseObserver): 110 | def __init__(self, nbit: int, unsigned: bool = True, num_tokens:int=197): 111 | # number of channels 112 | self.num_tokens = num_tokens 113 | super().__init__(nbit, unsigned) 114 | 115 | self.register_range() 116 | 117 | def register_range(self): 118 | # register buffer for the floating point range 119 | self.register_buffer("lb", torch.ones(1, self.num_tokens, 1).mul(float("-inf"))) 120 | self.register_buffer("ub", torch.ones(1, self.num_tokens, 1).mul(float("inf"))) 121 | 122 | def get_bound(self, x:torch.Tensor): 123 | x = x.reshape(x.size(1), -1) 124 | 125 | min_val = x.min(dim=1, keepdim=True)[0] 126 | max_val = x.max(dim=1, keepdim=True)[0] 127 | 128 | self.lb.data = min_val.unsqueeze(0) 129 | self.ub.data = max_val.unsqueeze(0) 130 | 131 | 132 | def lp_loss(pred, target, p=2.0, reduction='none'): 133 | """ 134 | loss function measured in lp norm 135 | """ 136 | if reduction == 'none': 137 | return (pred-target).abs().pow(p).sum(1).mean() 138 | else: 139 | return (pred-target).abs().pow(p).mean() -------------------------------------------------------------------------------- /src/quantization/qdrop.py: -------------------------------------------------------------------------------- 1 | """ 2 | T2C version of QDrop 3 | 4 | Paper: https://openreview.net/forum?id=ySQH0oDyp7 5 | """ 6 | 7 | import torch 8 | from src.quantization.lsq import LSQ, LSQTokenWise 9 | 10 | class QDrop(LSQ): 11 | def __init__(self, nbit: int = 8, train_flag: bool = True, unsigned: bool = True, drop_prob:float=0.5): 12 | super().__init__(nbit, train_flag, unsigned) 13 | self.drop_prob = drop_prob 14 | 15 | def forward(self, input:torch.Tensor): 16 | xorg = input 17 | y = super().forward(input) 18 | 19 | if self.drop_prob < 1.0 and self.training: 20 | x_prob = torch.where(torch.rand_like(input) < self.drop_prob, y, xorg) 21 | return x_prob 22 | return y 23 | 24 | class 
QDropTokenWise(LSQTokenWise): 25 | def __init__(self, nbit: int = 8, train_flag: bool = True, unsigned: bool = True, drop_prob:float=0.5): 26 | super().__init__(nbit, train_flag, unsigned) 27 | self.drop_prob=0.5 28 | 29 | def forward(self, input:torch.Tensor): 30 | xorg = input 31 | y = super().forward(input) 32 | 33 | if self.drop_prob < 1.0 and self.training: 34 | x_prob = torch.where(torch.rand_like(input) < self.drop_prob, y, xorg) 35 | return x_prob 36 | return y -------------------------------------------------------------------------------- /src/stage/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Register stage and import all the necessary packages / dependencies. 3 | 4 | Fetch arguments, define models, call necessary trainers, executors, etc. 5 | """ 6 | 7 | import os 8 | import yaml 9 | import logging 10 | import torch 11 | 12 | from src.models.auto_map import ModelMap 13 | 14 | class Execute: 15 | """ 16 | Costruct the starting point of all the executions for Torch2Chip, including model configuration and necessary module fetching (e.g., trainer) 17 | 18 | Args: 19 | config: configuration defined in an external .yaml file. 20 | pretrained_checkpoint: pre-trained checkpoint of the target model 21 | """ 22 | def __init__( 23 | self, 24 | config_dir, 25 | ): 26 | 27 | self.config_dir = config_dir 28 | self.config = self.prepare_config() 29 | self.run_dir = self.config["save"]["run_dir"] 30 | 31 | # detect device 32 | self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 33 | 34 | # run dir 35 | self.register_run_dir() 36 | 37 | # initialize logging 38 | self.logger = self.initialize_logger() 39 | 40 | def __name__(self): 41 | return "Execute" 42 | 43 | def register_run_dir(self): 44 | if not os.path.isdir(self.run_dir): 45 | os.makedirs(self.run_dir, exist_ok=True) 46 | 47 | def prepare_config(self): 48 | with open(self.config_dir, 'r') as f: 49 | config = yaml.full_load(f) 50 | return config 51 | 52 | def initialize_logger(self): 53 | logname = self.config["save"]["logger"] 54 | logpath = os.path.join(self.run_dir, logname) 55 | 56 | logger = logging.getLogger(logname) 57 | logger.setLevel(logging.DEBUG) 58 | 59 | if not logger.handlers: 60 | file_handler = logging.FileHandler(logpath, mode="w") 61 | console_handler = logging.StreamHandler() 62 | 63 | file_handler.setLevel(logging.DEBUG) 64 | console_handler.setLevel(logging.INFO) 65 | 66 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 67 | file_handler.setFormatter(formatter) 68 | console_handler.setFormatter(formatter) 69 | 70 | logger.addHandler(file_handler) 71 | logger.addHandler(console_handler) 72 | 73 | return logger 74 | 75 | def create_model(self): 76 | # mapper 77 | model_type = self.config["model"]["model_type"] 78 | print(f"Creating model {model_type}...") 79 | 80 | model_func = ModelMap(model_type) 81 | model = model_func.fetch() 82 | 83 | # map to device 84 | model.to(self.device) 85 | return model 86 | 87 | def output(self): 88 | """ 89 | Output of the execution stage. 
Default: Save model state_dict 90 | """ 91 | model_dict = self.model.state_dict() 92 | model_path = os.path.join(self.run_dir, "latest_model.pth.tar") 93 | 94 | torch.save(model_dict, model_path) 95 | 96 | def print_arch(self, model:torch.nn.Module, name:str): 97 | path = os.path.join(self.run_dir, name+".txt") 98 | 99 | with open(path, "w") as file: 100 | print(model, file=file) 101 | 102 | 103 | def run(self): 104 | """ 105 | Entrance of execution 106 | """ 107 | self.logger.info(f"Start stage {self.name}") -------------------------------------------------------------------------------- /src/stage/calib.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/stage/calib.py -------------------------------------------------------------------------------- /src/stage/hf.py: -------------------------------------------------------------------------------- 1 | """ 2 | HuggingFace Stage 3 | """ 4 | 5 | from src.stage.base import Execute 6 | from transformers import AutoTokenizer 7 | 8 | class HFExecute(Execute): 9 | def __init__(self, config_dir): 10 | super().__init__(config_dir) 11 | 12 | # tokenizer 13 | self.tokenizer = AutoTokenizer.from_pretrained(self.config["model"]["model_type"]) 14 | 15 | def commonsense_qa(self): 16 | self.tokenizer.pad_token_id = (0) 17 | self.tokenizer.padding_side = "left" 18 | 19 | def gsm8k(self): 20 | if self.tokenizer.eos_token_id is not None: 21 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 22 | else: 23 | self.tokenizer.pad_token_id = 0 24 | -------------------------------------------------------------------------------- /src/t2c/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/t2c/__init__.py -------------------------------------------------------------------------------- /src/t2c/fusers/bert.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fuser of BERT 3 | """ 4 | 5 | import torch.nn as nn 6 | from src.module.attention import QBertSelfAttention 7 | from src.t2c.convert import get_parent_name 8 | from src.t2c.fusers.vit import ViTFuser 9 | from src.module.fuse import MulQuant 10 | from src.quantization.observer import BaseObserver, BaseTokenWiseObserver 11 | 12 | from transformers.models.bert.modeling_bert import BertSelfOutput 13 | 14 | class BERTFuser(ViTFuser): 15 | def __init__(self, model: nn.Module): 16 | super().__init__(model) 17 | 18 | def qkv_fuser(self, module: QBertSelfAttention): 19 | module.inference() 20 | 21 | # fetch modules 22 | xq = getattr(module, "xq") 23 | sq = module.qquery.scale 24 | sk = module.qkey.scale 25 | sv = module.qvalue.scale 26 | 27 | query = getattr(module, "query") 28 | key = getattr(module, "key") 29 | value = getattr(module, "value") 30 | 31 | sxq = self.quantizer_fuse(xq, query.wq) 32 | sxk = self.quantizer_fuse(xq, key.wq) 33 | sxv = self.quantizer_fuse(xq, value.wq) 34 | 35 | qquery = MulQuant(nbit=module.qquery.nbit) 36 | sxq = sq.mul(sxq) 37 | qbias = query.bias.mul(sq) 38 | 39 | setattr(qquery, "scale", sxq) 40 | setattr(qquery, "bias", qbias) 41 | setattr(qquery, "zero_point", module.qquery.zero_point) 42 | 43 | qkey = MulQuant(nbit=module.qkey.nbit) 44 | sxk = sk.mul(sxk) 45 | kbias = key.bias.mul(sk) 46 | 47 | setattr(qkey, "scale", sxk) 48 | setattr(qkey, "bias", kbias) 
49 | setattr(qkey, "zero_point", module.qkey.zero_point) 50 | 51 | qvalue = MulQuant(nbit=module.qvalue.nbit) 52 | sxv = sv.mul(sxv) 53 | vbias = value.bias.mul(sv) 54 | 55 | setattr(qvalue, "scale", sxv) 56 | setattr(qvalue, "bias", vbias) 57 | setattr(qvalue, "zero_point", module.qvalue.zero_point) 58 | 59 | if isinstance(module.qquery.observer, BaseTokenWiseObserver): 60 | if isinstance(module.qkey.observer, BaseTokenWiseObserver): 61 | # [B, Head, Token, Token] 62 | qkscale = (sq @ sk.transpose(-1,-2)).unsqueeze(0) 63 | elif isinstance(module.qquery.observer, BaseObserver): 64 | qkscale = sq * sk 65 | 66 | # scale back after q @ k 67 | module.attn_scale.scale.data = 1 / (qkscale) * module.attn_scale.scale 68 | 69 | # scale back after attention @ v 70 | ssfmx = 1 / 255 71 | # module.qkv_deq.scale = 1 / sv * ssfmx 72 | module.qkv_deq.scale = 1 / sv * ssfmx 73 | 74 | # update the module 75 | setattr(module, "qquery", qquery) 76 | setattr(module, "qkey", qkey) 77 | setattr(module, "qvalue", qvalue) 78 | 79 | return module 80 | 81 | def output_fuser(self, module:BertSelfOutput): 82 | dense = getattr(module, "dense") 83 | 84 | fdense = self.fuse_linear(dense) 85 | setattr(module, "dense", fdense) 86 | return module 87 | 88 | def fuse(self): 89 | modules = dict(self.model.named_modules(remove_duplicate=True)) 90 | 91 | for n, m in self.model.named_modules(): 92 | if isinstance(m, QBertSelfAttention): 93 | print(f"Fusing {n}") 94 | parent_name, name = get_parent_name(n) 95 | 96 | module = self.qkv_fuser(m) 97 | setattr(modules[parent_name], name, module) 98 | 99 | elif isinstance(m, BertSelfOutput): 100 | print(f"Fusing {n}") 101 | parent_name, name = get_parent_name(n) 102 | 103 | module = self.output_fuser(m) 104 | setattr(modules[parent_name], name, module) 105 | 106 | return self.model 107 | 108 | 109 | -------------------------------------------------------------------------------- /src/t2c/fusers/fusers.py: -------------------------------------------------------------------------------- 1 | """ 2 | BatchNorm fusion with full observability 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from src.module.base import _QBase 8 | from typing import List, Union 9 | from src.module.fuse import QConvReLU, QConvBNReLU, _QBaseConv2d, _QBaseLinear, MulShift, MulQuant 10 | from src.quantization.observer import BaseChannelWiseObserver 11 | 12 | 13 | class LayerFuser(object): 14 | def __init__(self, model:nn.Module): 15 | self.model = model 16 | # flag 17 | self.flag = False 18 | 19 | # layers 20 | self.groups = [] 21 | 22 | # parameters 23 | self.xscales = [] 24 | self.xzps = [] 25 | 26 | # full precision conv layer 27 | self.fpl = 1 28 | 29 | # full precision classifier 30 | self.fpc = False 31 | 32 | def inference(self): 33 | """ 34 | Switch to inference mode 35 | """ 36 | for n, m in self.model.named_modules(): 37 | if hasattr(m, "inference"): 38 | m.inference() 39 | 40 | def fuse_linear(self, layer:_QBaseLinear): 41 | # switch 42 | layer.inference() 43 | 44 | if layer.bias is not None: 45 | bias = layer.bias.data 46 | 47 | scaler = MulShift(dtype=torch.float32) 48 | sq = 1 / (layer.wq.scale.data * layer.aq.scale.data) 49 | 50 | # scaler 51 | scaler.scale.data = sq 52 | 53 | # assign the scaling factor to the quantizer 54 | if isinstance(layer.wq.observer, BaseChannelWiseObserver): 55 | scaler.scale.data = sq.squeeze(1).unsqueeze(0) 56 | else: 57 | scaler.scale.data = sq 58 | 59 | scaler.bias.data = bias.unsqueeze(0) 60 | setattr(layer, "yq", scaler) 61 | return layer 62 | 63 | def
quantizer_bn_fuse(self, xq:_QBase, wq:_QBase, bn:Union[nn.BatchNorm1d, nn.BatchNorm2d]): 64 | sq = 1 / (wq.scale.data * xq.scale.data) 65 | 66 | # bn scaling 67 | std = torch.sqrt(bn.running_var.data + bn.eps) 68 | 69 | if isinstance(wq.observer, BaseChannelWiseObserver): 70 | sw = wq.scale.data.reshape(bn.weight.shape) 71 | else: 72 | sw = wq.scale.data 73 | 74 | # scaling 75 | sq = 1 / (sw * xq.scale) 76 | sbn = bn.weight.data.mul(sq) / std 77 | 78 | # bn bias 79 | bbn = bn.bias - bn.weight.mul(bn.running_mean.data).div(std) 80 | 81 | return sbn.unsqueeze(0).unsqueeze(2).unsqueeze(3), bbn.unsqueeze(0).unsqueeze(2).unsqueeze(3) 82 | 83 | 84 | def conv_bn_relu(self, cbr:List, l=-1.0, snxt:float=1.0, zpnxt:float=0.0, int_out:bool=False): 85 | assert len(cbr) == 3, "The input must include conv, bn, and relu modules" 86 | conv, bn, _ = cbr 87 | 88 | # fused layer 89 | tmp = QConvBNReLU(conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, groups=conv.groups, 90 | wbit=conv.wbit, abit=conv.abit, train_flag=conv.train_flag, int_out=False) 91 | 92 | # assign modules 93 | setattr(tmp, "conv", cbr[0]) 94 | setattr(tmp, "relu", cbr[2]) 95 | tmp.conv.inference() 96 | 97 | sbn, bbn = self.quantizer_bn_fuse(tmp.conv.aq, tmp.conv.wq, bn) 98 | 99 | # scale and bias 100 | tmp.scaler.scale.data = sbn 101 | tmp.scaler.bias.data = bbn 102 | 103 | return tmp 104 | 105 | def conv_relu(self, cr:List, l=-1.0, snxt:float=1.0, zpnxt:float=0.0, int_out:bool=False): 106 | assert len(cr) == 2, "The input must include conv and relu modules" 107 | 108 | conv, relu = cr 109 | 110 | # fused layer 111 | tmp = QConvReLU(conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, 112 | wbit=conv.wbit, abit=conv.abit, train_flag=False, int_out=int_out) 113 | 114 | # quantization scalers 115 | sq = 1 / (conv.wq.scale.data * conv.aq.scale.data) 116 | 117 | # scaled bias 118 | sb = conv.bias.data.div(sq) 119 | conv.bias.data = sb 120 | 121 | # assign modules 122 | setattr(tmp, "conv", conv) 123 | setattr(tmp, "relu", relu) 124 | 125 | # next layer scaler 126 | tmp.scaler.scale.data = sq.mul(snxt) 127 | 128 | if isinstance(tmp.scaler, MulQuant): 129 | tmp.scaler.zero_point.data = zpnxt 130 | 131 | # replace the activation quantizer by the Identity module 132 | if l > self.fpl-1: 133 | tmp.conv.aq = nn.Identity() 134 | 135 | return tmp 136 | 137 | def fuse(self): 138 | pass 139 | -------------------------------------------------------------------------------- /src/t2c/fusers/lm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fuse language model 3 | """ 4 | from src.t2c.convert import get_parent_name 5 | from src.module.base import _QBaseLinear, _QBase 6 | from src.quantization.observer import BaseChannelWiseObserver, BaseTokenWiseObserver 7 | from src.module.fuse import Add 8 | 9 | class LMFuser(object): 10 | """ 11 | Fuser of language model with the isolated matrix multiplication. 
12 | """ 13 | def __init__(self, fake_quant_model, rescale_out:bool=False): 14 | self.model = fake_quant_model 15 | self.device = self.model.device 16 | self.rescale_out = rescale_out 17 | 18 | def inference(self): 19 | """ 20 | Switch to inference mode 21 | """ 22 | pass 23 | 24 | def quantizer_fuse(self, xq:_QBase, wq:_QBase): 25 | scale_x = xq.scale 26 | scale_w = wq.scale 27 | if isinstance(xq.observer, BaseTokenWiseObserver): 28 | if isinstance(wq.observer, BaseChannelWiseObserver): 29 | scale_w = scale_w.unsqueeze(0) 30 | sw = scale_x @ scale_w.transpose(1,2) 31 | else: 32 | sw = scale_x * scale_w 33 | else: 34 | if isinstance(wq.observer, BaseChannelWiseObserver): 35 | scale_w = scale_w.unsqueeze(0).transpose(1,2) 36 | sw = scale_x * scale_w 37 | 38 | return sw 39 | 40 | def fuse_linear(self, layer:_QBaseLinear): 41 | # switch to inference mode 42 | layer.inference() 43 | 44 | # NOTE: Make sure the performance is not affected by the infinitesimal scales 45 | scaler = Add() 46 | bias = getattr(layer, "bias") 47 | 48 | if bias is not None: 49 | scaler.bias.data = bias 50 | 51 | setattr(layer, "yq", scaler) 52 | setattr(layer, "rescale_out", self.rescale_out) 53 | return layer 54 | 55 | def fuse(self): 56 | modules = dict(self.model.named_modules(remove_duplicate=True)) 57 | 58 | for n, m in modules.items(): 59 | if isinstance(m, _QBaseLinear): 60 | parent_name, name = get_parent_name(n) 61 | new_layer = self.fuse_linear(m) 62 | new_layer = new_layer.to(self.device) 63 | setattr(modules[parent_name], name, new_layer) 64 | 65 | return self.model 66 | 67 | class LlamaFuser(LMFuser): 68 | def __init__(self, fake_quant_model, rescale_out:bool=False): 69 | super().__init__(fake_quant_model, rescale_out) 70 | -------------------------------------------------------------------------------- /src/t2c/fusers/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from src.t2c.fusers.fusers import LayerFuser 4 | from src.module.base import _QBaseConv2d, _QBaseLinear 5 | 6 | class MobileNetV1Fuser(LayerFuser): 7 | def __init__(self, model: nn.Module): 8 | super().__init__(model) 9 | 10 | def inference(self): 11 | """ 12 | Switch to inference mode 13 | """ 14 | for n, m in self.model.named_modules(): 15 | if hasattr(m, "inference"): 16 | m.inference() 17 | 18 | def layers(self): 19 | """ 20 | Fetch layer information from pretrained model 21 | """ 22 | conv_bn_relu = [] 23 | l = 0 24 | for n, m in self.model.named_modules(): 25 | if isinstance(m, nn.Conv2d) and not hasattr(m, "wbit"): 26 | self.fpl += 1 27 | 28 | elif isinstance(m, _QBaseConv2d): 29 | self.flag = True 30 | conv_bn_relu.append(m) 31 | 32 | # scales and boundaries 33 | self.xscales.append(m.aq.scale.data) 34 | self.xzps.append(m.aq.zero_point.data) 35 | l += 1 36 | 37 | elif isinstance(m, nn.BatchNorm2d) and self.flag: 38 | conv_bn_relu.append(m) 39 | 40 | elif isinstance(m, nn.ReLU) and self.flag: 41 | conv_bn_relu.append(m) 42 | self.groups.append(conv_bn_relu) 43 | 44 | # reset 45 | self.flag = False 46 | conv_bn_relu = [] 47 | 48 | elif isinstance(m, _QBaseLinear): 49 | self.fpc = False 50 | 51 | if not isinstance(m.aq, nn.Identity): 52 | # scales and boundaries 53 | self.xscales.append(m.aq.scale.data) 54 | self.xzps.append(m.aq.zero_point.data) 55 | l += 1 56 | 57 | def fuse(self): 58 | """ 59 | Fuse conv, layer, relu for MobileNet architecture 60 | """ 61 | l = 0 # layer counter 62 | 63 | fused_model = self.model 64 | 65 | # update the groups 
66 | self.layers() 67 | 68 | for name, module in self.model.named_children(): 69 | if isinstance(module, (nn.AvgPool2d, nn.Linear)): 70 | continue 71 | else: 72 | # layers in the bottom level sequential 73 | for n, m in module.named_children(): 74 | assert len(m) > 0 75 | seq = [] 76 | for layer in m.modules(): 77 | if isinstance(layer, nn.Conv2d) and not hasattr(layer, "wbit"): 78 | seq.append(layer) 79 | 80 | elif isinstance(layer, _QBaseConv2d): 81 | # fetch the module 82 | conv_bn_relu = self.groups[l] 83 | self.flag = True 84 | l += 1 85 | 86 | if l < len(self.xscales)-1: 87 | snxt = self.xscales[l+1] 88 | zpnxt = self.xzps[l+1] 89 | int_out = True 90 | else: 91 | snxt = torch.tensor(1.0) 92 | zpnxt = torch.tensor(0.0) 93 | int_out = False 94 | 95 | tmp = self.conv_bn_relu(conv_bn_relu, l=l, snxt=snxt, zpnxt=zpnxt, int_out=int_out) 96 | seq.append(tmp) 97 | 98 | elif isinstance(layer, nn.BatchNorm2d): 99 | if l != 0: 100 | tmp = nn.Identity() 101 | seq.append(tmp) 102 | else: 103 | seq.append(layer) 104 | 105 | elif isinstance(layer, nn.ReLU): 106 | if l != 0: 107 | tmp = nn.Identity() 108 | seq.append(tmp) 109 | else: 110 | seq.append(layer) 111 | self.flag = False 112 | 113 | # reconstruct 114 | seq = nn.Sequential(*seq) 115 | setattr(module, n, seq) 116 | setattr(fused_model, name, module) 117 | 118 | # linear = fused_model.fc 119 | fused_linear = self.fuse_linear(fused_model.fc) 120 | setattr(fused_model, "fc", fused_linear) 121 | 122 | return fused_model -------------------------------------------------------------------------------- /src/t2c/fusers/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from src.t2c.fusers.fusers import LayerFuser 3 | 4 | class ResNet18Fuser(LayerFuser): 5 | def __init__(self, model: nn.Module): 6 | super().__init__(model) 7 | 8 | def fuse(self): 9 | for name, module in self.model.named_children(): 10 | if "layer" in name: 11 | for basic_block_name, basic_block in module.named_children(): 12 | cbr = [basic_block.conv1, basic_block.bn1, basic_block.relu] 13 | cb = [basic_block.conv2, basic_block.bn2, nn.Identity()] 14 | 15 | # get fused modules 16 | fm1 = self.conv_bn_relu(cbr) 17 | fm2 = self.conv_bn_relu(cb) 18 | 19 | # update modules 20 | basic_block.conv1 = fm1 21 | basic_block.conv2 = fm2 22 | 23 | # disable other modules 24 | basic_block.bn1 = nn.Identity() 25 | basic_block.bn2 = nn.Identity() 26 | 27 | for sub_block_name, sub_block in basic_block.named_children(): 28 | if "shortcut" in sub_block_name or "downsample" in sub_block_name: 29 | if len(sub_block) > 0: 30 | cbr = list(sub_block) 31 | cbr.append(nn.Identity()) 32 | fsc = self.conv_bn_relu(cbr) 33 | 34 | # update shortcut 35 | setattr(basic_block, sub_block_name, fsc) 36 | 37 | # special treatment on the first conv-bn-relu block 38 | elif "conv1" in name: 39 | cbr = [self.model.conv1, self.model.bn1, self.model.relu] 40 | 41 | # get fused modules 42 | fm1 = self.conv_bn_relu(cbr) 43 | 44 | # update the module 45 | self.model.conv1 = fm1 46 | 47 | # disable other modules 48 | self.model.bn1 = nn.Identity() 49 | self.model.relu = nn.Identity() 50 | 51 | elif "fc" in name: 52 | fm1 = self.fuse_linear(self.model.fc) 53 | self.model.fc = fm1 54 | 55 | return self.model 56 | 57 | class ResNet34Fuser(ResNet18Fuser): 58 | def __init__(self, model: nn.Module): 59 | super().__init__(model) 60 | 61 | class ResNet50Fuser(LayerFuser): 62 | def __init__(self, model: nn.Module): 63 | super().__init__(model) 64 | 65 | def 
fuse(self): 66 | for name, module in self.model.named_children(): 67 | if "layer" in name: 68 | for basic_block_name, basic_block in module.named_children(): 69 | cb0 = [basic_block.conv1, basic_block.bn1, basic_block.relu] 70 | cb1 = [basic_block.conv2, basic_block.bn2, basic_block.relu] 71 | cb2 = [basic_block.conv3, basic_block.bn3, nn.Identity()] 72 | 73 | # get fused modules 74 | fm0 = self.conv_bn_relu(cb0) 75 | fm1 = self.conv_bn_relu(cb1) 76 | fm2 = self.conv_bn_relu(cb2) 77 | 78 | # update modules 79 | basic_block.conv1 = fm0 80 | basic_block.conv2 = fm1 81 | basic_block.conv3 = fm2 82 | 83 | # disable other modules 84 | basic_block.bn1 = nn.Identity() 85 | basic_block.bn2 = nn.Identity() 86 | basic_block.bn3 = nn.Identity() 87 | 88 | for sub_block_name, sub_block in basic_block.named_children(): 89 | if "shortcut" in sub_block_name or "downsample" in sub_block_name: 90 | if len(sub_block) > 0: 91 | cbr = list(sub_block) 92 | cbr.append(nn.Identity()) 93 | fsc = self.conv_bn_relu(cbr) 94 | 95 | # update shortcut 96 | setattr(basic_block, sub_block_name, fsc) 97 | 98 | # special treatment on the first conv-bn-relu block 99 | elif "conv1" in name: 100 | cbr = [self.model.conv1, self.model.bn1, self.model.relu] 101 | 102 | # get fused modules 103 | fm1 = self.conv_bn_relu(cbr) 104 | 105 | # update the module 106 | self.model.conv1 = fm1 107 | 108 | # disable other modules 109 | self.model.bn1 = nn.Identity() 110 | self.model.relu = nn.Identity() 111 | 112 | elif "fc" in name: 113 | fm1 = self.fuse_linear(self.model.fc) 114 | self.model.fc = fm1 115 | 116 | return self.model -------------------------------------------------------------------------------- /src/t2c/fusers/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from src.t2c.fusers.fusers import LayerFuser 3 | from src.module.fuse import _QBaseConv2d, _QBaseLinear 4 | 5 | class VGGFuser(LayerFuser): 6 | def __init__(self, model: nn.Module): 7 | super().__init__(model) 8 | 9 | def layers(self): 10 | """ 11 | Fetch layer information from pretrained model 12 | """ 13 | conv_bn_relu = [] 14 | for i, m in enumerate(self.model.features): 15 | if isinstance(m, _QBaseConv2d): 16 | conv_bn_relu = [] 17 | conv_bn_relu.append(m) 18 | 19 | elif isinstance(m, nn.BatchNorm2d): 20 | conv_bn_relu.append(m) 21 | 22 | elif isinstance(m, nn.ReLU): 23 | conv_bn_relu.append(m) 24 | self.groups.append(conv_bn_relu) 25 | 26 | elif isinstance(m, (nn.MaxPool2d, nn.AvgPool2d)): 27 | self.groups.append([m]) 28 | 29 | def fuse(self): 30 | # update the groups 31 | self.layers() 32 | 33 | features = [] 34 | for cbr in self.groups: 35 | if len(cbr) == 3: 36 | new_layer = self.conv_bn_relu(cbr) 37 | features.append(new_layer) 38 | else: 39 | features.append(*cbr) 40 | 41 | self.model.features = nn.Sequential(*features) 42 | 43 | classifier = self.model.classifier 44 | for n, m in classifier.named_modules(): 45 | if isinstance(m, _QBaseLinear): 46 | new_layer = self.fuse_linear(m) 47 | classifier[int(n)] = new_layer 48 | 49 | self.model.classifier = classifier 50 | return self.model -------------------------------------------------------------------------------- /src/t2c/fusers/vit.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from src.t2c.convert import get_parent_name 4 | from src.module.attention import QAttention, QWindowAttention 5 | from src.module.base import _QBaseLinear, _QBase 6 | from src.module.fuse import 
MulQuant, MulShift 7 | from src.quantization.observer import BaseObserver, BaseTokenWiseObserver, BaseChannelWiseObserver 8 | 9 | from timm.layers import Mlp 10 | 11 | class ViTFuser(object): 12 | def __init__(self, model: nn.Module, rescale_out:bool=False): 13 | self.model = model.eval() 14 | self.rescale_out = rescale_out 15 | 16 | def inference(self): 17 | """ 18 | Switch to inference mode 19 | """ 20 | pass 21 | 22 | def layers(self): 23 | pass 24 | 25 | def quantizer_fuse(self, xq:_QBase, wq:_QBase): 26 | scale_x = xq.scale 27 | scale_w = wq.scale 28 | if isinstance(xq.observer, BaseTokenWiseObserver): 29 | if isinstance(wq.observer, BaseChannelWiseObserver): 30 | scale_w = scale_w.unsqueeze(0) 31 | sw = scale_x @ scale_w.transpose(1,2) 32 | else: 33 | sw = scale_x * scale_w 34 | else: 35 | if isinstance(wq.observer, BaseChannelWiseObserver): 36 | scale_w = scale_w.unsqueeze(0).transpose(1,2) 37 | sw = scale_x * scale_w 38 | 39 | return sw 40 | 41 | def fuse_linear(self, layer:_QBaseLinear): 42 | # switch to inference mode 43 | layer.inference() 44 | scaler = MulShift() 45 | 46 | # fetch the scaling factors 47 | sw = self.quantizer_fuse(layer.aq, layer.wq) 48 | bias = getattr(layer, "bias") 49 | 50 | # construct the scalers 51 | scaler.scale = sw 52 | scaler.bias.data = bias 53 | 54 | setattr(layer, "yq", scaler) 55 | return layer 56 | 57 | def qkv_fuser(self, module:QAttention): 58 | module.inference() 59 | 60 | # fuse the scaling factors 61 | sw = self.quantizer_fuse(module.qkv.aq, module.qkv.wq) 62 | sy = module.qkv.yq.scale 63 | 64 | # integer-only attention 65 | if isinstance(module.qkv.yq.observer, BaseTokenWiseObserver): 66 | qkscale = (sy @ sy.transpose(-2,-1)) 67 | elif isinstance(module.qkv.yq.observer, BaseObserver): 68 | qkscale = sy.pow(2) 69 | 70 | module.attn_scale.scale = qkscale.mul(module.attn_scale.scale) 71 | 72 | # replace the simple shifter to quantizer 73 | scaler = MulQuant(nbit=module.qkv.aq.nbit, unsigned=module.qkv.aq.unsigned) 74 | scaler.scale.data = sw.div(sy) 75 | scaler.bias.data = module.qkv.bias.div(sy) 76 | 77 | setattr(module.qkv, "yq", scaler) 78 | proj = self.fuse_linear(module.proj) 79 | proj.aq.scale.data.div_(sy) 80 | setattr(module, "proj", proj) 81 | 82 | return module 83 | 84 | def mlp_fuser(self, module:Mlp): 85 | fc1 = getattr(module, "fc1") 86 | fc2 = getattr(module, "fc2") 87 | 88 | ffc1 = self.fuse_linear(fc1) 89 | ffc2 = self.fuse_linear(fc2) 90 | 91 | setattr(module, "fc1", ffc1) 92 | setattr(module, "fc2", ffc2) 93 | return module 94 | 95 | def fuse(self): 96 | modules = dict(self.model.named_modules(remove_duplicate=True)) 97 | 98 | for n, m in self.model.named_modules(): 99 | if isinstance(m, QAttention): 100 | parent_name, name = get_parent_name(n) 101 | 102 | module = self.qkv_fuser(m) 103 | setattr(modules[parent_name], name, module) 104 | 105 | elif isinstance(m, Mlp): 106 | parent_name, name = get_parent_name(n) 107 | 108 | module = self.mlp_fuser(m) 109 | setattr(modules[parent_name], name, module) 110 | 111 | return self.model 112 | 113 | class SwinFuser(ViTFuser): 114 | def __init__(self, model: nn.Module, rescale_out:bool=False): 115 | super().__init__(model) 116 | self.rescale_out = rescale_out 117 | 118 | def qkv_fuser(self, module: QWindowAttention): 119 | module = super().qkv_fuser(module) 120 | 121 | return module 122 | 123 | def fuse(self): 124 | modules = dict(self.model.named_modules(remove_duplicate=True)) 125 | 126 | for n, m in self.model.named_modules(): 127 | if isinstance(m, QWindowAttention): 128 | 
parent_name, name = get_parent_name(n) 129 | 130 | module = self.qkv_fuser(m) 131 | setattr(modules[parent_name], name, module) 132 | elif isinstance(m, Mlp): 133 | parent_name, name = get_parent_name(n) 134 | 135 | module = self.mlp_fuser(m) 136 | setattr(modules[parent_name], name, module) 137 | 138 | return self.model -------------------------------------------------------------------------------- /src/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/trainer/__init__.py -------------------------------------------------------------------------------- /src/trainer/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/trainer/llm/__init__.py -------------------------------------------------------------------------------- /src/trainer/llm/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metric for different llm tasks 3 | """ 4 | 5 | import torch 6 | 7 | class Metric(object): 8 | def __init__(self): 9 | pass 10 | 11 | class Perplexity(Metric): 12 | def __init__(self, chunk_size:int=2048, n_samples:int=40): 13 | super().__init__() 14 | 15 | self.loss = [] 16 | self.chunk_size = chunk_size 17 | self.n_samples = n_samples 18 | 19 | def func(self, pred:torch.Tensor, target:torch.Tensor): 20 | target = target.long() 21 | 22 | loss_fn = torch.nn.CrossEntropyLoss() 23 | loss_val = loss_fn(pred.view(-1, pred.size(-1)), target.view(-1)) 24 | return loss_val 25 | 26 | def update(self, pred:torch.Tensor, target:torch.Tensor): 27 | loss_val = self.func(pred, target) 28 | neg_log_likelihood = loss_val.float() * self.chunk_size 29 | self.loss.append(neg_log_likelihood) 30 | 31 | def reduce(self): 32 | ppl = torch.exp(torch.stack(self.loss).sum() / (self.n_samples * self.chunk_size)) 33 | return ppl 34 | -------------------------------------------------------------------------------- /src/trainer/llm/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for LLM generation 3 | """ 4 | 5 | import torch 6 | import transformers 7 | from typing import List 8 | 9 | class MultiTokenEOSCriteria(transformers.StoppingCriteria): 10 | """ 11 | Criteria to stop on the specified multi-token sequence. 
12 | 13 | Adopted from: lm-evaluation-harness: https://github.com/EleutherAI/lm-evaluation-harness 14 | """ 15 | 16 | def __init__( 17 | self, 18 | sequence: str, 19 | tokenizer: transformers.PreTrainedTokenizer, 20 | initial_decoder_input_length: int, 21 | batch_size: int, 22 | ) -> None: 23 | self.initial_decoder_input_length = initial_decoder_input_length 24 | self.done_tracker = [False] * batch_size 25 | self.sequence = sequence 26 | self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False) 27 | # print(sequence, self.sequence_ids) 28 | # we look back for 2 more tokens than it takes to encode our stop sequence 29 | # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']` 30 | # and we don't want to mistakenly not stop a generation because our 31 | # (string) stop sequence was output in a different tokenization 32 | 33 | # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model, 34 | # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized 35 | # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described. 36 | self.sequence_id_len = len(self.sequence_ids) + 2 37 | self.tokenizer = tokenizer 38 | 39 | def __call__(self, input_ids, scores, **kwargs) -> bool: 40 | # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence 41 | lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :] 42 | 43 | lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :] 44 | 45 | lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) 46 | 47 | for i, done in enumerate(self.done_tracker): 48 | if not done: 49 | self.done_tracker[i] = self.sequence in lookback_tokens_batch[i] 50 | return False not in self.done_tracker 51 | 52 | 53 | def stop_sequences_criteria( 54 | tokenizer: transformers.PreTrainedTokenizer, 55 | stop_sequences: List[str], 56 | initial_decoder_input_length: int, 57 | batch_size: int, 58 | ) -> transformers.StoppingCriteriaList: 59 | """ 60 | Create a sequence list for LLM generation 61 | 62 | Adopted from: lm-evaluation-harness: https://github.com/EleutherAI/lm-evaluation-harness 63 | """ 64 | return transformers.StoppingCriteriaList( 65 | [ 66 | *[ 67 | MultiTokenEOSCriteria( 68 | sequence, tokenizer, initial_decoder_input_length, batch_size 69 | ) 70 | for sequence in stop_sequences 71 | ], 72 | ] 73 | ) -------------------------------------------------------------------------------- /src/trainer/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss functions 3 | """ 4 | 5 | import torch 6 | from src.module.base import _QBaseConv2d, _QBaseLinear 7 | from src.module.attention import QAttention 8 | from timm.layers.mlp import Mlp 9 | 10 | from typing import Union 11 | 12 | def lp_loss(pred, tgt, p=2.0): 13 | """ 14 | loss function 15 | """ 16 | return (pred - tgt).abs().pow(p).sum(1).mean() 17 | 18 | class LinearTempDecay: 19 | def __init__(self, t_max=20000, warm_up=0.2, start_b=20, end_b=2): 20 | self.t_max = t_max 21 | self.start_decay = warm_up * t_max 22 | self.start_b = start_b 23 | self.end_b = end_b 24 | 25 | def __call__(self, t): 26 | if t < self.start_decay: 27 | return self.start_b 28 | elif t > self.t_max: 29 | return self.end_b 30 | else: 31 | rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) 32 | return self.end_b + 
(self.start_b - self.end_b) * max(0.0, (1 - rel_t)) 33 | 34 | class AdaRoundLoss: 35 | def __init__(self, 36 | module:Union[_QBaseConv2d, _QBaseLinear, QAttention, Mlp], 37 | weight:float = 1e-5, 38 | iters: int = 100, 39 | b_range: tuple = (4, 2), 40 | warmup:float = 0.2, 41 | p: float = 2. 42 | ): 43 | 44 | self.module = module 45 | self.weight = weight 46 | self.p = p 47 | 48 | # temperature decay 49 | self.temp_decay = LinearTempDecay(iters, warm_up=warmup, start_b=b_range[0], end_b=b_range[1]) 50 | self.start = int(warmup * iters) 51 | 52 | self.steps = 0 53 | self.b = b_range[0] 54 | 55 | def step(self): 56 | self.b = self.temp_decay(self.steps) 57 | 58 | def attention(self, module:QAttention): 59 | rqkv = self.weight * (1 - ((module.qkv.wq.h() - .5).abs() * 2).pow(self.b)).sum() 60 | rproj = self.weight * (1 - ((module.proj.wq.h() - .5).abs() * 2).pow(self.b)).sum() 61 | return rqkv + rproj 62 | 63 | def mlp(self, module:Mlp): 64 | rfc1 = self.weight * (1 - ((module.fc1.wq.h() - .5).abs() * 2).pow(self.b)).sum() 65 | rfc2 = self.weight * (1 - ((module.fc2.wq.h() - .5).abs() * 2).pow(self.b)).sum() 66 | return rfc1 + rfc2 67 | 68 | def conv_linear(self, module: Union[_QBaseConv2d, _QBaseLinear]): 69 | rw = self.weight * (1 - ((module.wq.h() - .5).abs() * 2).pow(self.b)).sum() 70 | return rw 71 | 72 | def __call__(self, pred, target): 73 | # update b 74 | self.step() 75 | 76 | rec_loss = lp_loss(pred, target, p=self.p) 77 | 78 | if self.steps > self.start: 79 | if isinstance(self.module, QAttention): 80 | round_loss = self.attention(self.module) 81 | elif isinstance(self.module, Mlp): 82 | round_loss = self.mlp(self.module) 83 | elif isinstance(self.module, (_QBaseConv2d, _QBaseLinear)): 84 | round_loss = self.conv_linear(self.module) 85 | 86 | return rec_loss + round_loss 87 | else: 88 | return rec_loss 89 | 90 | 91 | -------------------------------------------------------------------------------- /src/trainer/pruning.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sparse Trainer 3 | """ 4 | 5 | from torch.nn.modules import Module 6 | from src.trainer.base import Trainer 7 | from src.pruner.base import CosineDecay 8 | from src.pruner.element import ElementPrune 9 | from src.pruner.nm import NMPruner 10 | 11 | PRUNER = { 12 | "element": ElementPrune, 13 | "nm": NMPruner 14 | } 15 | 16 | class STrainer(Trainer): 17 | def __init__(self, model: Module, trainloader, validloader, config, logger): 18 | super().__init__(model, trainloader, validloader, config, logger) 19 | 20 | prune_config = self.config["prune"] 21 | pruner = prune_config["type"] 22 | 23 | # decay 24 | pr_decay = CosineDecay(prune_config["drate"], T_max=int(len(trainloader)*self.epochs)) 25 | 26 | # pruner 27 | self.pruner = PRUNER[str(pruner)]( 28 | model=model, 29 | prune_ratio=prune_config["prune_ratio"], 30 | warmup=prune_config["warmup"], 31 | final_epoch=prune_config["final_epoch"], 32 | dataloader=trainloader, 33 | prune_freq=prune_config["prune_freq"], 34 | prune_decay=pr_decay, 35 | ) 36 | 37 | if str(pruner) == "nm": 38 | self.pruner.M = prune_config.get("M", 4) 39 | self.pruner.N = prune_config.get("N", 2) 40 | 41 | def train_step(self, inputs, target): 42 | out, loss = super().train_step(inputs, target) 43 | self.pruner.step() 44 | return out, loss 45 | 46 | def train_epoch(self): 47 | super().train_epoch() 48 | self.logger_dict["sparsity"] = self.pruner.sparsity 49 | self.logger_dict["pr"] = self.pruner.pr 50 | self.logger_dict["dr"] = self.pruner.dr 51 | 
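The sparse trainer above is configured entirely through the "prune" section of the stage config. The snippet below is a minimal, hypothetical sketch (not part of the repository) of how STrainer might be wired up: the key names (type, drate, prune_ratio, warmup, final_epoch, prune_freq, M, N) mirror what the constructor reads, while the model, dataloaders, logger, concrete values, and any additional keys expected by the base Trainer in src/trainer/base.py are placeholders.

from src.trainer.pruning import STrainer

# model / trainloader / validloader / logger are placeholders here: any torch
# model, DataLoaders, and logging.Logger accepted by the base Trainer would work.
config = {
    "prune": {
        "type": "nm",          # "element" or "nm", per the PRUNER registry above
        "drate": 0.5,          # decay rate handed to CosineDecay
        "prune_ratio": 0.5,    # forwarded to the pruner as prune_ratio
        "warmup": 5,           # forwarded as warmup
        "final_epoch": 80,     # forwarded as final_epoch
        "prune_freq": 1000,    # forwarded as prune_freq
        "M": 4,                # N:M pattern, only read when type == "nm"
        "N": 2,
    },
    # ...plus whatever training keys (epochs, optimizer, etc.) the base Trainer expects
}

trainer = STrainer(model, trainloader, validloader, config, logger)

The "element" pruner performs unstructured element-wise pruning, while "nm" targets the hardware-friendly N:M structured pattern; both share the same cosine-decayed pruning-rate schedule built in the constructor.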
-------------------------------------------------------------------------------- /src/trainer/vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/trainer/vision/__init__.py -------------------------------------------------------------------------------- /src/trainer/vision/smoothquant.py: -------------------------------------------------------------------------------- 1 | """ 2 | Post Training Quantizer of SmoothQuant 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from typing import List, Union 8 | 9 | from src.module.attention import QAttention, QWindowAttention 10 | from src.module.base import _QBaseLinear 11 | from src.trainer.vision.ptq import PTQViT 12 | from timm.layers.mlp import Mlp 13 | 14 | class SmoothQuantPTQViT(PTQViT): 15 | def __init__(self, model: nn.Module, loss_type: str, trainloader, validloader, args, logger): 16 | super().__init__(model, loss_type, trainloader, validloader, args, logger) 17 | 18 | # smooth coefficient 19 | self.alpha = self.args.alpha 20 | 21 | # smooth scaling 22 | self.sscale = None 23 | 24 | # calibration only 25 | self.layer_train = False 26 | 27 | def layer_stat(self, cached_data:List): 28 | xmax = torch.tensor(0.) 29 | 30 | for batch in cached_data: 31 | x, _ = batch 32 | xmax = torch.max(xmax, x.max()) 33 | 34 | return xmax 35 | 36 | def collect_scale(self): 37 | sscale = {} 38 | self.logger.info(f"Start Fetching the smooth factor!") 39 | for n, m in self.model.named_modules(): 40 | if isinstance(m, _QBaseLinear): 41 | cached_data = self.fetch_layer_data_all(m) 42 | 43 | xmax = self.layer_stat(cached_data) 44 | wmax = m.weight.abs().max() 45 | 46 | scale = (xmax.pow(self.alpha) / wmax.pow(1 - self.alpha)).clamp(1e-5) 47 | 48 | sscale[n] = scale 49 | 50 | del cached_data 51 | self.logger.info(f"Done!\n") 52 | return sscale 53 | 54 | def update_attn(self, layer: Union[QAttention, QWindowAttention], name=None): 55 | layer = super().update_attn(layer, name) 56 | 57 | # load the smooth factor 58 | layer.xq.smoother.scale.data.copy_(1 / self.sscale[name+".qkv"]) 59 | layer.qkv.wq.smoother.scale.data.copy_(self.sscale[name+".qkv"]) 60 | 61 | layer.qproj.smoother.scale.data.copy_(1 / self.sscale[name+".proj"]) 62 | layer.proj.wq.smoother.scale.data.copy_(self.sscale[name+".proj"]) 63 | 64 | return layer 65 | 66 | def update_mlp(self, layer: Mlp, name=None): 67 | layer = super().update_mlp(layer, name) 68 | 69 | # load the smooth factor 70 | layer.fc1.wq.smoother.scale.data.copy_(self.sscale[name+".fc1"]) 71 | layer.fc1.aq.smoother.scale.data.copy_(1 / self.sscale[name+".fc1"]) 72 | 73 | layer.fc2.wq.smoother.scale.data.copy_(self.sscale[name+".fc2"]) 74 | layer.fc2.aq.smoother.scale.data.copy_(1 / self.sscale[name+".fc2"]) 75 | 76 | return layer 77 | 78 | def fit(self): 79 | if self.sscale is None: 80 | self.sscale = self.collect_scale() 81 | 82 | super().fit() 83 | 84 | 85 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeoLabCornell/torch2chip/34090538b86fcbabe8ae1cd770cbd7757c81edfc/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities 3 | """ 4 | 
import shutil 5 | import torch 6 | import tabulate 7 | import argparse 8 | from collections import OrderedDict 9 | 10 | def str2bool(v): 11 | if isinstance(v, bool): 12 | return v 13 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 14 | return True 15 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 16 | return False 17 | else: 18 | raise argparse.ArgumentTypeError('Boolean value expected.') 19 | 20 | class AverageMeter(object): 21 | """Computes and stores the average and current value""" 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | 38 | def accuracy(output, target, topk=(1,)): 39 | """Computes the accuracy over the k top predictions for the specified values of k""" 40 | with torch.no_grad(): 41 | maxk = max(topk) 42 | batch_size = target.size(0) 43 | 44 | _, pred = output.topk(maxk, 1, True, True) 45 | pred = pred.t() 46 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 47 | 48 | res = [] 49 | for k in topk: 50 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 51 | res.append(correct_k.mul_(100.0 / batch_size)) 52 | return res 53 | 54 | def print_table(values, columns, epoch, logger): 55 | table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.6f') 56 | if epoch == 0: 57 | table = table.split('\n') 58 | table = '\n'.join([table[1]] + table) 59 | else: 60 | table = table.split('\n')[2] 61 | logger.info(table) 62 | 63 | def lr_schedule(epoch): 64 | if epoch >= 100: 65 | factor = 0.1 66 | if epoch >= 150: 67 | factor = 0.01 68 | else: 69 | factor = 1.0 70 | return factor 71 | 72 | def convert_secs2time(epoch_time): 73 | need_hour = int(epoch_time / 3600) 74 | need_mins = int((epoch_time - 3600*need_hour) / 60) 75 | need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) 76 | return need_hour, need_mins, need_secs 77 | 78 | def save_checkpoint(state, is_best, save_path, filename='checkpoint.pth.tar'): 79 | torch.save(state, save_path+filename) 80 | if is_best: 81 | shutil.copyfile(save_path+filename, save_path+'model_best.pth.tar') 82 | 83 | def load_checkpoint(ckpt, state=None): 84 | checkpoint = torch.load(ckpt) 85 | 86 | if "state_dict" in checkpoint.keys(): 87 | sdict = checkpoint['state_dict'] 88 | else: 89 | sdict = checkpoint 90 | 91 | new_state_dict = OrderedDict() 92 | 93 | for k, v in sdict.items(): 94 | name = k 95 | new_state_dict[name] = v 96 | 97 | if state is not None: 98 | state.update(new_state_dict) 99 | 100 | return state 101 | 102 | def load_ddp_checkpoint(ckpt, state): 103 | checkpoint = torch.load(ckpt) 104 | sdict = checkpoint['state_dict'] 105 | 106 | new_state_dict = OrderedDict() 107 | 108 | for k, v in sdict.items(): 109 | name = k[7:] 110 | new_state_dict[name] = v 111 | 112 | state.update(new_state_dict) 113 | return state 114 | 115 | def gpufloat2cpuint(tensor:torch.Tensor, dtype=torch.int32): 116 | return tensor.to(dtype).cpu().detach() -------------------------------------------------------------------------------- /vision/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet-50 3 | """ 4 | 5 | import sys 6 | sys.path.append("../torch2chip/") 7 | 8 | import argparse 9 | from src.trainer.vision.ptq import PTQ 10 | from src.stage.base import Execute 11 | from src.t2c.convert import Vanilla4Compress 12 | from 
src.data.vision.imagenet import ImageNet1K 13 | from src.t2c.t2c import T2C 14 | 15 | parser = argparse.ArgumentParser(description='Compress ResNet') 16 | parser.add_argument('--config_dir', type=str, default=None, help="Path to the configuration file (.yaml)") 17 | args = parser.parse_args() 18 | 19 | class CompressResNet(Execute): 20 | def __init__(self, config_dir): 21 | super().__init__(config_dir) 22 | 23 | model = self.create_model() 24 | converter = Vanilla4Compress(model=model, wbit=32, abit=32) 25 | self.model = converter.convert() 26 | 27 | # prepare dataloaders 28 | trainloader, self.testloader = self.prepare_dataloader() 29 | 30 | # quantizer 31 | self.trainer = PTQ( 32 | model=self.model, 33 | trainloader=trainloader, 34 | testloader=self.testloader, 35 | config=self.config, 36 | logger=self.logger 37 | ) 38 | 39 | def prepare_dataloader(self): 40 | data_gen = ImageNet1K(self.config_dir) 41 | 42 | trainloader, testloader = data_gen.run() 43 | return trainloader, testloader 44 | 45 | def ptq(self): 46 | wqtype = self.config["quantization"]["wqtype"] 47 | xqtype = self.config["quantization"]["xqtype"] 48 | method = f"w{wqtype}_a{xqtype}" 49 | 50 | self.logger.info(f"PTQ start! {method}") 51 | self.trainer.fit() 52 | 53 | fake_quant_model = getattr(self.trainer, "model") 54 | self.print_arch(fake_quant_model, "fake_quantized_model") 55 | 56 | self.trainer.valid_epoch() 57 | self.logger.info("[After PTQ] Test accuracy = {:.2f}".format(self.trainer.logger_dict["valid_top1"])) 58 | 59 | return fake_quant_model 60 | 61 | def t2c(self, fake_quant_model): 62 | t2c = T2C(model=fake_quant_model, config=self.config) 63 | fused_model = t2c.fused_model() 64 | 65 | assert hasattr(self, "trainer"), "Trainer must be defined before running T2C fusion!" 66 | setattr(self.trainer, "model", fused_model.to(self.device)) 67 | 68 | self.trainer.valid_epoch() 69 | self.logger.info("[After fusing]: Test accuracy = {:.2f}".format(self.trainer.logger_dict["valid_top1"])) 70 | 71 | fused_model = getattr(self.trainer, "model") 72 | self.print_arch(fused_model, "fused_model") 73 | 74 | # export the files 75 | t2c.export(self.testloader, path=self.run_dir, export_samples=1) 76 | 77 | return fused_model 78 | 79 | def run(self): 80 | fake_quant_model = self.ptq() 81 | fused_model = self.t2c(fake_quant_model) 82 | 83 | 84 | if __name__ == "__main__": 85 | executor = CompressResNet(args.config_dir) 86 | executor.run() -------------------------------------------------------------------------------- /vision/swin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compress Swin Transformer 3 | """ 4 | import sys 5 | sys.path.append("../torch2chip/") 6 | 7 | import argparse 8 | from src.stage.base import Execute 9 | from src.t2c.convert import ViTV4C 10 | from src.data.vision.imagenet import ImageNet1K 11 | from src.trainer.vision.ptq import PTQViT 12 | from src.trainer.vision.smoothquant import SmoothQuantPTQViT 13 | from src.t2c.t2c import T2C 14 | 15 | parser = argparse.ArgumentParser(description='Compress Swin Transformer') 16 | parser.add_argument('--config_dir', type=str, default=None, help="Path to the configuration file (.yaml)") 17 | args = parser.parse_args() 18 | 19 | class CompressSwin(Execute): 20 | def __init__(self, config_dir): 21 | super().__init__(config_dir) 22 | 23 | model = self.create_model() 24 | 25 | qconfig = self.config["quantization"] 26 | wbit = qconfig["wbit"] 27 | abit = qconfig["abit"] 28 | smooth = qconfig.get("smooth", None) 29 | 30 | converter = ViTV4C(model=model,
wbit=wbit, abit=abit) 31 | self.model = converter.convert() 32 | 33 | # prepare dataloaders 34 | trainloader, testloader = self.prepare_dataloader() 35 | 36 | if not smooth: 37 | self.trainer = PTQViT( 38 | model=self.model, 39 | trainloader=trainloader, 40 | testloader=testloader, 41 | logger=self.logger, 42 | config=self.config 43 | ) 44 | 45 | else: 46 | pass 47 | 48 | def prepare_dataloader(self): 49 | data_gen = ImageNet1K(self.config_dir) 50 | 51 | trainloader, testloader = data_gen.run() 52 | return trainloader, testloader 53 | 54 | def ptq(self): 55 | wqtype = self.config["quantization"]["wqtype"] 56 | xqtype = self.config["quantization"]["xqtype"] 57 | method = f"w{wqtype}_a{xqtype}" 58 | self.logger.info(f"PTQ start! {method}") 59 | self.trainer.fit() 60 | 61 | fake_quant_model = getattr(self.trainer, "model") 62 | self.print_arch(fake_quant_model, "fake_quantized_model") 63 | 64 | self.trainer.valid_epoch() 65 | self.logger.info("[After PTQ] Test accuracy = {:.2f}".format(self.trainer.logger_dict["valid_top1"])) 66 | 67 | return fake_quant_model 68 | 69 | def t2c(self, fake_quant_model): 70 | t2c = T2C(model=fake_quant_model, config=self.config) 71 | fused_model = t2c.fused_model() 72 | 73 | assert hasattr(self, "trainer"), "Trainer must be defined before running T2C fusion!" 74 | setattr(self.trainer, "model", fused_model.to(self.device)) 75 | 76 | self.trainer.valid_epoch() 77 | self.logger.info("[After fusing]: Test accuracy = {:.2f}".format(self.trainer.logger_dict["valid_top1"])) 78 | 79 | fused_model = getattr(self.trainer, "model") 80 | self.print_arch(fused_model, "fused_model") 81 | 82 | return fused_model 83 | 84 | def run(self): 85 | fake_quant_model = self.ptq() 86 | fused_model = self.t2c(fake_quant_model) 87 | 88 | if __name__ == "__main__": 89 | executor = CompressSwin(args.config_dir) 90 | executor.run() 91 | -------------------------------------------------------------------------------- /vision/vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vision model stage for compression 3 | """ 4 | 5 | import sys 6 | sys.path.append("../torch2chip/") 7 | 8 | import argparse 9 | from src.stage.base import Execute 10 | from src.t2c.convert import ViTV4C 11 | from src.trainer.vision.ptq import PTQViT 12 | from src.data.vision.imagenet import ImageNet1K 13 | from src.t2c.t2c import T2C 14 | 15 | parser = argparse.ArgumentParser(description='Compress Vision Transformers') 16 | parser.add_argument('--config_dir', type=str, default=None, help="Path to the configuration file (.yaml)") 17 | args = parser.parse_args() 18 | 19 | class CompressViT(Execute): 20 | def __init__(self, config_dir): 21 | super().__init__(config_dir) 22 | 23 | model = self.create_model() 24 | 25 | qconfig = self.config["quantization"] 26 | wbit = qconfig["wbit"] 27 | abit = qconfig["abit"] 28 | 29 | smooth = qconfig.get("smooth", None) 30 | 31 | converter = ViTV4C(model=model, wbit=wbit, abit=abit) 32 | self.model = converter.convert() 33 | 34 | # prepare dataloaders 35 | trainloader, testloader = self.prepare_dataloader() 36 | 37 | if not smooth: 38 | self.trainer = PTQViT( 39 | model=self.model, 40 | trainloader=trainloader, 41 | testloader=testloader, 42 | logger=self.logger, 43 | config=self.config 44 | ) 45 | else: 46 | pass 47 | 48 | def prepare_dataloader(self): 49 | data_gen = ImageNet1K(self.config_dir) 50 | 51 | trainloader, testloader = data_gen.run() 52 | return trainloader, testloader 53 | 54 | def ptq(self): 55 | wqtype = 
self.config["quantization"]["wqtype"] 56 | xqtype = self.config["quantization"]["xqtype"] 57 | method = f"w{wqtype}_a{xqtype}" 58 | self.logger.info(f"PTQ start! {method}") 59 | self.trainer.fit() 60 | 61 | fake_quant_model = getattr(self.trainer, "model") 62 | self.print_arch(fake_quant_model, "fake_quantized_model") 63 | 64 | self.trainer.valid_epoch() 65 | self.logger.info("[After PTQ] Test accuracy = {:.2f}".format(self.trainer.logger_dict["valid_top1"])) 66 | 67 | return fake_quant_model 68 | 69 | def t2c(self, fake_quant_model): 70 | t2c = T2C(model=fake_quant_model, config=self.config) 71 | fused_model = t2c.fused_model() 72 | 73 | assert hasattr(self, "trainer"), "Trainer must be defined before running T2C fusion!" 74 | setattr(self.trainer, "model", fused_model.to(self.device)) 75 | 76 | self.trainer.valid_epoch() 77 | self.logger.info("[After fusing]: Test accuracy = {:.2f}".format(self.trainer.logger_dict["valid_top1"])) 78 | 79 | fused_model = getattr(self.trainer, "model") 80 | self.print_arch(fused_model, "fused_model") 81 | 82 | return fused_model 83 | 84 | def run(self): 85 | fake_quant_model = self.ptq() 86 | fused_model = self.t2c(fake_quant_model) 87 | 88 | 89 | if __name__ == "__main__": 90 | executor = CompressViT(args.config_dir) 91 | executor.run() 92 | --------------------------------------------------------------------------------
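All three vision stages above (CompressResNet, CompressSwin, CompressViT) are launched the same way, with a single --config_dir argument pointing at a YAML file. The snippet below is a rough, hypothetical sketch of the keys those stages read directly (save.run_dir and save.logger in Execute, model.model_type in create_model, and the quantization block in the ViT/Swin stages); the concrete values are placeholders, and the shipped files under config/ (e.g. vit-imagenet1k.yaml) contain additional entries consumed by ImageNet1K and the PTQ trainers that are not reproduced here.

import yaml

config = {
    "save": {
        "run_dir": "./save/vit-demo",   # created by Execute.register_run_dir
        "logger": "vit_demo.log",       # log file written inside run_dir
    },
    "model": {
        "model_type": "vit_base_patch16_224",   # assumed name; must be resolvable by ModelMap
    },
    "quantization": {
        "wbit": 8,
        "abit": 8,
        "wqtype": "minmax",   # assumed quantizer identifiers
        "xqtype": "minmax",
        "smooth": False,      # a truthy value selects the SmoothQuant branch, currently a stub
    },
}

with open("vit-demo.yaml", "w") as f:
    yaml.dump(config, f)

# then: python vision/vit.py --config_dir vit-demo.yaml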