├── LICENSE ├── README.md ├── cga.py ├── configs ├── deit_default_imagent.attn_q.yml ├── ours_imagenet_recipe.attn_q.yml └── swin_t_imagenet.attn_q.yml ├── eval.py ├── eval_scripts ├── deit_s │ ├── w2a2.sh │ ├── w3a3.sh │ └── w4a4.sh ├── deit_t │ ├── w2a2.sh │ ├── w3a3.sh │ └── w4a4.sh └── swin_t │ ├── w2a2.sh │ ├── w3a3.sh │ └── w4a4.sh ├── imgs ├── ofq_vit.png └── w2a2_deit_s_cga_vis_.png ├── src ├── __init__.py ├── deit.py ├── deit_vision_transformer.py ├── quantization │ ├── __init__.py │ ├── modules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── qbias.py │ │ ├── qlinear.py │ │ ├── swin_attention_and_mlp.py │ │ └── utils.py │ ├── quantizer │ │ ├── __init__.py │ │ ├── lsq.py │ │ └── statsq.py │ └── utils.py ├── registry.py ├── swin.py └── utils │ ├── __init__.py │ ├── embedder.py │ ├── helpers.py │ ├── stochastic_depth.py │ ├── tokenizer.py │ ├── transformers.py │ └── utils.py ├── timm_fix_imagenet_loading_bugs └── dataset_factory.py ├── train.py └── train_scripts ├── deit_s ├── w2a2_deit_s.sh ├── w3a3_deit_s.sh └── w4a4_deit_s.sh ├── deit_t ├── w2a2_deit_t.sh ├── w3a3_deit_t.sh └── w4a4_deit_t.sh └── swin_t ├── w2a2_swin_t.sh ├── w3a3_swin_t.sh └── w4a4_swin_t.sh /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 LIU, Shih-yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OFQ: Oscillation-free Quantization for Low-bit Vision Transformers 2 | 3 | This repository contains the training code for the quantized ViTs introduced in our work: "[Oscillation-free Quantization for Low-bit Vision Transformers](https://arxiv.org/abs/2302.02210)", which has been accepted at ICML 2023. Please consider starring the repo if you find our work useful, thanks! 4 | 5 | In this work, we discuss the issue of weight oscillation in quantization-aware training and how it negatively affects model performance. We find that the learnable scaling factor, commonly used in quantization, aggravates weight oscillation. We propose three techniques to address this issue: statistical weight quantization (StatsQ), confidence-guided annealing (CGA), and query-key reparameterization (QKR). Evaluated on ViTs, these techniques improve quantization robustness and accuracy.
Our quantized 2-bit DeiT-T/DeiT-S models outperform the previous state-of-the-art by 9.8% and 7.7% top-1 accuracy, respectively. 6 | 7 | 8 |
9 | 10 |
Fig.1 - Trajectory of statistical scaling factors (StatsQ) from the 10th transformer block in a 2-bit DeiT-S throughout CGA with 4 different boundary ranges [BR_0.003, BR_0.005, BR_0.007, BR_0.01]. The y-axis represents the value of the scaling factors.
11 |
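For readers who want a quick mental model of StatsQ before reading `src/quantization/quantizer/statsq.py`, the sketch below captures the core idea: the weight scaling factor is computed from weight statistics rather than learned, so oscillating gradients cannot drag it around. The class name and the particular statistic used here (per-output-channel mean absolute weight) are illustrative assumptions, not the repository's exact implementation.

```python
# Minimal, self-contained sketch of the StatsQ idea (illustrative only).
import torch
import torch.nn as nn


class StatsWeightQuantSketch(nn.Module):
    """Scale comes from weight statistics, not from a learnable parameter."""

    def __init__(self, num_bits: int = 2):
        super().__init__()
        self.qn = -(2 ** (num_bits - 1))      # e.g. -2 for 2-bit weights
        self.qp = 2 ** (num_bits - 1) - 1     # e.g. +1 for 2-bit weights

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        # per-output-channel statistic -> scaling factor (assumed: mean |w|)
        scale = w.abs().mean(dim=1, keepdim=True).clamp(min=1e-8) / self.qp
        w_int = torch.clamp(torch.round(w / scale), self.qn, self.qp)
        w_q = w_int * scale
        # straight-through estimator: forward uses w_q, gradients flow to the latent float weights
        return w + (w_q - w).detach()


# usage: quantize a 2-bit weight matrix while keeping gradients on the float weights
w = torch.randn(192, 192, requires_grad=True)
w_q = StatsWeightQuantSketch(num_bits=2)(w)
```

CGA and QKR build on top of this: CGA freezes weights with high confidence late in training to calm the remaining oscillating ones, and QKR merges the query and key projections so they are quantized jointly (see `src/quantization/modules/attention.py`).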
12 | 13 | ## Run 14 | 15 | ### 1. Requirements: 16 | * numpy==1.22.3 17 | * torch==2.0.0 18 | * torchvision==0.15.1 19 | * timm=0.5.4 20 | * pyyaml 21 | 22 | Please replace "/your/miniconda3/envs/ofq/lib/python3.8/site-packages/timm/data/dataset_factory.py" with "timm_fix_imagenet_loading_bugs/dataset_factory.py" as with the original code there is a "TypeError: __init__() got an unexpected keyword argument 'download'" error. 23 | 24 | ### 2. Data: 25 | * Download [ILSVRC12 ImageNet classification dataset](https://www.image-net.org/download.php) 26 | 27 | ### 3. Pretrained models: 28 | * Pretrained models will be automatically downloaded for you if set args.pretrained to True. 29 | 30 | ### 4. Steps to run: 31 | * Examples of training scripts, finetuning scripts (CGA) are provided under "train_scripts/" and evaluation scripts are under "eval_scripts/" (please use the exact same batch size (batch_size * world_size) as provided in the evaluation scripts to reproduce the results reported in the paper). 32 | 33 | * Please modified the data path to your own dataset address 34 | 35 | ## Models 36 | ### 1. ImageNet1K dataset 37 | 38 | | Models | #Bits | Top-1 Accuracy (Model Link)| eval script | 39 | | --- | --- | --- | ------- | 40 | | DeiT-T | 32-32 | 72.02 | ------- | 41 | | OFQ DeiT-T | 2-2 | [**64.33**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/ETacXA4nZxtBuJ2RSVO6GToBqe6M8vnL0hXo75msagLKDw?e=RdConK)| eval_scripts/deit_t/w2a2.sh | 42 | | OFQ DeiT-T | 3-3 | [**72.72**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/EQB7riAo9y1GpX5Tr3bChqIBYetCcIHfR2xdxJACqvAuLw?e=VeHOm6) | eval_scripts/deit_t/w3a3.sh | 43 | | OFQ DeiT-T | 4-4 | [**75.46**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/Ebo0I7E5_4xMqfThmh3zpgkB77BWrWSyx5NzJqm-mDsGVg?e=23S3Bz) | eval_scripts/deit_t/w4a4.sh | 44 | | DeiT-S | 32-32 | 79.9 | ------- | 45 | | OFQ DeiT-S | 2-2 | [**75.72**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/EQEHW7tDmYpNjULhtUFpf2cB9EfIQrsN2LPF846TVNe2cg?e=eBqu6S) | eval_scripts/deit_s/w2a2.sh | 46 | | OFQ DeiT-S | 3-3 | [**79.57**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/ESMKhGXDKT1HjFzCKrSYOvgB9SfXQ5u_46HnIt0LIMS5MQ?e=4cBdYU) | eval_scripts/deit_s/w3a3.sh | 47 | | OFQ DeiT-S | 4-4 | [**81.10**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/EQ-KvPjXiI5GiZB-r03hcm8B_Q0M0cZ4b0Rj59RWiQ8ZtA?e=Lk4GjR) | eval_scripts/deit_s/w4a4.sh | 48 | | Swin-T | 32-32 | 81.2 | ------- | 49 | | OFQ Swin-T | 2-2 | [**78.52**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/EWl2BPKk2qpCv0uZ1VGYa2QBGog9ZmpckGco-A1aNLrxhA?e=Eh2X1X) | eval_scripts/swin_t/w2a2.sh | 50 | | OFQ Swin-T | 3-3 | [**81.09**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/ETnNiYGeG9JOnLgJAIzbXFIBlaovFZ-RoZkC95-FrLP9yA?e=gB77Vv) | eval_scripts/swin_t/w3a3.sh | 51 | | OFQ Swin-T | 4-4 | [**81.88**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/EX9WmBLUQsBNqWscdm1cIcQBWe6Yt5EARN0xoD12IOrE-g?e=cmuFle) | eval_scripts/swin_t/w4a4.sh | 52 | 53 | ## Acknowledgement 54 | 55 | The original code is borrowed from [DeiT](https://github.com/facebookresearch/deit). 
56 | 57 | ## Citation 58 | 59 | If you find our code useful for your research, please consider citing: 60 | ```bibtex 61 | @InProceedings{pmlr-v202-liu23w, 62 | title = {Oscillation-free Quantization for Low-bit Vision Transformers}, 63 | author = {Liu, Shih-Yang and Liu, Zechun and Cheng, Kwang-Ting}, 64 | booktitle = {Proceedings of the 40th International Conference on Machine Learning}, 65 | pages = {21813--21824}, 66 | year = {2023}, 67 | editor = {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan}, 68 | volume = {202}, 69 | series = {Proceedings of Machine Learning Research}, 70 | month = {23--29 Jul}, 71 | publisher = {PMLR}, 72 | pdf = {https://proceedings.mlr.press/v202/liu23w/liu23w.pdf}, 73 | url = {https://proceedings.mlr.press/v202/liu23w.html}, 74 | abstract = {Weight oscillation is a by-product of quantization-aware training, in which quantized weights frequently jump between two quantized levels, resulting in training instability and a sub-optimal final model. We discover that the learnable scaling factor, a widely-used $\textit{de facto}$ setting in quantization aggravates weight oscillation. In this work, we investigate the connection between learnable scaling factor and quantized weight oscillation using ViT, and we additionally find that the interdependence between quantized weights in $\textit{query}$ and $\textit{key}$ of a self-attention layer also makes ViT vulnerable to oscillation. We propose three techniques correspondingly: statistical weight quantization ($\rm StatsQ$) to improve quantization robustness compared to the prevalent learnable-scale-based method; confidence-guided annealing ($\rm CGA$) that freezes the weights with $\textit{high confidence}$ and calms the oscillating weights; and $\textit{query}$-$\textit{key}$ reparameterization ($\rm QKR$) to resolve the query-key intertwined oscillation and mitigate the resulting gradient misestimation. Extensive experiments demonstrate that our algorithms successfully abate weight oscillation and consistently achieve substantial accuracy improvement on ImageNet. Specifically, our 2-bit DeiT-T/DeiT-S surpass the previous state-of-the-art by 9.8% and 7.7%, respectively. 
The code is included in the supplementary material and will be released.} 75 | } 76 | ``` 77 | 78 | ## Contact 79 | 80 | Shih-Yang Liu, HKUST (sliuau at connect.ust.hk) 81 | 82 | 83 | -------------------------------------------------------------------------------- /configs/deit_default_imagent.attn_q.yml: -------------------------------------------------------------------------------- 1 | dataset: imagenet 2 | num_classes: 1000 3 | img_size: 224 4 | mean: 5 | - 0.485 6 | - 0.456 7 | - 0.406 8 | std: 9 | - 0.229 10 | - 0.224 11 | - 0.225 12 | crop_pct: 0.9 13 | scale: 14 | - 0.08 15 | - 1.0 16 | interpolation: bicubic 17 | train_interpolation: random 18 | aa: rand-m9-mstd0.5-inc1 19 | mixup: 0.8 20 | mixup_off_epoch: 0 21 | mixup_prob: 1.0 22 | mixup_mode: batch 23 | mixup_switch_prob: 0.5 24 | cutmix: 1.0 25 | reprob: 0.25 26 | remode: pixel 27 | amp: False 28 | model_ema: False 29 | batch_size: 128 30 | lr: 5e-4 31 | min_lr: 1e-5 32 | sched: cosine 33 | weight_decay: 5e-2 34 | epochs: 300 35 | cooldown_epochs: 10 36 | warmup_epochs: 5 37 | warmup_lr: 0.00001 38 | opt: adamw 39 | smoothing: 0.1 40 | num_aug_repeats: 3 41 | workers: 4 42 | qmodules: 43 | - "patch_embed.proj" 44 | - "blocks.0.attn" 45 | - "blocks.0.mlp" 46 | - "blocks.1.attn" 47 | - "blocks.1.mlp" 48 | - "blocks.2.attn" 49 | - "blocks.2.mlp" 50 | - "blocks.3.attn" 51 | - "blocks.3.mlp" 52 | - "blocks.4.attn" 53 | - "blocks.4.mlp" 54 | - "blocks.5.attn" 55 | - "blocks.5.mlp" 56 | - "blocks.6.attn" 57 | - "blocks.6.mlp" 58 | - "blocks.7.attn" 59 | - "blocks.7.mlp" 60 | - "blocks.8.attn" 61 | - "blocks.8.mlp" 62 | - "blocks.9.attn" 63 | - "blocks.9.mlp" 64 | - "blocks.10.attn" 65 | - "blocks.10.mlp" 66 | - "blocks.11.attn" 67 | - "blocks.11.mlp" 68 | - "head" 69 | - "head_dist" 70 | -------------------------------------------------------------------------------- /configs/ours_imagenet_recipe.attn_q.yml: -------------------------------------------------------------------------------- 1 | dataset: imagenet 2 | num_classes: 1000 3 | img_size: 224 4 | mean: 5 | - 0.485 6 | - 0.456 7 | - 0.406 8 | std: 9 | - 0.229 10 | - 0.224 11 | - 0.225 12 | crop_pct: 0.9 13 | scale: 14 | - 0.08 15 | - 1.0 16 | interpolation: bicubic 17 | train_interpolation: random 18 | aa: rand-m9-mstd0.5-inc1 19 | mixup: 0.8 20 | mixup_off_epoch: 0 21 | mixup_prob: 1.0 22 | mixup_mode: batch 23 | mixup_switch_prob: 0.5 24 | cutmix: 1.0 25 | reprob: 0.25 26 | remode: pixel 27 | amp: False 28 | model: deit_tiny_patch16_224 29 | model_ema: False 30 | batch_size: 128 31 | lr: 5e-4 32 | min_lr: 1e-5 33 | sched: cosine 34 | weight_decay: 5e-2 35 | epochs: 128 36 | cooldown_epochs: 10 37 | warmup_epochs: 10 38 | warmup_lr: 0.00001 39 | opt: adamw 40 | smoothing: 0.1 41 | workers: 4 42 | wq_enable: True 43 | wq_bitw: 2 44 | aq_enable: True 45 | aq_bitw: 2 46 | num_aug_repeats: 0 47 | qmodules: 48 | - "patch_embed.proj" 49 | - "blocks.0.attn" 50 | - "blocks.0.mlp" 51 | - "blocks.1.attn" 52 | - "blocks.1.mlp" 53 | - "blocks.2.attn" 54 | - "blocks.2.mlp" 55 | - "blocks.3.attn" 56 | - "blocks.3.mlp" 57 | - "blocks.4.attn" 58 | - "blocks.4.mlp" 59 | - "blocks.5.attn" 60 | - "blocks.5.mlp" 61 | - "blocks.6.attn" 62 | - "blocks.6.mlp" 63 | - "blocks.7.attn" 64 | - "blocks.7.mlp" 65 | - "blocks.8.attn" 66 | - "blocks.8.mlp" 67 | - "blocks.9.attn" 68 | - "blocks.9.mlp" 69 | - "blocks.10.attn" 70 | - "blocks.10.mlp" 71 | - "blocks.11.attn" 72 | - "blocks.11.mlp" 73 | - "head" 74 | - "head_dist" -------------------------------------------------------------------------------- 
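The `qmodules` lists in the configs above name, by their qualified names in the timm model, the sub-modules that get swapped for quantized counterparts. The repository's own helpers for this are `replace_module_by_qmodule_deit` / `replace_module_by_qmodule_swin` in `src/quantization/modules/utils.py`; their exact signatures may differ from the generic sketch below, which only illustrates how such a list can drive the replacement.

```python
# Generic sketch (not the repo's implementation): walk each dotted name from the
# YAML `qmodules` list and replace the addressed sub-module with a quantized one.
import torch.nn as nn


def replace_named_modules(model: nn.Module, qmodule_names, build_qmodule):
    """`build_qmodule` maps a float module (e.g. nn.Linear, Attention) to its quantized version."""
    for name in qmodule_names:
        parent = model
        *path, leaf = name.split(".")
        for part in path:              # descend to the parent of the target module
            parent = getattr(parent, part)
        float_module = getattr(parent, leaf)
        setattr(parent, leaf, build_qmodule(float_module))
    return model


# e.g. replace_named_modules(model, ["blocks.0.attn", "blocks.0.mlp"], my_quantize_fn)
```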
/configs/swin_t_imagenet.attn_q.yml: -------------------------------------------------------------------------------- 1 | dataset: imagenet 2 | num_classes: 1000 3 | img_size: 224 4 | mean: 5 | - 0.485 6 | - 0.456 7 | - 0.406 8 | std: 9 | - 0.229 10 | - 0.224 11 | - 0.225 12 | crop_pct: 0.9 13 | scale: 14 | - 0.08 15 | - 1.0 16 | interpolation: bicubic 17 | train_interpolation: random 18 | aa: rand-m9-mstd0.5-inc1 19 | mixup: 0.8 20 | mixup_off_epoch: 0 21 | mixup_prob: 1.0 22 | mixup_mode: batch 23 | mixup_switch_prob: 0.5 24 | cutmix: 1.0 25 | reprob: 0.25 26 | remode: pixel 27 | amp: False 28 | model: swin_t 29 | model_ema: False 30 | batch_size: 512 31 | lr: 2e-4 32 | min_lr: 1e-5 33 | sched: cosine 34 | weight_decay: 0.0 35 | epochs: 300 36 | cooldown_epochs: 10 37 | warmup_epochs: 0 38 | warmup_lr: 0.00001 39 | opt: adamw 40 | smoothing: 0.1 41 | num_aug_repeats: 0 42 | workers: 1 43 | drop_path: 0.0 44 | qmodules: 45 | - "features.0.0" 46 | - "features.1.0.attn" 47 | - "features.1.0.mlp" 48 | - "features.1.1.attn" 49 | - "features.1.1.mlp" 50 | - "features.2.reduction" 51 | - "features.3.0.attn" 52 | - "features.3.0.mlp" 53 | - "features.3.1.attn" 54 | - "features.3.1.mlp" 55 | - "features.4.reduction" 56 | - "features.5.0.attn" 57 | - "features.5.0.mlp" 58 | - "features.5.1.attn" 59 | - "features.5.1.mlp" 60 | - "features.5.2.attn" 61 | - "features.5.2.mlp" 62 | - "features.5.3.attn" 63 | - "features.5.3.mlp" 64 | - "features.5.4.attn" 65 | - "features.5.4.mlp" 66 | - "features.5.5.attn" 67 | - "features.5.5.mlp" 68 | - "features.6.reduction" 69 | - "features.7.0.attn" 70 | - "features.7.0.mlp" 71 | - "features.7.1.attn" 72 | - "features.7.1.mlp" 73 | - "head" 74 | -------------------------------------------------------------------------------- /eval_scripts/deit_s/w2a2.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_small_distilled_patch16_224 \ 2 | your_path/dataset/imagenet-1k/imagenet \ 3 | --dataset 'torch/imagenet' \ 4 | --epochs 300 \ 5 | --batch-size 200 \ 6 | --weight-decay 0.05 \ 7 | --warmup-lr 1.0e-6 \ 8 | --lr 5.47e-4 \ 9 | --warmup-epochs 5 \ 10 | --mixup 0.0 --cutmix 0.0 \ 11 | --aq-enable \ 12 | --aq-mode lsq \ 13 | --aq-per-channel \ 14 | --aq_clip_learnable \ 15 | --aq-bitw 2 \ 16 | --wq-enable \ 17 | --wq-per-channel \ 18 | --wq-bitw 2 \ 19 | --wq-mode statsq \ 20 | --model_type deit \ 21 | --quantized \ 22 | --pretrained \ 23 | --pretrained_initialized \ 24 | --use-kd --teacher deit_small_distilled_patch16_224 \ 25 | --kd_hard_and_soft 1 \ 26 | --qk_reparam \ 27 | --qk_reparam_type 0 \ 28 | --teacher_pretrained \ 29 | --resume your_path/model_saved/deit_s/w2a2/w2a2_deit_s_qkr_cga.pth.tar \ 30 | --output ./outputs/w2a2_deit_s_qkreparam_eval/ \ 31 | --visible_gpu '0,1,2,4' \ 32 | --world_size '4' \ 33 | --tcp_port '36969' -------------------------------------------------------------------------------- /eval_scripts/deit_s/w3a3.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \ 2 | your_path/dataset/imagenet-1k/imagenet \ 3 | --dataset 'torch/imagenet' \ 4 | --epochs 300 \ 5 | --batch-size 200 \ 6 | --weight-decay 0.0 \ 7 | --warmup-lr 1.0e-6 \ 8 | --lr 3.2e-4 \ 9 | --warmup-epochs 0 \ 10 | --aq-enable \ 11 | --aq-mode lsq \ 12 | --aq-per-channel \ 13 | --aq_clip_learnable \ 14 | --aq-bitw 3 \ 15 | --wq-enable \ 16 
| --wq-per-channel \ 17 | --wq-bitw 3 \ 18 | --wq-mode statsq \ 19 | --model_type deit \ 20 | --quantized \ 21 | --pretrained \ 22 | --pretrained_initialized \ 23 | --use-kd --teacher deit_small_distilled_patch16_224 \ 24 | --kd_hard_and_soft 1 \ 25 | --qk_reparam \ 26 | --qk_reparam_type 0 \ 27 | --teacher_pretrained \ 28 | --resume your_path/model_saved/deit_s/w3a3/w3a3_deit_s_qkr_cga.pth.tar \ 29 | --output ./outputs/w3a3_deit_s_qkreparam_eval/ \ 30 | --visible_gpu '0,1,2,4' \ 31 | --world_size '4' \ 32 | --tcp_port '36969' 33 | 34 | -------------------------------------------------------------------------------- /eval_scripts/deit_s/w4a4.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \ 2 | your_path/dataset/imagenet-1k/imagenet \ 3 | --dataset 'torch/imagenet' \ 4 | --epochs 300 \ 5 | --batch-size 200 \ 6 | --weight-decay 0.0 \ 7 | --warmup-lr 1.0e-6 \ 8 | --lr 3.2e-4 \ 9 | --warmup-epochs 0 \ 10 | --aq-enable \ 11 | --aq-mode lsq \ 12 | --aq-per-channel \ 13 | --aq_clip_learnable \ 14 | --aq-bitw 4 \ 15 | --wq-enable \ 16 | --wq-per-channel \ 17 | --wq-bitw 4 \ 18 | --wq-mode statsq \ 19 | --model_type deit \ 20 | --quantized \ 21 | --pretrained \ 22 | --pretrained_initialized \ 23 | --use-kd --teacher deit_small_distilled_patch16_224 \ 24 | --kd_hard_and_soft 1 \ 25 | --qk_reparam \ 26 | --qk_reparam_type 0 \ 27 | --teacher_pretrained \ 28 | --resume your_path/model_saved/deit_s/w4a4/w4a4_deit_s_qkr_cga.pth.tar \ 29 | --output ./outputs/w4a4_deit_s_qkreparam_eval/ \ 30 | --visible_gpu '0,1,2,4' \ 31 | --world_size '4' \ 32 | --tcp_port '36969' -------------------------------------------------------------------------------- /eval_scripts/deit_t/w2a2.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \ 2 | your_path/dataset/imagenet-1k/imagenet \ 3 | --dataset 'torch/imagenet' \ 4 | --epochs 300 \ 5 | --batch-size 140 \ 6 | --aq-enable \ 7 | --aq-mode lsq \ 8 | --aq-per-channel \ 9 | --aq_clip_learnable \ 10 | --aq-bitw 2 \ 11 | --wq-enable \ 12 | --wq-per-channel \ 13 | --wq-bitw 2 \ 14 | --wq-mode statsq \ 15 | --model_type deit \ 16 | --quantized \ 17 | --pretrained \ 18 | --pretrained_initialized \ 19 | --use-kd --teacher deit_tiny_distilled_patch16_224 \ 20 | --kd_hard_and_soft 1 \ 21 | --qk_reparam \ 22 | --qk_reparam_type 0 \ 23 | --teacher_pretrained \ 24 | --resume your_path/model_saved/deit_t/w2a2/w2a2_deit_t_qkr_cga.pth.tar \ 25 | --output ./outputs/w2a2_deit_t_qkreparam_eval/ \ 26 | --visible_gpu '0,1,2,4' \ 27 | --world_size '4' \ 28 | --tcp_port '36969' -------------------------------------------------------------------------------- /eval_scripts/deit_t/w3a3.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \ 2 | your_path/dataset/imagenet-1k/imagenet \ 3 | --dataset 'torch/imagenet' \ 4 | --epochs 300 \ 5 | --batch-size 140 \ 6 | --aq-enable \ 7 | --aq-mode lsq \ 8 | --aq-per-channel \ 9 | --aq_clip_learnable \ 10 | --aq-bitw 3 \ 11 | --wq-enable \ 12 | --wq-per-channel \ 13 | --wq-bitw 3 \ 14 | --wq-mode statsq \ 15 | --model_type deit \ 16 | --quantized \ 17 | --pretrained \ 18 | --pretrained_initialized \ 19 | --use-kd --teacher 
deit_tiny_distilled_patch16_224 \ 20 | --kd_hard_and_soft 1 \ 21 | --qk_reparam \ 22 | --qk_reparam_type 0 \ 23 | --teacher_pretrained \ 24 | --resume your_path/model_saved/deit_t/w3a3/w3a3_deit_t_qkr_cga.pth.tar \ 25 | --output ./outputs/w3a3_deit_t_qkreparam_eval/ \ 26 | --visible_gpu '0,1,2,4' \ 27 | --world_size '4' \ 28 | --tcp_port '36969' -------------------------------------------------------------------------------- /eval_scripts/deit_t/w4a4.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \ 2 | your_path/dataset/imagenet-1k/imagenet \ 3 | --dataset 'torch/imagenet' \ 4 | --epochs 300 \ 5 | --batch-size 140 \ 6 | --aq-enable \ 7 | --aq-mode lsq \ 8 | --aq-per-channel \ 9 | --aq_clip_learnable \ 10 | --aq-bitw 4 \ 11 | --wq-enable \ 12 | --wq-per-channel \ 13 | --wq-bitw 4 \ 14 | --wq-mode statsq \ 15 | --model_type deit \ 16 | --quantized \ 17 | --pretrained \ 18 | --pretrained_initialized \ 19 | --use-kd --teacher deit_tiny_distilled_patch16_224 \ 20 | --kd_hard_and_soft 1 \ 21 | --qk_reparam \ 22 | --qk_reparam_type 0 \ 23 | --teacher_pretrained \ 24 | --resume your_path/model_saved/deit_t/w4a4/w4a4_deit_t_qkr_cga.pth.tar \ 25 | --output ./outputs/w4a4_deit_t_qkreparam_eval/ \ 26 | --visible_gpu '0,1,2,4' \ 27 | --world_size '4' \ 28 | --tcp_port '36969' -------------------------------------------------------------------------------- /eval_scripts/swin_t/w2a2.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \ 2 | your_path/dataset/imagenet-1k/imagenet \ 3 | --dataset 'torch/imagenet' \ 4 | --epochs 300 \ 5 | --batch-size 128 \ 6 | --weight-decay 0.05 \ 7 | --warmup-lr 1.0e-6 \ 8 | --lr 5.0e-4 \ 9 | --warmup-epochs 5 \ 10 | --mixup 0.0 --cutmix 0.0 \ 11 | --aq-enable \ 12 | --aq-mode lsq \ 13 | --aq-per-channel \ 14 | --aq_clip_learnable \ 15 | --aq-bitw 2 \ 16 | --wq-enable \ 17 | --wq-per-channel \ 18 | --wq-bitw 2 \ 19 | --wq-mode statsq \ 20 | --model_type swin \ 21 | --teacher_type swin \ 22 | --quantized \ 23 | --pretrained \ 24 | --pretrained_initialized \ 25 | --use-kd --teacher swin_t \ 26 | --kd_hard_and_soft 1 \ 27 | --qk_reparam \ 28 | --qk_reparam_type 0 \ 29 | --teacher_pretrained \ 30 | --resume your_path/model_saved/swin_t/w2a2/w2a2_swin_t_qkr_cga.pth.tar \ 31 | --output ./outputs/w2a2_swin_t_qkreparam/ \ 32 | --visible_gpu '0,1,2,4' \ 33 | --world_size '4' \ 34 | --tcp_port '12345' -------------------------------------------------------------------------------- /eval_scripts/swin_t/w3a3.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \ 2 | your_path/dataset/imagenet-1k/imagenet \ 3 | --dataset 'torch/imagenet' \ 4 | --epochs 300 \ 5 | --batch-size 128 \ 6 | --weight-decay 0.0 \ 7 | --warmup-lr 1.0e-6 \ 8 | --lr 2.0e-4 \ 9 | --warmup-epochs 0 \ 10 | --aq-enable \ 11 | --aq-mode lsq \ 12 | --aq-per-channel \ 13 | --aq_clip_learnable \ 14 | --aq-bitw 3 \ 15 | --wq-enable \ 16 | --wq-per-channel \ 17 | --wq-bitw 3 \ 18 | --wq-mode statsq \ 19 | --model_type swin \ 20 | --teacher_type swin \ 21 | --quantized \ 22 | --pretrained \ 23 | --pretrained_initialized \ 24 | --use-kd --teacher swin_t \ 25 | --kd_hard_and_soft 1 \ 26 | --qk_reparam \ 27 | --qk_reparam_type 0 \ 28 | --teacher_pretrained \ 29 | --resume 
your_path/model_saved/swin_t/w3a3/w3a3_swin_t_qkr_cga.pth.tar \ 30 | --output ./outputs/w3a3_swin_t_qkreparam/ \ 31 | --visible_gpu '0,1,2,4' \ 32 | --world_size '4' \ 33 | --tcp_port '12345' -------------------------------------------------------------------------------- /eval_scripts/swin_t/w4a4.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \ 2 | your_path/dataset/imagenet-1k/imagenet \ 3 | --dataset 'torch/imagenet' \ 4 | --epochs 300 \ 5 | --batch-size 128 \ 6 | --weight-decay 0.0 \ 7 | --warmup-lr 1.0e-6 \ 8 | --lr 2.0e-4 \ 9 | --warmup-epochs 0 \ 10 | --aq-enable \ 11 | --aq-mode lsq \ 12 | --aq-per-channel \ 13 | --aq_clip_learnable \ 14 | --aq-bitw 4 \ 15 | --wq-enable \ 16 | --wq-per-channel \ 17 | --wq-bitw 4 \ 18 | --wq-mode statsq \ 19 | --model_type swin \ 20 | --teacher_type swin \ 21 | --quantized \ 22 | --pretrained \ 23 | --pretrained_initialized \ 24 | --use-kd --teacher swin_t \ 25 | --kd_hard_and_soft 1 \ 26 | --qk_reparam \ 27 | --qk_reparam_type 0 \ 28 | --teacher_pretrained \ 29 | --resume your_path/model_saved/swin_t/w4a4/w4a4_swin_t_qkr_cga.pth.tar \ 30 | --output ./outputs/w4a4_swin_t_qkreparam/ \ 31 | --visible_gpu '0,1,2,4' \ 32 | --world_size '4' \ 33 | --tcp_port '12345' -------------------------------------------------------------------------------- /imgs/ofq_vit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nbasyl/OFQ/7ed37d1dd33d39395edbf49fcbbc52f678ecf961/imgs/ofq_vit.png -------------------------------------------------------------------------------- /imgs/w2a2_deit_s_cga_vis_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nbasyl/OFQ/7ed37d1dd33d39395edbf49fcbbc52f678ecf961/imgs/w2a2_deit_s_cga_vis_.png -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | import imp 2 | from .deit_vision_transformer import * 3 | from .deit import * 4 | from .swin import * 5 | from .quantization import * 6 | from .utils import * 7 | 8 | -------------------------------------------------------------------------------- /src/deit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | from operator import is_ 4 | import torch 5 | import torch.nn as nn 6 | from functools import partial 7 | 8 | from .deit_vision_transformer import VisionTransformer, _cfg 9 | 10 | # from timm.models.vision_transformer import VisionTransformer, _cfg 11 | from timm.models.registry import register_model 12 | from timm.models.layers import trunc_normal_ 13 | 14 | 15 | __all__ = [ 16 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224' 17 | ] 18 | 19 | 20 | class DistilledVisionTransformer(VisionTransformer): 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 24 | num_patches = self.patch_embed.num_patches 25 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 26 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 27 | 28 | trunc_normal_(self.dist_token, std=.02) 29 | trunc_normal_(self.pos_embed, std=.02) 30 | self.head_dist.apply(self._init_weights) 31 | 32 | def forward_features(self, x): 33 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 34 | x = self.patch_embed(x) 35 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 36 | if self.dist_token is None: 37 | x = torch.cat((cls_token, x), dim=1) 38 | else: 39 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 40 | x = self.pos_drop(x + self.pos_embed) 41 | 42 | attn_matrixs = [] 43 | intermediate_features = [] 44 | for block in self.blocks: 45 | x, attn_matrix = block(x) 46 | attn_matrixs.append(attn_matrix) 47 | intermediate_features.append(x) 48 | 49 | x = self.norm(x) 50 | if self.dist_token is None: 51 | return self.pre_logits(x[:, 0]), attn_matrixs, intermediate_features 52 | else: 53 | return x[:, 0], x[:, 1], attn_matrixs, intermediate_features 54 | 55 | 56 | def forward(self, x): 57 | x = self.forward_features(x) 58 | 59 | if self.head_dist is not None: 60 | cls_x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 61 | if self.training and not torch.jit.is_scripting(): 62 | return (cls_x, x_dist), x[2] 63 | else: 64 | return (cls_x + x_dist) / 2, x[2] 65 | else: 66 | cls_x = self.head(x[0]) 67 | return cls_x, x[1] # 68 | 69 | 70 | 71 | 72 | @register_model 73 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 74 | model = DistilledVisionTransformer( 75 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 76 | norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer= nn.GELU, **kwargs) 77 | model.default_cfg = _cfg() 78 | if pretrained and kwargs['num_classes'] == 1000: 79 | checkpoint = torch.hub.load_state_dict_from_url( 80 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth", 81 | map_location="cpu", check_hash=True 82 | ) 83 | print("load pretrained") 84 | model.load_state_dict(checkpoint["model"]) 85 | elif pretrained and kwargs['num_classes'] == 100: 86 | raise ValueError('No trained model provided') 87 | return model 88 | 89 | @register_model 90 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs): 91 | model = DistilledVisionTransformer( 92 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 93 | norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer= nn.GELU, **kwargs) 94 | model.default_cfg = _cfg() 95 | if pretrained and 
kwargs['num_classes'] == 1000: 96 | checkpoint = torch.hub.load_state_dict_from_url( 97 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth", 98 | map_location="cpu", check_hash=True 99 | ) 100 | print("load pretrained") 101 | model.load_state_dict(checkpoint["model"]) 102 | elif pretrained and kwargs['num_classes'] == 100: 103 | raise ValueError('No trained model provided') 104 | return model 105 | 106 | -------------------------------------------------------------------------------- /src/deit_vision_transformer.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | The official jax code is released and available at https://github.com/google-research/vision_transformer 12 | 13 | DeiT model defs and weights from https://github.com/facebookresearch/deit, 14 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 15 | 16 | Acknowledgments: 17 | * The paper authors for releasing code and weights, thanks! 18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out 19 | for some einops/einsum fun 20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 22 | 23 | Hacked together by / Copyright 2020, Ross Wightman 24 | """ 25 | import math 26 | import logging 27 | from functools import partial 28 | from collections import OrderedDict 29 | from copy import deepcopy 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | 35 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 36 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 37 | from timm.models.layers import PatchEmbed, DropPath, trunc_normal_, lecun_normal_, to_2tuple 38 | from .registry import register_model 39 | 40 | _logger = logging.getLogger(__name__) 41 | 42 | 43 | def _cfg(url='', **kwargs): 44 | return { 45 | 'url': url, 46 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 47 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 48 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 49 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 50 | **kwargs 51 | } 52 | 53 | class Mlp(nn.Module): 54 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 55 | """ 56 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 57 | super().__init__() 58 | self.in_features = in_features 59 | self.hidden_features = hidden_features 60 | self.out_features = out_features 61 | self.drop = drop 62 | 63 | out_features = out_features or in_features 64 | hidden_features = hidden_features or in_features 65 | drop_probs = to_2tuple(drop) 66 | 67 | self.fc1 = nn.Linear(in_features, hidden_features) 68 | self.act = act_layer() 69 | 70 | 71 | 72 | self.drop1 = nn.Dropout(drop_probs[0]) 73 | self.fc2 = nn.Linear(hidden_features, out_features) 74 | self.drop2 = 
nn.Dropout(drop_probs[1]) 75 | 76 | 77 | def forward(self, x): 78 | x = self.fc1(x) 79 | x = self.act(x) 80 | x = self.drop1(x) 81 | x = self.fc2(x) 82 | x = self.drop2(x) 83 | return x 84 | 85 | class Attention(nn.Module): 86 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., qqkkvv = False): 87 | super().__init__() 88 | self.dim = dim 89 | self.num_heads = num_heads 90 | self.head_dim = dim // num_heads 91 | self.scale = self.head_dim ** -0.5 92 | 93 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 94 | self.attn_drop = nn.Dropout(attn_drop) 95 | self.proj = nn.Linear(dim, dim) 96 | self.proj_drop = nn.Dropout(proj_drop) 97 | 98 | self.qqkkvv = qqkkvv 99 | 100 | def forward(self, x): 101 | B, N, C = x.shape 102 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 103 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 104 | 105 | if self.qqkkvv: 106 | q_score = torch.matmul(q, q.transpose(-1, -2)) 107 | q_score = q_score / math.sqrt(self.head_dim) 108 | k_score = torch.matmul(k, k.transpose(-1, -2)) 109 | k_score = k_score / math.sqrt(self.head_dim) 110 | v_score = torch.matmul(v, v.transpose(-1, -2)) 111 | v_score = v_score / math.sqrt(self.head_dim) 112 | 113 | attn = (q @ k.transpose(-2, -1)) * self.scale 114 | attn_matrix = attn.softmax(dim=-1) 115 | attn = self.attn_drop(attn_matrix) 116 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 117 | x = self.proj(x) 118 | x = self.proj_drop(x) 119 | return x, (attn_matrix, q_score, k_score, v_score) 120 | else: 121 | 122 | 123 | attn = (q @ k.transpose(-2, -1)) * self.scale 124 | attn_matrix = attn.softmax(dim=-1) 125 | attn = self.attn_drop(attn_matrix) 126 | 127 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 128 | x = self.proj(x) 129 | x = self.proj_drop(x) 130 | return x, None 131 | 132 | class Block(nn.Module): 133 | 134 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 135 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, v_normalized = False, qk_normalized=False, qqkkvv = False, LN_affine = True): 136 | super().__init__() 137 | if norm_layer != None: 138 | self.norm1 = norm_layer(dim, elementwise_affine= LN_affine) 139 | else: 140 | self.norm1 = torch.nn.Identity() 141 | 142 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, qqkkvv = qqkkvv) 143 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 144 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 145 | if norm_layer != None: 146 | self.norm2 = norm_layer(dim, elementwise_affine= LN_affine) 147 | else: 148 | self.norm2 = torch.nn.Identity() 149 | 150 | mlp_hidden_dim = int(dim * mlp_ratio) 151 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 152 | self.qqkkvv = qqkkvv 153 | 154 | def forward(self, x): 155 | if self.qqkkvv: 156 | temp_x, attn_mtrx = self.attn(self.norm1(x)) 157 | x = x + self.drop_path(temp_x) 158 | x = x + self.drop_path(self.mlp(self.norm2(x))) 159 | return x, attn_mtrx 160 | else: 161 | temp_x, _ = self.attn(self.norm1(x)) 162 | x = x + self.drop_path(temp_x) 163 | x = x + self.drop_path(self.mlp(self.norm2(x))) 164 | return x, None 165 | 166 | 167 | 168 | class VisionTransformer(nn.Module): 169 | """ Vision Transformer 170 | 171 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 172 | - https://arxiv.org/abs/2010.11929 173 | 174 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 175 | - https://arxiv.org/abs/2012.12877 176 | """ 177 | 178 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 179 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 180 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 181 | act_layer=None, weight_init='', qqkkvv = False, LN_affine = True): 182 | """ 183 | Args: 184 | img_size (int, tuple): input image size 185 | patch_size (int, tuple): patch size 186 | in_chans (int): number of input channels 187 | num_classes (int): number of classes for classification head 188 | embed_dim (int): embedding dimension 189 | depth (int): depth of transformer 190 | num_heads (int): number of attention heads 191 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 192 | qkv_bias (bool): enable bias for qkv if True 193 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 194 | distilled (bool): model includes a distillation token and head as in DeiT models 195 | drop_rate (float): dropout rate 196 | attn_drop_rate (float): attention dropout rate 197 | drop_path_rate (float): stochastic depth rate 198 | embed_layer (nn.Module): patch embedding layer 199 | norm_layer: (nn.Module): normalization layer 200 | weight_init: (str): weight init scheme 201 | """ 202 | super().__init__() 203 | self.num_classes = num_classes 204 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 205 | self.num_tokens = 2 if distilled else 1 206 | # norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 207 | # act_layer = act_layer or nn.GELU 208 | norm_layer = norm_layer 209 | act_layer = act_layer 210 | self.qqkkvv = qqkkvv 211 | 212 | self.patch_embed = embed_layer( 213 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 214 | num_patches = self.patch_embed.num_patches 215 | 216 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 217 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 218 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 219 | self.pos_drop = nn.Dropout(p=drop_rate) 220 | 221 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 222 | self.blocks = nn.Sequential(*[ 223 | Block( 224 | dim=embed_dim, 
num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 225 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, qqkkvv =self.qqkkvv, LN_affine = LN_affine) 226 | for i in range(depth)]) 227 | 228 | if norm_layer == None: 229 | self.norm = torch.nn.Identity() 230 | else: 231 | self.norm = norm_layer(embed_dim) 232 | 233 | # Representation layer 234 | if representation_size and not distilled: 235 | self.num_features = representation_size 236 | self.pre_logits = nn.Sequential(OrderedDict([ 237 | ('fc', nn.Linear(embed_dim, representation_size)), 238 | ('act', nn.Tanh()) 239 | ])) 240 | else: 241 | self.pre_logits = nn.Identity() 242 | 243 | # Classifier head(s) 244 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 245 | self.head_dist = None 246 | if distilled: 247 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 248 | 249 | self.init_weights(weight_init) 250 | 251 | def init_weights(self, mode=''): 252 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 253 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 254 | trunc_normal_(self.pos_embed, std=.02) 255 | if self.dist_token is not None: 256 | trunc_normal_(self.dist_token, std=.02) 257 | if mode.startswith('jax'): 258 | # leave cls token as zeros to match jax impl 259 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) 260 | else: 261 | trunc_normal_(self.cls_token, std=.02) 262 | self.apply(_init_vit_weights) 263 | 264 | def _init_weights(self, m): 265 | # this fn left here for compat with downstream users 266 | _init_vit_weights(m) 267 | 268 | @torch.jit.ignore() 269 | def load_pretrained(self, checkpoint_path, prefix=''): 270 | _load_weights(self, checkpoint_path, prefix) 271 | 272 | @torch.jit.ignore 273 | def no_weight_decay(self): 274 | return {'pos_embed', 'cls_token', 'dist_token'} 275 | 276 | def get_classifier(self): 277 | if self.dist_token is None: 278 | return self.head 279 | else: 280 | return self.head, self.head_dist 281 | 282 | def reset_classifier(self, num_classes, global_pool=''): 283 | self.num_classes = num_classes 284 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 285 | if self.num_tokens == 2: 286 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 287 | 288 | def forward_features(self, x): 289 | x = self.patch_embed(x) 290 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 291 | if self.dist_token is None: 292 | x = torch.cat((cls_token, x), dim=1) 293 | else: 294 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 295 | x = self.pos_drop(x + self.pos_embed) 296 | 297 | attn_matrixs = [] 298 | intermediate_features = [] 299 | for block in self.blocks: 300 | x, attn_matrix = block(x) 301 | attn_matrixs.append(attn_matrix) 302 | intermediate_features.append(x) 303 | 304 | x = self.norm(x) 305 | if self.dist_token is None: 306 | return self.pre_logits(x[:, 0]), attn_matrixs, intermediate_features 307 | else: 308 | return x[:, 0], x[:, 1], attn_matrixs, intermediate_features 309 | 310 | def get_attn_matrix_and_intermediate_features(self, x): 311 | 312 | x = self.forward_features(x) 313 | if self.dist_token is None: 314 | return x[1], x[2] 315 | else: 316 | return x[2], x[3] 317 | 318 | def forward(self, x): 319 | x = self.forward_features(x) 320 | 
if self.head_dist is not None: 321 | cls_x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 322 | if self.training and not torch.jit.is_scripting(): 323 | # during inference, return the average of both classifier predictions 324 | return (cls_x, x_dist), x[2] 325 | else: 326 | return (cls_x + x_dist) / 2, x[2] 327 | else: 328 | cls_x = self.head(x[0]) 329 | 330 | return cls_x, x[1] 331 | 332 | 333 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): 334 | """ ViT weight initialization 335 | * When called without n, head_bias, jax_impl args it will behave exactly the same 336 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 337 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl 338 | """ 339 | if isinstance(module, nn.Linear): 340 | if name.startswith('head'): 341 | nn.init.zeros_(module.weight) 342 | nn.init.constant_(module.bias, head_bias) 343 | elif name.startswith('pre_logits'): 344 | lecun_normal_(module.weight) 345 | nn.init.zeros_(module.bias) 346 | else: 347 | if jax_impl: 348 | nn.init.xavier_uniform_(module.weight) 349 | if module.bias is not None: 350 | if 'mlp' in name: 351 | nn.init.normal_(module.bias, std=1e-6) 352 | else: 353 | nn.init.zeros_(module.bias) 354 | else: 355 | trunc_normal_(module.weight, std=.02) 356 | if module.bias is not None: 357 | nn.init.zeros_(module.bias) 358 | elif jax_impl and isinstance(module, nn.Conv2d): 359 | # NOTE conv was left to pytorch default in my original init 360 | lecun_normal_(module.weight) 361 | if module.bias is not None: 362 | nn.init.zeros_(module.bias) 363 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 364 | if module.elementwise_affine == True: 365 | nn.init.zeros_(module.bias) 366 | nn.init.ones_(module.weight) 367 | 368 | 369 | @torch.no_grad() 370 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 371 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 372 | """ 373 | import numpy as np 374 | 375 | def _n2p(w, t=True): 376 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 377 | w = w.flatten() 378 | if t: 379 | if w.ndim == 4: 380 | w = w.transpose([3, 2, 0, 1]) 381 | elif w.ndim == 3: 382 | w = w.transpose([2, 0, 1]) 383 | elif w.ndim == 2: 384 | w = w.transpose([1, 0]) 385 | return torch.from_numpy(w) 386 | 387 | w = np.load(checkpoint_path) 388 | if not prefix and 'opt/target/embedding/kernel' in w: 389 | prefix = 'opt/target/' 390 | 391 | if hasattr(model.patch_embed, 'backbone'): 392 | # hybrid 393 | backbone = model.patch_embed.backbone 394 | stem_only = not hasattr(backbone, 'stem') 395 | stem = backbone if stem_only else backbone.stem 396 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 397 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 398 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 399 | if not stem_only: 400 | for i, stage in enumerate(backbone.stages): 401 | for j, block in enumerate(stage.blocks): 402 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 403 | for r in range(3): 404 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 405 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 406 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 407 | if block.downsample is not None: 408 | 
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 409 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 410 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 411 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 412 | else: 413 | embed_conv_w = adapt_input_conv( 414 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 415 | model.patch_embed.proj.weight.copy_(embed_conv_w) 416 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 417 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 418 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 419 | if pos_embed_w.shape != model.pos_embed.shape: 420 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 421 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 422 | model.pos_embed.copy_(pos_embed_w) 423 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 424 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 425 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 426 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 427 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 428 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 429 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 430 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 431 | for i, block in enumerate(model.blocks.children()): 432 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 433 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 434 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 435 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 436 | block.attn.qkv.weight.copy_(torch.cat([ 437 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 438 | block.attn.qkv.bias.copy_(torch.cat([ 439 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 440 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 441 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 442 | for r in range(2): 443 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 444 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 445 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 446 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 447 | 448 | 449 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 450 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 451 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 452 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 453 | ntok_new = posemb_new.shape[1] 454 | if num_tokens: 455 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 456 | ntok_new -= num_tokens 457 | else: 458 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 459 | gs_old = int(math.sqrt(len(posemb_grid))) 460 | if not len(gs_new): # backwards compatibility 461 | gs_new = [int(math.sqrt(ntok_new))] * 2 462 | assert len(gs_new) >= 2 463 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 464 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 465 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) 466 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 467 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 468 | return posemb 469 | 470 | 471 | def checkpoint_filter_fn(state_dict, model): 472 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 473 | out_dict = {} 474 | if 'model' in state_dict: 475 | # For deit models 476 | state_dict = state_dict['model'] 477 | for k, v in state_dict.items(): 478 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 479 | # For old models that I trained prior to conv based patchification 480 | O, I, H, W = model.patch_embed.proj.weight.shape 481 | v = v.reshape(O, -1, H, W) 482 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 483 | # To resize pos embedding when using model at different size from pretrained weights 484 | v = resize_pos_embed( 485 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 486 | out_dict[k] = v 487 | return out_dict 488 | 489 | 490 | 491 | 492 | 493 | -------------------------------------------------------------------------------- /src/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .quantizer import * 3 | from .utils import adaptive_clip_grad, KLLossSoft, KLTokenMSELoss, KDLossSoftandHard, KDLossSoftandHard_qk, KDLossSoftandHard_qkv, KDLossSoftandHard_dampening, KDLossSoftandSoftTargetCE 4 | -------------------------------------------------------------------------------- /src/quantization/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import replace_module_by_qmodule_deit, replace_module_by_qmodule_swin 2 | -------------------------------------------------------------------------------- /src/quantization/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .qlinear import LSQ_input 6 | from src.deit_vision_transformer import Attention as deit_attention 7 | from ..modules.qbias import LearnableBias 8 | from .qlinear import QLinear, LSQ_w_and_act_QLinear 9 | from ..quantizer.lsq import LsqQuantizer, LsqQuantizer4v 10 | from ..quantizer.statsq import StatsQuantizer, StatsQuantizer_specific_4_qkreparam_cga 11 | 12 | class QAttention(deit_attention): 13 | def __init__(self, m: deit_attention, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, 14 | weight_channelwise=True, 
input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq", 15 | pretrained_initialized = False, 16 | **kwargs): 17 | assert type(m) == deit_attention 18 | super().__init__( 19 | dim = m.qkv.in_features, 20 | num_heads=m.num_heads, 21 | attn_drop=m.attn_drop.p, 22 | proj_drop=m.proj_drop.p, 23 | qqkkvv= m.qqkkvv 24 | ) 25 | self.weight_bits = weight_bits 26 | self.input_bits = input_bits 27 | self.input_channelwise = input_channelwise 28 | 29 | self.qkv = QLinear( 30 | m = self.qkv, 31 | weight_bits = weight_bits, 32 | input_bits = input_bits, 33 | weight_channelwise = weight_channelwise, 34 | input_channelwise = input_channelwise, 35 | weight_quant_method = weight_quant_method, 36 | input_quant_method = input_quant_method, 37 | aq_learnable = aq_learnable, ## act 38 | wq_learnable = wq_learnable,## weight 39 | symmetric = True, ## act 40 | pretrained_initialized = pretrained_initialized 41 | ) 42 | self.proj = QLinear( 43 | m = self.proj, 44 | weight_bits = weight_bits, 45 | input_bits = input_bits, 46 | weight_channelwise = weight_channelwise, 47 | input_channelwise = input_channelwise, 48 | weight_quant_method = weight_quant_method, 49 | input_quant_method = input_quant_method, 50 | aq_learnable = aq_learnable, ## act 51 | wq_learnable = wq_learnable,## weight 52 | symmetric = True, ## act 53 | pretrained_initialized = pretrained_initialized 54 | ) 55 | 56 | self.quan_a_q_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) 57 | self.quan_a_k_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) 58 | self.quan_a_v_fn = LsqQuantizer4v(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) 59 | 60 | self.move_qkv_b4 = LearnableBias(m.qkv.in_features*3) 61 | self.move_q_aft = LearnableBias(m.qkv.in_features) 62 | self.move_k_aft = LearnableBias(m.qkv.in_features) 63 | self.move_v_aft = LearnableBias(m.qkv.in_features) 64 | 65 | self.quan_a_softmax_fn = LsqQuantizer(bit=input_bits,all_positive=True,per_channel=True, learnable = aq_learnable) 66 | 67 | def forward(self, x): 68 | B, N, C = x.shape 69 | qkv = self.qkv(x) # B, N, 3*C 70 | if self.input_bits < 32: 71 | qkv = self.move_qkv_b4(qkv) 72 | qkv = qkv.reshape( 73 | B, N, 3, self.num_heads, C // self.num_heads 74 | ).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C // self.num_heads 75 | q, k, v = qkv[0], qkv[1], qkv[2] # B, num_heads, N, C//self.num_heads 76 | 77 | q = self.quan_a_q_fn(q) # quantize along N 78 | k = self.quan_a_k_fn(k) # quantize along N 79 | 80 | v = v.permute(0,2,1,3).reshape(B,N,C) 81 | v = self.quan_a_v_fn(v) # quantize along C 82 | v = v.reshape(B,N,self.num_heads,C//self.num_heads).permute(0,2,1,3) 83 | if self.input_bits < 32: 84 | 85 | q = q.permute(0, 2, 1, 3).reshape(B, N, C) 86 | k = k.permute(0, 2, 1, 3).reshape(B, N, C) 87 | v = v.permute(0, 2, 1, 3).reshape(B, N, C) 88 | q = self.move_q_aft(q) 89 | k = self.move_k_aft(k) 90 | v = self.move_v_aft(v) 91 | 92 | q = q.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B, num_heads, N, C // self.num_heads 93 | k = k.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 94 | v = v.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 95 | 96 | attn_weights = (q @ k.transpose(-2, -1).contiguous()) * self.scale 97 | attn_prob = F.softmax(attn_weights, dim=-1) 98 | 99 | attn_prob = self.quan_a_softmax_fn(attn_prob) 100 | attn_prob = self.attn_drop(attn_prob) 101 | 102 | x = (attn_prob 
@ v).transpose(1, 2).reshape(B, N, C) 103 | x = self.proj(x) 104 | x = self.proj_drop(x) 105 | return x, None 106 | 107 | class QAttention_qkreparam(deit_attention): 108 | def __init__(self, m: deit_attention, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, 109 | weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq", 110 | pretrained_initialized = False, 111 | **kwargs): 112 | assert type(m) == deit_attention 113 | super().__init__( 114 | dim = m.qkv.in_features, 115 | num_heads=m.num_heads, 116 | attn_drop=m.attn_drop.p, 117 | proj_drop=m.proj_drop.p, 118 | qqkkvv= m.qqkkvv 119 | ) 120 | self.weight_bits = weight_bits 121 | self.input_bits = input_bits 122 | self.input_channelwise = input_channelwise 123 | 124 | self.quant_x_4_qkv = LSQ_input(bit = input_bits, all_positive= False, learnable= aq_learnable, learanbaleBiasdim=m.qkv.in_features) 125 | 126 | self.q = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features, bias=False) 127 | self.k = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features, bias= False) 128 | self.v = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features) 129 | 130 | if pretrained_initialized: 131 | with torch.no_grad(): 132 | q_k_v_dim = int(m.qkv.weight.shape[0]/3) 133 | copy_weight = m.qkv.weight.detach() 134 | copy_bias = m.qkv.bias.detach() 135 | self.q.weight.copy_(copy_weight[:q_k_v_dim*1,:]) 136 | self.k.weight.copy_(copy_weight[q_k_v_dim*1:q_k_v_dim*2,:]) 137 | self.v.weight.copy_(copy_weight[q_k_v_dim*2:q_k_v_dim*3,:]) 138 | self.v.bias.copy_(copy_bias[q_k_v_dim*2:q_k_v_dim*3]) 139 | 140 | self.qk_quant = StatsQuantizer(num_bits=self.weight_bits, clip_learnable=wq_learnable) # num_heads*in_features, in_features 141 | self.v_quant = StatsQuantizer(num_bits=self.weight_bits, clip_learnable=wq_learnable)#.to(m.weight.device) 142 | 143 | self.proj = QLinear( 144 | m = self.proj, 145 | weight_bits = weight_bits, 146 | input_bits = input_bits, 147 | weight_channelwise = weight_channelwise, 148 | input_channelwise = input_channelwise, 149 | weight_quant_method = weight_quant_method, 150 | input_quant_method = input_quant_method, 151 | aq_learnable = aq_learnable, ## act 152 | wq_learnable = wq_learnable,## weight 153 | symmetric = True, ## act 154 | pretrained_initialized = pretrained_initialized 155 | ) 156 | 157 | ## THIS IS FOR QK PART 158 | self.quan_a_qkx_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) # for W_q*W_k*X -> B, num_heads, N, C//numheads 159 | self.quan_a_v_fn = LsqQuantizer4v(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) 160 | 161 | # for W_q*W_k*X 162 | self.move_qkx_b4 = LearnableBias(self.num_heads * m.qkv.in_features) 163 | self.move_qkx_aft = LearnableBias(self.num_heads * m.qkv.in_features) 164 | 165 | # for v 166 | self.move_v_b4 = LearnableBias(m.qkv.in_features) 167 | self.move_v_aft = LearnableBias(m.qkv.in_features) 168 | 169 | self.quan_a_softmax_fn = LsqQuantizer(bit=input_bits,all_positive=True,per_channel=True, learnable = aq_learnable) 170 | 171 | ## no longer need qkv 172 | del self.qkv 173 | 174 | def forward(self, x): 175 | B, N, C = x.shape 176 | ## first quantize input x 177 | quant_x = self.quant_x_4_qkv(x) 178 | ## V 179 | quant_v_weight = self.v_quant(self.v.weight) 180 | v_out = nn.functional.linear(quant_x, quant_v_weight) 181 | v_out += self.v.bias.view(1, -1).expand_as(v_out) # B,N,C 182 | 183 | ## TO MULTI_HEAD V 184 | 
v_out = self.move_v_b4(v_out) 185 | v_out = self.quan_a_v_fn(v_out) 186 | v_out = self.move_v_aft(v_out) # B, N, C 187 | v = v_out.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B, num_heads, N, C // num_heads 188 | 189 | ## TO MULTI_HEAD QK 190 | multi_head_q_weight = self.q.weight.reshape(self.num_heads,self.q.out_features // self.num_heads, self.q.in_features) # out_feat , in_feat -> num_heads, out_feat // num_heads, in_feat 191 | multi_head_k_weight = self.k.weight.reshape(self.num_heads,self.k.out_features // self.num_heads, self.k.in_features) 192 | 193 | multi_head_qk = multi_head_q_weight.transpose(-2, -1).contiguous() @ multi_head_k_weight # num_heads, in_features, in_features 194 | multi_head_qk = multi_head_qk.reshape(self.num_heads*self.q.out_features,self.q.in_features)# num_heads*in_features, in_features 195 | multi_head_qk_qunat = self.qk_quant(multi_head_qk) 196 | multi_head_qk_qunat = multi_head_qk_qunat.reshape(self.num_heads, self.q.in_features,self.q.in_features) #num_heads, in_features, in_features 197 | 198 | ## W_qk@X^T torch.einsum('BNC,BACD -> BAND',quant_x,quant_qkx) 199 | # quant_x: B,C,N 200 | qkx = torch.einsum('HDC, BCN -> BHDN', multi_head_qk_qunat, quant_x.transpose(-2, -1).contiguous() ) # B, num_heads, in_features, N 201 | qkx = qkx.permute(0,3,1,2).reshape(B,N, self.num_heads * C) #B, N, num_heads*in_features 202 | qkx = self.move_qkx_b4(qkx) 203 | qkx = qkx.reshape(B,N*self.num_heads, C) 204 | quant_qkx = self.quan_a_qkx_fn(qkx) # B, num_heads*N, in_features 205 | quant_qkx = quant_qkx.reshape(B,N,self.num_heads * C) #B, N, num_heads*in_features 206 | quant_qkx = self.move_qkx_aft(quant_qkx) #B, N, num_heads*in_features 207 | quant_qkx = quant_qkx.reshape(B, N, self.num_heads, -1).permute(0, 2, 3, 1) # B, num_heads, in_features, N 208 | 209 | ## x@W_qk@X^T, quant_x: B,N,C and quant_qkx: B,num_heads, C, N 210 | xqkx = torch.einsum('BNC,BHCD -> BHND',quant_x,quant_qkx) # B, num_heads, N, N 211 | # B, num_heads, N, C // self.num_heads 212 | value_4_softmax = xqkx 213 | attn_weights = (value_4_softmax) * self.scale 214 | attn_prob = F.softmax(attn_weights, dim=-1) 215 | 216 | attn_prob = self.quan_a_softmax_fn(attn_prob) 217 | attn_prob = self.attn_drop(attn_prob) 218 | 219 | x = (attn_prob @ v).transpose(1, 2).reshape(B, N, C) 220 | x = self.proj(x) 221 | x = self.proj_drop(x) 222 | return x, None 223 | 224 | class QAttention_qkreparam_4_cga(deit_attention): 225 | def __init__(self, m: deit_attention, clip_val=2.5, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, 226 | weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq", 227 | pretrained_initialized = False, boundaryRange = 0.005, 228 | **kwargs): 229 | assert type(m) == deit_attention 230 | super().__init__( 231 | dim = m.qkv.in_features, 232 | num_heads=m.num_heads, 233 | attn_drop=m.attn_drop.p, 234 | proj_drop=m.proj_drop.p, 235 | qqkkvv= m.qqkkvv 236 | ) 237 | self.weight_bits = weight_bits 238 | self.input_bits = input_bits 239 | self.input_channelwise = input_channelwise 240 | 241 | self.quant_x_4_qkv = LSQ_input(bit = input_bits, all_positive= False, learnable= aq_learnable, learanbaleBiasdim=m.qkv.in_features) 242 | 243 | self.q = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features, bias=False) 244 | self.k = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features, bias= False) 245 | self.v = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features) 246 | 
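## Pretrained initialization (block below): the fused qkv weight of the original attention has shape (3*dim, dim), so it is split row-wise into three (dim, dim) blocks for the separate q, k and v projections. Only v keeps its bias, since q and k are never applied individually; they are folded into the reparameterized per-head W_q^T @ W_k product in forward().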
247 | if pretrained_initialized: 248 | with torch.no_grad(): 249 | q_k_v_dim = int(m.qkv.weight.shape[0]/3) 250 | copy_weight = m.qkv.weight.detach() 251 | copy_bias = m.qkv.bias.detach() 252 | self.q.weight.copy_(copy_weight[:q_k_v_dim*1,:]) 253 | self.k.weight.copy_(copy_weight[q_k_v_dim*1:q_k_v_dim*2,:]) 254 | self.v.weight.copy_(copy_weight[q_k_v_dim*2:q_k_v_dim*3,:]) 255 | self.v.bias.copy_(copy_bias[q_k_v_dim*2:q_k_v_dim*3]) 256 | 257 | self.qk_quant = StatsQuantizer_specific_4_qkreparam_cga(num_bits=self.weight_bits, clip_learnable=wq_learnable,boundaryRange=boundaryRange) # num_heads*in_features, in_features 258 | self.v_quant = StatsQuantizer(num_bits=self.weight_bits, clip_learnable=wq_learnable)#.to(m.weight.device) 259 | 260 | self.proj = QLinear( 261 | m = self.proj, 262 | weight_bits = weight_bits, 263 | input_bits = input_bits, 264 | weight_channelwise = weight_channelwise, 265 | input_channelwise = input_channelwise, 266 | weight_quant_method = weight_quant_method, 267 | input_quant_method = input_quant_method, 268 | aq_learnable = aq_learnable, ## act 269 | wq_learnable = wq_learnable,## weight 270 | symmetric = True, ## act 271 | pretrained_initialized = pretrained_initialized 272 | ) 273 | 274 | ## THIS IS FOR QK PART 275 | self.quan_a_qkx_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) # for W_q*W_k*X -> B, num_heads, N, C//numheads 276 | self.quan_a_v_fn = LsqQuantizer4v(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) 277 | 278 | # for W_q*W_k*X 279 | self.move_qkx_b4 = LearnableBias(self.num_heads * m.qkv.in_features) 280 | self.move_qkx_aft = LearnableBias(self.num_heads * m.qkv.in_features) 281 | 282 | # for v 283 | self.move_v_b4 = LearnableBias(m.qkv.in_features) 284 | self.move_v_aft = LearnableBias(m.qkv.in_features) 285 | 286 | self.quan_a_softmax_fn = LsqQuantizer(bit=input_bits,all_positive=True,per_channel=True, learnable = aq_learnable) 287 | 288 | ## no longer need qkv 289 | del self.qkv 290 | 291 | def forward(self, x): 292 | B, N, C = x.shape 293 | ## first quantize input x 294 | quant_x = self.quant_x_4_qkv(x) 295 | ## V 296 | quant_v_weight = self.v_quant(self.v.weight) 297 | v_out = nn.functional.linear(quant_x, quant_v_weight) 298 | v_out += self.v.bias.view(1, -1).expand_as(v_out) # B,N,C 299 | 300 | ## TO MULTI_HEAD V 301 | v_out = self.move_v_b4(v_out) 302 | v_out = self.quan_a_v_fn(v_out) 303 | v_out = self.move_v_aft(v_out) # B, N, C 304 | v = v_out.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B, num_heads, N, C // num_heads 305 | 306 | ## TO MULTI_HEAD QK 307 | multi_head_q_weight = self.q.weight.reshape(self.num_heads,self.q.out_features // self.num_heads, self.q.in_features) # out_feat , in_feat -> num_heads, out_feat // num_heads, in_feat 308 | multi_head_k_weight = self.k.weight.reshape(self.num_heads,self.k.out_features // self.num_heads, self.k.in_features) 309 | 310 | multi_head_qk = multi_head_q_weight.transpose(-2, -1).contiguous() @ multi_head_k_weight # num_heads, in_features, in_features 311 | multi_head_qk = multi_head_qk.reshape(self.num_heads*self.q.out_features,self.q.in_features)# num_heads*in_features, in_features 312 | multi_head_qk_qunat = self.qk_quant(multi_head_qk) 313 | multi_head_qk_qunat = multi_head_qk_qunat.reshape(self.num_heads, self.q.in_features,self.q.in_features) #num_heads, in_features, in_features 314 | 315 | ## W_qk@X^T torch.einsum('BNC,BACD -> BAND',quant_x,quant_qkx) 316 | # quant_x: B,C,N 317 | qkx = 
torch.einsum('HDC, BCN -> BHDN', multi_head_qk_qunat, quant_x.transpose(-2, -1).contiguous() ) # B, num_heads, in_features, N 318 | qkx = qkx.permute(0,3,1,2).reshape(B,N, self.num_heads * C) #B, N, num_heads*in_features 319 | qkx = self.move_qkx_b4(qkx) 320 | qkx = qkx.reshape(B,N*self.num_heads, C) 321 | quant_qkx = self.quan_a_qkx_fn(qkx) # B, num_heads*N, in_features 322 | quant_qkx = quant_qkx.reshape(B,N,self.num_heads * C) #B, N, num_heads*in_features 323 | quant_qkx = self.move_qkx_aft(quant_qkx) #B, N, num_heads*in_features 324 | quant_qkx = quant_qkx.reshape(B, N, self.num_heads, -1).permute(0, 2, 3, 1) # B, num_heads, in_features, N 325 | 326 | ## x@W_qk@X^T, quant_x: B,N,C and quant_qkx: B,num_heads, C, N 327 | xqkx = torch.einsum('BNC,BHCD -> BHND',quant_x,quant_qkx) # B, num_heads, N, N 328 | # B, num_heads, N, C // self.num_heads 329 | value_4_softmax = xqkx 330 | attn_weights = (value_4_softmax) * self.scale 331 | attn_prob = F.softmax(attn_weights, dim=-1) 332 | 333 | attn_prob = self.quan_a_softmax_fn(attn_prob) 334 | attn_prob = self.attn_drop(attn_prob) 335 | 336 | x = (attn_prob @ v).transpose(1, 2).reshape(B, N, C) 337 | x = self.proj(x) 338 | x = self.proj_drop(x) 339 | return x, None 340 | 341 | class QAttention_lsq(deit_attention): 342 | def __init__(self, m: deit_attention, clip_val=2.5, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, 343 | symmetric=True, weight_channelwise=True, input_channelwise=True, weight_quant_method="lsq", input_quant_method="lsq", 344 | pretrained_initialized = False, 345 | **kwargs): 346 | assert type(m) == deit_attention 347 | super().__init__( 348 | dim = m.qkv.in_features, 349 | num_heads=m.num_heads, 350 | attn_drop=m.attn_drop.p, 351 | proj_drop=m.proj_drop.p, 352 | qqkkvv= m.qqkkvv 353 | ) 354 | self.weight_bits = weight_bits 355 | self.input_bits = input_bits 356 | self.input_channelwise = input_channelwise 357 | 358 | self.qkv = LSQ_w_and_act_QLinear( 359 | m = self.qkv, 360 | weight_bits = weight_bits, 361 | input_bits = input_bits, 362 | weight_channelwise = weight_channelwise, 363 | input_channelwise = input_channelwise, 364 | weight_quant_method = weight_quant_method, 365 | input_quant_method = input_quant_method, 366 | aq_learnable = aq_learnable, ## act 367 | wq_learnable = wq_learnable,## weight 368 | symmetric = True, ## act 369 | pretrained_initialized = pretrained_initialized 370 | ) 371 | self.proj = LSQ_w_and_act_QLinear( 372 | m = self.proj, 373 | weight_bits = weight_bits, 374 | input_bits = input_bits, 375 | weight_channelwise = weight_channelwise, 376 | input_channelwise = input_channelwise, 377 | weight_quant_method = weight_quant_method, 378 | input_quant_method = input_quant_method, 379 | aq_learnable = aq_learnable, ## act 380 | wq_learnable = wq_learnable,## weight 381 | symmetric = True, ## act 382 | pretrained_initialized = pretrained_initialized 383 | ) 384 | 385 | self.quan_a_q_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) 386 | self.quan_a_k_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) 387 | self.quan_a_v_fn = LsqQuantizer4v(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) 388 | 389 | self.move_qkv_b4 = LearnableBias(m.qkv.in_features*3) 390 | self.move_q_aft = LearnableBias(m.qkv.in_features) 391 | self.move_k_aft = LearnableBias(m.qkv.in_features) 392 | self.move_v_aft = LearnableBias(m.qkv.in_features) 393 | 394 | self.quan_a_softmax_fn = 
LsqQuantizer(bit=input_bits,all_positive=True,per_channel=True, learnable = aq_learnable) 395 | 396 | def forward(self, x): 397 | B, N, C = x.shape 398 | qkv = self.qkv(x) # B, N, 3*C 399 | if self.input_bits < 32: 400 | qkv = self.move_qkv_b4(qkv) 401 | qkv = qkv.reshape( 402 | B, N, 3, self.num_heads, C // self.num_heads 403 | ).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C // self.num_heads 404 | q, k, v = qkv[0], qkv[1], qkv[2] 405 | 406 | q = self.quan_a_q_fn(q) 407 | k = self.quan_a_k_fn(k) 408 | v = v.permute(0,2,1,3).reshape(B,N,C) 409 | v = self.quan_a_v_fn(v) # quantize along C 410 | v = v.reshape(B,N,self.num_heads,C//self.num_heads).permute(0,2,1,3) 411 | 412 | 413 | 414 | if self.input_bits < 32: 415 | 416 | q = q.permute(0, 2, 1, 3).reshape(B, N, C) 417 | k = k.permute(0, 2, 1, 3).reshape(B, N, C) 418 | v = v.permute(0, 2, 1, 3).reshape(B, N, C) 419 | q = self.move_q_aft(q) 420 | k = self.move_k_aft(k) 421 | v = self.move_v_aft(v) 422 | 423 | q = q.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B, num_heads, N, C // self.num_heads 424 | k = k.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 425 | v = v.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 426 | 427 | 428 | attn_weights = (q @ k.transpose(-2, -1).contiguous()) * self.scale 429 | attn_prob = F.softmax(attn_weights, dim=-1) 430 | 431 | attn_prob = self.quan_a_softmax_fn(attn_prob) 432 | attn_prob = self.attn_drop(attn_prob) 433 | 434 | x = (attn_prob @ v).transpose(1, 2).reshape(B, N, C) 435 | x = self.proj(x) 436 | x = self.proj_drop(x) 437 | 438 | return x, None 439 | 440 | -------------------------------------------------------------------------------- /src/quantization/modules/qbias.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LearnableBias(nn.Module): 6 | def __init__(self, out_chn): 7 | super(LearnableBias, self).__init__() 8 | self.bias = nn.Parameter(torch.zeros(out_chn), requires_grad=True) 9 | 10 | def forward(self, x): 11 | out = x + self.bias.expand_as(x) 12 | 13 | return out 14 | 15 | class LearnableBias4img(nn.Module): 16 | def __init__(self, out_chn): 17 | super(LearnableBias4img, self).__init__() 18 | self.bias = nn.Parameter(torch.zeros(out_chn), requires_grad=True) 19 | 20 | def forward(self, x): 21 | out = x + self.bias.reshape(x.shape[-1],x.shape[-2]).expand_as(x) 22 | 23 | return out 24 | 25 | -------------------------------------------------------------------------------- /src/quantization/modules/qlinear.py: -------------------------------------------------------------------------------- 1 | import imp 2 | from numpy import clip 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .qbias import LearnableBias, LearnableBias4img 7 | from ..quantizer.statsq import StatsQuantizer 8 | from ...deit_vision_transformer import Mlp 9 | from timm.models.layers import to_2tuple 10 | from ..quantizer.lsq import LsqQuantizer, LsqQuantizerWeight, LsqQuantizer4img, LsqQuantizer4Conv2d, LsqQuantizer4head_input 11 | 12 | class LSQ_input(nn.Module): 13 | def __init__(self, bit=2,all_positive=False, learnable = True, learanbaleBiasdim = 192): 14 | super().__init__() 15 | self.input_bits = bit 16 | self.all_positive = all_positive 17 | self.learnable = learnable 18 | self.input_quant_fn = LsqQuantizer(bit=bit,all_positive=all_positive, learnable = learnable) 19 | self.move_b4 = LearnableBias(learanbaleBiasdim) 20 | self.move_aft = 
LearnableBias(learanbaleBiasdim) 21 | def forward(self, input): 22 | 23 | input = self.move_b4(input) 24 | input = self.input_quant_fn(input) 25 | input = self.move_aft(input) 26 | return input 27 | 28 | class QLinear(nn.Linear): 29 | 30 | def __init__(self, *kargs, m: torch.nn.Linear, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, 31 | symmetric=True, weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq", 32 | pretrained_initialized = False, 33 | **kwargs): 34 | super(QLinear, self).__init__(m.in_features, m.out_features,bias=True) 35 | self.weight_bits = weight_bits 36 | self.input_bits = input_bits 37 | self.aq_learnable = aq_learnable 38 | self.wq_learnable = wq_learnable 39 | self.symmetric = symmetric 40 | self.weight_channelwise = weight_channelwise # not gonna used atm 41 | self.input_channelwise = input_channelwise 42 | self.weight_quant_method = weight_quant_method 43 | self.input_quant_method = input_quant_method 44 | self.input_quant_fn = LsqQuantizer(bit=input_bits,all_positive=(symmetric==False), learnable = aq_learnable) 45 | self.pretrained_initialized = pretrained_initialized 46 | if pretrained_initialized != False: 47 | self.weight = torch.nn.Parameter(m.weight.detach()) 48 | if m.bias is not None: 49 | self.bias = torch.nn.Parameter(m.bias.detach()) 50 | if weight_quant_method == 'statsq': 51 | self.statsq_fn = StatsQuantizer(num_bits=self.weight_bits, clip_learnable=wq_learnable).to(m.weight.device) 52 | else: 53 | raise ValueError("Unknown quant_method") 54 | 55 | self.move_b4 = LearnableBias(self.weight.shape[1]) 56 | self.move_aft = LearnableBias(self.weight.shape[1]) 57 | 58 | def forward(self, input): 59 | 60 | # quantize weight 61 | if self.weight_quant_method == 'statsq': 62 | weight = self.statsq_fn(self.weight) 63 | else: 64 | raise ValueError("Unknown quant_method") 65 | # quantize input 66 | input = self.move_b4(input) 67 | input = self.input_quant_fn(input) 68 | input = self.move_aft(input) 69 | out = nn.functional.linear(input, weight) 70 | if not self.bias is None: 71 | out += self.bias.view(1, -1).expand_as(out) 72 | 73 | return out 74 | 75 | def extra_repr(self): 76 | return ( 77 | f"act_bit={self.input_bits}, " 78 | f"weight_bit={self.weight_bits}, " 79 | f"act_all_positive={not self.symmetric}, " 80 | f"wq_learnable={self.wq_learnable}, " 81 | f"aq_learnable={self.aq_learnable}, " 82 | f"weight_channelwise ={self.weight_channelwise}, " 83 | f"input_channelwise ={self.input_channelwise}, " 84 | f"weight_quant_method={self.weight_quant_method}, " 85 | f"activation_quant_method={self.input_quant_method}, " 86 | f"pretrained_initialized = {self.pretrained_initialized}" 87 | ) 88 | 89 | class QMLP(Mlp): 90 | def __init__(self, *kargs, m: Mlp, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq", act_layer=nn.GELU, 91 | pretrained_initialized = False, 92 | **kwargs): 93 | super().__init__( 94 | in_features = m.in_features, 95 | hidden_features = m.hidden_features, 96 | out_features = m.out_features, 97 | drop = m.drop 98 | ) 99 | 100 | out_features = m.out_features or m.in_features 101 | hidden_features = m.hidden_features or m.in_features 102 | drop_probs = to_2tuple(self.drop) 103 | 104 | self.fc1 = QLinear(m = m.fc1,weight_bits=weight_bits,input_bits=input_bits, 105 | 
aq_learnable=aq_learnable,wq_learnable=wq_learnable,symmetric=True,weight_channelwise=weight_channelwise,input_channelwise=input_channelwise, 106 | weight_quant_method=weight_quant_method,input_quant_method=input_quant_method, pretrained_initialized = pretrained_initialized) 107 | 108 | self.act_layer = act_layer 109 | 110 | if act_layer != 'rprelu': 111 | if act_layer != 'None': 112 | self.act = act_layer() 113 | else: 114 | self.act = nn.Identity() 115 | 116 | 117 | self.drop1 = nn.Dropout(drop_probs[0]) 118 | self.fc2 = QLinear(m = m.fc2,weight_bits=weight_bits,input_bits=input_bits, 119 | aq_learnable=aq_learnable,wq_learnable=wq_learnable,symmetric=False,weight_channelwise=weight_channelwise,input_channelwise=input_channelwise, 120 | weight_quant_method=weight_quant_method,input_quant_method=input_quant_method, pretrained_initialized = pretrained_initialized) 121 | self.drop2 = nn.Dropout(drop_probs[1]) 122 | 123 | def forward(self, x): 124 | 125 | x = self.fc1(x) 126 | if self.act_layer != 'rprelu': 127 | x = self.act(x) 128 | else: 129 | x = self.move1(x) 130 | x = self.act(x) 131 | x = self.move2(x) 132 | 133 | x = self.drop1(x) 134 | x = self.fc2(x) 135 | x = self.drop2(x) 136 | return x 137 | 138 | class LSQ_QConv2d(nn.Conv2d): 139 | def __init__(self, *kargs, m: torch.nn.Conv2d, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, 140 | symmetric=True, weight_channelwise=True, input_channelwise=True, weight_quant_method="lsq", input_quant_method="lsq", 141 | pretrained_initialized = False, 142 | **kwargs): 143 | super(LSQ_QConv2d, self).__init__(in_channels = m.in_channels, out_channels = m.out_channels, kernel_size = m.kernel_size 144 | , stride=m.stride, padding=m.padding, dilation=m.dilation, groups=m.groups, bias=True) 145 | self.weight_bits = weight_bits 146 | self.input_bits = input_bits 147 | self.aq_learnable = aq_learnable 148 | self.wq_learnable = wq_learnable 149 | self.symmetric = symmetric 150 | self.weight_channelwise = weight_channelwise # not gonna used atm 151 | self.input_channelwise = input_channelwise 152 | self.weight_quant_method = weight_quant_method 153 | self.input_quant_method = input_quant_method 154 | self.input_quant_fn = LsqQuantizer4img(bit=input_bits,all_positive=(symmetric==False), learnable = aq_learnable) 155 | self.pretrained_initialized = pretrained_initialized 156 | if pretrained_initialized != False: 157 | self.weight = torch.nn.Parameter(m.weight.detach()) 158 | if m.bias is not None: 159 | self.bias = torch.nn.Parameter(m.bias.detach()) 160 | 161 | self.lsqw_fn = LsqQuantizer4Conv2d(bit=self.weight_bits, learnable=aq_learnable).to(m.weight.device) 162 | 163 | self.move_b4 = LearnableBias4img(224*224) # 3x224x224 164 | self.move_aft = LearnableBias4img(224*224) # 3x224x224 165 | 166 | def forward(self, input): 167 | # quantize weight 168 | weight = self.lsqw_fn(self.weight) 169 | 170 | # quantize input 171 | input = self.move_b4(input) 172 | input = self.input_quant_fn(input) 173 | input = self.move_aft(input) 174 | out = nn.functional.conv2d(input, weight, self.bias, self.stride, 175 | self.padding, self.dilation, self.groups) 176 | 177 | return out 178 | 179 | def extra_repr(self): 180 | return ( 181 | f"act_bit={self.input_bits}, " 182 | f"weight_bit={self.weight_bits}, " 183 | f"act_all_positive={not self.symmetric}, " 184 | f"wq_learnable={self.wq_learnable}, " 185 | f"aq_learnable={self.aq_learnable}, " 186 | f"weight_channelwise ={self.weight_channelwise}, " 187 | f"input_channelwise ={self.input_channelwise}, " 188 
| f"weight_quant_method={self.weight_quant_method}, " 189 | f"activation_quant_method={self.input_quant_method}, " 190 | f"pretrained_initialized = {self.pretrained_initialized}" 191 | ) 192 | 193 | class LSQ_QLinear4head(nn.Linear): 194 | 195 | def __init__(self, *kargs, m: torch.nn.Linear, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, 196 | symmetric=True, weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq", 197 | pretrained_initialized = False, 198 | **kwargs): 199 | super(LSQ_QLinear4head, self).__init__(m.in_features, m.out_features,bias=True) 200 | self.weight_bits = weight_bits 201 | self.input_bits = input_bits 202 | self.aq_learnable = aq_learnable 203 | self.wq_learnable = wq_learnable 204 | self.symmetric = symmetric 205 | self.weight_channelwise = weight_channelwise # not gonna used atm 206 | self.input_channelwise = input_channelwise 207 | self.weight_quant_method = weight_quant_method 208 | self.input_quant_method = input_quant_method 209 | self.input_quant_fn = LsqQuantizer4head_input(bit=input_bits,all_positive=(symmetric==False), learnable = aq_learnable) 210 | self.pretrained_initialized = pretrained_initialized 211 | if pretrained_initialized != False: 212 | self.weight = torch.nn.Parameter(m.weight.detach()) 213 | if m.bias is not None: 214 | self.bias = torch.nn.Parameter(m.bias.detach()) 215 | if weight_quant_method == 'lsq': 216 | self.lsqw_fn = LsqQuantizerWeight(bit=self.weight_bits, per_channel=weight_channelwise ,learnable=wq_learnable).to(m.weight.device) 217 | else: 218 | raise ValueError("Unknown quant_method") 219 | 220 | self.move_b4 = LearnableBias(self.weight.shape[1]) 221 | self.move_aft = LearnableBias(self.weight.shape[1]) 222 | 223 | def forward(self, input): 224 | 225 | # quantize weight 226 | if self.weight_quant_method == 'lsq': 227 | weight = self.lsqw_fn(self.weight) 228 | else: 229 | raise ValueError("Unknown quant_method") 230 | # quantize input 231 | input = self.move_b4(input) 232 | input = self.input_quant_fn(input) 233 | input = self.move_aft(input) 234 | out = nn.functional.linear(input, weight) 235 | if not self.bias is None: 236 | out += self.bias.view(1, -1).expand_as(out) 237 | 238 | return out 239 | 240 | def extra_repr(self): 241 | return ( 242 | f"act_bit={self.input_bits}, " 243 | f"weight_bit={self.weight_bits}, " 244 | f"act_all_positive={not self.symmetric}, " 245 | f"wq_learnable={self.wq_learnable}, " 246 | f"aq_learnable={self.aq_learnable}, " 247 | f"weight_channelwise ={self.weight_channelwise}, " 248 | f"input_channelwise ={self.input_channelwise}, " 249 | f"weight_quant_method={self.weight_quant_method}, " 250 | f"activation_quant_method={self.input_quant_method}, " 251 | f"pretrained_initialized = {self.pretrained_initialized}" 252 | ) 253 | 254 | class LSQ_w_and_act_QLinear(nn.Linear): 255 | 256 | def __init__(self, *kargs, m: torch.nn.Linear, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, 257 | symmetric=True, weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq", 258 | pretrained_initialized = False, 259 | **kwargs): 260 | super(LSQ_w_and_act_QLinear, self).__init__(m.in_features, m.out_features,bias=True) 261 | self.weight_bits = weight_bits 262 | self.input_bits = input_bits 263 | self.aq_learnable = aq_learnable 264 | self.wq_learnable = wq_learnable 265 | self.symmetric = symmetric 266 | self.weight_channelwise = weight_channelwise # not gonna used atm 267 | 
self.input_channelwise = input_channelwise 268 | self.weight_quant_method = weight_quant_method 269 | self.input_quant_method = input_quant_method 270 | self.input_quant_fn = LsqQuantizer(bit=input_bits,all_positive=(symmetric==False), learnable = aq_learnable) 271 | self.pretrained_initialized = pretrained_initialized 272 | if pretrained_initialized != False: 273 | self.weight = torch.nn.Parameter(m.weight.detach()) 274 | if m.bias is not None: 275 | self.bias = torch.nn.Parameter(m.bias.detach()) 276 | if weight_quant_method == 'lsq': 277 | self.lsqw_fn = LsqQuantizerWeight(bit=self.weight_bits, per_channel=weight_channelwise ,learnable=wq_learnable).to(m.weight.device) 278 | else: 279 | raise ValueError("Unknown quant_method") 280 | 281 | self.move_b4 = LearnableBias(self.weight.shape[1]) 282 | self.move_aft = LearnableBias(self.weight.shape[1]) 283 | 284 | def forward(self, input): 285 | 286 | # quantize weight 287 | if self.weight_quant_method == 'lsq': 288 | weight = self.lsqw_fn(self.weight) 289 | else: 290 | raise ValueError("Unknown quant_method") 291 | # quantize input 292 | input = self.move_b4(input) 293 | input = self.input_quant_fn(input) 294 | input = self.move_aft(input) 295 | out = nn.functional.linear(input, weight) 296 | if not self.bias is None: 297 | out += self.bias.view(1, -1).expand_as(out) 298 | 299 | return out 300 | 301 | def extra_repr(self): 302 | return ( 303 | f"act_bit={self.input_bits}, " 304 | f"weight_bit={self.weight_bits}, " 305 | f"act_all_positive={not self.symmetric}, " 306 | f"wq_learnable={self.wq_learnable}, " 307 | f"aq_learnable={self.aq_learnable}, " 308 | f"weight_channelwise ={self.weight_channelwise}, " 309 | f"input_channelwise ={self.input_channelwise}, " 310 | f"weight_quant_method={self.weight_quant_method}, " 311 | f"activation_quant_method={self.input_quant_method}, " 312 | f"pretrained_initialized = {self.pretrained_initialized}" 313 | ) 314 | 315 | class LSQ_w_and_act_QMLP(Mlp): 316 | def __init__(self, *kargs, m: Mlp, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, 317 | weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq", act_layer=nn.GELU, 318 | pretrained_initialized = False, 319 | **kwargs): 320 | super().__init__( 321 | in_features = m.in_features, 322 | hidden_features = m.hidden_features, 323 | out_features = m.out_features, 324 | drop = m.drop 325 | ) 326 | 327 | out_features = m.out_features or m.in_features 328 | hidden_features = m.hidden_features or m.in_features 329 | drop_probs = to_2tuple(self.drop) 330 | 331 | self.fc1 = LSQ_w_and_act_QLinear(m = m.fc1,weight_bits=weight_bits,input_bits=input_bits, 332 | aq_learnable=aq_learnable,wq_learnable=wq_learnable,symmetric=True,weight_channelwise=weight_channelwise,input_channelwise=input_channelwise, 333 | weight_quant_method=weight_quant_method,input_quant_method=input_quant_method, pretrained_initialized = pretrained_initialized) 334 | 335 | self.act_layer = act_layer 336 | 337 | if act_layer != 'rprelu': 338 | if act_layer != 'None': 339 | self.act = act_layer() 340 | else: 341 | self.act = nn.Identity() 342 | 343 | 344 | self.drop1 = nn.Dropout(drop_probs[0]) 345 | self.fc2 = LSQ_w_and_act_QLinear(m = m.fc2,weight_bits=weight_bits,input_bits=input_bits, 346 | aq_learnable=aq_learnable,wq_learnable=wq_learnable,symmetric=False,weight_channelwise=weight_channelwise,input_channelwise=input_channelwise, 347 | weight_quant_method=weight_quant_method,input_quant_method=input_quant_method, 
pretrained_initialized = pretrained_initialized) 348 | self.drop2 = nn.Dropout(drop_probs[1]) 349 | 350 | def forward(self, x): 351 | 352 | x = self.fc1(x) 353 | if self.act_layer != 'rprelu': 354 | x = self.act(x) 355 | else: 356 | x = self.move1(x) 357 | x = self.act(x) 358 | x = self.move2(x) 359 | 360 | x = self.drop1(x) 361 | x = self.fc2(x) 362 | x = self.drop2(x) 363 | return x 364 | 365 | 366 | 367 | 368 | 369 | -------------------------------------------------------------------------------- /src/quantization/modules/utils.py: -------------------------------------------------------------------------------- 1 | from email.policy import default 2 | import os 3 | import torch 4 | # from .conv import QConv2d, QConvBn2d 5 | # from .linear import QLinear 6 | from .qlinear import LSQ_QConv2d 7 | from .qlinear import QLinear, QMLP, LSQ_w_and_act_QMLP, LSQ_w_and_act_QLinear, LSQ_QLinear4head 8 | from .attention import QAttention, QAttention_lsq 9 | from .attention import QAttention_qkreparam_4_cga 10 | from .attention import QAttention_qkreparam 11 | 12 | # from src.utils import Attention 13 | from src.deit_vision_transformer import Attention as deit_attention 14 | from src.deit_vision_transformer import Mlp 15 | from src.swin import ShiftedWindowAttention 16 | from torchvision.ops.misc import MLP as swin_MLP 17 | 18 | from .swin_attention_and_mlp import QAttention_swin, QMLP_swin, QAttention_swin_qkreparam, QAttention_swin_qkreparam_4_cga 19 | 20 | 21 | QMODULE_MAPPINGS = { 22 | torch.nn.Linear: QLinear, 23 | deit_attention: QAttention, 24 | Mlp: QMLP 25 | } 26 | 27 | ## 0: QAttention_qkreparam, 1: QAttention_qkreparam_4_cga 28 | QMODULE_MAPPINGS_QK_REPARAM = [ 29 | { 30 | torch.nn.Linear: QLinear, 31 | deit_attention: QAttention_qkreparam, 32 | Mlp: QMLP 33 | }, 34 | { 35 | torch.nn.Linear: QLinear, 36 | deit_attention: QAttention_qkreparam_4_cga, 37 | Mlp: QMLP 38 | } 39 | ] 40 | QMODULE_MAPPINGS_W_AND_ACT = { 41 | torch.nn.Linear: LSQ_w_and_act_QLinear, 42 | deit_attention: QAttention_lsq, 43 | Mlp: LSQ_w_and_act_QMLP 44 | } 45 | def get_module_by_name(model, module_name): 46 | names = module_name.split(".") 47 | module = model 48 | for name in names: 49 | module = getattr(module, name) 50 | return module 51 | 52 | 53 | def set_module_by_name(model, module_name, module): 54 | if module_name == 'head' or module_name == 'head_dist': 55 | setattr(model, module_name, module) 56 | else: 57 | names = module_name.split(".") 58 | parent = get_module_by_name(model, ".".join(names[:-1])) 59 | setattr(parent, names[-1], module) 60 | 61 | 62 | def replace_module_by_qmodule_deit(model, qconfigs, pretrained_initialized = False, 63 | qk_reparam = False, qk_reparam_type = 0, boundaryRange = 0.005): 64 | 65 | if qconfigs[list(qconfigs.keys())[0]]["weight"]["mode"] == 'lsq' and qconfigs[list(qconfigs.keys())[0]]["act"]["mode"] == 'lsq': 66 | 67 | for name, cfg in qconfigs.items(): 68 | if name == "patch_embed.proj": 69 | module = get_module_by_name(model, name) 70 | 71 | qmodule = LSQ_QConv2d( 72 | m = module, 73 | weight_bits = 8, 74 | input_bits = 8, 75 | weight_channelwise = True, 76 | input_channelwise = True, 77 | weight_quant_method = 'lsq', 78 | input_quant_method = 'lsq', 79 | aq_learnable = True, ## act 80 | wq_learnable = True,## weight 81 | act_layer = cfg["act_layer"], 82 | pretrained_initialized = pretrained_initialized 83 | ) 84 | set_module_by_name(model, name, qmodule) 85 | elif name == "head" or name == "head_dist": 86 | module = get_module_by_name(model, name) 87 | qmodule = 
LSQ_QLinear4head( 88 | m = module, 89 | weight_bits = 8, 90 | input_bits = 8, 91 | weight_channelwise = True, 92 | input_channelwise = True, 93 | weight_quant_method = 'lsq', 94 | input_quant_method = 'lsq', 95 | aq_learnable = True, ## act 96 | wq_learnable = True,## weight 97 | symmetric = True, ## ac 98 | act_layer = cfg["act_layer"], 99 | pretrained_initialized = pretrained_initialized 100 | ) 101 | set_module_by_name(model, name, qmodule) 102 | else: 103 | module = get_module_by_name(model, name) 104 | 105 | qmodule = QMODULE_MAPPINGS_W_AND_ACT[type(module)]( 106 | m = module, 107 | weight_bits = cfg["weight"]['bit'], 108 | input_bits = cfg["act"]['bit'], 109 | weight_channelwise = cfg["weight"]["per_channel"], 110 | input_channelwise = cfg["act"]["per_channel"], 111 | weight_quant_method = cfg["weight"]["mode"], 112 | input_quant_method = cfg["act"]["mode"], 113 | aq_learnable = cfg["act"]["learnable"], ## act 114 | wq_learnable = cfg["weight"]["learnable"],## weight 115 | symmetric = not cfg["act"]['all_positive'], ## ac 116 | act_layer = cfg["act_layer"], 117 | pretrained_initialized = pretrained_initialized 118 | ) 119 | set_module_by_name(model, name, qmodule) 120 | 121 | elif qk_reparam: 122 | if qk_reparam_type == 0: 123 | for name, cfg in qconfigs.items(): 124 | if name == "patch_embed.proj": 125 | module = get_module_by_name(model, name) 126 | qmodule = LSQ_QConv2d( 127 | m = module, 128 | weight_bits = 8, 129 | input_bits = 8, 130 | weight_channelwise = True, 131 | input_channelwise = True, 132 | weight_quant_method = 'lsq', 133 | input_quant_method = 'lsq', 134 | aq_learnable = True, ## act 135 | wq_learnable = True,## weight 136 | act_layer = cfg["act_layer"], 137 | pretrained_initialized = pretrained_initialized 138 | ) 139 | set_module_by_name(model, name, qmodule) 140 | elif name == "head" or name == "head_dist": 141 | module = get_module_by_name(model, name) 142 | qmodule = LSQ_QLinear4head( 143 | m = module, 144 | weight_bits = 8, 145 | input_bits = 8, 146 | weight_channelwise = True, 147 | input_channelwise = True, 148 | weight_quant_method = 'lsq', 149 | input_quant_method = 'lsq', 150 | aq_learnable = True, ## act 151 | wq_learnable = True,## weight 152 | symmetric = True, ## ac 153 | act_layer = cfg["act_layer"], 154 | pretrained_initialized = pretrained_initialized 155 | ) 156 | set_module_by_name(model, name, qmodule) 157 | else: 158 | module = get_module_by_name(model, name) 159 | qmodule = QMODULE_MAPPINGS_QK_REPARAM[qk_reparam_type][type(module)]( 160 | m = module, 161 | weight_bits = cfg["weight"]['bit'], 162 | input_bits = cfg["act"]['bit'], 163 | weight_channelwise = cfg["weight"]["per_channel"], 164 | input_channelwise = cfg["act"]["per_channel"], 165 | weight_quant_method = cfg["weight"]["mode"], 166 | input_quant_method = cfg["act"]["mode"], 167 | aq_learnable = cfg["act"]["learnable"], ## act 168 | wq_learnable = cfg["weight"]["learnable"],## weight 169 | # symmetric = not cfg["act"]['all_positive'], ## ac 170 | act_layer = cfg["act_layer"], 171 | pretrained_initialized = pretrained_initialized 172 | ) 173 | set_module_by_name(model, name, qmodule) 174 | elif qk_reparam_type == 1: 175 | for name, cfg in qconfigs.items(): 176 | if name == "patch_embed.proj": 177 | module = get_module_by_name(model, name) 178 | qmodule = LSQ_QConv2d( 179 | m = module, 180 | weight_bits = 8, 181 | input_bits = 8, 182 | weight_channelwise = True, 183 | input_channelwise = True, 184 | weight_quant_method = 'lsq', 185 | input_quant_method = 'lsq', 186 | aq_learnable = True, 
## act 187 | wq_learnable = True,## weight 188 | act_layer = cfg["act_layer"], 189 | pretrained_initialized = pretrained_initialized 190 | ) 191 | set_module_by_name(model, name, qmodule) 192 | elif name == "head" or name == "head_dist": 193 | module = get_module_by_name(model, name) 194 | qmodule = LSQ_QLinear4head( 195 | m = module, 196 | weight_bits = 8, 197 | input_bits = 8, 198 | weight_channelwise = True, 199 | input_channelwise = True, 200 | weight_quant_method = 'lsq', 201 | input_quant_method = 'lsq', 202 | aq_learnable = True, ## act 203 | wq_learnable = True,## weight 204 | symmetric = True, ## ac 205 | act_layer = cfg["act_layer"], 206 | pretrained_initialized = pretrained_initialized 207 | ) 208 | set_module_by_name(model, name, qmodule) 209 | else: 210 | module = get_module_by_name(model, name) 211 | qmodule = QMODULE_MAPPINGS_QK_REPARAM[qk_reparam_type][type(module)]( 212 | m = module, 213 | weight_bits = cfg["weight"]['bit'], 214 | input_bits = cfg["act"]['bit'], 215 | weight_channelwise = cfg["weight"]["per_channel"], 216 | input_channelwise = cfg["act"]["per_channel"], 217 | weight_quant_method = cfg["weight"]["mode"], 218 | input_quant_method = cfg["act"]["mode"], 219 | aq_learnable = cfg["act"]["learnable"], ## act 220 | wq_learnable = cfg["weight"]["learnable"],## weight 221 | act_layer = cfg["act_layer"], 222 | pretrained_initialized = pretrained_initialized, 223 | boundaryRange = boundaryRange 224 | ) 225 | set_module_by_name(model, name, qmodule) 226 | 227 | else: ## statsq w quant 228 | for name, cfg in qconfigs.items(): 229 | if name == "patch_embed.proj": 230 | module = get_module_by_name(model, name) 231 | 232 | qmodule = LSQ_QConv2d( 233 | m = module, 234 | weight_bits = 8, 235 | input_bits = 8, 236 | weight_channelwise = True, 237 | input_channelwise = True, 238 | weight_quant_method = 'lsq', 239 | input_quant_method = 'lsq', 240 | aq_learnable = True, ## act 241 | wq_learnable = True,## weight 242 | act_layer = cfg["act_layer"], 243 | pretrained_initialized = pretrained_initialized 244 | ) 245 | set_module_by_name(model, name, qmodule) 246 | elif name == "head" or name == "head_dist": 247 | module = get_module_by_name(model, name) 248 | qmodule = LSQ_QLinear4head( 249 | m = module, 250 | weight_bits = 8, 251 | input_bits = 8, 252 | weight_channelwise = True, 253 | input_channelwise = True, 254 | weight_quant_method = 'lsq', 255 | input_quant_method = 'lsq', 256 | aq_learnable = True, ## act 257 | wq_learnable = True,## weight 258 | symmetric = True, ## ac 259 | act_layer = cfg["act_layer"], 260 | pretrained_initialized = pretrained_initialized 261 | ) 262 | set_module_by_name(model, name, qmodule) 263 | else: 264 | module = get_module_by_name(model, name) 265 | 266 | qmodule = QMODULE_MAPPINGS[type(module)]( 267 | m = module, 268 | weight_bits = cfg["weight"]['bit'], 269 | input_bits = cfg["act"]['bit'], 270 | weight_channelwise = cfg["weight"]["per_channel"], 271 | input_channelwise = cfg["act"]["per_channel"], 272 | weight_quant_method = cfg["weight"]["mode"], 273 | input_quant_method = cfg["act"]["mode"], 274 | aq_learnable = cfg["act"]["learnable"], ## act 275 | wq_learnable = cfg["weight"]["learnable"],## weight 276 | # symmetric = not cfg["act"]['all_positive'], ## ac 277 | act_layer = cfg["act_layer"], 278 | pretrained_initialized = pretrained_initialized 279 | ) 280 | set_module_by_name(model, name, qmodule) 281 | 282 | return model 283 | 284 | 285 | 286 | QMODULE_MAPPINGS_SWIN = { 287 | torch.nn.Linear: QLinear, 288 | ShiftedWindowAttention: 
QAttention_swin, 289 | swin_MLP: QMLP_swin 290 | } 291 | 292 | QMODULE_MAPPINGS_QK_REPARAM_SWIN = [ 293 | { 294 | torch.nn.Linear: QLinear, 295 | ShiftedWindowAttention: QAttention_swin_qkreparam, 296 | swin_MLP: QMLP_swin 297 | }, 298 | { 299 | torch.nn.Linear: QLinear, 300 | ShiftedWindowAttention: QAttention_swin_qkreparam_4_cga, 301 | swin_MLP: QMLP_swin 302 | } 303 | ] 304 | 305 | def replace_module_by_qmodule_swin(model, qconfigs, pretrained_initialized = False, qk_reparam = False, qk_reparam_type = 0, boundaryRange = 0.005): 306 | if qk_reparam: 307 | for name, cfg in qconfigs.items(): 308 | if name == "features.0.0": 309 | module = get_module_by_name(model, name) 310 | qmodule = LSQ_QConv2d( 311 | m = module, 312 | weight_bits = 8, 313 | input_bits = 8, 314 | weight_channelwise = True, 315 | input_channelwise = True, 316 | weight_quant_method = 'lsq', 317 | input_quant_method = 'lsq', 318 | aq_learnable = True, ## act 319 | wq_learnable = True,## weight 320 | act_layer = cfg["act_layer"], 321 | pretrained_initialized = pretrained_initialized 322 | ) 323 | set_module_by_name(model, name, qmodule) 324 | elif name == "head": 325 | module = get_module_by_name(model, name) 326 | qmodule = LSQ_QLinear4head( 327 | m = module, 328 | weight_bits = 8, 329 | input_bits = 8, 330 | weight_channelwise = True, 331 | input_channelwise = True, 332 | weight_quant_method = 'lsq', 333 | input_quant_method = 'lsq', 334 | aq_learnable = True, ## act 335 | wq_learnable = True,## weight 336 | symmetric = True, ## ac 337 | act_layer = cfg["act_layer"], 338 | pretrained_initialized = pretrained_initialized 339 | ) 340 | set_module_by_name(model, name, qmodule) 341 | else: 342 | module = get_module_by_name(model, name) 343 | qmodule = QMODULE_MAPPINGS_QK_REPARAM_SWIN[qk_reparam_type][type(module)]( 344 | m = module, 345 | weight_bits = cfg["weight"]['bit'], 346 | input_bits = cfg["act"]['bit'], 347 | weight_channelwise = cfg["weight"]["per_channel"], 348 | input_channelwise = cfg["act"]["per_channel"], 349 | weight_quant_method = cfg["weight"]["mode"], 350 | input_quant_method = cfg["act"]["mode"], 351 | aq_learnable = cfg["act"]["learnable"], ## act 352 | wq_learnable = cfg["weight"]["learnable"],## weight 353 | # symmetric = not cfg["act"]['all_positive'], ## ac 354 | act_layer = cfg["act_layer"], 355 | pretrained_initialized = pretrained_initialized 356 | ) 357 | set_module_by_name(model, name, qmodule) 358 | 359 | else: ## statsq w quant 360 | for name, cfg in qconfigs.items(): 361 | if name == "features.0.0": 362 | module = get_module_by_name(model, name) 363 | qmodule = LSQ_QConv2d( 364 | m = module, 365 | weight_bits = 8, 366 | input_bits = 8, 367 | weight_channelwise = True, 368 | input_channelwise = True, 369 | weight_quant_method = 'lsq', 370 | input_quant_method = 'lsq', 371 | aq_learnable = True, ## act 372 | wq_learnable = True,## weight 373 | act_layer = cfg["act_layer"], 374 | pretrained_initialized = pretrained_initialized 375 | ) 376 | set_module_by_name(model, name, qmodule) 377 | elif name == "head": 378 | module = get_module_by_name(model, name) 379 | qmodule = LSQ_QLinear4head( 380 | m = module, 381 | weight_bits = 8, 382 | input_bits = 8, 383 | weight_channelwise = True, 384 | input_channelwise = True, 385 | weight_quant_method = 'lsq', 386 | input_quant_method = 'lsq', 387 | aq_learnable = True, ## act 388 | wq_learnable = True,## weight 389 | symmetric = True, ## ac 390 | act_layer = cfg["act_layer"], 391 | pretrained_initialized = pretrained_initialized 392 | ) 393 | 
set_module_by_name(model, name, qmodule) 394 | else: 395 | module = get_module_by_name(model, name) 396 | 397 | qmodule = QMODULE_MAPPINGS_SWIN[type(module)]( 398 | m = module, 399 | weight_bits = cfg["weight"]['bit'], 400 | input_bits = cfg["act"]['bit'], 401 | weight_channelwise = cfg["weight"]["per_channel"], 402 | input_channelwise = cfg["act"]["per_channel"], 403 | weight_quant_method = cfg["weight"]["mode"], 404 | input_quant_method = cfg["act"]["mode"], 405 | aq_learnable = cfg["act"]["learnable"], ## act 406 | wq_learnable = cfg["weight"]["learnable"],## weight 407 | # symmetric = not cfg["act"]['all_positive'], ## ac 408 | act_layer = cfg["act_layer"], 409 | pretrained_initialized = pretrained_initialized 410 | ) 411 | set_module_by_name(model, name, qmodule) 412 | 413 | return model 414 | -------------------------------------------------------------------------------- /src/quantization/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from .lsq import LsqQuantizer, LsqQuantizerWeight 4 | from .statsq import StatsQuantizer 5 | 6 | -------------------------------------------------------------------------------- /src/quantization/quantizer/statsq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | 6 | ## create 1D mask 7 | def create_mask(s2, prob): 8 | raw = torch.zeros((s2,)) 9 | raw[:int((1-prob) * s2)] = 1.0/(1.0-prob) # set EXACTLY 30% of the pixels in the mask 10 | ridx = torch.randperm(s2) # a random permutation of the entries 11 | return raw[ridx] 12 | 13 | def round_pass(x): 14 | y = x.round() 15 | y_grad = x 16 | return (y - y_grad).detach() + y_grad 17 | 18 | def grad_scale(x, scale): 19 | y = x 20 | y_grad = x * scale 21 | return (y - y_grad).detach() + y_grad 22 | 23 | def clip(x, eps): 24 | x_clip = torch.where(x > eps, x, eps) 25 | return x - x.detach() + x_clip.detach() 26 | 27 | def modify_grad(x, freeze_inds): 28 | x = x * freeze_inds * 0 + x * (1-freeze_inds) 29 | return x 30 | 31 | 32 | class TrackOscillation(nn.Module): 33 | """ 34 | This is a wrapper of the int_forward function of a quantizer. 35 | It tracks the oscillations in integer domain. 36 | """ 37 | 38 | def __init__(self, momentum=0.01, freeze_threshold=0, use_ema_x_int=True): 39 | super(TrackOscillation, self).__init__() 40 | self.momentum = momentum 41 | 42 | self.prev_x_int = None 43 | self.prev_switch_dir = None 44 | 45 | # Statistics to log 46 | self.ema_oscillation = None 47 | self.oscillated_sum = None 48 | self.total_oscillation = None 49 | self.iters_since_reset = 0 50 | 51 | # Extra variables for weight freezing 52 | self.freeze_threshold = freeze_threshold # This should be at least 2-3x the momentum value. 
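# Rough intuition for the threshold above: the EMA of `oscillated` updated in __call__ approximates the fraction of recent iterations on which a weight flipped its integer value back and forth, so with momentum=0.01 a freeze_threshold of roughly 0.02-0.03 freezes weights that oscillate on more than a few percent of recent updates.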
53 | self.use_ema_x_int = use_ema_x_int 54 | self.frozen = None 55 | self.frozen_x_int = None 56 | self.ema_x_int = None 57 | 58 | def __call__(self, x_int, skip_tracking=False, *args, **kwargs): 59 | 60 | # Apply weight freezing 61 | if self.frozen is not None: 62 | x_int = ~self.frozen * x_int + self.frozen * self.frozen_x_int 63 | 64 | if skip_tracking: 65 | return x_int 66 | 67 | with torch.no_grad(): 68 | # Check if everything is correctly initialized, otherwise do so 69 | self.check_init(x_int) 70 | 71 | # detect difference in x_int NB we round to avoid int inaccuracies 72 | delta_x_int = torch.round(self.prev_x_int - x_int).detach() # should be {-1, 0, 1} 73 | switch_dir = torch.sign(delta_x_int) # This is {-1, 0, 1} as sign(0) is mapped to 0 74 | # binary mask for switching 75 | switched = delta_x_int != 0 76 | 77 | oscillated = (self.prev_switch_dir * switch_dir) == -1 78 | self.ema_oscillation = ( 79 | self.momentum * oscillated + (1 - self.momentum) * self.ema_oscillation 80 | ) 81 | 82 | # Update prev_switch_dir for the switch variables 83 | self.prev_switch_dir[switched] = switch_dir[switched] 84 | self.prev_x_int = x_int 85 | self.oscillated_sum = oscillated.sum() 86 | self.total_oscillation += oscillated 87 | self.iters_since_reset += 1 88 | 89 | # Freeze some weights 90 | if self.freeze_threshold > 0: 91 | freeze_weights = self.ema_oscillation > self.freeze_threshold 92 | self.frozen[freeze_weights] = True # Set them to frozen 93 | if self.use_ema_x_int: 94 | self.frozen_x_int[freeze_weights] = torch.round(self.ema_x_int[freeze_weights]) 95 | # Update x_int EMA which can be used for freezing 96 | self.ema_x_int = self.momentum * x_int + (1 - self.momentum) * self.ema_x_int 97 | else: 98 | self.frozen_x_int[freeze_weights] = x_int[freeze_weights] 99 | 100 | return x_int 101 | 102 | def check_init(self, x_int): 103 | if self.prev_x_int is None: 104 | # Init prev switch dir to 0 105 | self.prev_switch_dir = torch.zeros_like(x_int) 106 | self.prev_x_int = x_int.detach() # Not sure if needed, don't think so 107 | self.ema_oscillation = torch.zeros_like(x_int) 108 | self.oscillated_sum = 0 109 | self.total_oscillation = torch.zeros_like(x_int) 110 | else: 111 | assert ( 112 | self.prev_x_int.shape == x_int.shape 113 | ), "Tracking shape does not match current tensor shape." 
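# The weight-freezing buffers below are created lazily on the first call so that they match the shape of the tracked integer tensor: `frozen` is a boolean mask and `frozen_x_int` stores the integer values that frozen entries are pinned to in __call__.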
114 | 115 | # For weight freezing 116 | if self.frozen is None and self.freeze_threshold > 0: 117 | self.frozen = torch.zeros_like(x_int, dtype=torch.bool) 118 | self.frozen_x_int = torch.zeros_like(x_int) 119 | if self.use_ema_x_int: 120 | self.ema_x_int = x_int.detach().clone() 121 | 122 | class StatsQuantizer(nn.Module): 123 | def __init__(self, num_bits, clip_learnable): 124 | super(StatsQuantizer, self).__init__() 125 | self.num_bits = num_bits 126 | init_act_clip_val = 2.0 127 | 128 | self.clip_val = nn.Parameter(torch.Tensor([init_act_clip_val]), requires_grad=False) 129 | 130 | self.s = None 131 | 132 | 133 | def forward(self, weight): 134 | 135 | real_weights = weight 136 | 137 | if len(weight.shape) == 2: 138 | scaling_factor = 2 * torch.mean(abs(real_weights),dim=1,keepdim=True) # dim, 1 139 | elif len(weight.shape) == 3: 140 | scaling_factor = 2 * torch.mean(torch.mean(abs(real_weights),dim=-1,keepdim=True),dim=0,keepdim=True) # 1, dim, 1 141 | 142 | scaling_factor = scaling_factor.detach() 143 | self.s = scaling_factor.squeeze().cpu() 144 | scaled_weights = real_weights/scaling_factor 145 | cliped_weights = torch.clamp(scaled_weights, min=(-self.clip_val/2), max=(self.clip_val/2)-1e-6) 146 | n = float(2 ** (self.num_bits - 1)) 147 | quan_weights_no_grad = scaling_factor * ((torch.round((cliped_weights) * n - 0.5 ) + 0.5) / n) 148 | quan_weights = quan_weights_no_grad.detach() - real_weights.detach() + real_weights 149 | 150 | return quan_weights 151 | 152 | 153 | 154 | class StatsQuantizer_specific_4_qkreparam_cga(nn.Module): 155 | def __init__(self, num_bits, clip_learnable, boundaryRange = 0.005): 156 | super(StatsQuantizer_specific_4_qkreparam_cga, self).__init__() 157 | 158 | self.num_bits = num_bits 159 | init_act_clip_val = 2.0 160 | self.clip_val = nn.Parameter(torch.Tensor([init_act_clip_val]), requires_grad=False) 161 | self.s = None 162 | self.boundaryRange = boundaryRange 163 | 164 | def forward(self, weight): 165 | 166 | real_weights = weight 167 | 168 | if len(weight.shape) == 2: 169 | scaling_factor = 2 * torch.mean(abs(real_weights),dim=1,keepdim=True) # dim, 1 170 | elif len(weight.shape) == 3: 171 | scaling_factor = 2 * torch.mean(torch.mean(abs(real_weights),dim=-1,keepdim=True),dim=0,keepdim=True) # 1, dim, 1 172 | 173 | 174 | scaling_factor = scaling_factor.detach() 175 | self.s = scaling_factor.squeeze().cpu() 176 | scaled_weights = real_weights/scaling_factor 177 | cliped_weights = torch.clamp(scaled_weights, min=(-self.clip_val/2), max=(self.clip_val/2)-1e-6) 178 | n = float(2 ** (self.num_bits - 1)) 179 | b4_round = (cliped_weights) * n - 0.5 180 | 181 | if self.training: 182 | not_freeze_idx = torch.zeros_like(real_weights).cuda() 183 | for i in np.arange(start=-(2**(self.num_bits - 1)),stop=(2**(self.num_bits - 1) - 1)): # 0.5 - boundaryRange < x < 0.5 + boundaryRange 184 | within_boundary = ((b4_round - i) <= (0.5 + self.boundaryRange)) * ((b4_round - i) >= (0.5 - self.boundaryRange)) #idx of # 0.5 - boundaryRange < x < 0.5 + boundaryRange 185 | not_freeze_idx = not_freeze_idx + within_boundary.float() 186 | 187 | freeze_idx = 1.0-not_freeze_idx 188 | b4_round = b4_round.detach() * freeze_idx + b4_round* (1-freeze_idx) 189 | 190 | quan_weights_no_grad = scaling_factor * ((torch.round( b4_round ) + 0.5) / n) 191 | quan_weights = quan_weights_no_grad.detach() - real_weights.detach() + real_weights 192 | 193 | return quan_weights 194 | 195 | 196 | class StatsQuantizer_4d(nn.Module): # B, num_heads, N, in_features 197 | def __init__(self, num_bits, 
clip_learnable): 198 | super(StatsQuantizer_4d, self).__init__() 199 | 200 | self.num_bits = num_bits 201 | init_act_clip_val = 2.0 202 | self.clip_val = nn.Parameter(torch.Tensor([init_act_clip_val]), requires_grad=False) 203 | self.s = None 204 | 205 | def forward(self, weight): 206 | 207 | real_weights = weight 208 | 209 | scaling_factor = 2 * torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=1,keepdim=True),dim=0,keepdim=True) 210 | 211 | scaling_factor = scaling_factor.detach() 212 | self.s = scaling_factor.squeeze().cpu() 213 | scaled_weights = real_weights/scaling_factor 214 | cliped_weights = torch.clamp(scaled_weights, min=(-self.clip_val/2), max=(self.clip_val/2)-1e-6) 215 | n = float(2 ** (self.num_bits - 1)) 216 | quan_weights_no_grad = scaling_factor * ((torch.round((cliped_weights) * n - 0.5 ) + 0.5) / n) 217 | quan_weights = quan_weights_no_grad.detach() - real_weights.detach() + real_weights 218 | 219 | return quan_weights 220 | 221 | 222 | 223 | 224 | 225 | -------------------------------------------------------------------------------- /src/quantization/utils.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .modules.qlinear import QLinear, LSQ_w_and_act_QLinear 6 | from timm.loss import SoftTargetCrossEntropy 7 | 8 | def unitwise_norm(x, norm_type=2.0): 9 | if x.ndim <= 1: 10 | return x.norm(norm_type) 11 | else: 12 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) 13 | 14 | 15 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): 16 | if isinstance(parameters, torch.Tensor): 17 | parameters = [parameters] 18 | for p in parameters: 19 | if p.grad is None: 20 | continue 21 | p_data = p.detach() 22 | g_data = p.grad.detach() 23 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) 24 | grad_norm = unitwise_norm(g_data, norm_type=norm_type) 25 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) 26 | new_grad = torch.where(grad_norm < max_norm, g_data, clipped_grad) 27 | p.grad.detach().copy_(new_grad) 28 | 29 | class Multi_KLLossSoft(torch.nn.modules.loss._Loss): 30 | def forward(self, output, target, T=1.0): 31 | output = output[0] if isinstance(output, tuple) else output 32 | target = target[0] if isinstance(target, tuple) else target 33 | output, target = output / T, target / T 34 | target_prob = F.softmax(target, dim=1) 35 | output_log_prob = F.log_softmax(output, dim=1) 36 | loss = - torch.sum(target_prob * output_log_prob, dim=1) 37 | if self.reduction == "mean": 38 | return loss.mean() 39 | elif self.reduction == "sum": 40 | return loss.sum() 41 | else: 42 | return loss 43 | 44 | class KLLossSoft(torch.nn.modules.loss._Loss): 45 | def forward(self, output, target, T=1.0): 46 | output = output[0] if isinstance(output, tuple) else output 47 | target = target[0] if isinstance(target, tuple) else target 48 | output, target = output / T, target / T 49 | target_prob = F.softmax(target, dim=1) 50 | output_log_prob = F.log_softmax(output, dim=1) 51 | loss = - torch.sum(target_prob * output_log_prob, dim=1) 52 | if self.reduction == "mean": 53 | return loss.mean() 54 | elif self.reduction == "sum": 55 | return loss.sum() 56 | else: 57 | return loss 58 | 59 | class KDLossSoftandHard(torch.nn.Module): 60 | def __init__(self) -> None: 61 | super().__init__() 62 | self.KLSoft = KLLossSoft() 63 | self.Hard = 
nn.CrossEntropyLoss() 64 | 65 | def forward(self, output, hard_target, soft_target): 66 | 67 | if isinstance(output, tuple): 68 | cls_output = output[0] 69 | dist_output = output[1] 70 | soft_loss = self.KLSoft(dist_output,soft_target) 71 | hard_loss = self.Hard(cls_output, hard_target) 72 | 73 | else: 74 | soft_loss = self.KLSoft(output,soft_target) 75 | hard_loss = self.Hard(output, hard_target) 76 | 77 | return soft_loss + hard_loss 78 | 79 | def dampening_loss(w_fp, w_q, x_min, x_max): 80 | # L &= (s*w_{int} - w)^2 81 | # We also need to add clipping for both cases, we can do so by using the forward 82 | # this is also clipped and our target 83 | # clamp w in FP32 domain to not change range learning (min(max) is needed for per-channel) 84 | w_fp_clip = torch.min(torch.max(w_fp, x_min), x_max) 85 | loss = (w_q - w_fp_clip) ** 2 86 | 87 | return loss.sum() 88 | 89 | class DampeningLoss(torch.nn.Module): 90 | def __init__(self, weighting=1.0, weight_quant_method = 'nu2u') -> None: 91 | super().__init__() 92 | """ 93 | Calculates the dampening loss for all weights in a given quantized model. It is 94 | expected that all quantized weights are in a Hijacker module. 95 | 96 | """ 97 | self.weighting = weighting 98 | self.weight_quant_method = weight_quant_method 99 | 100 | def forward(self, model): 101 | total_bin_loss = 0 102 | for name, module in model.named_modules(): 103 | if isinstance(module, QLinear) or isinstance(module, LSQ_w_and_act_QLinear) : 104 | # print(name,"calculate dampening loss") 105 | # FP32 weight tensor, potential folded but before quantization 106 | weight = module.weight 107 | # The matching weight quantizer (not manager, direct quantizer class) 108 | if self.weight_quant_method == 'lsq': 109 | weight_q = module.lsqw_fn(weight).detach() 110 | weight_q_min = (module.lsqw_fn.thd_neg * module.lsqw_fn.s).unsqueeze(dim=-1) 111 | weight_q_max = (module.lsqw_fn.thd_pos * module.lsqw_fn.s).unsqueeze(dim=-1) 112 | 113 | elif self.weight_quant_method == 'nu2u': 114 | weight_q = module.nu2u_fn(weight).detach() 115 | weight_q_min, _ = torch.min(weight_q, 1) 116 | weight_q_min = weight_q_min.unsqueeze(dim=-1) 117 | weight_q_max, _ = torch.max(weight_q, 1) 118 | weight_q_max = weight_q_max.unsqueeze(dim=-1) 119 | 120 | total_bin_loss += dampening_loss(weight, weight_q, weight_q_min, weight_q_max) 121 | return total_bin_loss * self.weighting 122 | 123 | class KDLossSoftandHard_dampening(torch.nn.Module): 124 | def __init__(self, weight_quant_method) -> None: 125 | super().__init__() 126 | self.KLSoft = KLLossSoft() 127 | self.Hard = nn.CrossEntropyLoss() 128 | self.dampening_loss = DampeningLoss(weighting=0, weight_quant_method=weight_quant_method) 129 | 130 | def forward(self, output, hard_target, soft_target, model): 131 | 132 | if isinstance(output, tuple): 133 | cls_output = output[0] 134 | dist_output = output[1] 135 | soft_loss = self.KLSoft(dist_output,soft_target) 136 | hard_loss = self.Hard(cls_output, hard_target) 137 | 138 | else: 139 | soft_loss = self.KLSoft(output,soft_target) 140 | hard_loss = self.Hard(output, hard_target) 141 | 142 | dampening_loss = self.dampening_loss(model) ## as for now only works for LSQ_QLinea AND LSQ_w_and_act_QLinear 143 | 144 | return soft_loss + hard_loss + dampening_loss 145 | 146 | class KDLossSoftandSoftTargetCE(torch.nn.Module): 147 | def __init__(self) -> None: 148 | super().__init__() 149 | self.KLSoft = KLLossSoft() 150 | self.Hard = SoftTargetCrossEntropy() 151 | 152 | def forward(self, output, hard_target, soft_target): 153 | 154 
| if isinstance(output, tuple): 155 | cls_output = output[0] 156 | dist_output = output[1] 157 | soft_loss = self.KLSoft(dist_output,soft_target) 158 | hard_loss = self.Hard(cls_output, hard_target) 159 | 160 | else: 161 | soft_loss = self.KLSoft(output,soft_target) 162 | hard_loss = self.Hard(output, hard_target) 163 | 164 | return soft_loss + hard_loss 165 | 166 | def att_loss_r2b(Q_s, Q_t): 167 | Q_s_norm = Q_s / torch.norm(Q_s, p=2) 168 | Q_t_norm = Q_t / torch.norm(Q_t, p=2) 169 | tmp = Q_s_norm - Q_t_norm 170 | loss = torch.norm(tmp, p=2) 171 | return loss 172 | 173 | def direction_matching_distillation(student_scores, teacher_scores): 174 | tmp_loss = 0. 175 | # new_teacher_scores = [teacher_scores[i * layers_per_block + layers_per_block - 1] for i in range(student_layer_num)] 176 | for student_score, teacher_score in zip(student_scores, teacher_scores): 177 | student_score = torch.where(student_score <= -1e2, 178 | torch.zeros_like(student_score).to(student_scores[0].device), 179 | student_score) 180 | teacher_score = torch.where(teacher_score <= -1e2, 181 | torch.zeros_like(teacher_score).to(student_scores[0].device), 182 | teacher_score) 183 | tmp_loss += att_loss_r2b(student_score, teacher_score) 184 | return tmp_loss 185 | 186 | class KDLossSoftandHard_qk(torch.nn.Module): 187 | def __init__(self) -> None: 188 | super().__init__() 189 | self.KLSoft = KLLossSoft() 190 | self.Hard = nn.CrossEntropyLoss() 191 | 192 | def forward(self, student_logit, student_attn_info ,target, teacher_logit, teacher_attn_info): 193 | 194 | student_q = [] 195 | teacher_q = [] 196 | student_k = [] 197 | teacher_k = [] 198 | for layer in range(len(student_attn_info)): 199 | student_q.append(student_attn_info[layer][1]) 200 | teacher_q.append(teacher_attn_info[layer][1]) 201 | student_k.append(student_attn_info[layer][2]) 202 | teacher_k.append(teacher_attn_info[layer][2]) 203 | 204 | if isinstance(student_logit, tuple): 205 | cls_output = student_logit[0] 206 | dist_output = student_logit[1] 207 | soft_loss = self.KLSoft(dist_output, teacher_logit) 208 | hard_loss = self.Hard(cls_output, target) 209 | 210 | else: 211 | soft_loss = self.KLSoft(student_logit,teacher_logit) 212 | hard_loss = self.Hard(student_logit, target) 213 | 214 | 215 | q_loss = direction_matching_distillation(student_q, teacher_q) 216 | k_loss = direction_matching_distillation(student_k, teacher_k) 217 | 218 | 219 | return soft_loss + hard_loss + q_loss + k_loss 220 | 221 | class KDLossSoftandHard_qkv(torch.nn.Module): 222 | def __init__(self) -> None: 223 | super().__init__() 224 | self.KLSoft = KLLossSoft() 225 | self.Hard = nn.CrossEntropyLoss() 226 | 227 | def forward(self, student_logit, student_attn_info ,target, teacher_logit, teacher_attn_info): 228 | 229 | student_q = [] 230 | teacher_q = [] 231 | student_k = [] 232 | teacher_k = [] 233 | student_v = [] 234 | teacher_v = [] 235 | for layer in range(len(student_attn_info)): 236 | student_q.append(student_attn_info[layer][1]) 237 | teacher_q.append(teacher_attn_info[layer][1]) 238 | student_k.append(student_attn_info[layer][2]) 239 | teacher_k.append(teacher_attn_info[layer][2]) 240 | student_v.append(student_attn_info[layer][3]) 241 | teacher_v.append(teacher_attn_info[layer][3]) 242 | 243 | if isinstance(student_logit, tuple): 244 | cls_output = student_logit[0] 245 | dist_output = student_logit[1] 246 | soft_loss = self.KLSoft(dist_output, teacher_logit) 247 | hard_loss = self.Hard(cls_output, target) 248 | 249 | else: 250 | soft_loss = 
self.KLSoft(student_logit,teacher_logit) 251 | hard_loss = self.Hard(student_logit, target) 252 | 253 | 254 | q_loss = direction_matching_distillation(student_q, teacher_q) 255 | k_loss = direction_matching_distillation(student_k, teacher_k) 256 | v_loss = direction_matching_distillation(student_v, teacher_v) 257 | 258 | return soft_loss + hard_loss + q_loss + k_loss + v_loss 259 | 260 | class KLTokenMSELoss(torch.nn.Module): 261 | def __init__( 262 | self, 263 | alpha: float = 0.5, 264 | kd_type: str = "last", 265 | reduction: str = "mean", 266 | ): 267 | super().__init__() 268 | self.reduction = reduction 269 | self.alpha = alpha 270 | self.kl_loss = KLLossSoft(reduction=reduction) 271 | self.mse_loss = nn.MSELoss(reduction=reduction) 272 | self.kd_type = kd_type 273 | 274 | def _kl_loss(self, output, target): 275 | return self.kl_loss(output, target) 276 | 277 | def _mse_loss(self, output, target): 278 | mse_loss = 0 279 | if self.kd_type == "last": 280 | if isinstance(output, torch.Tensor): 281 | _, N, _ = target.size() 282 | mse_loss = self.mse_loss(output[:, -N:], target) 283 | else: 284 | _, N, _ = target[-1].size() 285 | mse_loss = self.mse_loss(output[-1][:, -N:], target[-1]) 286 | elif self.kd_type == "all": 287 | if isinstance(output, torch.Tensor): 288 | _, N, _ = target.size() 289 | mse_loss = self.mse_loss(output[:, -N:], target) 290 | else: 291 | assert len(output) == len(target) 292 | for i in range(len(output)): 293 | _, N, _ = target[i].size() 294 | mse_loss += self.mse_loss(output[i][:, -N:], target[i]) 295 | mse_loss = mse_loss / len(output) 296 | else: 297 | raise NotImplementedError 298 | return mse_loss 299 | 300 | def forward(self, output, target): 301 | assert len(output) == len(target) 302 | 303 | kl_loss = self.kl_loss(output[0], target[0]) 304 | mse_loss = self._mse_loss(output[1], target[1]) 305 | loss = kl_loss + self.alpha * mse_loss 306 | # print(f"KL loss {kl_loss}, MSE loss {mse_loss}, total loss {loss}") 307 | 308 | return loss 309 | 310 | -------------------------------------------------------------------------------- /src/registry.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def register_model(func): 4 | """ 5 | Fallback wrapper in case timm isn't installed 6 | """ 7 | return func 8 | -------------------------------------------------------------------------------- /src/swin.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional, Callable, List, Any 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, Tensor 7 | 8 | from torchvision.ops.misc import MLP, Permute 9 | from torchvision.ops.stochastic_depth import StochasticDepth 10 | from torchvision.transforms._presets import ImageClassification, InterpolationMode 11 | from torchvision.utils import _log_api_usage_once 12 | from torchvision.models._api import WeightsEnum, Weights 13 | from torchvision.models._meta import _IMAGENET_CATEGORIES 14 | from torchvision.models._utils import _ovewrite_named_param 15 | import math 16 | 17 | from timm.models.registry import register_model 18 | 19 | __all__ = [ 20 | "SwinTransformer", 21 | "Swin_T_Weights", 22 | "swin_t" 23 | ] 24 | 25 | 26 | class PatchMerging(nn.Module): 27 | """Patch Merging Layer. 28 | Args: 29 | dim (int): Number of input channels. 30 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 
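
As a stand-alone illustration of the direction-matching distillation defined earlier (`att_loss_r2b`), the loss compares L2-normalised score maps, so it penalises differences in direction while ignoring the overall scale of the student and teacher attention statistics. The tensor shapes below are hypothetical.

```python
import torch

def att_loss_r2b(q_s, q_t):
    q_s = q_s / torch.norm(q_s, p=2)
    q_t = q_t / torch.norm(q_t, p=2)
    return torch.norm(q_s - q_t, p=2)

student = torch.randn(2, 3, 197, 197)   # hypothetical per-head score maps
teacher = 10.0 * student                # same direction, larger magnitude
print(att_loss_r2b(student, teacher))   # ~0: only the direction matters
```
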
31 | """ 32 | 33 | def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): 34 | super().__init__() 35 | _log_api_usage_once(self) 36 | self.dim = dim 37 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 38 | self.norm = norm_layer(4 * dim) 39 | 40 | def forward(self, x): 41 | """ 42 | Args: 43 | x (Tensor): input tensor with expected layout of [..., H, W, C] 44 | Returns: 45 | Tensor with layout of [..., H/2, W/2, 2*C] 46 | """ 47 | features_x, attn_x = x 48 | H, W, _ = features_x.shape[-3:] 49 | features_x = F.pad(features_x, (0, 0, 0, W % 2, 0, H % 2)) 50 | 51 | features_x0 = features_x[..., 0::2, 0::2, :] # ... H/2 W/2 C 52 | features_x1 = features_x[..., 1::2, 0::2, :] # ... H/2 W/2 C 53 | features_x2 = features_x[..., 0::2, 1::2, :] # ... H/2 W/2 C 54 | features_x3 = features_x[..., 1::2, 1::2, :] # ... H/2 W/2 C 55 | features_x = torch.cat([features_x0, features_x1, features_x2, features_x3], -1) # ... H/2 W/2 4*C 56 | 57 | features_x = self.norm(features_x) 58 | features_x = self.reduction(features_x) # ... H/2 W/2 2*C 59 | return features_x, attn_x 60 | 61 | 62 | def shifted_window_attention( 63 | input: Tensor, 64 | qkv_weight: Tensor, 65 | proj_weight: Tensor, 66 | relative_position_bias: Tensor, 67 | window_size: List[int], 68 | num_heads: int, 69 | shift_size: List[int], 70 | attention_dropout: float = 0.0, 71 | dropout: float = 0.0, 72 | qkv_bias: Optional[Tensor] = None, 73 | proj_bias: Optional[Tensor] = None, 74 | qqkkvv: bool = False 75 | ): 76 | """ 77 | Window based multi-head self attention (W-MSA) module with relative position bias. 78 | It supports both of shifted and non-shifted window. 79 | Args: 80 | input (Tensor[N, H, W, C]): The input tensor or 4-dimensions. 81 | qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. 82 | proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. 83 | relative_position_bias (Tensor): The learned relative position bias added to attention. 84 | window_size (List[int]): Window size. 85 | num_heads (int): Number of attention heads. 86 | shift_size (List[int]): Shift size for shifted window attention. 87 | attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. 88 | dropout (float): Dropout ratio of output. Default: 0.0. 89 | qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. 90 | proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. 91 | Returns: 92 | Tensor[N, H, W, C]: The output tensor after shifted window attention. 
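
A stand-alone sketch of the PatchMerging arithmetic above: the four spatial neighbours are concatenated channel-wise (C -> 4C), normalised, and projected to 2C, so each merging stage halves the resolution and doubles the embedding width. The sizes below are illustrative.

```python
import torch
import torch.nn as nn

H = W = 8
C = 96
x = torch.randn(2, H, W, C)
x = torch.cat([x[..., 0::2, 0::2, :], x[..., 1::2, 0::2, :],
               x[..., 0::2, 1::2, :], x[..., 1::2, 1::2, :]], dim=-1)
x = nn.Linear(4 * C, 2 * C, bias=False)(nn.LayerNorm(4 * C)(x))
print(x.shape)   # torch.Size([2, 4, 4, 192])
```
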
93 | """ 94 | B, H, W, C = input.shape 95 | # pad feature maps to multiples of window size 96 | pad_r = (window_size[1] - W % window_size[1]) % window_size[1] 97 | pad_b = (window_size[0] - H % window_size[0]) % window_size[0] 98 | x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) 99 | _, pad_H, pad_W, _ = x.shape 100 | 101 | # If window size is larger than feature size, there is no need to shift window 102 | if window_size[0] >= pad_H: 103 | shift_size[0] = 0 104 | if window_size[1] >= pad_W: 105 | shift_size[1] = 0 106 | 107 | # cyclic shift 108 | if sum(shift_size) > 0: 109 | x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) 110 | 111 | # partition windows 112 | num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1]) 113 | x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C) 114 | x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C 115 | 116 | # multi-head attention 117 | qkv = F.linear(x, qkv_weight, qkv_bias) 118 | qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) 119 | q, k, v = qkv[0], qkv[1], qkv[2] 120 | 121 | attn = q.matmul(k.transpose(-2, -1)) * (C // num_heads) ** -0.5 122 | # add relative position bias 123 | attn = attn + relative_position_bias 124 | 125 | if sum(shift_size) > 0: 126 | # generate attention mask 127 | attn_mask = x.new_zeros((pad_H, pad_W)) 128 | h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None)) 129 | w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None)) 130 | count = 0 131 | for h in h_slices: 132 | for w in w_slices: 133 | attn_mask[h[0] : h[1], w[0] : w[1]] = count 134 | count += 1 135 | attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1]) 136 | attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1]) 137 | attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) 138 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 139 | attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) 140 | attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) 141 | attn = attn.view(-1, num_heads, x.size(1), x.size(1)) 142 | 143 | attn = F.softmax(attn, dim=-1) 144 | attn = F.dropout(attn, p=attention_dropout) 145 | 146 | x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) 147 | x = F.linear(x, proj_weight, proj_bias) 148 | x = F.dropout(x, p=dropout) 149 | 150 | # reverse windows 151 | x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C) 152 | x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) 153 | 154 | # reverse cyclic shift 155 | if sum(shift_size) > 0: 156 | x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) 157 | 158 | # unpad features 159 | x = x[:, :H, :W, :].contiguous() 160 | if qqkkvv: 161 | q_score = torch.matmul(q, q.transpose(-1, -2)) 162 | q_score = q_score / math.sqrt(C // num_heads) 163 | k_score = torch.matmul(k, k.transpose(-1, -2)) 164 | k_score = k_score / math.sqrt(C // num_heads) 165 | v_score = torch.matmul(v, v.transpose(-1, -2)) 166 | v_score = v_score / math.sqrt(C // num_heads) 167 | 168 | return x, (attn, q_score, k_score, v_score) 169 | else: 170 | return x, None 171 | 172 | 173 | torch.fx.wrap("shifted_window_attention") 174 | 
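
The reshape/permute bookkeeping in `shifted_window_attention` is easy to get wrong, so here is a stand-alone snippet (toy sizes) that reproduces just the padding and window-partition step: a 13x13 feature map is padded to 14x14 and split into 4 windows of 49 tokens each.

```python
import torch
import torch.nn.functional as F

B, H, W, C = 1, 13, 13, 8
ws = [7, 7]
x = torch.randn(B, H, W, C)

pad_r = (ws[1] - W % ws[1]) % ws[1]
pad_b = (ws[0] - H % ws[0]) % ws[0]
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
_, pad_H, pad_W, _ = x.shape

num_windows = (pad_H // ws[0]) * (pad_W // ws[1])
x = x.view(B, pad_H // ws[0], ws[0], pad_W // ws[1], ws[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, ws[0] * ws[1], C)
print(x.shape)   # torch.Size([4, 49, 8])
```
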
175 | 176 | class ShiftedWindowAttention(nn.Module): 177 | """ 178 | See :func:`shifted_window_attention`. 179 | """ 180 | 181 | def __init__( 182 | self, 183 | dim: int, 184 | window_size: List[int], 185 | shift_size: List[int], 186 | num_heads: int, 187 | qkv_bias: bool = True, 188 | proj_bias: bool = True, 189 | attention_dropout: float = 0.0, 190 | dropout: float = 0.0, 191 | qqkkvv: bool = False, 192 | ): 193 | super().__init__() 194 | if len(window_size) != 2 or len(shift_size) != 2: 195 | raise ValueError("window_size and shift_size must be of length 2") 196 | self.dim = dim 197 | self.window_size = window_size 198 | self.shift_size = shift_size 199 | self.num_heads = num_heads 200 | self.attention_dropout = attention_dropout 201 | self.dropout = dropout 202 | self.qqkkvv = qqkkvv 203 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 204 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 205 | 206 | # define a parameter table of relative position bias 207 | self.relative_position_bias_table = nn.Parameter( 208 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) 209 | ) # 2*Wh-1 * 2*Ww-1, nH 210 | 211 | # get pair-wise relative position index for each token inside the window 212 | coords_h = torch.arange(self.window_size[0]) 213 | coords_w = torch.arange(self.window_size[1]) 214 | coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww 215 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 216 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 217 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 218 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 219 | relative_coords[:, :, 1] += self.window_size[1] - 1 220 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 221 | relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww 222 | self.register_buffer("relative_position_index", relative_position_index) 223 | 224 | nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) 225 | 226 | def forward(self, x: Tensor): 227 | """ 228 | Args: 229 | x (Tensor): Tensor with layout of [B, H, W, C] 230 | Returns: 231 | Tensor with same layout as input, i.e. [B, H, W, C] 232 | """ 233 | 234 | N = self.window_size[0] * self.window_size[1] 235 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] 236 | relative_position_bias = relative_position_bias.view(N, N, -1) 237 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) 238 | 239 | return shifted_window_attention( 240 | x, 241 | self.qkv.weight, 242 | self.proj.weight, 243 | relative_position_bias, 244 | self.window_size, 245 | self.num_heads, 246 | shift_size=self.shift_size, 247 | attention_dropout=self.attention_dropout, 248 | dropout=self.dropout, 249 | qkv_bias=self.qkv.bias, 250 | proj_bias=self.proj.bias, 251 | qqkkvv = self.qqkkvv 252 | ) 253 | 254 | 255 | class SwinTransformerBlock(nn.Module): 256 | """ 257 | Swin Transformer Block. 258 | Args: 259 | dim (int): Number of input channels. 260 | num_heads (int): Number of attention heads. 261 | window_size (List[int]): Window size. 262 | shift_size (List[int]): Shift size for shifted window attention. 263 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. 264 | dropout (float): Dropout rate. Default: 0.0. 265 | attention_dropout (float): Attention dropout rate. Default: 0.0. 
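
To make the relative-position-bias indexing above concrete, here is a self-contained sketch for a tiny 3x3 window (sizes chosen only for illustration): every ordered pair of positions inside the window maps into a table of (2*3-1)*(2*3-1) = 25 learned entries per head.

```python
import torch

window = [3, 3]
num_heads = 2
table = torch.zeros((2 * window[0] - 1) * (2 * window[1] - 1), num_heads)

coords = torch.stack(torch.meshgrid(
    torch.arange(window[0]), torch.arange(window[1]), indexing="ij"))
flat = torch.flatten(coords, 1)                       # 2, Wh*Ww
rel = flat[:, :, None] - flat[:, None, :]             # 2, Wh*Ww, Wh*Ww
rel = rel.permute(1, 2, 0).contiguous()
rel[:, :, 0] += window[0] - 1                         # shift to start from 0
rel[:, :, 1] += window[1] - 1
rel[:, :, 0] *= 2 * window[1] - 1
index = rel.sum(-1).view(-1)                          # (Wh*Ww)^2 lookups

N = window[0] * window[1]
bias = table[index].view(N, N, -1).permute(2, 0, 1)   # nH, N, N
print(bias.shape)                                     # torch.Size([2, 9, 9])
```
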
266 | stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. 267 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 268 | attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention 269 | """ 270 | 271 | def __init__( 272 | self, 273 | dim: int, 274 | num_heads: int, 275 | window_size: List[int], 276 | shift_size: List[int], 277 | mlp_ratio: float = 4.0, 278 | dropout: float = 0.0, 279 | qqkkvv: bool = False, 280 | attention_dropout: float = 0.0, 281 | stochastic_depth_prob: float = 0.0, 282 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 283 | attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, 284 | ): 285 | super().__init__() 286 | _log_api_usage_once(self) 287 | 288 | self.norm1 = norm_layer(dim) 289 | self.attn = attn_layer( 290 | dim, 291 | window_size, 292 | shift_size, 293 | num_heads, 294 | attention_dropout=attention_dropout, 295 | dropout=dropout, 296 | qqkkvv=qqkkvv 297 | ) 298 | self.qqkkvv = qqkkvv 299 | self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") 300 | self.norm2 = norm_layer(dim) 301 | self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) 302 | 303 | for m in self.mlp.modules(): 304 | if isinstance(m, nn.Linear): 305 | nn.init.xavier_uniform_(m.weight) 306 | if m.bias is not None: 307 | nn.init.normal_(m.bias, std=1e-6) 308 | 309 | def forward(self, x): 310 | 311 | x = x[0] 312 | 313 | if self.qqkkvv: 314 | temp_x, attn_mtrx = self.attn(self.norm1(x)) 315 | x = x + self.stochastic_depth(temp_x) 316 | x = x + self.stochastic_depth(self.mlp(self.norm2(x))) 317 | return x, attn_mtrx 318 | else: 319 | temp_x, _ = self.attn(self.norm1(x)) 320 | x = x + self.stochastic_depth(temp_x) 321 | x = x + self.stochastic_depth(self.mlp(self.norm2(x))) 322 | return x, None 323 | 324 | 325 | class SwinTransformer(nn.Module): 326 | """ 327 | Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using 328 | Shifted Windows" `_ paper. 329 | Args: 330 | patch_size (List[int]): Patch size. 331 | embed_dim (int): Patch embedding dimension. 332 | depths (List(int)): Depth of each Swin Transformer layer. 333 | num_heads (List(int)): Number of attention heads in different layers. 334 | window_size (List[int]): Window size. 335 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. 336 | dropout (float): Dropout rate. Default: 0.0. 337 | attention_dropout (float): Attention dropout rate. Default: 0.0. 338 | stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0. 339 | num_classes (int): Number of classes for classification head. Default: 1000. 340 | block (nn.Module, optional): SwinTransformer Block. Default: None. 341 | norm_layer (nn.Module, optional): Normalization layer. Default: None. 
342 | """ 343 | 344 | def __init__( 345 | self, 346 | patch_size: List[int], 347 | embed_dim: int, 348 | depths: List[int], 349 | num_heads: List[int], 350 | window_size: List[int], 351 | mlp_ratio: float = 4.0, 352 | dropout: float = 0.0, 353 | qqkkvv: bool = False, 354 | attention_dropout: float = 0.0, 355 | stochastic_depth_prob: float = 0.0, 356 | num_classes: int = 1000, 357 | norm_layer: Optional[Callable[..., nn.Module]] = None, 358 | block: Optional[Callable[..., nn.Module]] = None, 359 | ): 360 | super().__init__() 361 | _log_api_usage_once(self) 362 | self.num_classes = num_classes 363 | 364 | if block is None: 365 | block = SwinTransformerBlock 366 | 367 | if norm_layer is None: 368 | norm_layer = partial(nn.LayerNorm, eps=1e-5) 369 | 370 | layers: List[nn.Module] = [] 371 | # split image into non-overlapping patches 372 | layers.append( 373 | nn.Sequential( 374 | nn.Conv2d( 375 | 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) 376 | ), 377 | Permute([0, 2, 3, 1]), 378 | norm_layer(embed_dim), 379 | ) 380 | ) 381 | 382 | 383 | total_stage_blocks = sum(depths) 384 | stage_block_id = 0 385 | # build SwinTransformer blocks 386 | for i_stage in range(len(depths)): 387 | stage: List[nn.Module] = [] 388 | dim = embed_dim * 2 ** i_stage 389 | for i_layer in range(depths[i_stage]): 390 | # adjust stochastic depth probability based on the depth of the stage block 391 | sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) 392 | stage.append( 393 | block( 394 | dim, 395 | num_heads[i_stage], 396 | window_size=window_size, 397 | shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], 398 | mlp_ratio=mlp_ratio, 399 | dropout=dropout, 400 | qqkkvv=qqkkvv, 401 | attention_dropout=attention_dropout, 402 | stochastic_depth_prob=sd_prob, 403 | norm_layer=norm_layer, 404 | ) 405 | ) 406 | stage_block_id += 1 407 | layers.append(nn.Sequential(*stage)) 408 | # add patch merging layer 409 | if i_stage < (len(depths) - 1): 410 | layers.append(PatchMerging(dim, norm_layer)) 411 | self.features = nn.Sequential(*layers) 412 | self.qqkkvv = qqkkvv 413 | num_features = embed_dim * 2 ** (len(depths) - 1) 414 | self.norm = norm_layer(num_features) 415 | self.avgpool = nn.AdaptiveAvgPool2d(1) 416 | self.head = nn.Linear(num_features, num_classes) 417 | 418 | for m in self.modules(): 419 | if isinstance(m, nn.Linear): 420 | nn.init.trunc_normal_(m.weight, std=0.02) 421 | if m.bias is not None: 422 | nn.init.zeros_(m.bias) 423 | 424 | 425 | def forward_features(self, x): 426 | 427 | ## patch embeddings 428 | x = self.features[0](x) 429 | x = (x, None) 430 | ## blocks 431 | attn_matrixs = [] 432 | for block in self.features[1:]: 433 | x, attn_matrix = block(x) 434 | x = (x, None) 435 | attn_matrixs.append(attn_matrix) 436 | 437 | return x[0], attn_matrixs 438 | 439 | 440 | 441 | def forward(self, x): 442 | x, attn_matrixs = self.forward_features(x) 443 | x = self.norm(x) 444 | x = x.permute(0, 3, 1, 2) 445 | x = self.avgpool(x) 446 | x = torch.flatten(x, 1) 447 | x = self.head(x) 448 | return x, attn_matrixs 449 | 450 | 451 | def _swin_transformer( 452 | patch_size: List[int], 453 | embed_dim: int, 454 | depths: List[int], 455 | num_heads: List[int], 456 | window_size: List[int], 457 | stochastic_depth_prob: float, 458 | weights: Optional[WeightsEnum], 459 | progress: bool, 460 | **kwargs: Any, 461 | ) -> SwinTransformer: 462 | if weights is not None: 463 | _ovewrite_named_param(kwargs, "num_classes", 
len(weights.meta["categories"])) 464 | 465 | model = SwinTransformer( 466 | patch_size=patch_size, 467 | embed_dim=embed_dim, 468 | depths=depths, 469 | num_heads=num_heads, 470 | window_size=window_size, 471 | stochastic_depth_prob=stochastic_depth_prob, 472 | **kwargs, 473 | ) 474 | 475 | if weights is not None: 476 | model.load_state_dict(weights.get_state_dict(progress=progress)) 477 | 478 | return model 479 | 480 | 481 | 482 | 483 | _COMMON_META = { 484 | "categories": _IMAGENET_CATEGORIES, 485 | } 486 | 487 | 488 | class Swin_T_Weights(WeightsEnum): 489 | IMAGENET1K_V1 = Weights( 490 | url="https://download.pytorch.org/models/swin_t-704ceda3.pth", 491 | transforms=partial( 492 | ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC 493 | ), 494 | meta={ 495 | **_COMMON_META, 496 | "num_params": 28288354, 497 | "min_size": (224, 224), 498 | "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", 499 | "_metrics": { 500 | "ImageNet-1K": { 501 | "acc@1": 81.474, 502 | "acc@5": 95.776, 503 | } 504 | }, 505 | "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", 506 | }, 507 | ) 508 | DEFAULT = IMAGENET1K_V1 509 | 510 | 511 | @register_model 512 | def swin_t(*, drop_path = 0.2, pretrained = False, progress: bool = True, **kwargs: Any) -> SwinTransformer: 513 | """ 514 | Constructs a swin_tiny architecture from 515 | `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. 516 | 517 | Args: 518 | weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The 519 | pretrained weights to use. See 520 | :class:`~torchvision.models.Swin_T_Weights` below for 521 | more details, and possible values. By default, no pre-trained 522 | weights are used. 523 | progress (bool, optional): If True, displays a progress bar of the 524 | download to stderr. Default is True. 525 | **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` 526 | base class. Please refer to the `source code 527 | `_ 528 | for more details about this class. 529 | 530 | .. 
autoclass:: torchvision.models.Swin_T_Weights 531 | :members: 532 | """ 533 | 534 | model = _swin_transformer( 535 | patch_size=[4, 4], 536 | embed_dim=96, 537 | depths=[2, 2, 6, 2], 538 | num_heads=[3, 6, 12, 24], 539 | window_size=[7, 7], 540 | stochastic_depth_prob=drop_path, 541 | weights=None, 542 | progress=progress, 543 | **kwargs, 544 | ) 545 | if pretrained and kwargs['num_classes'] == 1000: 546 | print("load pretrained") 547 | checkpoint = torch.hub.load_state_dict_from_url( 548 | url="https://download.pytorch.org/models/swin_t-704ceda3.pth", 549 | map_location="cpu", check_hash=True 550 | ) 551 | model.load_state_dict(checkpoint) 552 | 553 | return model 554 | 555 | 556 | 557 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .transformers import * 3 | -------------------------------------------------------------------------------- /src/utils/embedder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Embedder(nn.Module): 5 | def __init__(self, 6 | word_embedding_dim=300, 7 | vocab_size=100000, 8 | padding_idx=1, 9 | pretrained_weight=None, 10 | embed_freeze=False, 11 | *args, **kwargs): 12 | super(Embedder, self).__init__() 13 | self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \ 14 | if pretrained_weight is not None else \ 15 | nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx) 16 | self.embeddings.weight.requires_grad = not embed_freeze 17 | 18 | def forward_mask(self, mask): 19 | bsz, seq_len = mask.shape 20 | new_mask = mask.view(bsz, seq_len, 1) 21 | new_mask = new_mask.sum(-1) 22 | new_mask = (new_mask > 0) 23 | return new_mask 24 | 25 | def forward(self, x, mask=None): 26 | embed = self.embeddings(x) 27 | embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float() 28 | return embed, mask 29 | 30 | @staticmethod 31 | def init_weight(m): 32 | if isinstance(m, nn.Linear): 33 | nn.init.trunc_normal_(m.weight, std=.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | else: 37 | nn.init.normal_(m.weight) 38 | -------------------------------------------------------------------------------- /src/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | import logging 5 | 6 | _logger = logging.getLogger('train') 7 | 8 | 9 | def resize_pos_embed(posemb, posemb_new, num_tokens=1): 10 | # Copied from `timm` by Ross Wightman: 11 | # github.com/rwightman/pytorch-image-models 12 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 13 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 14 | ntok_new = posemb_new.shape[1] 15 | if num_tokens: 16 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 17 | ntok_new -= num_tokens 18 | else: 19 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 20 | gs_old = int(math.sqrt(len(posemb_grid))) 21 | gs_new = int(math.sqrt(ntok_new)) 22 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 23 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') 24 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) 25 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 26 | return posemb 27 | 28 | 29 | def pe_check(model, state_dict, pe_key='classifier.positional_emb'): 30 | if pe_key is not None and pe_key in state_dict.keys() and pe_key in model.state_dict().keys(): 31 | if model.state_dict()[pe_key].shape != state_dict[pe_key].shape: 32 | state_dict[pe_key] = resize_pos_embed(state_dict[pe_key], 33 | model.state_dict()[pe_key], 34 | num_tokens=model.classifier.num_tokens) 35 | return state_dict 36 | 37 | 38 | def fc_check(model, state_dict, fc_key='classifier.fc'): 39 | for key in [f'{fc_key}.weight', f'{fc_key}.bias']: 40 | if key is not None and key in state_dict.keys() and key in model.state_dict().keys(): 41 | if model.state_dict()[key].shape != state_dict[key].shape: 42 | _logger.warning(f'Removing {key}, number of classes has changed.') 43 | state_dict[key] = model.state_dict()[key] 44 | return state_dict 45 | -------------------------------------------------------------------------------- /src/utils/stochastic_depth.py: -------------------------------------------------------------------------------- 1 | # Thanks to rwightman's timm package 2 | # github.com:rwightman/pytorch-image-models 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def drop_path(x, drop_prob: float = 0., training: bool = False): 9 | """ 10 | Obtained from: github.com:rwightman/pytorch-image-models 11 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 16 | 'survival rate' as the argument. 17 | """ 18 | if drop_prob == 0. or not training: 19 | return x 20 | keep_prob = 1 - drop_prob 21 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 22 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 23 | random_tensor.floor_() # binarize 24 | output = x.div(keep_prob) * random_tensor 25 | return output 26 | 27 | 28 | class DropPath(nn.Module): 29 | """ 30 | Obtained from: github.com:rwightman/pytorch-image-models 31 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
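
As a stand-alone illustration of what `resize_pos_embed` does, assume a hypothetical positional embedding with one class token plus a 14x14 grid being resized to 16x16: only the grid part is bilinearly interpolated, while prefix tokens are carried over unchanged.

```python
import math
import torch
import torch.nn.functional as F

posemb = torch.randn(1, 1 + 14 * 14, 192)   # hypothetical: 1 cls token + 14x14 grid
num_tokens, new_grid = 1, 16

tok, grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
gs_old = int(math.sqrt(grid.shape[0]))
grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bilinear")
grid = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, -1)
print(torch.cat([tok, grid], dim=1).shape)   # torch.Size([1, 257, 192])
```
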
32 | """ 33 | 34 | def __init__(self, drop_prob=None): 35 | super(DropPath, self).__init__() 36 | self.drop_prob = drop_prob 37 | 38 | def forward(self, x): 39 | return drop_path(x, self.drop_prob, self.training) 40 | -------------------------------------------------------------------------------- /src/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Tokenizer(nn.Module): 7 | def __init__(self, 8 | kernel_size, stride, padding, 9 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, 10 | n_conv_layers=1, 11 | n_input_channels=3, 12 | n_output_channels=64, 13 | in_planes=64, 14 | activation=None, 15 | max_pool=True, 16 | conv_bias=False): 17 | super(Tokenizer, self).__init__() 18 | 19 | n_filter_list = [n_input_channels] + \ 20 | [in_planes for _ in range(n_conv_layers - 1)] + \ 21 | [n_output_channels] 22 | 23 | self.conv_layers = nn.Sequential( 24 | *[nn.Sequential( 25 | nn.Conv2d(n_filter_list[i], n_filter_list[i + 1], 26 | kernel_size=(kernel_size, kernel_size), 27 | stride=(stride, stride), 28 | padding=(padding, padding), bias=conv_bias), 29 | nn.Identity() if activation is None else activation(), 30 | nn.MaxPool2d(kernel_size=pooling_kernel_size, 31 | stride=pooling_stride, 32 | padding=pooling_padding) if max_pool else nn.Identity() 33 | ) 34 | for i in range(n_conv_layers) 35 | ]) 36 | # print(self.conv_layers) 37 | self.flattener = nn.Flatten(2, 3) 38 | self.apply(self.init_weight) 39 | 40 | def sequence_length(self, n_channels=3, height=224, width=224): 41 | return self.forward(torch.zeros((1, n_channels, height, width))).shape[1] 42 | 43 | def forward(self, x): 44 | return self.flattener(self.conv_layers(x)).transpose(-2, -1) 45 | 46 | @staticmethod 47 | def init_weight(m): 48 | if isinstance(m, nn.Conv2d): 49 | nn.init.kaiming_normal_(m.weight) 50 | 51 | 52 | class TextTokenizer(nn.Module): 53 | def __init__(self, 54 | kernel_size, stride, padding, 55 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1, 56 | embedding_dim=300, 57 | n_output_channels=128, 58 | activation=None, 59 | max_pool=True, 60 | *args, **kwargs): 61 | super(TextTokenizer, self).__init__() 62 | 63 | self.max_pool = max_pool 64 | self.conv_layers = nn.Sequential( 65 | nn.Conv2d(1, n_output_channels, 66 | kernel_size=(kernel_size, embedding_dim), 67 | stride=(stride, 1), 68 | padding=(padding, 0), bias=False), 69 | nn.Identity() if activation is None else activation(), 70 | nn.MaxPool2d( 71 | kernel_size=(pooling_kernel_size, 1), 72 | stride=(pooling_stride, 1), 73 | padding=(pooling_padding, 0) 74 | ) if max_pool else nn.Identity() 75 | ) 76 | 77 | self.apply(self.init_weight) 78 | 79 | def seq_len(self, seq_len=32, embed_dim=300): 80 | return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1] 81 | 82 | def forward_mask(self, mask): 83 | new_mask = mask.unsqueeze(1).float() 84 | cnn_weight = torch.ones( 85 | (1, 1, self.conv_layers[0].kernel_size[0]), 86 | device=mask.device, 87 | dtype=torch.float) 88 | new_mask = F.conv1d( 89 | new_mask, cnn_weight, None, 90 | self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], 1, 1) 91 | if self.max_pool: 92 | new_mask = F.max_pool1d( 93 | new_mask, self.conv_layers[2].kernel_size[0], 94 | self.conv_layers[2].stride[0], self.conv_layers[2].padding[0], 1, False, False) 95 | new_mask = new_mask.squeeze(1) 96 | new_mask = (new_mask > 0) 97 | return new_mask 98 | 99 | def forward(self, x, 
mask=None): 100 | x = x.unsqueeze(1) 101 | x = self.conv_layers(x) 102 | x = x.transpose(1, 3).squeeze(1) 103 | if mask is not None: 104 | mask = self.forward_mask(mask).unsqueeze(-1).float() 105 | x = x * mask 106 | return x, mask 107 | 108 | @staticmethod 109 | def init_weight(m): 110 | if isinstance(m, nn.Conv2d): 111 | nn.init.kaiming_normal_(m.weight) 112 | -------------------------------------------------------------------------------- /src/utils/transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init 5 | import torch.nn.functional as F 6 | from .stochastic_depth import DropPath 7 | 8 | 9 | class Attention(Module): 10 | """ 11 | Obtained from timm: github.com:rwightman/pytorch-image-models 12 | """ 13 | 14 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1, use_skip=False): 15 | super().__init__() 16 | self.dim = dim 17 | self.num_heads = num_heads 18 | self.use_skip = use_skip 19 | self.attention_dropout = attention_dropout 20 | self.projection_dropout = projection_dropout 21 | 22 | head_dim = dim // self.num_heads 23 | self.scale = head_dim ** -0.5 24 | 25 | self.qkv = Linear(dim, dim * 3, bias=False) 26 | self.attn_drop = Dropout(attention_dropout) 27 | self.proj = Linear(dim, dim) 28 | self.proj_drop = Dropout(projection_dropout) 29 | 30 | def forward(self, x): 31 | B, N, C = x.shape 32 | # qkv dimension: 3, B, num_heads, N, C // num_heads 33 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 34 | q, k, v = qkv[0], qkv[1], qkv[2] 35 | 36 | if self.use_skip: 37 | res = x.reshape( 38 | B, N, self.num_heads, C // self.num_heads 39 | ).permute(0, 2, 1, 3) 40 | q += res 41 | k += res 42 | v += res 43 | 44 | attn = (q @ k.transpose(-2, -1)) * self.scale 45 | attn = attn.softmax(dim=-1) 46 | attn = self.attn_drop(attn) 47 | 48 | 49 | 50 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 51 | if not self.use_skip: 52 | x = self.proj(x) 53 | else: 54 | x = x + self.proj(x) 55 | x = self.proj_drop(x) 56 | return x 57 | 58 | 59 | class MaskedAttention(Module): 60 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): 61 | super().__init__() 62 | self.num_heads = num_heads 63 | head_dim = dim // self.num_heads 64 | self.scale = head_dim ** -0.5 65 | 66 | self.qkv = Linear(dim, dim * 3, bias=False) 67 | self.attn_drop = Dropout(attention_dropout) 68 | self.proj = Linear(dim, dim) 69 | self.proj_drop = Dropout(projection_dropout) 70 | 71 | def forward(self, x, mask=None): 72 | B, N, C = x.shape 73 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 74 | q, k, v = qkv[0], qkv[1], qkv[2] 75 | 76 | attn = (q @ k.transpose(-2, -1)) * self.scale 77 | 78 | if mask is not None: 79 | mask_value = -torch.finfo(attn.dtype).max 80 | assert mask.shape[-1] == attn.shape[-1], 'mask has incorrect dimensions' 81 | mask = mask[:, None, :] * mask[:, :, None] 82 | mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) 83 | attn.masked_fill_(~mask, mask_value) 84 | 85 | attn = attn.softmax(dim=-1) 86 | attn = self.attn_drop(attn) 87 | 88 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 89 | x = self.proj(x) 90 | x = self.proj_drop(x) 91 | return x 92 | 93 | 94 | class TransformerEncoderLayer(Module): 95 | """ 96 | Inspired by torch.nn.TransformerEncoderLayer and timm. 
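
A small, self-contained sketch of the masking step used in MaskedAttention above (toy shapes): the padding mask is expanded to a pairwise token-by-token mask and invalid positions are filled with a very negative value before the softmax, so padded tokens receive effectively zero attention weight.

```python
import torch

B, H, N = 1, 2, 4
attn = torch.randn(B, H, N, N)                          # raw attention scores
mask = torch.tensor([[True, True, True, False]])        # last token is padding

pair = mask[:, None, :] * mask[:, :, None]              # B, N, N
pair = pair.unsqueeze(1).repeat(1, H, 1, 1)             # B, nH, N, N
attn = attn.masked_fill(~pair, -torch.finfo(attn.dtype).max)
attn = attn.softmax(dim=-1)
print(attn[0, 0, 0])    # last entry is ~0: no weight on the padded token
```
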
97 | """ 98 | 99 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 100 | attention_dropout=0.1, drop_path_rate=0.1, 101 | use_layer_scale=False, use_skip=False, use_relu=False): 102 | super(TransformerEncoderLayer, self).__init__() 103 | self.pre_norm = LayerNorm(d_model) 104 | self.self_attn = Attention(dim=d_model, num_heads=nhead, 105 | attention_dropout=attention_dropout, projection_dropout=dropout, use_skip=use_skip) 106 | 107 | self.linear1 = Linear(d_model, dim_feedforward) 108 | self.dropout1 = Dropout(dropout) 109 | self.norm1 = LayerNorm(d_model) 110 | self.linear2 = Linear(dim_feedforward, d_model) 111 | self.dropout2 = Dropout(dropout) 112 | 113 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity() 114 | 115 | # self.activation = F.gelu if not use_relu else F.relu 116 | self.activation = nn.GELU() if not use_relu else nn.ReLU(inplace=True) 117 | 118 | init_values = 1e-4 119 | self.use_layer_scale = use_layer_scale 120 | 121 | if self.use_layer_scale: 122 | self.gamma_1 = torch.nn.Parameter(init_values * torch.ones(d_model), requires_grad=True) 123 | self.gamma_2 = torch.nn.Parameter(init_values * torch.ones(d_model), requires_grad=True) 124 | 125 | def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor: 126 | src2 = self.self_attn(self.pre_norm(src)) 127 | src2 = self.gamma_1 * src2 if self.use_layer_scale else src2 128 | src = src + self.drop_path(src2) 129 | src = self.norm1(src) 130 | gelu_output = self.activation(self.linear1(src)) 131 | src2 = self.linear2(self.dropout1(gelu_output)) 132 | src2 = self.gamma_2 * src2 if self.use_layer_scale else src2 133 | src = src + self.drop_path(self.dropout2(src2)) 134 | return src 135 | 136 | 137 | class MaskedTransformerEncoderLayer(Module): 138 | """ 139 | Inspired by torch.nn.TransformerEncoderLayer and timm. 
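
A stand-alone sketch of the LayerScale option above (toy sizes): when `use_layer_scale` is enabled, the residual branch is multiplied by a small learnable per-channel gamma initialised at 1e-4, which keeps each block close to the identity at the start of training.

```python
import torch
import torch.nn as nn

d_model = 8
gamma = nn.Parameter(1e-4 * torch.ones(d_model))   # per-channel scale, init 1e-4
x = torch.randn(2, 4, d_model)
branch = torch.randn(2, 4, d_model)                # e.g. attention or MLP output
out = x + gamma * branch
print((out - x).abs().max())                       # tiny update: block ~ identity
```
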
140 | """ 141 | 142 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 143 | attention_dropout=0.1, drop_path_rate=0.1): 144 | super(MaskedTransformerEncoderLayer, self).__init__() 145 | self.pre_norm = LayerNorm(d_model) 146 | self.self_attn = MaskedAttention(dim=d_model, num_heads=nhead, 147 | attention_dropout=attention_dropout, projection_dropout=dropout) 148 | 149 | self.linear1 = Linear(d_model, dim_feedforward) 150 | self.dropout1 = Dropout(dropout) 151 | self.norm1 = LayerNorm(d_model) 152 | self.linear2 = Linear(dim_feedforward, d_model) 153 | self.dropout2 = Dropout(dropout) 154 | 155 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity() 156 | 157 | self.activation = F.gelu 158 | 159 | def forward(self, src: torch.Tensor, mask=None, *args, **kwargs) -> torch.Tensor: 160 | src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask)) 161 | src = self.norm1(src) 162 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) 163 | src = src + self.drop_path(self.dropout2(src2)) 164 | return src 165 | 166 | 167 | class TransformerClassifier(Module): 168 | def __init__(self, 169 | seq_pool=True, 170 | embedding_dim=768, 171 | num_layers=12, 172 | num_heads=12, 173 | mlp_ratio=4.0, 174 | num_classes=1000, 175 | dropout=0.1, 176 | attention_dropout=0.1, 177 | stochastic_depth=0.1, 178 | positional_embedding='learnable', 179 | sequence_length=None, 180 | use_layer_scale=False, 181 | use_skip=False, 182 | use_relu=False): 183 | super().__init__() 184 | 185 | positional_embedding = positional_embedding if \ 186 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine' 187 | dim_feedforward = int(embedding_dim * mlp_ratio) 188 | self.embedding_dim = embedding_dim 189 | self.sequence_length = sequence_length 190 | self.seq_pool = seq_pool 191 | self.num_tokens = 0 192 | 193 | assert sequence_length is not None or positional_embedding == 'none', \ 194 | f"Positional embedding is set to {positional_embedding} and" \ 195 | f" the sequence length was not specified." 
196 | 197 | if not seq_pool: 198 | sequence_length += 1 199 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim), 200 | requires_grad=True) 201 | self.num_tokens = 1 202 | else: 203 | self.attention_pool = Linear(self.embedding_dim, 1) 204 | 205 | if positional_embedding != 'none': 206 | if positional_embedding == 'learnable': 207 | self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim), 208 | requires_grad=True) 209 | init.trunc_normal_(self.positional_emb, std=0.2) 210 | else: 211 | self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim), 212 | requires_grad=False) 213 | else: 214 | self.positional_emb = None 215 | 216 | self.dropout = Dropout(p=dropout) 217 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)] 218 | self.blocks = ModuleList([ 219 | TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, 220 | dim_feedforward=dim_feedforward, dropout=dropout, 221 | attention_dropout=attention_dropout, 222 | drop_path_rate=dpr[i], use_layer_scale=use_layer_scale, 223 | use_skip=use_skip, use_relu=use_relu) 224 | for i in range(num_layers)]) 225 | self.norm = LayerNorm(embedding_dim) 226 | 227 | self.fc = Linear(embedding_dim, num_classes) 228 | self.apply(self.init_weight) 229 | 230 | def forward(self, x): 231 | if self.positional_emb is None and x.size(1) < self.sequence_length: 232 | x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) 233 | 234 | features = [] 235 | 236 | if not self.seq_pool: 237 | cls_token = self.class_emb.expand(x.shape[0], -1, -1) 238 | x = torch.cat((cls_token, x), dim=1) 239 | 240 | if self.positional_emb is not None: 241 | x += self.positional_emb 242 | 243 | x = self.dropout(x) 244 | 245 | for blk in self.blocks: 246 | features.append(x) 247 | x = blk(x) 248 | x = self.norm(x) 249 | 250 | features.append(x) 251 | if self.seq_pool: 252 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2) 253 | else: 254 | x = x[:, 0] 255 | 256 | x = self.fc(x) 257 | return x, features 258 | 259 | @staticmethod 260 | def init_weight(m): 261 | if isinstance(m, Linear): 262 | init.trunc_normal_(m.weight, std=.02) 263 | if isinstance(m, Linear) and m.bias is not None: 264 | init.constant_(m.bias, 0) 265 | elif isinstance(m, LayerNorm): 266 | init.constant_(m.bias, 0) 267 | init.constant_(m.weight, 1.0) 268 | 269 | @staticmethod 270 | def sinusoidal_embedding(n_channels, dim): 271 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] 272 | for p in range(n_channels)]) 273 | pe[:, 0::2] = torch.sin(pe[:, 0::2]) 274 | pe[:, 1::2] = torch.cos(pe[:, 1::2]) 275 | return pe.unsqueeze(0) 276 | 277 | 278 | class MaskedTransformerClassifier(Module): 279 | def __init__(self, 280 | seq_pool=True, 281 | embedding_dim=768, 282 | num_layers=12, 283 | num_heads=12, 284 | mlp_ratio=4.0, 285 | num_classes=1000, 286 | dropout=0.1, 287 | attention_dropout=0.1, 288 | stochastic_depth=0.1, 289 | positional_embedding='sine', 290 | seq_len=None, 291 | *args, **kwargs): 292 | super().__init__() 293 | positional_embedding = positional_embedding if \ 294 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine' 295 | dim_feedforward = int(embedding_dim * mlp_ratio) 296 | self.embedding_dim = embedding_dim 297 | self.seq_len = seq_len 298 | self.seq_pool = seq_pool 299 | self.num_tokens = 0 300 | 301 | assert seq_len is not None or positional_embedding == 'none', \ 302 | f"Positional embedding is set to 
{positional_embedding} and" \ 303 | f" the sequence length was not specified." 304 | 305 | if not seq_pool: 306 | seq_len += 1 307 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim), 308 | requires_grad=True) 309 | self.num_tokens = 1 310 | else: 311 | self.attention_pool = Linear(self.embedding_dim, 1) 312 | 313 | if positional_embedding != 'none': 314 | if positional_embedding == 'learnable': 315 | seq_len += 1 # padding idx 316 | self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim), 317 | requires_grad=True) 318 | init.trunc_normal_(self.positional_emb, std=0.2) 319 | else: 320 | self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len, 321 | embedding_dim, 322 | padding_idx=True), 323 | requires_grad=False) 324 | else: 325 | self.positional_emb = None 326 | 327 | self.dropout = Dropout(p=dropout) 328 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)] 329 | self.blocks = ModuleList([ 330 | MaskedTransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, 331 | dim_feedforward=dim_feedforward, dropout=dropout, 332 | attention_dropout=attention_dropout, drop_path_rate=dpr[i]) 333 | for i in range(num_layers)]) 334 | self.norm = LayerNorm(embedding_dim) 335 | 336 | self.fc = Linear(embedding_dim, num_classes) 337 | self.apply(self.init_weight) 338 | 339 | def forward(self, x, mask=None): 340 | if self.positional_emb is None and x.size(1) < self.seq_len: 341 | x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) 342 | 343 | if not self.seq_pool: 344 | cls_token = self.class_emb.expand(x.shape[0], -1, -1) 345 | x = torch.cat((cls_token, x), dim=1) 346 | if mask is not None: 347 | mask = torch.cat([torch.ones(size=(mask.shape[0], 1), device=mask.device), mask.float()], dim=1) 348 | mask = (mask > 0) 349 | 350 | if self.positional_emb is not None: 351 | x += self.positional_emb 352 | 353 | x = self.dropout(x) 354 | 355 | for blk in self.blocks: 356 | x = blk(x, mask=mask) 357 | x = self.norm(x) 358 | 359 | if self.seq_pool: 360 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2) 361 | else: 362 | x = x[:, 0] 363 | 364 | x = self.fc(x) 365 | return x 366 | 367 | @staticmethod 368 | def init_weight(m): 369 | if isinstance(m, Linear): 370 | init.trunc_normal_(m.weight, std=.02) 371 | if isinstance(m, Linear) and m.bias is not None: 372 | init.constant_(m.bias, 0) 373 | elif isinstance(m, LayerNorm): 374 | init.constant_(m.bias, 0) 375 | init.constant_(m.weight, 1.0) 376 | 377 | @staticmethod 378 | def sinusoidal_embedding(n_channels, dim, padding_idx=False): 379 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] 380 | for p in range(n_channels)]) 381 | pe[:, 0::2] = torch.sin(pe[:, 0::2]) 382 | pe[:, 1::2] = torch.cos(pe[:, 1::2]) 383 | pe = pe.unsqueeze(0) 384 | if padding_idx: 385 | return torch.cat([torch.zeros((1, 1, dim)), pe], dim=1) 386 | return pe 387 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BatchNorm(nn.modules.batchnorm._BatchNorm): 6 | def __init__( 7 | self, 8 | num_features, 9 | eps=1e-5, 10 | momentum=0.1, 11 | affine=True, 12 | track_running_stats=True, 13 | device=None, 14 | dtype=None, 15 | transpose=False, 16 | ): 17 | factory_kwargs = {'device': device, 'dtype': dtype} 18 | super().__init__( 19 | 
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs 20 | ) 21 | self.transpose = transpose 22 | 23 | def _check_input_dim(self, input): 24 | if input.dim() < 2: 25 | raise ValueError( 26 | "expected at least 2D input (got {}D input)".format(input.dim()) 27 | ) 28 | 29 | def forward(self, input: torch.Tensor) -> torch.Tensor: 30 | if self.transpose: 31 | dim = input.ndim 32 | input = input.transpose(dim-1, dim-2) 33 | output = super().forward(input) 34 | if self.transpose: 35 | return output.transpose(dim-1, dim-2) 36 | else: 37 | return output 38 | 39 | 40 | def build_bn_from_ln(ln, transpose=False): 41 | assert isinstance(ln.normalized_shape, int) or len(ln.normalized_shape) == 1 42 | num_features = ( 43 | ln.normalized_shape 44 | if isinstance(ln.normalized_shape, int) 45 | else ln.normalized_shape[0] 46 | ) 47 | return BatchNorm( 48 | num_features=num_features, 49 | affine=ln.elementwise_affine, 50 | device=ln.weight.device, 51 | transpose=transpose, 52 | ) 53 | 54 | 55 | def replace_module_by_module(module, m_to_replace, build_fn): 56 | module_output = module 57 | if isinstance(module, m_to_replace): 58 | module_output = build_fn(module) 59 | for name, child in module.named_children(): 60 | module_output.add_module( 61 | name, replace_module_by_module(child, m_to_replace, build_fn) 62 | ) 63 | del module 64 | return module_output 65 | 66 | 67 | def replace_ln_by_bn2d(module): 68 | return replace_module_by_module( 69 | module, 70 | nn.LayerNorm, 71 | lambda x: build_bn_from_ln(ln=x, tranpose=False), 72 | ) 73 | 74 | 75 | def replace_ln_by_bn1d(module): 76 | return replace_module_by_module( 77 | module, 78 | nn.LayerNorm, 79 | lambda x: build_bn_from_ln(ln=x, transpose=True), 80 | ) 81 | 82 | -------------------------------------------------------------------------------- /timm_fix_imagenet_loading_bugs/dataset_factory.py: -------------------------------------------------------------------------------- 1 | """ Dataset Factory 2 | 3 | Hacked together by / Copyright 2021, Ross Wightman 4 | """ 5 | import os 6 | 7 | from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder 8 | try: 9 | from torchvision.datasets import Places365 10 | has_places365 = True 11 | except ImportError: 12 | has_places365 = False 13 | try: 14 | from torchvision.datasets import INaturalist 15 | has_inaturalist = True 16 | except ImportError: 17 | has_inaturalist = False 18 | 19 | from .dataset import IterableImageDataset, ImageDataset 20 | 21 | _TORCH_BASIC_DS = dict( 22 | cifar10=CIFAR10, 23 | cifar100=CIFAR100, 24 | mnist=MNIST, 25 | qmist=QMNIST, 26 | kmnist=KMNIST, 27 | fashion_mnist=FashionMNIST, 28 | ) 29 | _TRAIN_SYNONYM = {'train', 'training'} 30 | _EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'} 31 | 32 | 33 | def _search_split(root, split): 34 | # look for sub-folder with name of split in root and use that if it exists 35 | split_name = split.split('[')[0] 36 | try_root = os.path.join(root, split_name) 37 | if os.path.exists(try_root): 38 | return try_root 39 | 40 | def _try(syn): 41 | for s in syn: 42 | try_root = os.path.join(root, s) 43 | if os.path.exists(try_root): 44 | return try_root 45 | return root 46 | if split_name in _TRAIN_SYNONYM: 47 | root = _try(_TRAIN_SYNONYM) 48 | elif split_name in _EVAL_SYNONYM: 49 | root = _try(_EVAL_SYNONYM) 50 | return root 51 | 52 | 53 | def create_dataset( 54 | name, 55 | root, 56 | split='validation', 57 | search_split=True, 58 | class_map=None, 59 | load_bytes=False, 60 
| is_training=False, 61 | download=False, 62 | batch_size=None, 63 | repeats=0, 64 | **kwargs 65 | ): 66 | """ Dataset factory method 67 | 68 | In parenthesis after each arg are the type of dataset supported for each arg, one of: 69 | * folder - default, timm folder (or tar) based ImageDataset 70 | * torch - torchvision based datasets 71 | * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset 72 | * all - any of the above 73 | 74 | Args: 75 | name: dataset name, empty is okay for folder based datasets 76 | root: root folder of dataset (all) 77 | split: dataset split (all) 78 | search_split: search for split specific child fold from root so one can specify 79 | `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder) 80 | class_map: specify class -> index mapping via text file or dict (folder) 81 | load_bytes: load data, return images as undecoded bytes (folder) 82 | download: download dataset if not present and supported (TFDS, torch) 83 | is_training: create dataset in train mode, this is different from the split. 84 | For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS) 85 | batch_size: batch size hint for (TFDS) 86 | repeats: dataset repeats per iteration i.e. epoch (TFDS) 87 | **kwargs: other args to pass to dataset 88 | 89 | Returns: 90 | Dataset object 91 | """ 92 | name = name.lower() 93 | if name.startswith('torch/'): 94 | name = name.split('/', 2)[-1] 95 | torch_kwargs = dict(root=root, download=download, **kwargs) 96 | if name in _TORCH_BASIC_DS: 97 | ds_class = _TORCH_BASIC_DS[name] 98 | use_train = split in _TRAIN_SYNONYM 99 | ds = ds_class(train=use_train, **torch_kwargs) 100 | elif name == 'inaturalist' or name == 'inat': 101 | assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist' 102 | target_type = 'full' 103 | split_split = split.split('/') 104 | if len(split_split) > 1: 105 | target_type = split_split[0].split('_') 106 | if len(target_type) == 1: 107 | target_type = target_type[0] 108 | split = split_split[-1] 109 | if split in _TRAIN_SYNONYM: 110 | split = '2021_train' 111 | elif split in _EVAL_SYNONYM: 112 | split = '2021_valid' 113 | ds = INaturalist(version=split, target_type=target_type, **torch_kwargs) 114 | elif name == 'places365': 115 | assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.' 
116 | if split in _TRAIN_SYNONYM: 117 | split = 'train-standard' 118 | elif split in _EVAL_SYNONYM: 119 | split = 'val' 120 | ds = Places365(split=split, **torch_kwargs) 121 | elif name == 'imagenet': 122 | if split in _EVAL_SYNONYM: 123 | split = 'val' 124 | print('imagenet torch') 125 | torch_kwargs = dict(root=root, **kwargs) 126 | ds = ImageNet(split=split, **torch_kwargs) 127 | elif name == 'image_folder' or name == 'folder': 128 | # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason 129 | if search_split and os.path.isdir(root): 130 | # look for split specific sub-folder in root 131 | root = _search_split(root, split) 132 | ds = ImageFolder(root, **kwargs) 133 | else: 134 | assert False, f"Unknown torchvision dataset {name}" 135 | elif name.startswith('tfds/'): 136 | ds = IterableImageDataset( 137 | root, parser=name, split=split, is_training=is_training, 138 | download=download, batch_size=batch_size, repeats=repeats, **kwargs) 139 | else: 140 | # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future 141 | if search_split and os.path.isdir(root): 142 | # look for split specific sub-folder in root 143 | root = _search_split(root, split) 144 | ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs) 145 | return ds 146 | -------------------------------------------------------------------------------- /train_scripts/deit_s/w2a2_deit_s.sh: -------------------------------------------------------------------------------- 1 | ## traing of 2-bit DeiT-S 2 | python3 train.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_small_distilled_patch16_224 \ 3 | your_path/dataset/imagenet-1k/imagenet \ 4 | --dataset 'torch/imagenet' \ 5 | --epochs 300 \ 6 | --batch-size 140 \ 7 | --weight-decay 0.05 \ 8 | --warmup-lr 1.0e-6 \ 9 | --lr 5.47e-4 \ 10 | --warmup-epochs 5 \ 11 | --mixup 0.0 --cutmix 0.0 \ 12 | --aq-enable \ 13 | --aq-mode lsq \ 14 | --aq-per-channel \ 15 | --aq_clip_learnable \ 16 | --aq-bitw 2 \ 17 | --wq-enable \ 18 | --wq-per-channel \ 19 | --wq-bitw 2 \ 20 | --wq-mode statsq \ 21 | --model_type deit \ 22 | --quantized \ 23 | --pretrained \ 24 | --pretrained_initialized \ 25 | --use-kd --teacher deit_small_distilled_patch16_224 \ 26 | --kd_hard_and_soft 1 \ 27 | --qk_reparam \ 28 | --qk_reparam_type 0 \ 29 | --teacher_pretrained \ 30 | --output ./outputs/w2a2_deit_s_qkreparam/ \ 31 | --visible_gpu '4,5,6,7' \ 32 | --world_size '4' \ 33 | --tcp_port '36969' 34 | 35 | ## Finetune 2-bit DeiT-S with CGA 36 | python3 cga.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_small_distilled_patch16_224 \ 37 | your_path/dataset/imagenet-1k/imagenet \ 38 | --dataset 'torch/imagenet' \ 39 | --epochs 300 \ 40 | --batch-size 140 \ 41 | --weight-decay 0.05 \ 42 | --warmup-lr 1.0e-6 \ 43 | --lr 5.47e-4 \ 44 | --warmup-epochs 5 \ 45 | --mixup 0.0 --cutmix 0.0 \ 46 | --aq-enable \ 47 | --aq-mode lsq \ 48 | --aq-per-channel \ 49 | --aq_clip_learnable \ 50 | --aq-bitw 2 \ 51 | --wq-enable \ 52 | --wq-per-channel \ 53 | --wq-bitw 2 \ 54 | --wq-mode statsq \ 55 | --model_type deit \ 56 | --quantized \ 57 | --pretrained \ 58 | --pretrained_initialized \ 59 | --use-kd --teacher deit_small_distilled_patch16_224 \ 60 | --kd_hard_and_soft 1 \ 61 | --qk_reparam \ 62 | --qk_reparam_type 1 \ 63 | --boundaryRange 0.005 \ 64 | --freeze_for_n_epochs 30 \ 65 | --teacher_pretrained \ 66 | --resume put the model you wish to finetune here \ 67 | --output ./outputs/w2a2_deit_s_qkreparam_cga_0005/ \ 68 | 
68 | --visible_gpu '4,5,6,7' \
69 | --world_size '4' \
70 | --tcp_port '36969'
71 | 
--------------------------------------------------------------------------------
/train_scripts/deit_s/w3a3_deit_s.sh:
--------------------------------------------------------------------------------
1 | ## training of 3-bit DeiT-S
2 | python3 train.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 100 \
7 | --weight-decay 0.0 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 3.2e-4 \
10 | --warmup-epochs 0 \
11 | --aq-enable \
12 | --aq-mode lsq \
13 | --aq-per-channel \
14 | --aq_clip_learnable \
15 | --aq-bitw 3 \
16 | --wq-enable \
17 | --wq-per-channel \
18 | --wq-bitw 3 \
19 | --wq-mode statsq \
20 | --model_type deit \
21 | --quantized \
22 | --pretrained \
23 | --pretrained_initialized \
24 | --use-kd --teacher deit_small_distilled_patch16_224 \
25 | --kd_hard_and_soft 1 \
26 | --qk_reparam \
27 | --qk_reparam_type 0 \
28 | --teacher_pretrained \
29 | --output ./outputs/w3a3_deit_s_qkreparam/ \
30 | --visible_gpu '0,1,2,3,4,5,6,7' \
31 | --world_size '8' \
32 | --tcp_port '36969'
33 | 
34 | ## Finetune 3-bit DeiT-S with CGA
35 | python3 cga.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \
36 | your_path/dataset/imagenet-1k/imagenet \
37 | --dataset 'torch/imagenet' \
38 | --epochs 300 \
39 | --batch-size 100 \
40 | --weight-decay 0.0 \
41 | --warmup-lr 1.0e-6 \
42 | --lr 3.2e-4 \
43 | --warmup-epochs 0 \
44 | --aq-enable \
45 | --aq-mode lsq \
46 | --aq-per-channel \
47 | --aq_clip_learnable \
48 | --aq-bitw 3 \
49 | --wq-enable \
50 | --wq-per-channel \
51 | --wq-bitw 3 \
52 | --wq-mode statsq \
53 | --model_type deit \
54 | --quantized \
55 | --pretrained \
56 | --pretrained_initialized \
57 | --use-kd --teacher deit_small_distilled_patch16_224 \
58 | --kd_hard_and_soft 1 \
59 | --qk_reparam \
60 | --qk_reparam_type 1 \
61 | --boundaryRange 0.005 \
62 | --freeze_for_n_epochs 30 \
63 | --teacher_pretrained \
64 | --resume put the model you wish to finetune here \
65 | --output ./outputs/w3a3_deit_s_qkreparam_cga_0005/ \
66 | --visible_gpu '0,1,2,3,4,5,6,7' \
67 | --world_size '8' \
68 | --tcp_port '36969'
69 | 
70 | 
71 | 
--------------------------------------------------------------------------------
/train_scripts/deit_s/w4a4_deit_s.sh:
--------------------------------------------------------------------------------
1 | ## training of 4-bit DeiT-S
2 | python3 train.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 100 \
7 | --weight-decay 0.0 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 3.2e-4 \
10 | --warmup-epochs 0 \
11 | --aq-enable \
12 | --aq-mode lsq \
13 | --aq-per-channel \
14 | --aq_clip_learnable \
15 | --aq-bitw 4 \
16 | --wq-enable \
17 | --wq-per-channel \
18 | --wq-bitw 4 \
19 | --wq-mode statsq \
20 | --model_type deit \
21 | --quantized \
22 | --pretrained \
23 | --pretrained_initialized \
24 | --use-kd --teacher deit_small_distilled_patch16_224 \
25 | --kd_hard_and_soft 1 \
26 | --qk_reparam \
27 | --qk_reparam_type 0 \
28 | --teacher_pretrained \
29 | --output ./outputs/w4a4_deit_s_qkreparam/ \
30 | --visible_gpu '0,1,2,3,4,5,6,7' \
31 | --world_size '8' \
32 | --tcp_port '36969'
33 | 
34 | ## Finetune 4-bit DeiT-S with CGA
35 | python3 cga.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \
36 | your_path/dataset/imagenet-1k/imagenet \
37 | --dataset 'torch/imagenet' \
38 | --epochs 300 \
39 | --batch-size 100 \
40 | --weight-decay 0.0 \
41 | --warmup-lr 1.0e-6 \
42 | --lr 3.2e-4 \
43 | --warmup-epochs 0 \
44 | --aq-enable \
45 | --aq-mode lsq \
46 | --aq-per-channel \
47 | --aq_clip_learnable \
48 | --aq-bitw 4 \
49 | --wq-enable \
50 | --wq-per-channel \
51 | --wq-bitw 4 \
52 | --wq-mode statsq \
53 | --model_type deit \
54 | --quantized \
55 | --pretrained \
56 | --pretrained_initialized \
57 | --use-kd --teacher deit_small_distilled_patch16_224 \
58 | --kd_hard_and_soft 1 \
59 | --qk_reparam \
60 | --qk_reparam_type 1 \
61 | --boundaryRange 0.005 \
62 | --freeze_for_n_epochs 30 \
63 | --teacher_pretrained \
64 | --resume put the model you wish to finetune here \
65 | --output ./outputs/w4a4_deit_s_qkreparam_cga_0005/ \
66 | --visible_gpu '0,1,2,3,4,5,6,7' \
67 | --world_size '8' \
68 | --tcp_port '36969'
69 | 
70 | 
71 | 
--------------------------------------------------------------------------------
/train_scripts/deit_t/w2a2_deit_t.sh:
--------------------------------------------------------------------------------
1 | ## training of 2-bit DeiT-T
2 | python3 train.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 140 \
7 | --weight-decay 0.05 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 5.47e-4 \
10 | --warmup-epochs 5 \
11 | --mixup 0.0 --cutmix 0.0 \
12 | --aq-enable \
13 | --aq-mode lsq \
14 | --aq-per-channel \
15 | --aq_clip_learnable \
16 | --aq-bitw 2 \
17 | --wq-enable \
18 | --wq-per-channel \
19 | --wq-bitw 2 \
20 | --wq-mode statsq \
21 | --model_type deit \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w2a2_deit_t_qkreparam/ \
31 | --visible_gpu '0,1,2,4' \
32 | --world_size '4' \
33 | --tcp_port '36969'
34 | 
35 | ## Finetune 2-bit DeiT-T with CGA
36 | python3 cga.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 140 \
41 | --weight-decay 0.05 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 5.47e-4 \
44 | --warmup-epochs 5 \
45 | --mixup 0.0 --cutmix 0.0 \
46 | --aq-enable \
47 | --aq-mode lsq \
48 | --aq-per-channel \
49 | --aq_clip_learnable \
50 | --aq-bitw 2 \
51 | --wq-enable \
52 | --wq-per-channel \
53 | --wq-bitw 2 \
54 | --wq-mode statsq \
55 | --model_type deit \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w2a2_deit_t_qkreparam_cga_0005/ \
68 | --visible_gpu '4,5,6,7' \
69 | --world_size '4' \
70 | --tcp_port '36969'
71 | 
72 | 
--------------------------------------------------------------------------------
/train_scripts/deit_t/w3a3_deit_t.sh:
--------------------------------------------------------------------------------
1 | ## training of 3-bit DeiT-T
2 | python3 train.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 140 \
7 | --weight-decay 0.05 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 5.47e-4 \
10 | --warmup-epochs 5 \
11 | --mixup 0.0 --cutmix 0.0 \
12 | --aq-enable \
13 | --aq-mode lsq \
14 | --aq-per-channel \
15 | --aq_clip_learnable \
16 | --aq-bitw 3 \
17 | --wq-enable \
18 | --wq-per-channel \
19 | --wq-bitw 3 \
20 | --wq-mode statsq \
21 | --model_type deit \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w3a3_deit_t_qkreparam/ \
31 | --visible_gpu '0,1,2,4' \
32 | --world_size '4' \
33 | --tcp_port '36969'
34 | 
35 | ## Finetune 3-bit DeiT-T with CGA
36 | python3 cga.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 140 \
41 | --weight-decay 0.05 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 5.47e-4 \
44 | --warmup-epochs 5 \
45 | --mixup 0.0 --cutmix 0.0 \
46 | --aq-enable \
47 | --aq-mode lsq \
48 | --aq-per-channel \
49 | --aq_clip_learnable \
50 | --aq-bitw 3 \
51 | --wq-enable \
52 | --wq-per-channel \
53 | --wq-bitw 3 \
54 | --wq-mode statsq \
55 | --model_type deit \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w3a3_deit_t_qkreparam_cga_0005/ \
68 | --visible_gpu '4,5,6,7' \
69 | --world_size '4' \
70 | --tcp_port '36969'
71 | 
72 | 
--------------------------------------------------------------------------------
/train_scripts/deit_t/w4a4_deit_t.sh:
--------------------------------------------------------------------------------
1 | ## training of 4-bit DeiT-T
2 | python3 train.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 140 \
7 | --weight-decay 0.05 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 5.47e-4 \
10 | --warmup-epochs 5 \
11 | --mixup 0.0 --cutmix 0.0 \
12 | --aq-enable \
13 | --aq-mode lsq \
14 | --aq-per-channel \
15 | --aq_clip_learnable \
16 | --aq-bitw 4 \
17 | --wq-enable \
18 | --wq-per-channel \
19 | --wq-bitw 4 \
20 | --wq-mode statsq \
21 | --model_type deit \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w4a4_deit_t_qkreparam/ \
31 | --visible_gpu '0,1,2,4' \
32 | --world_size '4' \
33 | --tcp_port '36969'
34 | 
35 | ## Finetune 4-bit DeiT-T with CGA
36 | python3 cga.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 140 \
41 | --weight-decay 0.05 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 5.47e-4 \
44 | --warmup-epochs 5 \
45 | --mixup 0.0 --cutmix 0.0 \
46 | --aq-enable \
47 | --aq-mode lsq \
48 | --aq-per-channel \
49 | --aq_clip_learnable \
50 | --aq-bitw 4 \
51 | --wq-enable \
52 | --wq-per-channel \
53 | --wq-bitw 4 \
54 | --wq-mode statsq \
55 | --model_type deit \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w4a4_deit_t_qkreparam_cga_0005/ \
68 | --visible_gpu '4,5,6,7' \
69 | --world_size '4' \
70 | --tcp_port '36969'
71 | 
72 | 
--------------------------------------------------------------------------------
/train_scripts/swin_t/w2a2_swin_t.sh:
--------------------------------------------------------------------------------
1 | ## training of 2-bit Swin-T
2 | python3 train.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 64 \
7 | --weight-decay 0.05 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 5.0e-4 \
10 | --warmup-epochs 5 \
11 | --mixup 0.0 --cutmix 0.0 \
12 | --aq-enable \
13 | --aq-mode lsq \
14 | --aq-per-channel \
15 | --aq_clip_learnable \
16 | --aq-bitw 2 \
17 | --wq-enable \
18 | --wq-per-channel \
19 | --wq-bitw 2 \
20 | --wq-mode statsq \
21 | --model_type swin \
22 | --teacher_type swin \
23 | --quantized \
24 | --pretrained \
25 | --pretrained_initialized \
26 | --use-kd --teacher swin_t \
27 | --kd_hard_and_soft 1 \
28 | --qk_reparam \
29 | --qk_reparam_type 0 \
30 | --teacher_pretrained \
31 | --output ./outputs/w2a2_swin_t_qkreparam/ \
32 | --visible_gpu '0,1,2,3,4,5,6,7' \
33 | --world_size '8' \
34 | --tcp_port '12345'
35 | 
36 | ## Finetune with CGA
37 | python3 cga.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
38 | your_path/dataset/imagenet-1k/imagenet \
39 | --dataset 'torch/imagenet' \
40 | --epochs 300 \
41 | --batch-size 64 \
42 | --weight-decay 0.05 \
43 | --warmup-lr 1.0e-6 \
44 | --lr 5.0e-4 \
45 | --warmup-epochs 5 \
46 | --mixup 0.0 --cutmix 0.0 \
47 | --aq-enable \
48 | --aq-mode lsq \
49 | --aq-per-channel \
50 | --aq_clip_learnable \
51 | --aq-bitw 2 \
52 | --wq-enable \
53 | --wq-per-channel \
54 | --wq-bitw 2 \
55 | --wq-mode statsq \
56 | --model_type swin \
57 | --teacher_type swin \
58 | --quantized \
59 | --pretrained \
60 | --pretrained_initialized \
61 | --use-kd --teacher swin_t \
62 | --kd_hard_and_soft 1 \
63 | --qk_reparam \
64 | --qk_reparam_type 1 \
65 | --boundaryRange 0.005 \
66 | --freeze_for_n_epochs 30 \
67 | --teacher_pretrained \
68 | --resume put the model you wish to finetune here \
69 | --output ./outputs/w2a2_swin_t_qkreparam/ \
70 | --visible_gpu '0,1,2,3,4,5,6,7' \
71 | --world_size '8' \
72 | --tcp_port '12345'
--------------------------------------------------------------------------------
/train_scripts/swin_t/w3a3_swin_t.sh:
--------------------------------------------------------------------------------
1 | ## training of 3-bit Swin-T
2 | python3 train.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 64 \
7 | --weight-decay 0.0 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 2.0e-4 \
10 | --warmup-epochs 0 \
11 | --aq-enable \
12 | --aq-mode lsq \
13 | --aq-per-channel \
14 | --aq_clip_learnable \
15 | --aq-bitw 3 \
16 | --wq-enable \
17 | --wq-per-channel \
18 | --wq-bitw 3 \
19 | --wq-mode statsq \
20 | --model_type swin \
21 | --teacher_type swin \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher swin_t \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w3a3_swin_t_qkreparam/ \
31 | --visible_gpu '0,1,2,3,4,5,6,7' \
32 | --world_size '8' \
33 | --tcp_port '12345'
34 | 
35 | ## Finetune with CGA
36 | python3 cga.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 64 \
41 | --weight-decay 0.0 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 2.0e-4 \
44 | --warmup-epochs 0 \
45 | --aq-enable \
46 | --aq-mode lsq \
47 | --aq-per-channel \
48 | --aq_clip_learnable \
49 | --aq-bitw 3 \
50 | --wq-enable \
51 | --wq-per-channel \
52 | --wq-bitw 3 \
53 | --wq-mode statsq \
54 | --model_type swin \
55 | --teacher_type swin \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher swin_t \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w3a3_swin_t_qkreparam/ \
68 | --visible_gpu '0,1,2,3,4,5,6,7' \
69 | --world_size '8' \
70 | --tcp_port '12345'
--------------------------------------------------------------------------------
/train_scripts/swin_t/w4a4_swin_t.sh:
--------------------------------------------------------------------------------
1 | ## training of 4-bit Swin-T
2 | python3 train.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 64 \
7 | --weight-decay 0.0 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 2.0e-4 \
10 | --warmup-epochs 0 \
11 | --aq-enable \
12 | --aq-mode lsq \
13 | --aq-per-channel \
14 | --aq_clip_learnable \
15 | --aq-bitw 4 \
16 | --wq-enable \
17 | --wq-per-channel \
18 | --wq-bitw 4 \
19 | --wq-mode statsq \
20 | --model_type swin \
21 | --teacher_type swin \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher swin_t \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w4a4_swin_t_qkreparam/ \
31 | --visible_gpu '0,1,2,3,4,5,6,7' \
32 | --world_size '8' \
33 | --tcp_port '12345'
34 | 
35 | ## Finetune with CGA
36 | python3 cga.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 64 \
41 | --weight-decay 0.0 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 2.0e-4 \
44 | --warmup-epochs 0 \
45 | --aq-enable \
46 | --aq-mode lsq \
47 | --aq-per-channel \
48 | --aq_clip_learnable \
49 | --aq-bitw 4 \
50 | --wq-enable \
51 | --wq-per-channel \
52 | --wq-bitw 4 \
53 | --wq-mode statsq \
54 | --model_type swin \
55 | --teacher_type swin \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher swin_t \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w4a4_swin_t_qkreparam/ \
68 | --visible_gpu '0,1,2,3,4,5,6,7' \
69 | --world_size '8' \
70 | --tcp_port '12345'
--------------------------------------------------------------------------------
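
The `timm_fix_imagenet_loading_bugs/dataset_factory.py` shown earlier is what backs the `--dataset 'torch/imagenet'` flag used by every train and eval script: the `torch/` prefix is stripped, eval-split names are remapped to torchvision's `val` split, and the patched ImageNet branch rebuilds `torch_kwargs` without `download` before constructing `torchvision.datasets.ImageNet`. Below is a minimal usage sketch, not part of the repository; it assumes the factory function keeps timm's name `create_dataset`, that the module is importable from the repo root, and that a standard extracted ImageNet-1k tree (with torchvision's expected metadata) sits at the placeholder root used in the scripts.

```python
# Minimal sketch, assuming the factory above is named `create_dataset` (timm's convention)
# and that `timm_fix_imagenet_loading_bugs` is importable from the repository root.
from timm_fix_imagenet_loading_bugs.dataset_factory import create_dataset

# 'torch/imagenet' -> the 'torch/' prefix is stripped and the request is dispatched to
# torchvision.datasets.ImageNet; the patched branch drops `download` from the kwargs,
# which is the ImageNet-loading fix this directory provides.
val_ds = create_dataset(
    name='torch/imagenet',
    root='your_path/dataset/imagenet-1k/imagenet',  # same placeholder root as the scripts
    split='val',        # eval-split synonyms are mapped to 'val' inside the factory
    is_training=False,  # only affects shuffling for the TFDS/iterable path
)
print(len(val_ds))  # 50,000 images for the standard ImageNet-1k validation split
```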