├── LICENSE
├── README.md
├── cga.py
├── configs
├── deit_default_imagent.attn_q.yml
├── ours_imagenet_recipe.attn_q.yml
└── swin_t_imagenet.attn_q.yml
├── eval.py
├── eval_scripts
├── deit_s
│ ├── w2a2.sh
│ ├── w3a3.sh
│ └── w4a4.sh
├── deit_t
│ ├── w2a2.sh
│ ├── w3a3.sh
│ └── w4a4.sh
└── swin_t
│ ├── w2a2.sh
│ ├── w3a3.sh
│ └── w4a4.sh
├── imgs
├── ofq_vit.png
└── w2a2_deit_s_cga_vis_.png
├── src
├── __init__.py
├── deit.py
├── deit_vision_transformer.py
├── quantization
│ ├── __init__.py
│ ├── modules
│ │ ├── __init__.py
│ │ ├── attention.py
│ │ ├── qbias.py
│ │ ├── qlinear.py
│ │ ├── swin_attention_and_mlp.py
│ │ └── utils.py
│ ├── quantizer
│ │ ├── __init__.py
│ │ ├── lsq.py
│ │ └── statsq.py
│ └── utils.py
├── registry.py
├── swin.py
└── utils
│ ├── __init__.py
│ ├── embedder.py
│ ├── helpers.py
│ ├── stochastic_depth.py
│ ├── tokenizer.py
│ ├── transformers.py
│ └── utils.py
├── timm_fix_imagenet_loading_bugs
└── dataset_factory.py
├── train.py
└── train_scripts
├── deit_s
├── w2a2_deit_s.sh
├── w3a3_deit_s.sh
└── w4a4_deit_s.sh
├── deit_t
├── w2a2_deit_t.sh
├── w3a3_deit_t.sh
└── w4a4_deit_t.sh
└── swin_t
├── w2a2_swin_t.sh
├── w3a3_swin_t.sh
└── w4a4_swin_t.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 LIU, Shih-yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OFQ: Oscillation-free Quantization for Low-bit Vision Transformers
2 |
3 | This repository contains the training code of ViT introduced in our work: "[Oscillation-free Quantization for Low-bit Vision Transformers](https://arxiv.org/abs/2302.02210)" which has been accepted for ICML 2023. Please consider starring the repo if you find our work useful, thanks!
4 |
5 | In this work, we discusses the issue of weight oscillation in quantization-aware training and how it negatively affects model performance. The learnable scaling factor, commonly used in quantization, was found to worsen weight oscillation. The study proposes three techniques to address this issue: statistical weight quantization (StatsQ), confidence-guided annealing (CGA), and query-key reparameterization (QKR). These techniques were tested on the ViT model and were found to improve quantization robustness and accuracy. The proposed 2-bit DeiT-T/DeiT-S algorithms outperform the previous state-of-the-art by 9.8% and 7.7%, respectively.
6 |
7 |
8 |
9 |

10 |
Fig.1 - Trajectory of statistical scaling factors (StatsQ) from the 10th transformer blocks in a 2-bit DeiT-S throughout CGA with 4 different boundary ranges [BR_0.003, BR_0.005, BR_0.007, BR_0.01]. The y-axis represents the value of scaling factorsß
11 |
12 |
13 | ## Run
14 |
15 | ### 1. Requirements:
16 | * numpy==1.22.3
17 | * torch==2.0.0
18 | * torchvision==0.15.1
19 | * timm=0.5.4
20 | * pyyaml
21 |
22 | Please replace "/your/miniconda3/envs/ofq/lib/python3.8/site-packages/timm/data/dataset_factory.py" with "timm_fix_imagenet_loading_bugs/dataset_factory.py" as with the original code there is a "TypeError: __init__() got an unexpected keyword argument 'download'" error.
23 |
24 | ### 2. Data:
25 | * Download [ILSVRC12 ImageNet classification dataset](https://www.image-net.org/download.php)
26 |
27 | ### 3. Pretrained models:
28 | * Pretrained models will be automatically downloaded for you if set args.pretrained to True.
29 |
30 | ### 4. Steps to run:
31 | * Examples of training scripts, finetuning scripts (CGA) are provided under "train_scripts/" and evaluation scripts are under "eval_scripts/" (please use the exact same batch size (batch_size * world_size) as provided in the evaluation scripts to reproduce the results reported in the paper).
32 |
33 | * Please modified the data path to your own dataset address
34 |
35 | ## Models
36 | ### 1. ImageNet1K dataset
37 |
38 | | Models | #Bits | Top-1 Accuracy (Model Link)| eval script |
39 | | --- | --- | --- | ------- |
40 | | DeiT-T | 32-32 | 72.02 | ------- |
41 | | OFQ DeiT-T | 2-2 | [**64.33**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/ETacXA4nZxtBuJ2RSVO6GToBqe6M8vnL0hXo75msagLKDw?e=RdConK)| eval_scripts/deit_t/w2a2.sh |
42 | | OFQ DeiT-T | 3-3 | [**72.72**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/EQB7riAo9y1GpX5Tr3bChqIBYetCcIHfR2xdxJACqvAuLw?e=VeHOm6) | eval_scripts/deit_t/w3a3.sh |
43 | | OFQ DeiT-T | 4-4 | [**75.46**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/Ebo0I7E5_4xMqfThmh3zpgkB77BWrWSyx5NzJqm-mDsGVg?e=23S3Bz) | eval_scripts/deit_t/w4a4.sh |
44 | | DeiT-S | 32-32 | 79.9 | ------- |
45 | | OFQ DeiT-S | 2-2 | [**75.72**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/EQEHW7tDmYpNjULhtUFpf2cB9EfIQrsN2LPF846TVNe2cg?e=eBqu6S) | eval_scripts/deit_s/w2a2.sh |
46 | | OFQ DeiT-S | 3-3 | [**79.57**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/ESMKhGXDKT1HjFzCKrSYOvgB9SfXQ5u_46HnIt0LIMS5MQ?e=4cBdYU) | eval_scripts/deit_s/w3a3.sh |
47 | | OFQ DeiT-S | 4-4 | [**81.10**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/EQ-KvPjXiI5GiZB-r03hcm8B_Q0M0cZ4b0Rj59RWiQ8ZtA?e=Lk4GjR) | eval_scripts/deit_s/w4a4.sh |
48 | | Swin-T | 32-32 | 81.2 | ------- |
49 | | OFQ Swin-T | 2-2 | [**78.52**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/EWl2BPKk2qpCv0uZ1VGYa2QBGog9ZmpckGco-A1aNLrxhA?e=Eh2X1X) | eval_scripts/swin_t/w2a2.sh |
50 | | OFQ Swin-T | 3-3 | [**81.09**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/ETnNiYGeG9JOnLgJAIzbXFIBlaovFZ-RoZkC95-FrLP9yA?e=gB77Vv) | eval_scripts/swin_t/w3a3.sh |
51 | | OFQ Swin-T | 4-4 | [**81.88**](https://hkustconnect-my.sharepoint.com/:u:/g/personal/sliuau_connect_ust_hk/EX9WmBLUQsBNqWscdm1cIcQBWe6Yt5EARN0xoD12IOrE-g?e=cmuFle) | eval_scripts/swin_t/w4a4.sh |
52 |
53 | ## Acknowledgement
54 |
55 | The original code is borrowed from [DeiT](https://github.com/facebookresearch/deit).
56 |
57 | ## Citation
58 |
59 | If you find our code useful for your research, please consider citing:
60 | ```bibtex
61 | @InProceedings{pmlr-v202-liu23w,
62 | title = {Oscillation-free Quantization for Low-bit Vision Transformers},
63 | author = {Liu, Shih-Yang and Liu, Zechun and Cheng, Kwang-Ting},
64 | booktitle = {Proceedings of the 40th International Conference on Machine Learning},
65 | pages = {21813--21824},
66 | year = {2023},
67 | editor = {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
68 | volume = {202},
69 | series = {Proceedings of Machine Learning Research},
70 | month = {23--29 Jul},
71 | publisher = {PMLR},
72 | pdf = {https://proceedings.mlr.press/v202/liu23w/liu23w.pdf},
73 | url = {https://proceedings.mlr.press/v202/liu23w.html},
74 | abstract = {Weight oscillation is a by-product of quantization-aware training, in which quantized weights frequently jump between two quantized levels, resulting in training instability and a sub-optimal final model. We discover that the learnable scaling factor, a widely-used $\textit{de facto}$ setting in quantization aggravates weight oscillation. In this work, we investigate the connection between learnable scaling factor and quantized weight oscillation using ViT, and we additionally find that the interdependence between quantized weights in $\textit{query}$ and $\textit{key}$ of a self-attention layer also makes ViT vulnerable to oscillation. We propose three techniques correspondingly: statistical weight quantization ($\rm StatsQ$) to improve quantization robustness compared to the prevalent learnable-scale-based method; confidence-guided annealing ($\rm CGA$) that freezes the weights with $\textit{high confidence}$ and calms the oscillating weights; and $\textit{query}$-$\textit{key}$ reparameterization ($\rm QKR$) to resolve the query-key intertwined oscillation and mitigate the resulting gradient misestimation. Extensive experiments demonstrate that our algorithms successfully abate weight oscillation and consistently achieve substantial accuracy improvement on ImageNet. Specifically, our 2-bit DeiT-T/DeiT-S surpass the previous state-of-the-art by 9.8% and 7.7%, respectively. The code is included in the supplementary material and will be released.}
75 | }
76 | ```
77 |
78 | ## Contact
79 |
80 | Shih-Yang Liu, HKUST (sliuau at connect.ust.hk)
81 |
82 |
83 |
--------------------------------------------------------------------------------
/configs/deit_default_imagent.attn_q.yml:
--------------------------------------------------------------------------------
1 | dataset: imagenet
2 | num_classes: 1000
3 | img_size: 224
4 | mean:
5 | - 0.485
6 | - 0.456
7 | - 0.406
8 | std:
9 | - 0.229
10 | - 0.224
11 | - 0.225
12 | crop_pct: 0.9
13 | scale:
14 | - 0.08
15 | - 1.0
16 | interpolation: bicubic
17 | train_interpolation: random
18 | aa: rand-m9-mstd0.5-inc1
19 | mixup: 0.8
20 | mixup_off_epoch: 0
21 | mixup_prob: 1.0
22 | mixup_mode: batch
23 | mixup_switch_prob: 0.5
24 | cutmix: 1.0
25 | reprob: 0.25
26 | remode: pixel
27 | amp: False
28 | model_ema: False
29 | batch_size: 128
30 | lr: 5e-4
31 | min_lr: 1e-5
32 | sched: cosine
33 | weight_decay: 5e-2
34 | epochs: 300
35 | cooldown_epochs: 10
36 | warmup_epochs: 5
37 | warmup_lr: 0.00001
38 | opt: adamw
39 | smoothing: 0.1
40 | num_aug_repeats: 3
41 | workers: 4
42 | qmodules:
43 | - "patch_embed.proj"
44 | - "blocks.0.attn"
45 | - "blocks.0.mlp"
46 | - "blocks.1.attn"
47 | - "blocks.1.mlp"
48 | - "blocks.2.attn"
49 | - "blocks.2.mlp"
50 | - "blocks.3.attn"
51 | - "blocks.3.mlp"
52 | - "blocks.4.attn"
53 | - "blocks.4.mlp"
54 | - "blocks.5.attn"
55 | - "blocks.5.mlp"
56 | - "blocks.6.attn"
57 | - "blocks.6.mlp"
58 | - "blocks.7.attn"
59 | - "blocks.7.mlp"
60 | - "blocks.8.attn"
61 | - "blocks.8.mlp"
62 | - "blocks.9.attn"
63 | - "blocks.9.mlp"
64 | - "blocks.10.attn"
65 | - "blocks.10.mlp"
66 | - "blocks.11.attn"
67 | - "blocks.11.mlp"
68 | - "head"
69 | - "head_dist"
70 |
--------------------------------------------------------------------------------
/configs/ours_imagenet_recipe.attn_q.yml:
--------------------------------------------------------------------------------
1 | dataset: imagenet
2 | num_classes: 1000
3 | img_size: 224
4 | mean:
5 | - 0.485
6 | - 0.456
7 | - 0.406
8 | std:
9 | - 0.229
10 | - 0.224
11 | - 0.225
12 | crop_pct: 0.9
13 | scale:
14 | - 0.08
15 | - 1.0
16 | interpolation: bicubic
17 | train_interpolation: random
18 | aa: rand-m9-mstd0.5-inc1
19 | mixup: 0.8
20 | mixup_off_epoch: 0
21 | mixup_prob: 1.0
22 | mixup_mode: batch
23 | mixup_switch_prob: 0.5
24 | cutmix: 1.0
25 | reprob: 0.25
26 | remode: pixel
27 | amp: False
28 | model: deit_tiny_patch16_224
29 | model_ema: False
30 | batch_size: 128
31 | lr: 5e-4
32 | min_lr: 1e-5
33 | sched: cosine
34 | weight_decay: 5e-2
35 | epochs: 128
36 | cooldown_epochs: 10
37 | warmup_epochs: 10
38 | warmup_lr: 0.00001
39 | opt: adamw
40 | smoothing: 0.1
41 | workers: 4
42 | wq_enable: True
43 | wq_bitw: 2
44 | aq_enable: True
45 | aq_bitw: 2
46 | num_aug_repeats: 0
47 | qmodules:
48 | - "patch_embed.proj"
49 | - "blocks.0.attn"
50 | - "blocks.0.mlp"
51 | - "blocks.1.attn"
52 | - "blocks.1.mlp"
53 | - "blocks.2.attn"
54 | - "blocks.2.mlp"
55 | - "blocks.3.attn"
56 | - "blocks.3.mlp"
57 | - "blocks.4.attn"
58 | - "blocks.4.mlp"
59 | - "blocks.5.attn"
60 | - "blocks.5.mlp"
61 | - "blocks.6.attn"
62 | - "blocks.6.mlp"
63 | - "blocks.7.attn"
64 | - "blocks.7.mlp"
65 | - "blocks.8.attn"
66 | - "blocks.8.mlp"
67 | - "blocks.9.attn"
68 | - "blocks.9.mlp"
69 | - "blocks.10.attn"
70 | - "blocks.10.mlp"
71 | - "blocks.11.attn"
72 | - "blocks.11.mlp"
73 | - "head"
74 | - "head_dist"
--------------------------------------------------------------------------------
/configs/swin_t_imagenet.attn_q.yml:
--------------------------------------------------------------------------------
1 | dataset: imagenet
2 | num_classes: 1000
3 | img_size: 224
4 | mean:
5 | - 0.485
6 | - 0.456
7 | - 0.406
8 | std:
9 | - 0.229
10 | - 0.224
11 | - 0.225
12 | crop_pct: 0.9
13 | scale:
14 | - 0.08
15 | - 1.0
16 | interpolation: bicubic
17 | train_interpolation: random
18 | aa: rand-m9-mstd0.5-inc1
19 | mixup: 0.8
20 | mixup_off_epoch: 0
21 | mixup_prob: 1.0
22 | mixup_mode: batch
23 | mixup_switch_prob: 0.5
24 | cutmix: 1.0
25 | reprob: 0.25
26 | remode: pixel
27 | amp: False
28 | model: swin_t
29 | model_ema: False
30 | batch_size: 512
31 | lr: 2e-4
32 | min_lr: 1e-5
33 | sched: cosine
34 | weight_decay: 0.0
35 | epochs: 300
36 | cooldown_epochs: 10
37 | warmup_epochs: 0
38 | warmup_lr: 0.00001
39 | opt: adamw
40 | smoothing: 0.1
41 | num_aug_repeats: 0
42 | workers: 1
43 | drop_path: 0.0
44 | qmodules:
45 | - "features.0.0"
46 | - "features.1.0.attn"
47 | - "features.1.0.mlp"
48 | - "features.1.1.attn"
49 | - "features.1.1.mlp"
50 | - "features.2.reduction"
51 | - "features.3.0.attn"
52 | - "features.3.0.mlp"
53 | - "features.3.1.attn"
54 | - "features.3.1.mlp"
55 | - "features.4.reduction"
56 | - "features.5.0.attn"
57 | - "features.5.0.mlp"
58 | - "features.5.1.attn"
59 | - "features.5.1.mlp"
60 | - "features.5.2.attn"
61 | - "features.5.2.mlp"
62 | - "features.5.3.attn"
63 | - "features.5.3.mlp"
64 | - "features.5.4.attn"
65 | - "features.5.4.mlp"
66 | - "features.5.5.attn"
67 | - "features.5.5.mlp"
68 | - "features.6.reduction"
69 | - "features.7.0.attn"
70 | - "features.7.0.mlp"
71 | - "features.7.1.attn"
72 | - "features.7.1.mlp"
73 | - "head"
74 |
--------------------------------------------------------------------------------
/eval_scripts/deit_s/w2a2.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_small_distilled_patch16_224 \
2 | your_path/dataset/imagenet-1k/imagenet \
3 | --dataset 'torch/imagenet' \
4 | --epochs 300 \
5 | --batch-size 200 \
6 | --weight-decay 0.05 \
7 | --warmup-lr 1.0e-6 \
8 | --lr 5.47e-4 \
9 | --warmup-epochs 5 \
10 | --mixup 0.0 --cutmix 0.0 \
11 | --aq-enable \
12 | --aq-mode lsq \
13 | --aq-per-channel \
14 | --aq_clip_learnable \
15 | --aq-bitw 2 \
16 | --wq-enable \
17 | --wq-per-channel \
18 | --wq-bitw 2 \
19 | --wq-mode statsq \
20 | --model_type deit \
21 | --quantized \
22 | --pretrained \
23 | --pretrained_initialized \
24 | --use-kd --teacher deit_small_distilled_patch16_224 \
25 | --kd_hard_and_soft 1 \
26 | --qk_reparam \
27 | --qk_reparam_type 0 \
28 | --teacher_pretrained \
29 | --resume your_path/model_saved/deit_s/w2a2/w2a2_deit_s_qkr_cga.pth.tar \
30 | --output ./outputs/w2a2_deit_s_qkreparam_eval/ \
31 | --visible_gpu '0,1,2,4' \
32 | --world_size '4' \
33 | --tcp_port '36969'
--------------------------------------------------------------------------------
/eval_scripts/deit_s/w3a3.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \
2 | your_path/dataset/imagenet-1k/imagenet \
3 | --dataset 'torch/imagenet' \
4 | --epochs 300 \
5 | --batch-size 200 \
6 | --weight-decay 0.0 \
7 | --warmup-lr 1.0e-6 \
8 | --lr 3.2e-4 \
9 | --warmup-epochs 0 \
10 | --aq-enable \
11 | --aq-mode lsq \
12 | --aq-per-channel \
13 | --aq_clip_learnable \
14 | --aq-bitw 3 \
15 | --wq-enable \
16 | --wq-per-channel \
17 | --wq-bitw 3 \
18 | --wq-mode statsq \
19 | --model_type deit \
20 | --quantized \
21 | --pretrained \
22 | --pretrained_initialized \
23 | --use-kd --teacher deit_small_distilled_patch16_224 \
24 | --kd_hard_and_soft 1 \
25 | --qk_reparam \
26 | --qk_reparam_type 0 \
27 | --teacher_pretrained \
28 | --resume your_path/model_saved/deit_s/w3a3/w3a3_deit_s_qkr_cga.pth.tar \
29 | --output ./outputs/w3a3_deit_s_qkreparam_eval/ \
30 | --visible_gpu '0,1,2,4' \
31 | --world_size '4' \
32 | --tcp_port '36969'
33 |
34 |
--------------------------------------------------------------------------------
/eval_scripts/deit_s/w4a4.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \
2 | your_path/dataset/imagenet-1k/imagenet \
3 | --dataset 'torch/imagenet' \
4 | --epochs 300 \
5 | --batch-size 200 \
6 | --weight-decay 0.0 \
7 | --warmup-lr 1.0e-6 \
8 | --lr 3.2e-4 \
9 | --warmup-epochs 0 \
10 | --aq-enable \
11 | --aq-mode lsq \
12 | --aq-per-channel \
13 | --aq_clip_learnable \
14 | --aq-bitw 4 \
15 | --wq-enable \
16 | --wq-per-channel \
17 | --wq-bitw 4 \
18 | --wq-mode statsq \
19 | --model_type deit \
20 | --quantized \
21 | --pretrained \
22 | --pretrained_initialized \
23 | --use-kd --teacher deit_small_distilled_patch16_224 \
24 | --kd_hard_and_soft 1 \
25 | --qk_reparam \
26 | --qk_reparam_type 0 \
27 | --teacher_pretrained \
28 | --resume your_path/model_saved/deit_s/w4a4/w4a4_deit_s_qkr_cga.pth.tar \
29 | --output ./outputs/w4a4_deit_s_qkreparam_eval/ \
30 | --visible_gpu '0,1,2,4' \
31 | --world_size '4' \
32 | --tcp_port '36969'
--------------------------------------------------------------------------------
/eval_scripts/deit_t/w2a2.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
2 | your_path/dataset/imagenet-1k/imagenet \
3 | --dataset 'torch/imagenet' \
4 | --epochs 300 \
5 | --batch-size 140 \
6 | --aq-enable \
7 | --aq-mode lsq \
8 | --aq-per-channel \
9 | --aq_clip_learnable \
10 | --aq-bitw 2 \
11 | --wq-enable \
12 | --wq-per-channel \
13 | --wq-bitw 2 \
14 | --wq-mode statsq \
15 | --model_type deit \
16 | --quantized \
17 | --pretrained \
18 | --pretrained_initialized \
19 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
20 | --kd_hard_and_soft 1 \
21 | --qk_reparam \
22 | --qk_reparam_type 0 \
23 | --teacher_pretrained \
24 | --resume your_path/model_saved/deit_t/w2a2/w2a2_deit_t_qkr_cga.pth.tar \
25 | --output ./outputs/w2a2_deit_t_qkreparam_eval/ \
26 | --visible_gpu '0,1,2,4' \
27 | --world_size '4' \
28 | --tcp_port '36969'
--------------------------------------------------------------------------------
/eval_scripts/deit_t/w3a3.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
2 | your_path/dataset/imagenet-1k/imagenet \
3 | --dataset 'torch/imagenet' \
4 | --epochs 300 \
5 | --batch-size 140 \
6 | --aq-enable \
7 | --aq-mode lsq \
8 | --aq-per-channel \
9 | --aq_clip_learnable \
10 | --aq-bitw 3 \
11 | --wq-enable \
12 | --wq-per-channel \
13 | --wq-bitw 3 \
14 | --wq-mode statsq \
15 | --model_type deit \
16 | --quantized \
17 | --pretrained \
18 | --pretrained_initialized \
19 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
20 | --kd_hard_and_soft 1 \
21 | --qk_reparam \
22 | --qk_reparam_type 0 \
23 | --teacher_pretrained \
24 | --resume your_path/model_saved/deit_t/w3a3/w3a3_deit_t_qkr_cga.pth.tar \
25 | --output ./outputs/w3a3_deit_t_qkreparam_eval/ \
26 | --visible_gpu '0,1,2,4' \
27 | --world_size '4' \
28 | --tcp_port '36969'
--------------------------------------------------------------------------------
/eval_scripts/deit_t/w4a4.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
2 | your_path/dataset/imagenet-1k/imagenet \
3 | --dataset 'torch/imagenet' \
4 | --epochs 300 \
5 | --batch-size 140 \
6 | --aq-enable \
7 | --aq-mode lsq \
8 | --aq-per-channel \
9 | --aq_clip_learnable \
10 | --aq-bitw 4 \
11 | --wq-enable \
12 | --wq-per-channel \
13 | --wq-bitw 4 \
14 | --wq-mode statsq \
15 | --model_type deit \
16 | --quantized \
17 | --pretrained \
18 | --pretrained_initialized \
19 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
20 | --kd_hard_and_soft 1 \
21 | --qk_reparam \
22 | --qk_reparam_type 0 \
23 | --teacher_pretrained \
24 | --resume your_path/model_saved/deit_t/w4a4/w4a4_deit_t_qkr_cga.pth.tar \
25 | --output ./outputs/w4a4_deit_t_qkreparam_eval/ \
26 | --visible_gpu '0,1,2,4' \
27 | --world_size '4' \
28 | --tcp_port '36969'
--------------------------------------------------------------------------------
/eval_scripts/swin_t/w2a2.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
2 | your_path/dataset/imagenet-1k/imagenet \
3 | --dataset 'torch/imagenet' \
4 | --epochs 300 \
5 | --batch-size 128 \
6 | --weight-decay 0.05 \
7 | --warmup-lr 1.0e-6 \
8 | --lr 5.0e-4 \
9 | --warmup-epochs 5 \
10 | --mixup 0.0 --cutmix 0.0 \
11 | --aq-enable \
12 | --aq-mode lsq \
13 | --aq-per-channel \
14 | --aq_clip_learnable \
15 | --aq-bitw 2 \
16 | --wq-enable \
17 | --wq-per-channel \
18 | --wq-bitw 2 \
19 | --wq-mode statsq \
20 | --model_type swin \
21 | --teacher_type swin \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher swin_t \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --resume your_path/model_saved/swin_t/w2a2/w2a2_swin_t_qkr_cga.pth.tar \
31 | --output ./outputs/w2a2_swin_t_qkreparam/ \
32 | --visible_gpu '0,1,2,4' \
33 | --world_size '4' \
34 | --tcp_port '12345'
--------------------------------------------------------------------------------
/eval_scripts/swin_t/w3a3.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
2 | your_path/dataset/imagenet-1k/imagenet \
3 | --dataset 'torch/imagenet' \
4 | --epochs 300 \
5 | --batch-size 128 \
6 | --weight-decay 0.0 \
7 | --warmup-lr 1.0e-6 \
8 | --lr 2.0e-4 \
9 | --warmup-epochs 0 \
10 | --aq-enable \
11 | --aq-mode lsq \
12 | --aq-per-channel \
13 | --aq_clip_learnable \
14 | --aq-bitw 3 \
15 | --wq-enable \
16 | --wq-per-channel \
17 | --wq-bitw 3 \
18 | --wq-mode statsq \
19 | --model_type swin \
20 | --teacher_type swin \
21 | --quantized \
22 | --pretrained \
23 | --pretrained_initialized \
24 | --use-kd --teacher swin_t \
25 | --kd_hard_and_soft 1 \
26 | --qk_reparam \
27 | --qk_reparam_type 0 \
28 | --teacher_pretrained \
29 | --resume your_path/model_saved/swin_t/w3a3/w3a3_swin_t_qkr_cga.pth.tar \
30 | --output ./outputs/w3a3_swin_t_qkreparam/ \
31 | --visible_gpu '0,1,2,4' \
32 | --world_size '4' \
33 | --tcp_port '12345'
--------------------------------------------------------------------------------
/eval_scripts/swin_t/w4a4.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
2 | your_path/dataset/imagenet-1k/imagenet \
3 | --dataset 'torch/imagenet' \
4 | --epochs 300 \
5 | --batch-size 128 \
6 | --weight-decay 0.0 \
7 | --warmup-lr 1.0e-6 \
8 | --lr 2.0e-4 \
9 | --warmup-epochs 0 \
10 | --aq-enable \
11 | --aq-mode lsq \
12 | --aq-per-channel \
13 | --aq_clip_learnable \
14 | --aq-bitw 4 \
15 | --wq-enable \
16 | --wq-per-channel \
17 | --wq-bitw 4 \
18 | --wq-mode statsq \
19 | --model_type swin \
20 | --teacher_type swin \
21 | --quantized \
22 | --pretrained \
23 | --pretrained_initialized \
24 | --use-kd --teacher swin_t \
25 | --kd_hard_and_soft 1 \
26 | --qk_reparam \
27 | --qk_reparam_type 0 \
28 | --teacher_pretrained \
29 | --resume your_path/model_saved/swin_t/w4a4/w4a4_swin_t_qkr_cga.pth.tar \
30 | --output ./outputs/w4a4_swin_t_qkreparam/ \
31 | --visible_gpu '0,1,2,4' \
32 | --world_size '4' \
33 | --tcp_port '12345'
--------------------------------------------------------------------------------
/imgs/ofq_vit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nbasyl/OFQ/7ed37d1dd33d39395edbf49fcbbc52f678ecf961/imgs/ofq_vit.png
--------------------------------------------------------------------------------
/imgs/w2a2_deit_s_cga_vis_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nbasyl/OFQ/7ed37d1dd33d39395edbf49fcbbc52f678ecf961/imgs/w2a2_deit_s_cga_vis_.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | import imp
2 | from .deit_vision_transformer import *
3 | from .deit import *
4 | from .swin import *
5 | from .quantization import *
6 | from .utils import *
7 |
8 |
--------------------------------------------------------------------------------
/src/deit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | from operator import is_
4 | import torch
5 | import torch.nn as nn
6 | from functools import partial
7 |
8 | from .deit_vision_transformer import VisionTransformer, _cfg
9 |
10 | # from timm.models.vision_transformer import VisionTransformer, _cfg
11 | from timm.models.registry import register_model
12 | from timm.models.layers import trunc_normal_
13 |
14 |
15 | __all__ = [
16 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224'
17 | ]
18 |
19 |
20 | class DistilledVisionTransformer(VisionTransformer):
21 | def __init__(self, *args, **kwargs):
22 | super().__init__(*args, **kwargs)
23 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
24 | num_patches = self.patch_embed.num_patches
25 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
26 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
27 |
28 | trunc_normal_(self.dist_token, std=.02)
29 | trunc_normal_(self.pos_embed, std=.02)
30 | self.head_dist.apply(self._init_weights)
31 |
32 | def forward_features(self, x):
33 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
34 | x = self.patch_embed(x)
35 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
36 | if self.dist_token is None:
37 | x = torch.cat((cls_token, x), dim=1)
38 | else:
39 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
40 | x = self.pos_drop(x + self.pos_embed)
41 |
42 | attn_matrixs = []
43 | intermediate_features = []
44 | for block in self.blocks:
45 | x, attn_matrix = block(x)
46 | attn_matrixs.append(attn_matrix)
47 | intermediate_features.append(x)
48 |
49 | x = self.norm(x)
50 | if self.dist_token is None:
51 | return self.pre_logits(x[:, 0]), attn_matrixs, intermediate_features
52 | else:
53 | return x[:, 0], x[:, 1], attn_matrixs, intermediate_features
54 |
55 |
56 | def forward(self, x):
57 | x = self.forward_features(x)
58 |
59 | if self.head_dist is not None:
60 | cls_x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
61 | if self.training and not torch.jit.is_scripting():
62 | return (cls_x, x_dist), x[2]
63 | else:
64 | return (cls_x + x_dist) / 2, x[2]
65 | else:
66 | cls_x = self.head(x[0])
67 | return cls_x, x[1] #
68 |
69 |
70 |
71 |
72 | @register_model
73 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
74 | model = DistilledVisionTransformer(
75 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
76 | norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer= nn.GELU, **kwargs)
77 | model.default_cfg = _cfg()
78 | if pretrained and kwargs['num_classes'] == 1000:
79 | checkpoint = torch.hub.load_state_dict_from_url(
80 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
81 | map_location="cpu", check_hash=True
82 | )
83 | print("load pretrained")
84 | model.load_state_dict(checkpoint["model"])
85 | elif pretrained and kwargs['num_classes'] == 100:
86 | raise ValueError('No trained model provided')
87 | return model
88 |
89 | @register_model
90 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
91 | model = DistilledVisionTransformer(
92 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
93 | norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer= nn.GELU, **kwargs)
94 | model.default_cfg = _cfg()
95 | if pretrained and kwargs['num_classes'] == 1000:
96 | checkpoint = torch.hub.load_state_dict_from_url(
97 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
98 | map_location="cpu", check_hash=True
99 | )
100 | print("load pretrained")
101 | model.load_state_dict(checkpoint["model"])
102 | elif pretrained and kwargs['num_classes'] == 100:
103 | raise ValueError('No trained model provided')
104 | return model
105 |
106 |
--------------------------------------------------------------------------------
/src/deit_vision_transformer.py:
--------------------------------------------------------------------------------
1 | """ Vision Transformer (ViT) in PyTorch
2 |
3 | A PyTorch implement of Vision Transformers as described in:
4 |
5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
6 | - https://arxiv.org/abs/2010.11929
7 |
8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
9 | - https://arxiv.org/abs/2106.10270
10 |
11 | The official jax code is released and available at https://github.com/google-research/vision_transformer
12 |
13 | DeiT model defs and weights from https://github.com/facebookresearch/deit,
14 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
15 |
16 | Acknowledgments:
17 | * The paper authors for releasing code and weights, thanks!
18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
19 | for some einops/einsum fun
20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
22 |
23 | Hacked together by / Copyright 2020, Ross Wightman
24 | """
25 | import math
26 | import logging
27 | from functools import partial
28 | from collections import OrderedDict
29 | from copy import deepcopy
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 |
35 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
36 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv
37 | from timm.models.layers import PatchEmbed, DropPath, trunc_normal_, lecun_normal_, to_2tuple
38 | from .registry import register_model
39 |
40 | _logger = logging.getLogger(__name__)
41 |
42 |
43 | def _cfg(url='', **kwargs):
44 | return {
45 | 'url': url,
46 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
47 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
48 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
49 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
50 | **kwargs
51 | }
52 |
53 | class Mlp(nn.Module):
54 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
55 | """
56 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
57 | super().__init__()
58 | self.in_features = in_features
59 | self.hidden_features = hidden_features
60 | self.out_features = out_features
61 | self.drop = drop
62 |
63 | out_features = out_features or in_features
64 | hidden_features = hidden_features or in_features
65 | drop_probs = to_2tuple(drop)
66 |
67 | self.fc1 = nn.Linear(in_features, hidden_features)
68 | self.act = act_layer()
69 |
70 |
71 |
72 | self.drop1 = nn.Dropout(drop_probs[0])
73 | self.fc2 = nn.Linear(hidden_features, out_features)
74 | self.drop2 = nn.Dropout(drop_probs[1])
75 |
76 |
77 | def forward(self, x):
78 | x = self.fc1(x)
79 | x = self.act(x)
80 | x = self.drop1(x)
81 | x = self.fc2(x)
82 | x = self.drop2(x)
83 | return x
84 |
85 | class Attention(nn.Module):
86 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., qqkkvv = False):
87 | super().__init__()
88 | self.dim = dim
89 | self.num_heads = num_heads
90 | self.head_dim = dim // num_heads
91 | self.scale = self.head_dim ** -0.5
92 |
93 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
94 | self.attn_drop = nn.Dropout(attn_drop)
95 | self.proj = nn.Linear(dim, dim)
96 | self.proj_drop = nn.Dropout(proj_drop)
97 |
98 | self.qqkkvv = qqkkvv
99 |
100 | def forward(self, x):
101 | B, N, C = x.shape
102 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
103 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
104 |
105 | if self.qqkkvv:
106 | q_score = torch.matmul(q, q.transpose(-1, -2))
107 | q_score = q_score / math.sqrt(self.head_dim)
108 | k_score = torch.matmul(k, k.transpose(-1, -2))
109 | k_score = k_score / math.sqrt(self.head_dim)
110 | v_score = torch.matmul(v, v.transpose(-1, -2))
111 | v_score = v_score / math.sqrt(self.head_dim)
112 |
113 | attn = (q @ k.transpose(-2, -1)) * self.scale
114 | attn_matrix = attn.softmax(dim=-1)
115 | attn = self.attn_drop(attn_matrix)
116 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
117 | x = self.proj(x)
118 | x = self.proj_drop(x)
119 | return x, (attn_matrix, q_score, k_score, v_score)
120 | else:
121 |
122 |
123 | attn = (q @ k.transpose(-2, -1)) * self.scale
124 | attn_matrix = attn.softmax(dim=-1)
125 | attn = self.attn_drop(attn_matrix)
126 |
127 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
128 | x = self.proj(x)
129 | x = self.proj_drop(x)
130 | return x, None
131 |
132 | class Block(nn.Module):
133 |
134 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
135 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, v_normalized = False, qk_normalized=False, qqkkvv = False, LN_affine = True):
136 | super().__init__()
137 | if norm_layer != None:
138 | self.norm1 = norm_layer(dim, elementwise_affine= LN_affine)
139 | else:
140 | self.norm1 = torch.nn.Identity()
141 |
142 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, qqkkvv = qqkkvv)
143 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
144 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
145 | if norm_layer != None:
146 | self.norm2 = norm_layer(dim, elementwise_affine= LN_affine)
147 | else:
148 | self.norm2 = torch.nn.Identity()
149 |
150 | mlp_hidden_dim = int(dim * mlp_ratio)
151 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
152 | self.qqkkvv = qqkkvv
153 |
154 | def forward(self, x):
155 | if self.qqkkvv:
156 | temp_x, attn_mtrx = self.attn(self.norm1(x))
157 | x = x + self.drop_path(temp_x)
158 | x = x + self.drop_path(self.mlp(self.norm2(x)))
159 | return x, attn_mtrx
160 | else:
161 | temp_x, _ = self.attn(self.norm1(x))
162 | x = x + self.drop_path(temp_x)
163 | x = x + self.drop_path(self.mlp(self.norm2(x)))
164 | return x, None
165 |
166 |
167 |
168 | class VisionTransformer(nn.Module):
169 | """ Vision Transformer
170 |
171 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
172 | - https://arxiv.org/abs/2010.11929
173 |
174 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
175 | - https://arxiv.org/abs/2012.12877
176 | """
177 |
178 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
179 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
180 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
181 | act_layer=None, weight_init='', qqkkvv = False, LN_affine = True):
182 | """
183 | Args:
184 | img_size (int, tuple): input image size
185 | patch_size (int, tuple): patch size
186 | in_chans (int): number of input channels
187 | num_classes (int): number of classes for classification head
188 | embed_dim (int): embedding dimension
189 | depth (int): depth of transformer
190 | num_heads (int): number of attention heads
191 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
192 | qkv_bias (bool): enable bias for qkv if True
193 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
194 | distilled (bool): model includes a distillation token and head as in DeiT models
195 | drop_rate (float): dropout rate
196 | attn_drop_rate (float): attention dropout rate
197 | drop_path_rate (float): stochastic depth rate
198 | embed_layer (nn.Module): patch embedding layer
199 | norm_layer: (nn.Module): normalization layer
200 | weight_init: (str): weight init scheme
201 | """
202 | super().__init__()
203 | self.num_classes = num_classes
204 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
205 | self.num_tokens = 2 if distilled else 1
206 | # norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
207 | # act_layer = act_layer or nn.GELU
208 | norm_layer = norm_layer
209 | act_layer = act_layer
210 | self.qqkkvv = qqkkvv
211 |
212 | self.patch_embed = embed_layer(
213 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
214 | num_patches = self.patch_embed.num_patches
215 |
216 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
217 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
218 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
219 | self.pos_drop = nn.Dropout(p=drop_rate)
220 |
221 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
222 | self.blocks = nn.Sequential(*[
223 | Block(
224 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
225 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, qqkkvv =self.qqkkvv, LN_affine = LN_affine)
226 | for i in range(depth)])
227 |
228 | if norm_layer == None:
229 | self.norm = torch.nn.Identity()
230 | else:
231 | self.norm = norm_layer(embed_dim)
232 |
233 | # Representation layer
234 | if representation_size and not distilled:
235 | self.num_features = representation_size
236 | self.pre_logits = nn.Sequential(OrderedDict([
237 | ('fc', nn.Linear(embed_dim, representation_size)),
238 | ('act', nn.Tanh())
239 | ]))
240 | else:
241 | self.pre_logits = nn.Identity()
242 |
243 | # Classifier head(s)
244 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
245 | self.head_dist = None
246 | if distilled:
247 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
248 |
249 | self.init_weights(weight_init)
250 |
251 | def init_weights(self, mode=''):
252 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
253 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
254 | trunc_normal_(self.pos_embed, std=.02)
255 | if self.dist_token is not None:
256 | trunc_normal_(self.dist_token, std=.02)
257 | if mode.startswith('jax'):
258 | # leave cls token as zeros to match jax impl
259 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
260 | else:
261 | trunc_normal_(self.cls_token, std=.02)
262 | self.apply(_init_vit_weights)
263 |
264 | def _init_weights(self, m):
265 | # this fn left here for compat with downstream users
266 | _init_vit_weights(m)
267 |
268 | @torch.jit.ignore()
269 | def load_pretrained(self, checkpoint_path, prefix=''):
270 | _load_weights(self, checkpoint_path, prefix)
271 |
272 | @torch.jit.ignore
273 | def no_weight_decay(self):
274 | return {'pos_embed', 'cls_token', 'dist_token'}
275 |
276 | def get_classifier(self):
277 | if self.dist_token is None:
278 | return self.head
279 | else:
280 | return self.head, self.head_dist
281 |
282 | def reset_classifier(self, num_classes, global_pool=''):
283 | self.num_classes = num_classes
284 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
285 | if self.num_tokens == 2:
286 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
287 |
288 | def forward_features(self, x):
289 | x = self.patch_embed(x)
290 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
291 | if self.dist_token is None:
292 | x = torch.cat((cls_token, x), dim=1)
293 | else:
294 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
295 | x = self.pos_drop(x + self.pos_embed)
296 |
297 | attn_matrixs = []
298 | intermediate_features = []
299 | for block in self.blocks:
300 | x, attn_matrix = block(x)
301 | attn_matrixs.append(attn_matrix)
302 | intermediate_features.append(x)
303 |
304 | x = self.norm(x)
305 | if self.dist_token is None:
306 | return self.pre_logits(x[:, 0]), attn_matrixs, intermediate_features
307 | else:
308 | return x[:, 0], x[:, 1], attn_matrixs, intermediate_features
309 |
310 | def get_attn_matrix_and_intermediate_features(self, x):
311 |
312 | x = self.forward_features(x)
313 | if self.dist_token is None:
314 | return x[1], x[2]
315 | else:
316 | return x[2], x[3]
317 |
318 | def forward(self, x):
319 | x = self.forward_features(x)
320 | if self.head_dist is not None:
321 | cls_x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
322 | if self.training and not torch.jit.is_scripting():
323 | # during inference, return the average of both classifier predictions
324 | return (cls_x, x_dist), x[2]
325 | else:
326 | return (cls_x + x_dist) / 2, x[2]
327 | else:
328 | cls_x = self.head(x[0])
329 |
330 | return cls_x, x[1]
331 |
332 |
333 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
334 | """ ViT weight initialization
335 | * When called without n, head_bias, jax_impl args it will behave exactly the same
336 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
337 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
338 | """
339 | if isinstance(module, nn.Linear):
340 | if name.startswith('head'):
341 | nn.init.zeros_(module.weight)
342 | nn.init.constant_(module.bias, head_bias)
343 | elif name.startswith('pre_logits'):
344 | lecun_normal_(module.weight)
345 | nn.init.zeros_(module.bias)
346 | else:
347 | if jax_impl:
348 | nn.init.xavier_uniform_(module.weight)
349 | if module.bias is not None:
350 | if 'mlp' in name:
351 | nn.init.normal_(module.bias, std=1e-6)
352 | else:
353 | nn.init.zeros_(module.bias)
354 | else:
355 | trunc_normal_(module.weight, std=.02)
356 | if module.bias is not None:
357 | nn.init.zeros_(module.bias)
358 | elif jax_impl and isinstance(module, nn.Conv2d):
359 | # NOTE conv was left to pytorch default in my original init
360 | lecun_normal_(module.weight)
361 | if module.bias is not None:
362 | nn.init.zeros_(module.bias)
363 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
364 | if module.elementwise_affine == True:
365 | nn.init.zeros_(module.bias)
366 | nn.init.ones_(module.weight)
367 |
368 |
369 | @torch.no_grad()
370 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
371 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
372 | """
373 | import numpy as np
374 |
375 | def _n2p(w, t=True):
376 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
377 | w = w.flatten()
378 | if t:
379 | if w.ndim == 4:
380 | w = w.transpose([3, 2, 0, 1])
381 | elif w.ndim == 3:
382 | w = w.transpose([2, 0, 1])
383 | elif w.ndim == 2:
384 | w = w.transpose([1, 0])
385 | return torch.from_numpy(w)
386 |
387 | w = np.load(checkpoint_path)
388 | if not prefix and 'opt/target/embedding/kernel' in w:
389 | prefix = 'opt/target/'
390 |
391 | if hasattr(model.patch_embed, 'backbone'):
392 | # hybrid
393 | backbone = model.patch_embed.backbone
394 | stem_only = not hasattr(backbone, 'stem')
395 | stem = backbone if stem_only else backbone.stem
396 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
397 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
398 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
399 | if not stem_only:
400 | for i, stage in enumerate(backbone.stages):
401 | for j, block in enumerate(stage.blocks):
402 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
403 | for r in range(3):
404 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
405 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
406 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
407 | if block.downsample is not None:
408 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
409 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
410 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
411 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
412 | else:
413 | embed_conv_w = adapt_input_conv(
414 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
415 | model.patch_embed.proj.weight.copy_(embed_conv_w)
416 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
417 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
418 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
419 | if pos_embed_w.shape != model.pos_embed.shape:
420 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
421 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
422 | model.pos_embed.copy_(pos_embed_w)
423 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
424 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
425 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
426 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
427 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
428 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
429 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
430 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
431 | for i, block in enumerate(model.blocks.children()):
432 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
433 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
434 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
435 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
436 | block.attn.qkv.weight.copy_(torch.cat([
437 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
438 | block.attn.qkv.bias.copy_(torch.cat([
439 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
440 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
441 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
442 | for r in range(2):
443 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
444 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
445 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
446 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
447 |
448 |
449 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
450 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
451 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
452 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
453 | ntok_new = posemb_new.shape[1]
454 | if num_tokens:
455 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
456 | ntok_new -= num_tokens
457 | else:
458 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
459 | gs_old = int(math.sqrt(len(posemb_grid)))
460 | if not len(gs_new): # backwards compatibility
461 | gs_new = [int(math.sqrt(ntok_new))] * 2
462 | assert len(gs_new) >= 2
463 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
464 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
465 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
466 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
467 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
468 | return posemb
469 |
470 |
471 | def checkpoint_filter_fn(state_dict, model):
472 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
473 | out_dict = {}
474 | if 'model' in state_dict:
475 | # For deit models
476 | state_dict = state_dict['model']
477 | for k, v in state_dict.items():
478 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
479 | # For old models that I trained prior to conv based patchification
480 | O, I, H, W = model.patch_embed.proj.weight.shape
481 | v = v.reshape(O, -1, H, W)
482 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
483 | # To resize pos embedding when using model at different size from pretrained weights
484 | v = resize_pos_embed(
485 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
486 | out_dict[k] = v
487 | return out_dict
488 |
489 |
490 |
491 |
492 |
493 |
--------------------------------------------------------------------------------
/src/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from .quantizer import *
3 | from .utils import adaptive_clip_grad, KLLossSoft, KLTokenMSELoss, KDLossSoftandHard, KDLossSoftandHard_qk, KDLossSoftandHard_qkv, KDLossSoftandHard_dampening, KDLossSoftandSoftTargetCE
4 |
--------------------------------------------------------------------------------
/src/quantization/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import replace_module_by_qmodule_deit, replace_module_by_qmodule_swin
2 |
--------------------------------------------------------------------------------
/src/quantization/modules/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .qlinear import LSQ_input
6 | from src.deit_vision_transformer import Attention as deit_attention
7 | from ..modules.qbias import LearnableBias
8 | from .qlinear import QLinear, LSQ_w_and_act_QLinear
9 | from ..quantizer.lsq import LsqQuantizer, LsqQuantizer4v
10 | from ..quantizer.statsq import StatsQuantizer, StatsQuantizer_specific_4_qkreparam_cga
11 |
12 | class QAttention(deit_attention):
13 | def __init__(self, m: deit_attention, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True,
14 | weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq",
15 | pretrained_initialized = False,
16 | **kwargs):
17 | assert type(m) == deit_attention
18 | super().__init__(
19 | dim = m.qkv.in_features,
20 | num_heads=m.num_heads,
21 | attn_drop=m.attn_drop.p,
22 | proj_drop=m.proj_drop.p,
23 | qqkkvv= m.qqkkvv
24 | )
25 | self.weight_bits = weight_bits
26 | self.input_bits = input_bits
27 | self.input_channelwise = input_channelwise
28 |
29 | self.qkv = QLinear(
30 | m = self.qkv,
31 | weight_bits = weight_bits,
32 | input_bits = input_bits,
33 | weight_channelwise = weight_channelwise,
34 | input_channelwise = input_channelwise,
35 | weight_quant_method = weight_quant_method,
36 | input_quant_method = input_quant_method,
37 | aq_learnable = aq_learnable, ## act
38 | wq_learnable = wq_learnable,## weight
39 | symmetric = True, ## act
40 | pretrained_initialized = pretrained_initialized
41 | )
42 | self.proj = QLinear(
43 | m = self.proj,
44 | weight_bits = weight_bits,
45 | input_bits = input_bits,
46 | weight_channelwise = weight_channelwise,
47 | input_channelwise = input_channelwise,
48 | weight_quant_method = weight_quant_method,
49 | input_quant_method = input_quant_method,
50 | aq_learnable = aq_learnable, ## act
51 | wq_learnable = wq_learnable,## weight
52 | symmetric = True, ## act
53 | pretrained_initialized = pretrained_initialized
54 | )
55 |
56 | self.quan_a_q_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable)
57 | self.quan_a_k_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable)
58 | self.quan_a_v_fn = LsqQuantizer4v(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable)
59 |
60 | self.move_qkv_b4 = LearnableBias(m.qkv.in_features*3)
61 | self.move_q_aft = LearnableBias(m.qkv.in_features)
62 | self.move_k_aft = LearnableBias(m.qkv.in_features)
63 | self.move_v_aft = LearnableBias(m.qkv.in_features)
64 |
65 | self.quan_a_softmax_fn = LsqQuantizer(bit=input_bits,all_positive=True,per_channel=True, learnable = aq_learnable)
66 |
67 | def forward(self, x):
68 | B, N, C = x.shape
69 | qkv = self.qkv(x) # B, N, 3*C
70 | if self.input_bits < 32:
71 | qkv = self.move_qkv_b4(qkv)
72 | qkv = qkv.reshape(
73 | B, N, 3, self.num_heads, C // self.num_heads
74 | ).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C // self.num_heads
75 | q, k, v = qkv[0], qkv[1], qkv[2] # B, num_heads, N, C//self.num_heads
76 |
77 | q = self.quan_a_q_fn(q) # quantize along N
78 | k = self.quan_a_k_fn(k) # quantize along N
79 |
80 | v = v.permute(0,2,1,3).reshape(B,N,C)
81 | v = self.quan_a_v_fn(v) # quantize along C
82 | v = v.reshape(B,N,self.num_heads,C//self.num_heads).permute(0,2,1,3)
83 | if self.input_bits < 32:
84 |
85 | q = q.permute(0, 2, 1, 3).reshape(B, N, C)
86 | k = k.permute(0, 2, 1, 3).reshape(B, N, C)
87 | v = v.permute(0, 2, 1, 3).reshape(B, N, C)
88 | q = self.move_q_aft(q)
89 | k = self.move_k_aft(k)
90 | v = self.move_v_aft(v)
91 |
92 | q = q.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B, num_heads, N, C // self.num_heads
93 | k = k.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
94 | v = v.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
95 |
96 | attn_weights = (q @ k.transpose(-2, -1).contiguous()) * self.scale
97 | attn_prob = F.softmax(attn_weights, dim=-1)
98 |
99 | attn_prob = self.quan_a_softmax_fn(attn_prob)
100 | attn_prob = self.attn_drop(attn_prob)
101 |
102 | x = (attn_prob @ v).transpose(1, 2).reshape(B, N, C)
103 | x = self.proj(x)
104 | x = self.proj_drop(x)
105 | return x, None
106 |
107 | class QAttention_qkreparam(deit_attention):
108 | def __init__(self, m: deit_attention, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True,
109 | weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq",
110 | pretrained_initialized = False,
111 | **kwargs):
112 | assert type(m) == deit_attention
113 | super().__init__(
114 | dim = m.qkv.in_features,
115 | num_heads=m.num_heads,
116 | attn_drop=m.attn_drop.p,
117 | proj_drop=m.proj_drop.p,
118 | qqkkvv= m.qqkkvv
119 | )
120 | self.weight_bits = weight_bits
121 | self.input_bits = input_bits
122 | self.input_channelwise = input_channelwise
123 |
124 | self.quant_x_4_qkv = LSQ_input(bit = input_bits, all_positive= False, learnable= aq_learnable, learanbaleBiasdim=m.qkv.in_features)
125 |
126 | self.q = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features, bias=False)
127 | self.k = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features, bias= False)
128 | self.v = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features)
129 |
130 | if pretrained_initialized:
131 | with torch.no_grad():
132 | q_k_v_dim = int(m.qkv.weight.shape[0]/3)
133 | copy_weight = m.qkv.weight.detach()
134 | copy_bias = m.qkv.bias.detach()
135 | self.q.weight.copy_(copy_weight[:q_k_v_dim*1,:])
136 | self.k.weight.copy_(copy_weight[q_k_v_dim*1:q_k_v_dim*2,:])
137 | self.v.weight.copy_(copy_weight[q_k_v_dim*2:q_k_v_dim*3,:])
138 | self.v.bias.copy_(copy_bias[q_k_v_dim*2:q_k_v_dim*3])
139 |
140 | self.qk_quant = StatsQuantizer(num_bits=self.weight_bits, clip_learnable=wq_learnable) # num_heads*in_features, in_features
141 | self.v_quant = StatsQuantizer(num_bits=self.weight_bits, clip_learnable=wq_learnable)#.to(m.weight.device)
142 |
143 | self.proj = QLinear(
144 | m = self.proj,
145 | weight_bits = weight_bits,
146 | input_bits = input_bits,
147 | weight_channelwise = weight_channelwise,
148 | input_channelwise = input_channelwise,
149 | weight_quant_method = weight_quant_method,
150 | input_quant_method = input_quant_method,
151 | aq_learnable = aq_learnable, ## act
152 | wq_learnable = wq_learnable,## weight
153 | symmetric = True, ## act
154 | pretrained_initialized = pretrained_initialized
155 | )
156 |
157 | ## THIS IS FOR QK PART
158 | self.quan_a_qkx_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) # for W_q*W_k*X -> B, num_heads, N, C//numheads
159 | self.quan_a_v_fn = LsqQuantizer4v(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable)
160 |
161 | # for W_q*W_k*X
162 | self.move_qkx_b4 = LearnableBias(self.num_heads * m.qkv.in_features)
163 | self.move_qkx_aft = LearnableBias(self.num_heads * m.qkv.in_features)
164 |
165 | # for v
166 | self.move_v_b4 = LearnableBias(m.qkv.in_features)
167 | self.move_v_aft = LearnableBias(m.qkv.in_features)
168 |
169 | self.quan_a_softmax_fn = LsqQuantizer(bit=input_bits,all_positive=True,per_channel=True, learnable = aq_learnable)
170 |
171 | ## no longer need qkv
172 | del self.qkv
173 |
174 | def forward(self, x):
175 | B, N, C = x.shape
176 | ## first quantize input x
177 | quant_x = self.quant_x_4_qkv(x)
178 | ## V
179 | quant_v_weight = self.v_quant(self.v.weight)
180 | v_out = nn.functional.linear(quant_x, quant_v_weight)
181 | v_out += self.v.bias.view(1, -1).expand_as(v_out) # B,N,C
182 |
183 | ## TO MULTI_HEAD V
184 | v_out = self.move_v_b4(v_out)
185 | v_out = self.quan_a_v_fn(v_out)
186 | v_out = self.move_v_aft(v_out) # B, N, C
187 | v = v_out.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B, num_heads, N, C // num_heads
188 |
189 | ## TO MULTI_HEAD QK
190 | multi_head_q_weight = self.q.weight.reshape(self.num_heads,self.q.out_features // self.num_heads, self.q.in_features) # out_feat , in_feat -> num_heads, out_feat // num_heads, in_feat
191 | multi_head_k_weight = self.k.weight.reshape(self.num_heads,self.k.out_features // self.num_heads, self.k.in_features)
192 |
193 | multi_head_qk = multi_head_q_weight.transpose(-2, -1).contiguous() @ multi_head_k_weight # num_heads, in_features, in_features
194 | multi_head_qk = multi_head_qk.reshape(self.num_heads*self.q.out_features,self.q.in_features)# num_heads*in_features, in_features
195 | multi_head_qk_qunat = self.qk_quant(multi_head_qk)
196 | multi_head_qk_qunat = multi_head_qk_qunat.reshape(self.num_heads, self.q.in_features,self.q.in_features) #num_heads, in_features, in_features
197 |
198 | ## W_qk@X^T torch.einsum('BNC,BACD -> BAND',quant_x,quant_qkx)
199 | # quant_x: B,C,N
200 | qkx = torch.einsum('HDC, BCN -> BHDN', multi_head_qk_qunat, quant_x.transpose(-2, -1).contiguous() ) # B, num_heads, in_features, N
201 | qkx = qkx.permute(0,3,1,2).reshape(B,N, self.num_heads * C) #B, N, num_heads*in_features
202 | qkx = self.move_qkx_b4(qkx)
203 | qkx = qkx.reshape(B,N*self.num_heads, C)
204 | quant_qkx = self.quan_a_qkx_fn(qkx) # B, num_heads*N, in_features
205 | quant_qkx = quant_qkx.reshape(B,N,self.num_heads * C) #B, N, num_heads*in_features
206 | quant_qkx = self.move_qkx_aft(quant_qkx) #B, N, num_heads*in_features
207 | quant_qkx = quant_qkx.reshape(B, N, self.num_heads, -1).permute(0, 2, 3, 1) # B, num_heads, in_features, N
208 |
209 | ## x@W_qk@X^T, quant_x: B,N,C and quant_qkx: B,num_heads, C, N
210 | xqkx = torch.einsum('BNC,BHCD -> BHND',quant_x,quant_qkx) # B, num_heads, N, N
211 | # B, num_heads, N, C // self.num_heads
212 | value_4_softmax = xqkx
213 | attn_weights = (value_4_softmax) * self.scale
214 | attn_prob = F.softmax(attn_weights, dim=-1)
215 |
216 | attn_prob = self.quan_a_softmax_fn(attn_prob)
217 | attn_prob = self.attn_drop(attn_prob)
218 |
219 | x = (attn_prob @ v).transpose(1, 2).reshape(B, N, C)
220 | x = self.proj(x)
221 | x = self.proj_drop(x)
222 | return x, None
223 |
224 | class QAttention_qkreparam_4_cga(deit_attention):
225 | def __init__(self, m: deit_attention, clip_val=2.5, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True,
226 | weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq",
227 | pretrained_initialized = False, boundaryRange = 0.005,
228 | **kwargs):
229 | assert type(m) == deit_attention
230 | super().__init__(
231 | dim = m.qkv.in_features,
232 | num_heads=m.num_heads,
233 | attn_drop=m.attn_drop.p,
234 | proj_drop=m.proj_drop.p,
235 | qqkkvv= m.qqkkvv
236 | )
237 | self.weight_bits = weight_bits
238 | self.input_bits = input_bits
239 | self.input_channelwise = input_channelwise
240 |
241 | self.quant_x_4_qkv = LSQ_input(bit = input_bits, all_positive= False, learnable= aq_learnable, learanbaleBiasdim=m.qkv.in_features)
242 |
243 | self.q = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features, bias=False)
244 | self.k = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features, bias= False)
245 | self.v = nn.Linear(in_features=m.qkv.in_features,out_features=m.qkv.in_features)
246 |
247 | if pretrained_initialized:
248 | with torch.no_grad():
249 | q_k_v_dim = int(m.qkv.weight.shape[0]/3)
250 | copy_weight = m.qkv.weight.detach()
251 | copy_bias = m.qkv.bias.detach()
252 | self.q.weight.copy_(copy_weight[:q_k_v_dim*1,:])
253 | self.k.weight.copy_(copy_weight[q_k_v_dim*1:q_k_v_dim*2,:])
254 | self.v.weight.copy_(copy_weight[q_k_v_dim*2:q_k_v_dim*3,:])
255 | self.v.bias.copy_(copy_bias[q_k_v_dim*2:q_k_v_dim*3])
256 |
257 | self.qk_quant = StatsQuantizer_specific_4_qkreparam_cga(num_bits=self.weight_bits, clip_learnable=wq_learnable,boundaryRange=boundaryRange) # num_heads*in_features, in_features
258 | self.v_quant = StatsQuantizer(num_bits=self.weight_bits, clip_learnable=wq_learnable)#.to(m.weight.device)
259 |
260 | self.proj = QLinear(
261 | m = self.proj,
262 | weight_bits = weight_bits,
263 | input_bits = input_bits,
264 | weight_channelwise = weight_channelwise,
265 | input_channelwise = input_channelwise,
266 | weight_quant_method = weight_quant_method,
267 | input_quant_method = input_quant_method,
268 | aq_learnable = aq_learnable, ## act
269 | wq_learnable = wq_learnable,## weight
270 | symmetric = True, ## act
271 | pretrained_initialized = pretrained_initialized
272 | )
273 |
274 | ## THIS IS FOR QK PART
275 | self.quan_a_qkx_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable) # for W_q*W_k*X -> B, num_heads, N, C//numheads
276 | self.quan_a_v_fn = LsqQuantizer4v(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable)
277 |
278 | # for W_q*W_k*X
279 | self.move_qkx_b4 = LearnableBias(self.num_heads * m.qkv.in_features)
280 | self.move_qkx_aft = LearnableBias(self.num_heads * m.qkv.in_features)
281 |
282 | # for v
283 | self.move_v_b4 = LearnableBias(m.qkv.in_features)
284 | self.move_v_aft = LearnableBias(m.qkv.in_features)
285 |
286 | self.quan_a_softmax_fn = LsqQuantizer(bit=input_bits,all_positive=True,per_channel=True, learnable = aq_learnable)
287 |
288 | ## no longer need qkv
289 | del self.qkv
290 |
291 | def forward(self, x):
292 | B, N, C = x.shape
293 | ## first quantize input x
294 | quant_x = self.quant_x_4_qkv(x)
295 | ## V
296 | quant_v_weight = self.v_quant(self.v.weight)
297 | v_out = nn.functional.linear(quant_x, quant_v_weight)
298 | v_out += self.v.bias.view(1, -1).expand_as(v_out) # B,N,C
299 |
300 | ## TO MULTI_HEAD V
301 | v_out = self.move_v_b4(v_out)
302 | v_out = self.quan_a_v_fn(v_out)
303 | v_out = self.move_v_aft(v_out) # B, N, C
304 | v = v_out.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B, num_heads, N, C // num_heads
305 |
306 | ## TO MULTI_HEAD QK
307 | multi_head_q_weight = self.q.weight.reshape(self.num_heads,self.q.out_features // self.num_heads, self.q.in_features) # out_feat , in_feat -> num_heads, out_feat // num_heads, in_feat
308 | multi_head_k_weight = self.k.weight.reshape(self.num_heads,self.k.out_features // self.num_heads, self.k.in_features)
309 |
310 | multi_head_qk = multi_head_q_weight.transpose(-2, -1).contiguous() @ multi_head_k_weight # num_heads, in_features, in_features
311 | multi_head_qk = multi_head_qk.reshape(self.num_heads*self.q.out_features,self.q.in_features)# num_heads*in_features, in_features
312 | multi_head_qk_qunat = self.qk_quant(multi_head_qk)
313 | multi_head_qk_qunat = multi_head_qk_qunat.reshape(self.num_heads, self.q.in_features,self.q.in_features) #num_heads, in_features, in_features
314 |
315 | ## W_qk@X^T torch.einsum('BNC,BACD -> BAND',quant_x,quant_qkx)
316 | # quant_x: B,C,N
317 | qkx = torch.einsum('HDC, BCN -> BHDN', multi_head_qk_qunat, quant_x.transpose(-2, -1).contiguous() ) # B, num_heads, in_features, N
318 | qkx = qkx.permute(0,3,1,2).reshape(B,N, self.num_heads * C) #B, N, num_heads*in_features
319 | qkx = self.move_qkx_b4(qkx)
320 | qkx = qkx.reshape(B,N*self.num_heads, C)
321 | quant_qkx = self.quan_a_qkx_fn(qkx) # B, num_heads*N, in_features
322 | quant_qkx = quant_qkx.reshape(B,N,self.num_heads * C) #B, N, num_heads*in_features
323 | quant_qkx = self.move_qkx_aft(quant_qkx) #B, N, num_heads*in_features
324 | quant_qkx = quant_qkx.reshape(B, N, self.num_heads, -1).permute(0, 2, 3, 1) # B, num_heads, in_features, N
325 |
326 | ## x@W_qk@X^T, quant_x: B,N,C and quant_qkx: B,num_heads, C, N
327 | xqkx = torch.einsum('BNC,BHCD -> BHND',quant_x,quant_qkx) # B, num_heads, N, N
328 | # B, num_heads, N, C // self.num_heads
329 | value_4_softmax = xqkx
330 | attn_weights = (value_4_softmax) * self.scale
331 | attn_prob = F.softmax(attn_weights, dim=-1)
332 |
333 | attn_prob = self.quan_a_softmax_fn(attn_prob)
334 | attn_prob = self.attn_drop(attn_prob)
335 |
336 | x = (attn_prob @ v).transpose(1, 2).reshape(B, N, C)
337 | x = self.proj(x)
338 | x = self.proj_drop(x)
339 | return x, None
340 |
341 | class QAttention_lsq(deit_attention):
342 | def __init__(self, m: deit_attention, clip_val=2.5, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True,
343 | symmetric=True, weight_channelwise=True, input_channelwise=True, weight_quant_method="lsq", input_quant_method="lsq",
344 | pretrained_initialized = False,
345 | **kwargs):
346 | assert type(m) == deit_attention
347 | super().__init__(
348 | dim = m.qkv.in_features,
349 | num_heads=m.num_heads,
350 | attn_drop=m.attn_drop.p,
351 | proj_drop=m.proj_drop.p,
352 | qqkkvv= m.qqkkvv
353 | )
354 | self.weight_bits = weight_bits
355 | self.input_bits = input_bits
356 | self.input_channelwise = input_channelwise
357 |
358 | self.qkv = LSQ_w_and_act_QLinear(
359 | m = self.qkv,
360 | weight_bits = weight_bits,
361 | input_bits = input_bits,
362 | weight_channelwise = weight_channelwise,
363 | input_channelwise = input_channelwise,
364 | weight_quant_method = weight_quant_method,
365 | input_quant_method = input_quant_method,
366 | aq_learnable = aq_learnable, ## act
367 | wq_learnable = wq_learnable,## weight
368 | symmetric = True, ## act
369 | pretrained_initialized = pretrained_initialized
370 | )
371 | self.proj = LSQ_w_and_act_QLinear(
372 | m = self.proj,
373 | weight_bits = weight_bits,
374 | input_bits = input_bits,
375 | weight_channelwise = weight_channelwise,
376 | input_channelwise = input_channelwise,
377 | weight_quant_method = weight_quant_method,
378 | input_quant_method = input_quant_method,
379 | aq_learnable = aq_learnable, ## act
380 | wq_learnable = wq_learnable,## weight
381 | symmetric = True, ## act
382 | pretrained_initialized = pretrained_initialized
383 | )
384 |
385 | self.quan_a_q_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable)
386 | self.quan_a_k_fn = LsqQuantizer(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable)
387 | self.quan_a_v_fn = LsqQuantizer4v(bit=input_bits,all_positive=False,per_channel=True, learnable = aq_learnable)
388 |
389 | self.move_qkv_b4 = LearnableBias(m.qkv.in_features*3)
390 | self.move_q_aft = LearnableBias(m.qkv.in_features)
391 | self.move_k_aft = LearnableBias(m.qkv.in_features)
392 | self.move_v_aft = LearnableBias(m.qkv.in_features)
393 |
394 | self.quan_a_softmax_fn = LsqQuantizer(bit=input_bits,all_positive=True,per_channel=True, learnable = aq_learnable)
395 |
396 | def forward(self, x):
397 | B, N, C = x.shape
398 | qkv = self.qkv(x) # B, N, 3*C
399 | if self.input_bits < 32:
400 | qkv = self.move_qkv_b4(qkv)
401 | qkv = qkv.reshape(
402 | B, N, 3, self.num_heads, C // self.num_heads
403 | ).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C // self.num_heads
404 | q, k, v = qkv[0], qkv[1], qkv[2]
405 |
406 | q = self.quan_a_q_fn(q)
407 | k = self.quan_a_k_fn(k)
408 | v = v.permute(0,2,1,3).reshape(B,N,C)
409 | v = self.quan_a_v_fn(v) # quantize along C
410 | v = v.reshape(B,N,self.num_heads,C//self.num_heads).permute(0,2,1,3)
411 |
412 |
413 |
414 | if self.input_bits < 32:
415 |
416 | q = q.permute(0, 2, 1, 3).reshape(B, N, C)
417 | k = k.permute(0, 2, 1, 3).reshape(B, N, C)
418 | v = v.permute(0, 2, 1, 3).reshape(B, N, C)
419 | q = self.move_q_aft(q)
420 | k = self.move_k_aft(k)
421 | v = self.move_v_aft(v)
422 |
423 | q = q.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B, num_heads, N, C // self.num_heads
424 | k = k.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
425 | v = v.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
426 |
427 |
428 | attn_weights = (q @ k.transpose(-2, -1).contiguous()) * self.scale
429 | attn_prob = F.softmax(attn_weights, dim=-1)
430 |
431 | attn_prob = self.quan_a_softmax_fn(attn_prob)
432 | attn_prob = self.attn_drop(attn_prob)
433 |
434 | x = (attn_prob @ v).transpose(1, 2).reshape(B, N, C)
435 | x = self.proj(x)
436 | x = self.proj_drop(x)
437 |
438 | return x, None
439 |
440 |
--------------------------------------------------------------------------------
/src/quantization/modules/qbias.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class LearnableBias(nn.Module):
6 | def __init__(self, out_chn):
7 | super(LearnableBias, self).__init__()
8 | self.bias = nn.Parameter(torch.zeros(out_chn), requires_grad=True)
9 |
10 | def forward(self, x):
11 | out = x + self.bias.expand_as(x)
12 |
13 | return out
14 |
15 | class LearnableBias4img(nn.Module):
16 | def __init__(self, out_chn):
17 | super(LearnableBias4img, self).__init__()
18 | self.bias = nn.Parameter(torch.zeros(out_chn), requires_grad=True)
19 |
20 | def forward(self, x):
21 | out = x + self.bias.reshape(x.shape[-1],x.shape[-2]).expand_as(x)
22 |
23 | return out
24 |
25 |
--------------------------------------------------------------------------------
/src/quantization/modules/qlinear.py:
--------------------------------------------------------------------------------
1 | import imp
2 | from numpy import clip
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .qbias import LearnableBias, LearnableBias4img
7 | from ..quantizer.statsq import StatsQuantizer
8 | from ...deit_vision_transformer import Mlp
9 | from timm.models.layers import to_2tuple
10 | from ..quantizer.lsq import LsqQuantizer, LsqQuantizerWeight, LsqQuantizer4img, LsqQuantizer4Conv2d, LsqQuantizer4head_input
11 |
12 | class LSQ_input(nn.Module):
13 | def __init__(self, bit=2,all_positive=False, learnable = True, learanbaleBiasdim = 192):
14 | super().__init__()
15 | self.input_bits = bit
16 | self.all_positive = all_positive
17 | self.learnable = learnable
18 | self.input_quant_fn = LsqQuantizer(bit=bit,all_positive=all_positive, learnable = learnable)
19 | self.move_b4 = LearnableBias(learanbaleBiasdim)
20 | self.move_aft = LearnableBias(learanbaleBiasdim)
21 | def forward(self, input):
22 |
23 | input = self.move_b4(input)
24 | input = self.input_quant_fn(input)
25 | input = self.move_aft(input)
26 | return input
27 |
28 | class QLinear(nn.Linear):
29 |
30 | def __init__(self, *kargs, m: torch.nn.Linear, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True,
31 | symmetric=True, weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq",
32 | pretrained_initialized = False,
33 | **kwargs):
34 | super(QLinear, self).__init__(m.in_features, m.out_features,bias=True)
35 | self.weight_bits = weight_bits
36 | self.input_bits = input_bits
37 | self.aq_learnable = aq_learnable
38 | self.wq_learnable = wq_learnable
39 | self.symmetric = symmetric
40 | self.weight_channelwise = weight_channelwise # not gonna used atm
41 | self.input_channelwise = input_channelwise
42 | self.weight_quant_method = weight_quant_method
43 | self.input_quant_method = input_quant_method
44 | self.input_quant_fn = LsqQuantizer(bit=input_bits,all_positive=(symmetric==False), learnable = aq_learnable)
45 | self.pretrained_initialized = pretrained_initialized
46 | if pretrained_initialized != False:
47 | self.weight = torch.nn.Parameter(m.weight.detach())
48 | if m.bias is not None:
49 | self.bias = torch.nn.Parameter(m.bias.detach())
50 | if weight_quant_method == 'statsq':
51 | self.statsq_fn = StatsQuantizer(num_bits=self.weight_bits, clip_learnable=wq_learnable).to(m.weight.device)
52 | else:
53 | raise ValueError("Unknown quant_method")
54 |
55 | self.move_b4 = LearnableBias(self.weight.shape[1])
56 | self.move_aft = LearnableBias(self.weight.shape[1])
57 |
58 | def forward(self, input):
59 |
60 | # quantize weight
61 | if self.weight_quant_method == 'statsq':
62 | weight = self.statsq_fn(self.weight)
63 | else:
64 | raise ValueError("Unknown quant_method")
65 | # quantize input
66 | input = self.move_b4(input)
67 | input = self.input_quant_fn(input)
68 | input = self.move_aft(input)
69 | out = nn.functional.linear(input, weight)
70 | if not self.bias is None:
71 | out += self.bias.view(1, -1).expand_as(out)
72 |
73 | return out
74 |
75 | def extra_repr(self):
76 | return (
77 | f"act_bit={self.input_bits}, "
78 | f"weight_bit={self.weight_bits}, "
79 | f"act_all_positive={not self.symmetric}, "
80 | f"wq_learnable={self.wq_learnable}, "
81 | f"aq_learnable={self.aq_learnable}, "
82 | f"weight_channelwise ={self.weight_channelwise}, "
83 | f"input_channelwise ={self.input_channelwise}, "
84 | f"weight_quant_method={self.weight_quant_method}, "
85 | f"activation_quant_method={self.input_quant_method}, "
86 | f"pretrained_initialized = {self.pretrained_initialized}"
87 | )
88 |
89 | class QMLP(Mlp):
90 | def __init__(self, *kargs, m: Mlp, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True, weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq", act_layer=nn.GELU,
91 | pretrained_initialized = False,
92 | **kwargs):
93 | super().__init__(
94 | in_features = m.in_features,
95 | hidden_features = m.hidden_features,
96 | out_features = m.out_features,
97 | drop = m.drop
98 | )
99 |
100 | out_features = m.out_features or m.in_features
101 | hidden_features = m.hidden_features or m.in_features
102 | drop_probs = to_2tuple(self.drop)
103 |
104 | self.fc1 = QLinear(m = m.fc1,weight_bits=weight_bits,input_bits=input_bits,
105 | aq_learnable=aq_learnable,wq_learnable=wq_learnable,symmetric=True,weight_channelwise=weight_channelwise,input_channelwise=input_channelwise,
106 | weight_quant_method=weight_quant_method,input_quant_method=input_quant_method, pretrained_initialized = pretrained_initialized)
107 |
108 | self.act_layer = act_layer
109 |
110 | if act_layer != 'rprelu':
111 | if act_layer != 'None':
112 | self.act = act_layer()
113 | else:
114 | self.act = nn.Identity()
115 |
116 |
117 | self.drop1 = nn.Dropout(drop_probs[0])
118 | self.fc2 = QLinear(m = m.fc2,weight_bits=weight_bits,input_bits=input_bits,
119 | aq_learnable=aq_learnable,wq_learnable=wq_learnable,symmetric=False,weight_channelwise=weight_channelwise,input_channelwise=input_channelwise,
120 | weight_quant_method=weight_quant_method,input_quant_method=input_quant_method, pretrained_initialized = pretrained_initialized)
121 | self.drop2 = nn.Dropout(drop_probs[1])
122 |
123 | def forward(self, x):
124 |
125 | x = self.fc1(x)
126 | if self.act_layer != 'rprelu':
127 | x = self.act(x)
128 | else:
129 | x = self.move1(x)
130 | x = self.act(x)
131 | x = self.move2(x)
132 |
133 | x = self.drop1(x)
134 | x = self.fc2(x)
135 | x = self.drop2(x)
136 | return x
137 |
138 | class LSQ_QConv2d(nn.Conv2d):
139 | def __init__(self, *kargs, m: torch.nn.Conv2d, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True,
140 | symmetric=True, weight_channelwise=True, input_channelwise=True, weight_quant_method="lsq", input_quant_method="lsq",
141 | pretrained_initialized = False,
142 | **kwargs):
143 | super(LSQ_QConv2d, self).__init__(in_channels = m.in_channels, out_channels = m.out_channels, kernel_size = m.kernel_size
144 | , stride=m.stride, padding=m.padding, dilation=m.dilation, groups=m.groups, bias=True)
145 | self.weight_bits = weight_bits
146 | self.input_bits = input_bits
147 | self.aq_learnable = aq_learnable
148 | self.wq_learnable = wq_learnable
149 | self.symmetric = symmetric
150 | self.weight_channelwise = weight_channelwise # not gonna used atm
151 | self.input_channelwise = input_channelwise
152 | self.weight_quant_method = weight_quant_method
153 | self.input_quant_method = input_quant_method
154 | self.input_quant_fn = LsqQuantizer4img(bit=input_bits,all_positive=(symmetric==False), learnable = aq_learnable)
155 | self.pretrained_initialized = pretrained_initialized
156 | if pretrained_initialized != False:
157 | self.weight = torch.nn.Parameter(m.weight.detach())
158 | if m.bias is not None:
159 | self.bias = torch.nn.Parameter(m.bias.detach())
160 |
161 | self.lsqw_fn = LsqQuantizer4Conv2d(bit=self.weight_bits, learnable=aq_learnable).to(m.weight.device)
162 |
163 | self.move_b4 = LearnableBias4img(224*224) # 3x224x224
164 | self.move_aft = LearnableBias4img(224*224) # 3x224x224
165 |
166 | def forward(self, input):
167 | # quantize weight
168 | weight = self.lsqw_fn(self.weight)
169 |
170 | # quantize input
171 | input = self.move_b4(input)
172 | input = self.input_quant_fn(input)
173 | input = self.move_aft(input)
174 | out = nn.functional.conv2d(input, weight, self.bias, self.stride,
175 | self.padding, self.dilation, self.groups)
176 |
177 | return out
178 |
179 | def extra_repr(self):
180 | return (
181 | f"act_bit={self.input_bits}, "
182 | f"weight_bit={self.weight_bits}, "
183 | f"act_all_positive={not self.symmetric}, "
184 | f"wq_learnable={self.wq_learnable}, "
185 | f"aq_learnable={self.aq_learnable}, "
186 | f"weight_channelwise ={self.weight_channelwise}, "
187 | f"input_channelwise ={self.input_channelwise}, "
188 | f"weight_quant_method={self.weight_quant_method}, "
189 | f"activation_quant_method={self.input_quant_method}, "
190 | f"pretrained_initialized = {self.pretrained_initialized}"
191 | )
192 |
193 | class LSQ_QLinear4head(nn.Linear):
194 |
195 | def __init__(self, *kargs, m: torch.nn.Linear, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True,
196 | symmetric=True, weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq",
197 | pretrained_initialized = False,
198 | **kwargs):
199 | super(LSQ_QLinear4head, self).__init__(m.in_features, m.out_features,bias=True)
200 | self.weight_bits = weight_bits
201 | self.input_bits = input_bits
202 | self.aq_learnable = aq_learnable
203 | self.wq_learnable = wq_learnable
204 | self.symmetric = symmetric
205 | self.weight_channelwise = weight_channelwise # not gonna used atm
206 | self.input_channelwise = input_channelwise
207 | self.weight_quant_method = weight_quant_method
208 | self.input_quant_method = input_quant_method
209 | self.input_quant_fn = LsqQuantizer4head_input(bit=input_bits,all_positive=(symmetric==False), learnable = aq_learnable)
210 | self.pretrained_initialized = pretrained_initialized
211 | if pretrained_initialized != False:
212 | self.weight = torch.nn.Parameter(m.weight.detach())
213 | if m.bias is not None:
214 | self.bias = torch.nn.Parameter(m.bias.detach())
215 | if weight_quant_method == 'lsq':
216 | self.lsqw_fn = LsqQuantizerWeight(bit=self.weight_bits, per_channel=weight_channelwise ,learnable=wq_learnable).to(m.weight.device)
217 | else:
218 | raise ValueError("Unknown quant_method")
219 |
220 | self.move_b4 = LearnableBias(self.weight.shape[1])
221 | self.move_aft = LearnableBias(self.weight.shape[1])
222 |
223 | def forward(self, input):
224 |
225 | # quantize weight
226 | if self.weight_quant_method == 'lsq':
227 | weight = self.lsqw_fn(self.weight)
228 | else:
229 | raise ValueError("Unknown quant_method")
230 | # quantize input
231 | input = self.move_b4(input)
232 | input = self.input_quant_fn(input)
233 | input = self.move_aft(input)
234 | out = nn.functional.linear(input, weight)
235 | if not self.bias is None:
236 | out += self.bias.view(1, -1).expand_as(out)
237 |
238 | return out
239 |
240 | def extra_repr(self):
241 | return (
242 | f"act_bit={self.input_bits}, "
243 | f"weight_bit={self.weight_bits}, "
244 | f"act_all_positive={not self.symmetric}, "
245 | f"wq_learnable={self.wq_learnable}, "
246 | f"aq_learnable={self.aq_learnable}, "
247 | f"weight_channelwise ={self.weight_channelwise}, "
248 | f"input_channelwise ={self.input_channelwise}, "
249 | f"weight_quant_method={self.weight_quant_method}, "
250 | f"activation_quant_method={self.input_quant_method}, "
251 | f"pretrained_initialized = {self.pretrained_initialized}"
252 | )
253 |
254 | class LSQ_w_and_act_QLinear(nn.Linear):
255 |
256 | def __init__(self, *kargs, m: torch.nn.Linear, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True,
257 | symmetric=True, weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq",
258 | pretrained_initialized = False,
259 | **kwargs):
260 | super(LSQ_w_and_act_QLinear, self).__init__(m.in_features, m.out_features,bias=True)
261 | self.weight_bits = weight_bits
262 | self.input_bits = input_bits
263 | self.aq_learnable = aq_learnable
264 | self.wq_learnable = wq_learnable
265 | self.symmetric = symmetric
266 | self.weight_channelwise = weight_channelwise # not gonna used atm
267 | self.input_channelwise = input_channelwise
268 | self.weight_quant_method = weight_quant_method
269 | self.input_quant_method = input_quant_method
270 | self.input_quant_fn = LsqQuantizer(bit=input_bits,all_positive=(symmetric==False), learnable = aq_learnable)
271 | self.pretrained_initialized = pretrained_initialized
272 | if pretrained_initialized != False:
273 | self.weight = torch.nn.Parameter(m.weight.detach())
274 | if m.bias is not None:
275 | self.bias = torch.nn.Parameter(m.bias.detach())
276 | if weight_quant_method == 'lsq':
277 | self.lsqw_fn = LsqQuantizerWeight(bit=self.weight_bits, per_channel=weight_channelwise ,learnable=wq_learnable).to(m.weight.device)
278 | else:
279 | raise ValueError("Unknown quant_method")
280 |
281 | self.move_b4 = LearnableBias(self.weight.shape[1])
282 | self.move_aft = LearnableBias(self.weight.shape[1])
283 |
284 | def forward(self, input):
285 |
286 | # quantize weight
287 | if self.weight_quant_method == 'lsq':
288 | weight = self.lsqw_fn(self.weight)
289 | else:
290 | raise ValueError("Unknown quant_method")
291 | # quantize input
292 | input = self.move_b4(input)
293 | input = self.input_quant_fn(input)
294 | input = self.move_aft(input)
295 | out = nn.functional.linear(input, weight)
296 | if not self.bias is None:
297 | out += self.bias.view(1, -1).expand_as(out)
298 |
299 | return out
300 |
301 | def extra_repr(self):
302 | return (
303 | f"act_bit={self.input_bits}, "
304 | f"weight_bit={self.weight_bits}, "
305 | f"act_all_positive={not self.symmetric}, "
306 | f"wq_learnable={self.wq_learnable}, "
307 | f"aq_learnable={self.aq_learnable}, "
308 | f"weight_channelwise ={self.weight_channelwise}, "
309 | f"input_channelwise ={self.input_channelwise}, "
310 | f"weight_quant_method={self.weight_quant_method}, "
311 | f"activation_quant_method={self.input_quant_method}, "
312 | f"pretrained_initialized = {self.pretrained_initialized}"
313 | )
314 |
315 | class LSQ_w_and_act_QMLP(Mlp):
316 | def __init__(self, *kargs, m: Mlp, weight_bits=8, input_bits=8, aq_learnable=True, wq_learnable = True,
317 | weight_channelwise=True, input_channelwise=True, weight_quant_method="statsq", input_quant_method="lsq", act_layer=nn.GELU,
318 | pretrained_initialized = False,
319 | **kwargs):
320 | super().__init__(
321 | in_features = m.in_features,
322 | hidden_features = m.hidden_features,
323 | out_features = m.out_features,
324 | drop = m.drop
325 | )
326 |
327 | out_features = m.out_features or m.in_features
328 | hidden_features = m.hidden_features or m.in_features
329 | drop_probs = to_2tuple(self.drop)
330 |
331 | self.fc1 = LSQ_w_and_act_QLinear(m = m.fc1,weight_bits=weight_bits,input_bits=input_bits,
332 | aq_learnable=aq_learnable,wq_learnable=wq_learnable,symmetric=True,weight_channelwise=weight_channelwise,input_channelwise=input_channelwise,
333 | weight_quant_method=weight_quant_method,input_quant_method=input_quant_method, pretrained_initialized = pretrained_initialized)
334 |
335 | self.act_layer = act_layer
336 |
337 | if act_layer != 'rprelu':
338 | if act_layer != 'None':
339 | self.act = act_layer()
340 | else:
341 | self.act = nn.Identity()
342 |
343 |
344 | self.drop1 = nn.Dropout(drop_probs[0])
345 | self.fc2 = LSQ_w_and_act_QLinear(m = m.fc2,weight_bits=weight_bits,input_bits=input_bits,
346 | aq_learnable=aq_learnable,wq_learnable=wq_learnable,symmetric=False,weight_channelwise=weight_channelwise,input_channelwise=input_channelwise,
347 | weight_quant_method=weight_quant_method,input_quant_method=input_quant_method, pretrained_initialized = pretrained_initialized)
348 | self.drop2 = nn.Dropout(drop_probs[1])
349 |
350 | def forward(self, x):
351 |
352 | x = self.fc1(x)
353 | if self.act_layer != 'rprelu':
354 | x = self.act(x)
355 | else:
356 | x = self.move1(x)
357 | x = self.act(x)
358 | x = self.move2(x)
359 |
360 | x = self.drop1(x)
361 | x = self.fc2(x)
362 | x = self.drop2(x)
363 | return x
364 |
365 |
366 |
367 |
368 |
369 |
--------------------------------------------------------------------------------
/src/quantization/modules/utils.py:
--------------------------------------------------------------------------------
1 | from email.policy import default
2 | import os
3 | import torch
4 | # from .conv import QConv2d, QConvBn2d
5 | # from .linear import QLinear
6 | from .qlinear import LSQ_QConv2d
7 | from .qlinear import QLinear, QMLP, LSQ_w_and_act_QMLP, LSQ_w_and_act_QLinear, LSQ_QLinear4head
8 | from .attention import QAttention, QAttention_lsq
9 | from .attention import QAttention_qkreparam_4_cga
10 | from .attention import QAttention_qkreparam
11 |
12 | # from src.utils import Attention
13 | from src.deit_vision_transformer import Attention as deit_attention
14 | from src.deit_vision_transformer import Mlp
15 | from src.swin import ShiftedWindowAttention
16 | from torchvision.ops.misc import MLP as swin_MLP
17 |
18 | from .swin_attention_and_mlp import QAttention_swin, QMLP_swin, QAttention_swin_qkreparam, QAttention_swin_qkreparam_4_cga
19 |
20 |
21 | QMODULE_MAPPINGS = {
22 | torch.nn.Linear: QLinear,
23 | deit_attention: QAttention,
24 | Mlp: QMLP
25 | }
26 |
27 | ## 0: QAttention_qkreparam, 1: QAttention_qkreparam_4_cga
28 | QMODULE_MAPPINGS_QK_REPARAM = [
29 | {
30 | torch.nn.Linear: QLinear,
31 | deit_attention: QAttention_qkreparam,
32 | Mlp: QMLP
33 | },
34 | {
35 | torch.nn.Linear: QLinear,
36 | deit_attention: QAttention_qkreparam_4_cga,
37 | Mlp: QMLP
38 | }
39 | ]
40 | QMODULE_MAPPINGS_W_AND_ACT = {
41 | torch.nn.Linear: LSQ_w_and_act_QLinear,
42 | deit_attention: QAttention_lsq,
43 | Mlp: LSQ_w_and_act_QMLP
44 | }
45 | def get_module_by_name(model, module_name):
46 | names = module_name.split(".")
47 | module = model
48 | for name in names:
49 | module = getattr(module, name)
50 | return module
51 |
52 |
53 | def set_module_by_name(model, module_name, module):
54 | if module_name == 'head' or module_name == 'head_dist':
55 | setattr(model, module_name, module)
56 | else:
57 | names = module_name.split(".")
58 | parent = get_module_by_name(model, ".".join(names[:-1]))
59 | setattr(parent, names[-1], module)
60 |
61 |
62 | def replace_module_by_qmodule_deit(model, qconfigs, pretrained_initialized = False,
63 | qk_reparam = False, qk_reparam_type = 0, boundaryRange = 0.005):
64 |
65 | if qconfigs[list(qconfigs.keys())[0]]["weight"]["mode"] == 'lsq' and qconfigs[list(qconfigs.keys())[0]]["act"]["mode"] == 'lsq':
66 |
67 | for name, cfg in qconfigs.items():
68 | if name == "patch_embed.proj":
69 | module = get_module_by_name(model, name)
70 |
71 | qmodule = LSQ_QConv2d(
72 | m = module,
73 | weight_bits = 8,
74 | input_bits = 8,
75 | weight_channelwise = True,
76 | input_channelwise = True,
77 | weight_quant_method = 'lsq',
78 | input_quant_method = 'lsq',
79 | aq_learnable = True, ## act
80 | wq_learnable = True,## weight
81 | act_layer = cfg["act_layer"],
82 | pretrained_initialized = pretrained_initialized
83 | )
84 | set_module_by_name(model, name, qmodule)
85 | elif name == "head" or name == "head_dist":
86 | module = get_module_by_name(model, name)
87 | qmodule = LSQ_QLinear4head(
88 | m = module,
89 | weight_bits = 8,
90 | input_bits = 8,
91 | weight_channelwise = True,
92 | input_channelwise = True,
93 | weight_quant_method = 'lsq',
94 | input_quant_method = 'lsq',
95 | aq_learnable = True, ## act
96 | wq_learnable = True,## weight
97 | symmetric = True, ## ac
98 | act_layer = cfg["act_layer"],
99 | pretrained_initialized = pretrained_initialized
100 | )
101 | set_module_by_name(model, name, qmodule)
102 | else:
103 | module = get_module_by_name(model, name)
104 |
105 | qmodule = QMODULE_MAPPINGS_W_AND_ACT[type(module)](
106 | m = module,
107 | weight_bits = cfg["weight"]['bit'],
108 | input_bits = cfg["act"]['bit'],
109 | weight_channelwise = cfg["weight"]["per_channel"],
110 | input_channelwise = cfg["act"]["per_channel"],
111 | weight_quant_method = cfg["weight"]["mode"],
112 | input_quant_method = cfg["act"]["mode"],
113 | aq_learnable = cfg["act"]["learnable"], ## act
114 | wq_learnable = cfg["weight"]["learnable"],## weight
115 | symmetric = not cfg["act"]['all_positive'], ## ac
116 | act_layer = cfg["act_layer"],
117 | pretrained_initialized = pretrained_initialized
118 | )
119 | set_module_by_name(model, name, qmodule)
120 |
121 | elif qk_reparam:
122 | if qk_reparam_type == 0:
123 | for name, cfg in qconfigs.items():
124 | if name == "patch_embed.proj":
125 | module = get_module_by_name(model, name)
126 | qmodule = LSQ_QConv2d(
127 | m = module,
128 | weight_bits = 8,
129 | input_bits = 8,
130 | weight_channelwise = True,
131 | input_channelwise = True,
132 | weight_quant_method = 'lsq',
133 | input_quant_method = 'lsq',
134 | aq_learnable = True, ## act
135 | wq_learnable = True,## weight
136 | act_layer = cfg["act_layer"],
137 | pretrained_initialized = pretrained_initialized
138 | )
139 | set_module_by_name(model, name, qmodule)
140 | elif name == "head" or name == "head_dist":
141 | module = get_module_by_name(model, name)
142 | qmodule = LSQ_QLinear4head(
143 | m = module,
144 | weight_bits = 8,
145 | input_bits = 8,
146 | weight_channelwise = True,
147 | input_channelwise = True,
148 | weight_quant_method = 'lsq',
149 | input_quant_method = 'lsq',
150 | aq_learnable = True, ## act
151 | wq_learnable = True,## weight
152 | symmetric = True, ## ac
153 | act_layer = cfg["act_layer"],
154 | pretrained_initialized = pretrained_initialized
155 | )
156 | set_module_by_name(model, name, qmodule)
157 | else:
158 | module = get_module_by_name(model, name)
159 | qmodule = QMODULE_MAPPINGS_QK_REPARAM[qk_reparam_type][type(module)](
160 | m = module,
161 | weight_bits = cfg["weight"]['bit'],
162 | input_bits = cfg["act"]['bit'],
163 | weight_channelwise = cfg["weight"]["per_channel"],
164 | input_channelwise = cfg["act"]["per_channel"],
165 | weight_quant_method = cfg["weight"]["mode"],
166 | input_quant_method = cfg["act"]["mode"],
167 | aq_learnable = cfg["act"]["learnable"], ## act
168 | wq_learnable = cfg["weight"]["learnable"],## weight
169 | # symmetric = not cfg["act"]['all_positive'], ## ac
170 | act_layer = cfg["act_layer"],
171 | pretrained_initialized = pretrained_initialized
172 | )
173 | set_module_by_name(model, name, qmodule)
174 | elif qk_reparam_type == 1:
175 | for name, cfg in qconfigs.items():
176 | if name == "patch_embed.proj":
177 | module = get_module_by_name(model, name)
178 | qmodule = LSQ_QConv2d(
179 | m = module,
180 | weight_bits = 8,
181 | input_bits = 8,
182 | weight_channelwise = True,
183 | input_channelwise = True,
184 | weight_quant_method = 'lsq',
185 | input_quant_method = 'lsq',
186 | aq_learnable = True, ## act
187 | wq_learnable = True,## weight
188 | act_layer = cfg["act_layer"],
189 | pretrained_initialized = pretrained_initialized
190 | )
191 | set_module_by_name(model, name, qmodule)
192 | elif name == "head" or name == "head_dist":
193 | module = get_module_by_name(model, name)
194 | qmodule = LSQ_QLinear4head(
195 | m = module,
196 | weight_bits = 8,
197 | input_bits = 8,
198 | weight_channelwise = True,
199 | input_channelwise = True,
200 | weight_quant_method = 'lsq',
201 | input_quant_method = 'lsq',
202 | aq_learnable = True, ## act
203 | wq_learnable = True,## weight
204 | symmetric = True, ## ac
205 | act_layer = cfg["act_layer"],
206 | pretrained_initialized = pretrained_initialized
207 | )
208 | set_module_by_name(model, name, qmodule)
209 | else:
210 | module = get_module_by_name(model, name)
211 | qmodule = QMODULE_MAPPINGS_QK_REPARAM[qk_reparam_type][type(module)](
212 | m = module,
213 | weight_bits = cfg["weight"]['bit'],
214 | input_bits = cfg["act"]['bit'],
215 | weight_channelwise = cfg["weight"]["per_channel"],
216 | input_channelwise = cfg["act"]["per_channel"],
217 | weight_quant_method = cfg["weight"]["mode"],
218 | input_quant_method = cfg["act"]["mode"],
219 | aq_learnable = cfg["act"]["learnable"], ## act
220 | wq_learnable = cfg["weight"]["learnable"],## weight
221 | act_layer = cfg["act_layer"],
222 | pretrained_initialized = pretrained_initialized,
223 | boundaryRange = boundaryRange
224 | )
225 | set_module_by_name(model, name, qmodule)
226 |
227 | else: ## statsq w quant
228 | for name, cfg in qconfigs.items():
229 | if name == "patch_embed.proj":
230 | module = get_module_by_name(model, name)
231 |
232 | qmodule = LSQ_QConv2d(
233 | m = module,
234 | weight_bits = 8,
235 | input_bits = 8,
236 | weight_channelwise = True,
237 | input_channelwise = True,
238 | weight_quant_method = 'lsq',
239 | input_quant_method = 'lsq',
240 | aq_learnable = True, ## act
241 | wq_learnable = True,## weight
242 | act_layer = cfg["act_layer"],
243 | pretrained_initialized = pretrained_initialized
244 | )
245 | set_module_by_name(model, name, qmodule)
246 | elif name == "head" or name == "head_dist":
247 | module = get_module_by_name(model, name)
248 | qmodule = LSQ_QLinear4head(
249 | m = module,
250 | weight_bits = 8,
251 | input_bits = 8,
252 | weight_channelwise = True,
253 | input_channelwise = True,
254 | weight_quant_method = 'lsq',
255 | input_quant_method = 'lsq',
256 | aq_learnable = True, ## act
257 | wq_learnable = True,## weight
258 | symmetric = True, ## ac
259 | act_layer = cfg["act_layer"],
260 | pretrained_initialized = pretrained_initialized
261 | )
262 | set_module_by_name(model, name, qmodule)
263 | else:
264 | module = get_module_by_name(model, name)
265 |
266 | qmodule = QMODULE_MAPPINGS[type(module)](
267 | m = module,
268 | weight_bits = cfg["weight"]['bit'],
269 | input_bits = cfg["act"]['bit'],
270 | weight_channelwise = cfg["weight"]["per_channel"],
271 | input_channelwise = cfg["act"]["per_channel"],
272 | weight_quant_method = cfg["weight"]["mode"],
273 | input_quant_method = cfg["act"]["mode"],
274 | aq_learnable = cfg["act"]["learnable"], ## act
275 | wq_learnable = cfg["weight"]["learnable"],## weight
276 | # symmetric = not cfg["act"]['all_positive'], ## ac
277 | act_layer = cfg["act_layer"],
278 | pretrained_initialized = pretrained_initialized
279 | )
280 | set_module_by_name(model, name, qmodule)
281 |
282 | return model
283 |
284 |
285 |
286 | QMODULE_MAPPINGS_SWIN = {
287 | torch.nn.Linear: QLinear,
288 | ShiftedWindowAttention: QAttention_swin,
289 | swin_MLP: QMLP_swin
290 | }
291 |
292 | QMODULE_MAPPINGS_QK_REPARAM_SWIN = [
293 | {
294 | torch.nn.Linear: QLinear,
295 | ShiftedWindowAttention: QAttention_swin_qkreparam,
296 | swin_MLP: QMLP_swin
297 | },
298 | {
299 | torch.nn.Linear: QLinear,
300 | ShiftedWindowAttention: QAttention_swin_qkreparam_4_cga,
301 | swin_MLP: QMLP_swin
302 | }
303 | ]
304 |
305 | def replace_module_by_qmodule_swin(model, qconfigs, pretrained_initialized = False, qk_reparam = False, qk_reparam_type = 0, boundaryRange = 0.005):
306 | if qk_reparam:
307 | for name, cfg in qconfigs.items():
308 | if name == "features.0.0":
309 | module = get_module_by_name(model, name)
310 | qmodule = LSQ_QConv2d(
311 | m = module,
312 | weight_bits = 8,
313 | input_bits = 8,
314 | weight_channelwise = True,
315 | input_channelwise = True,
316 | weight_quant_method = 'lsq',
317 | input_quant_method = 'lsq',
318 | aq_learnable = True, ## act
319 | wq_learnable = True,## weight
320 | act_layer = cfg["act_layer"],
321 | pretrained_initialized = pretrained_initialized
322 | )
323 | set_module_by_name(model, name, qmodule)
324 | elif name == "head":
325 | module = get_module_by_name(model, name)
326 | qmodule = LSQ_QLinear4head(
327 | m = module,
328 | weight_bits = 8,
329 | input_bits = 8,
330 | weight_channelwise = True,
331 | input_channelwise = True,
332 | weight_quant_method = 'lsq',
333 | input_quant_method = 'lsq',
334 | aq_learnable = True, ## act
335 | wq_learnable = True,## weight
336 | symmetric = True, ## ac
337 | act_layer = cfg["act_layer"],
338 | pretrained_initialized = pretrained_initialized
339 | )
340 | set_module_by_name(model, name, qmodule)
341 | else:
342 | module = get_module_by_name(model, name)
343 | qmodule = QMODULE_MAPPINGS_QK_REPARAM_SWIN[qk_reparam_type][type(module)](
344 | m = module,
345 | weight_bits = cfg["weight"]['bit'],
346 | input_bits = cfg["act"]['bit'],
347 | weight_channelwise = cfg["weight"]["per_channel"],
348 | input_channelwise = cfg["act"]["per_channel"],
349 | weight_quant_method = cfg["weight"]["mode"],
350 | input_quant_method = cfg["act"]["mode"],
351 | aq_learnable = cfg["act"]["learnable"], ## act
352 | wq_learnable = cfg["weight"]["learnable"],## weight
353 | # symmetric = not cfg["act"]['all_positive'], ## ac
354 | act_layer = cfg["act_layer"],
355 | pretrained_initialized = pretrained_initialized
356 | )
357 | set_module_by_name(model, name, qmodule)
358 |
359 | else: ## statsq w quant
360 | for name, cfg in qconfigs.items():
361 | if name == "features.0.0":
362 | module = get_module_by_name(model, name)
363 | qmodule = LSQ_QConv2d(
364 | m = module,
365 | weight_bits = 8,
366 | input_bits = 8,
367 | weight_channelwise = True,
368 | input_channelwise = True,
369 | weight_quant_method = 'lsq',
370 | input_quant_method = 'lsq',
371 | aq_learnable = True, ## act
372 | wq_learnable = True,## weight
373 | act_layer = cfg["act_layer"],
374 | pretrained_initialized = pretrained_initialized
375 | )
376 | set_module_by_name(model, name, qmodule)
377 | elif name == "head":
378 | module = get_module_by_name(model, name)
379 | qmodule = LSQ_QLinear4head(
380 | m = module,
381 | weight_bits = 8,
382 | input_bits = 8,
383 | weight_channelwise = True,
384 | input_channelwise = True,
385 | weight_quant_method = 'lsq',
386 | input_quant_method = 'lsq',
387 | aq_learnable = True, ## act
388 | wq_learnable = True,## weight
389 | symmetric = True, ## ac
390 | act_layer = cfg["act_layer"],
391 | pretrained_initialized = pretrained_initialized
392 | )
393 | set_module_by_name(model, name, qmodule)
394 | else:
395 | module = get_module_by_name(model, name)
396 |
397 | qmodule = QMODULE_MAPPINGS_SWIN[type(module)](
398 | m = module,
399 | weight_bits = cfg["weight"]['bit'],
400 | input_bits = cfg["act"]['bit'],
401 | weight_channelwise = cfg["weight"]["per_channel"],
402 | input_channelwise = cfg["act"]["per_channel"],
403 | weight_quant_method = cfg["weight"]["mode"],
404 | input_quant_method = cfg["act"]["mode"],
405 | aq_learnable = cfg["act"]["learnable"], ## act
406 | wq_learnable = cfg["weight"]["learnable"],## weight
407 | # symmetric = not cfg["act"]['all_positive'], ## ac
408 | act_layer = cfg["act_layer"],
409 | pretrained_initialized = pretrained_initialized
410 | )
411 | set_module_by_name(model, name, qmodule)
412 |
413 | return model
414 |
--------------------------------------------------------------------------------
/src/quantization/quantizer/__init__.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | from .lsq import LsqQuantizer, LsqQuantizerWeight
4 | from .statsq import StatsQuantizer
5 |
6 |
--------------------------------------------------------------------------------
/src/quantization/quantizer/statsq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import numpy as np
5 |
6 | ## create 1D mask
7 | def create_mask(s2, prob):
8 | raw = torch.zeros((s2,))
9 | raw[:int((1-prob) * s2)] = 1.0/(1.0-prob) # set EXACTLY 30% of the pixels in the mask
10 | ridx = torch.randperm(s2) # a random permutation of the entries
11 | return raw[ridx]
12 |
13 | def round_pass(x):
14 | y = x.round()
15 | y_grad = x
16 | return (y - y_grad).detach() + y_grad
17 |
18 | def grad_scale(x, scale):
19 | y = x
20 | y_grad = x * scale
21 | return (y - y_grad).detach() + y_grad
22 |
23 | def clip(x, eps):
24 | x_clip = torch.where(x > eps, x, eps)
25 | return x - x.detach() + x_clip.detach()
26 |
27 | def modify_grad(x, freeze_inds):
28 | x = x * freeze_inds * 0 + x * (1-freeze_inds)
29 | return x
30 |
31 |
32 | class TrackOscillation(nn.Module):
33 | """
34 | This is a wrapper of the int_forward function of a quantizer.
35 | It tracks the oscillations in integer domain.
36 | """
37 |
38 | def __init__(self, momentum=0.01, freeze_threshold=0, use_ema_x_int=True):
39 | super(TrackOscillation, self).__init__()
40 | self.momentum = momentum
41 |
42 | self.prev_x_int = None
43 | self.prev_switch_dir = None
44 |
45 | # Statistics to log
46 | self.ema_oscillation = None
47 | self.oscillated_sum = None
48 | self.total_oscillation = None
49 | self.iters_since_reset = 0
50 |
51 | # Extra variables for weight freezing
52 | self.freeze_threshold = freeze_threshold # This should be at least 2-3x the momentum value.
53 | self.use_ema_x_int = use_ema_x_int
54 | self.frozen = None
55 | self.frozen_x_int = None
56 | self.ema_x_int = None
57 |
58 | def __call__(self, x_int, skip_tracking=False, *args, **kwargs):
59 |
60 | # Apply weight freezing
61 | if self.frozen is not None:
62 | x_int = ~self.frozen * x_int + self.frozen * self.frozen_x_int
63 |
64 | if skip_tracking:
65 | return x_int
66 |
67 | with torch.no_grad():
68 | # Check if everything is correctly initialized, otherwise do so
69 | self.check_init(x_int)
70 |
71 | # detect difference in x_int NB we round to avoid int inaccuracies
72 | delta_x_int = torch.round(self.prev_x_int - x_int).detach() # should be {-1, 0, 1}
73 | switch_dir = torch.sign(delta_x_int) # This is {-1, 0, 1} as sign(0) is mapped to 0
74 | # binary mask for switching
75 | switched = delta_x_int != 0
76 |
77 | oscillated = (self.prev_switch_dir * switch_dir) == -1
78 | self.ema_oscillation = (
79 | self.momentum * oscillated + (1 - self.momentum) * self.ema_oscillation
80 | )
81 |
82 | # Update prev_switch_dir for the switch variables
83 | self.prev_switch_dir[switched] = switch_dir[switched]
84 | self.prev_x_int = x_int
85 | self.oscillated_sum = oscillated.sum()
86 | self.total_oscillation += oscillated
87 | self.iters_since_reset += 1
88 |
89 | # Freeze some weights
90 | if self.freeze_threshold > 0:
91 | freeze_weights = self.ema_oscillation > self.freeze_threshold
92 | self.frozen[freeze_weights] = True # Set them to frozen
93 | if self.use_ema_x_int:
94 | self.frozen_x_int[freeze_weights] = torch.round(self.ema_x_int[freeze_weights])
95 | # Update x_int EMA which can be used for freezing
96 | self.ema_x_int = self.momentum * x_int + (1 - self.momentum) * self.ema_x_int
97 | else:
98 | self.frozen_x_int[freeze_weights] = x_int[freeze_weights]
99 |
100 | return x_int
101 |
102 | def check_init(self, x_int):
103 | if self.prev_x_int is None:
104 | # Init prev switch dir to 0
105 | self.prev_switch_dir = torch.zeros_like(x_int)
106 | self.prev_x_int = x_int.detach() # Not sure if needed, don't think so
107 | self.ema_oscillation = torch.zeros_like(x_int)
108 | self.oscillated_sum = 0
109 | self.total_oscillation = torch.zeros_like(x_int)
110 | else:
111 | assert (
112 | self.prev_x_int.shape == x_int.shape
113 | ), "Tracking shape does not match current tensor shape."
114 |
115 | # For weight freezing
116 | if self.frozen is None and self.freeze_threshold > 0:
117 | self.frozen = torch.zeros_like(x_int, dtype=torch.bool)
118 | self.frozen_x_int = torch.zeros_like(x_int)
119 | if self.use_ema_x_int:
120 | self.ema_x_int = x_int.detach().clone()
121 |
122 | class StatsQuantizer(nn.Module):
123 | def __init__(self, num_bits, clip_learnable):
124 | super(StatsQuantizer, self).__init__()
125 | self.num_bits = num_bits
126 | init_act_clip_val = 2.0
127 |
128 | self.clip_val = nn.Parameter(torch.Tensor([init_act_clip_val]), requires_grad=False)
129 |
130 | self.s = None
131 |
132 |
133 | def forward(self, weight):
134 |
135 | real_weights = weight
136 |
137 | if len(weight.shape) == 2:
138 | scaling_factor = 2 * torch.mean(abs(real_weights),dim=1,keepdim=True) # dim, 1
139 | elif len(weight.shape) == 3:
140 | scaling_factor = 2 * torch.mean(torch.mean(abs(real_weights),dim=-1,keepdim=True),dim=0,keepdim=True) # 1, dim, 1
141 |
142 | scaling_factor = scaling_factor.detach()
143 | self.s = scaling_factor.squeeze().cpu()
144 | scaled_weights = real_weights/scaling_factor
145 | cliped_weights = torch.clamp(scaled_weights, min=(-self.clip_val/2), max=(self.clip_val/2)-1e-6)
146 | n = float(2 ** (self.num_bits - 1))
147 | quan_weights_no_grad = scaling_factor * ((torch.round((cliped_weights) * n - 0.5 ) + 0.5) / n)
148 | quan_weights = quan_weights_no_grad.detach() - real_weights.detach() + real_weights
149 |
150 | return quan_weights
151 |
152 |
153 |
154 | class StatsQuantizer_specific_4_qkreparam_cga(nn.Module):
155 | def __init__(self, num_bits, clip_learnable, boundaryRange = 0.005):
156 | super(StatsQuantizer_specific_4_qkreparam_cga, self).__init__()
157 |
158 | self.num_bits = num_bits
159 | init_act_clip_val = 2.0
160 | self.clip_val = nn.Parameter(torch.Tensor([init_act_clip_val]), requires_grad=False)
161 | self.s = None
162 | self.boundaryRange = boundaryRange
163 |
164 | def forward(self, weight):
165 |
166 | real_weights = weight
167 |
168 | if len(weight.shape) == 2:
169 | scaling_factor = 2 * torch.mean(abs(real_weights),dim=1,keepdim=True) # dim, 1
170 | elif len(weight.shape) == 3:
171 | scaling_factor = 2 * torch.mean(torch.mean(abs(real_weights),dim=-1,keepdim=True),dim=0,keepdim=True) # 1, dim, 1
172 |
173 |
174 | scaling_factor = scaling_factor.detach()
175 | self.s = scaling_factor.squeeze().cpu()
176 | scaled_weights = real_weights/scaling_factor
177 | cliped_weights = torch.clamp(scaled_weights, min=(-self.clip_val/2), max=(self.clip_val/2)-1e-6)
178 | n = float(2 ** (self.num_bits - 1))
179 | b4_round = (cliped_weights) * n - 0.5
180 |
181 | if self.training:
182 | not_freeze_idx = torch.zeros_like(real_weights).cuda()
183 | for i in np.arange(start=-(2**(self.num_bits - 1)),stop=(2**(self.num_bits - 1) - 1)): # 0.5 - boundaryRange < x < 0.5 + boundaryRange
184 | within_boundary = ((b4_round - i) <= (0.5 + self.boundaryRange)) * ((b4_round - i) >= (0.5 - self.boundaryRange)) #idx of # 0.5 - boundaryRange < x < 0.5 + boundaryRange
185 | not_freeze_idx = not_freeze_idx + within_boundary.float()
186 |
187 | freeze_idx = 1.0-not_freeze_idx
188 | b4_round = b4_round.detach() * freeze_idx + b4_round* (1-freeze_idx)
189 |
190 | quan_weights_no_grad = scaling_factor * ((torch.round( b4_round ) + 0.5) / n)
191 | quan_weights = quan_weights_no_grad.detach() - real_weights.detach() + real_weights
192 |
193 | return quan_weights
194 |
195 |
196 | class StatsQuantizer_4d(nn.Module): # B, num_heads, N, in_features
197 | def __init__(self, num_bits, clip_learnable):
198 | super(StatsQuantizer_4d, self).__init__()
199 |
200 | self.num_bits = num_bits
201 | init_act_clip_val = 2.0
202 | self.clip_val = nn.Parameter(torch.Tensor([init_act_clip_val]), requires_grad=False)
203 | self.s = None
204 |
205 | def forward(self, weight):
206 |
207 | real_weights = weight
208 |
209 | scaling_factor = 2 * torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=1,keepdim=True),dim=0,keepdim=True)
210 |
211 | scaling_factor = scaling_factor.detach()
212 | self.s = scaling_factor.squeeze().cpu()
213 | scaled_weights = real_weights/scaling_factor
214 | cliped_weights = torch.clamp(scaled_weights, min=(-self.clip_val/2), max=(self.clip_val/2)-1e-6)
215 | n = float(2 ** (self.num_bits - 1))
216 | quan_weights_no_grad = scaling_factor * ((torch.round((cliped_weights) * n - 0.5 ) + 0.5) / n)
217 | quan_weights = quan_weights_no_grad.detach() - real_weights.detach() + real_weights
218 |
219 | return quan_weights
220 |
221 |
222 |
223 |
224 |
225 |
--------------------------------------------------------------------------------
/src/quantization/utils.py:
--------------------------------------------------------------------------------
1 | from turtle import forward
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .modules.qlinear import QLinear, LSQ_w_and_act_QLinear
6 | from timm.loss import SoftTargetCrossEntropy
7 |
8 | def unitwise_norm(x, norm_type=2.0):
9 | if x.ndim <= 1:
10 | return x.norm(norm_type)
11 | else:
12 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
13 |
14 |
15 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
16 | if isinstance(parameters, torch.Tensor):
17 | parameters = [parameters]
18 | for p in parameters:
19 | if p.grad is None:
20 | continue
21 | p_data = p.detach()
22 | g_data = p.grad.detach()
23 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
24 | grad_norm = unitwise_norm(g_data, norm_type=norm_type)
25 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
26 | new_grad = torch.where(grad_norm < max_norm, g_data, clipped_grad)
27 | p.grad.detach().copy_(new_grad)
28 |
29 | class Multi_KLLossSoft(torch.nn.modules.loss._Loss):
30 | def forward(self, output, target, T=1.0):
31 | output = output[0] if isinstance(output, tuple) else output
32 | target = target[0] if isinstance(target, tuple) else target
33 | output, target = output / T, target / T
34 | target_prob = F.softmax(target, dim=1)
35 | output_log_prob = F.log_softmax(output, dim=1)
36 | loss = - torch.sum(target_prob * output_log_prob, dim=1)
37 | if self.reduction == "mean":
38 | return loss.mean()
39 | elif self.reduction == "sum":
40 | return loss.sum()
41 | else:
42 | return loss
43 |
44 | class KLLossSoft(torch.nn.modules.loss._Loss):
45 | def forward(self, output, target, T=1.0):
46 | output = output[0] if isinstance(output, tuple) else output
47 | target = target[0] if isinstance(target, tuple) else target
48 | output, target = output / T, target / T
49 | target_prob = F.softmax(target, dim=1)
50 | output_log_prob = F.log_softmax(output, dim=1)
51 | loss = - torch.sum(target_prob * output_log_prob, dim=1)
52 | if self.reduction == "mean":
53 | return loss.mean()
54 | elif self.reduction == "sum":
55 | return loss.sum()
56 | else:
57 | return loss
58 |
59 | class KDLossSoftandHard(torch.nn.Module):
60 | def __init__(self) -> None:
61 | super().__init__()
62 | self.KLSoft = KLLossSoft()
63 | self.Hard = nn.CrossEntropyLoss()
64 |
65 | def forward(self, output, hard_target, soft_target):
66 |
67 | if isinstance(output, tuple):
68 | cls_output = output[0]
69 | dist_output = output[1]
70 | soft_loss = self.KLSoft(dist_output,soft_target)
71 | hard_loss = self.Hard(cls_output, hard_target)
72 |
73 | else:
74 | soft_loss = self.KLSoft(output,soft_target)
75 | hard_loss = self.Hard(output, hard_target)
76 |
77 | return soft_loss + hard_loss
78 |
79 | def dampening_loss(w_fp, w_q, x_min, x_max):
80 | # L &= (s*w_{int} - w)^2
81 | # We also need to add clipping for both cases, we can do so by using the forward
82 | # this is also clipped and our target
83 | # clamp w in FP32 domain to not change range learning (min(max) is needed for per-channel)
84 | w_fp_clip = torch.min(torch.max(w_fp, x_min), x_max)
85 | loss = (w_q - w_fp_clip) ** 2
86 |
87 | return loss.sum()
88 |
89 | class DampeningLoss(torch.nn.Module):
90 | def __init__(self, weighting=1.0, weight_quant_method = 'nu2u') -> None:
91 | super().__init__()
92 | """
93 | Calculates the dampening loss for all weights in a given quantized model. It is
94 | expected that all quantized weights are in a Hijacker module.
95 |
96 | """
97 | self.weighting = weighting
98 | self.weight_quant_method = weight_quant_method
99 |
100 | def forward(self, model):
101 | total_bin_loss = 0
102 | for name, module in model.named_modules():
103 | if isinstance(module, QLinear) or isinstance(module, LSQ_w_and_act_QLinear) :
104 | # print(name,"calculate dampening loss")
105 | # FP32 weight tensor, potential folded but before quantization
106 | weight = module.weight
107 | # The matching weight quantizer (not manager, direct quantizer class)
108 | if self.weight_quant_method == 'lsq':
109 | weight_q = module.lsqw_fn(weight).detach()
110 | weight_q_min = (module.lsqw_fn.thd_neg * module.lsqw_fn.s).unsqueeze(dim=-1)
111 | weight_q_max = (module.lsqw_fn.thd_pos * module.lsqw_fn.s).unsqueeze(dim=-1)
112 |
113 | elif self.weight_quant_method == 'nu2u':
114 | weight_q = module.nu2u_fn(weight).detach()
115 | weight_q_min, _ = torch.min(weight_q, 1)
116 | weight_q_min = weight_q_min.unsqueeze(dim=-1)
117 | weight_q_max, _ = torch.max(weight_q, 1)
118 | weight_q_max = weight_q_max.unsqueeze(dim=-1)
119 |
120 | total_bin_loss += dampening_loss(weight, weight_q, weight_q_min, weight_q_max)
121 | return total_bin_loss * self.weighting
122 |
123 | class KDLossSoftandHard_dampening(torch.nn.Module):
124 | def __init__(self, weight_quant_method) -> None:
125 | super().__init__()
126 | self.KLSoft = KLLossSoft()
127 | self.Hard = nn.CrossEntropyLoss()
128 | self.dampening_loss = DampeningLoss(weighting=0, weight_quant_method=weight_quant_method)
129 |
130 | def forward(self, output, hard_target, soft_target, model):
131 |
132 | if isinstance(output, tuple):
133 | cls_output = output[0]
134 | dist_output = output[1]
135 | soft_loss = self.KLSoft(dist_output,soft_target)
136 | hard_loss = self.Hard(cls_output, hard_target)
137 |
138 | else:
139 | soft_loss = self.KLSoft(output,soft_target)
140 | hard_loss = self.Hard(output, hard_target)
141 |
142 | dampening_loss = self.dampening_loss(model) ## as for now only works for LSQ_QLinea AND LSQ_w_and_act_QLinear
143 |
144 | return soft_loss + hard_loss + dampening_loss
145 |
146 | class KDLossSoftandSoftTargetCE(torch.nn.Module):
147 | def __init__(self) -> None:
148 | super().__init__()
149 | self.KLSoft = KLLossSoft()
150 | self.Hard = SoftTargetCrossEntropy()
151 |
152 | def forward(self, output, hard_target, soft_target):
153 |
154 | if isinstance(output, tuple):
155 | cls_output = output[0]
156 | dist_output = output[1]
157 | soft_loss = self.KLSoft(dist_output,soft_target)
158 | hard_loss = self.Hard(cls_output, hard_target)
159 |
160 | else:
161 | soft_loss = self.KLSoft(output,soft_target)
162 | hard_loss = self.Hard(output, hard_target)
163 |
164 | return soft_loss + hard_loss
165 |
166 | def att_loss_r2b(Q_s, Q_t):
167 | Q_s_norm = Q_s / torch.norm(Q_s, p=2)
168 | Q_t_norm = Q_t / torch.norm(Q_t, p=2)
169 | tmp = Q_s_norm - Q_t_norm
170 | loss = torch.norm(tmp, p=2)
171 | return loss
172 |
173 | def direction_matching_distillation(student_scores, teacher_scores):
174 | tmp_loss = 0.
175 | # new_teacher_scores = [teacher_scores[i * layers_per_block + layers_per_block - 1] for i in range(student_layer_num)]
176 | for student_score, teacher_score in zip(student_scores, teacher_scores):
177 | student_score = torch.where(student_score <= -1e2,
178 | torch.zeros_like(student_score).to(student_scores[0].device),
179 | student_score)
180 | teacher_score = torch.where(teacher_score <= -1e2,
181 | torch.zeros_like(teacher_score).to(student_scores[0].device),
182 | teacher_score)
183 | tmp_loss += att_loss_r2b(student_score, teacher_score)
184 | return tmp_loss
185 |
186 | class KDLossSoftandHard_qk(torch.nn.Module):
187 | def __init__(self) -> None:
188 | super().__init__()
189 | self.KLSoft = KLLossSoft()
190 | self.Hard = nn.CrossEntropyLoss()
191 |
192 | def forward(self, student_logit, student_attn_info ,target, teacher_logit, teacher_attn_info):
193 |
194 | student_q = []
195 | teacher_q = []
196 | student_k = []
197 | teacher_k = []
198 | for layer in range(len(student_attn_info)):
199 | student_q.append(student_attn_info[layer][1])
200 | teacher_q.append(teacher_attn_info[layer][1])
201 | student_k.append(student_attn_info[layer][2])
202 | teacher_k.append(teacher_attn_info[layer][2])
203 |
204 | if isinstance(student_logit, tuple):
205 | cls_output = student_logit[0]
206 | dist_output = student_logit[1]
207 | soft_loss = self.KLSoft(dist_output, teacher_logit)
208 | hard_loss = self.Hard(cls_output, target)
209 |
210 | else:
211 | soft_loss = self.KLSoft(student_logit,teacher_logit)
212 | hard_loss = self.Hard(student_logit, target)
213 |
214 |
215 | q_loss = direction_matching_distillation(student_q, teacher_q)
216 | k_loss = direction_matching_distillation(student_k, teacher_k)
217 |
218 |
219 | return soft_loss + hard_loss + q_loss + k_loss
220 |
221 | class KDLossSoftandHard_qkv(torch.nn.Module):
222 | def __init__(self) -> None:
223 | super().__init__()
224 | self.KLSoft = KLLossSoft()
225 | self.Hard = nn.CrossEntropyLoss()
226 |
227 | def forward(self, student_logit, student_attn_info ,target, teacher_logit, teacher_attn_info):
228 |
229 | student_q = []
230 | teacher_q = []
231 | student_k = []
232 | teacher_k = []
233 | student_v = []
234 | teacher_v = []
235 | for layer in range(len(student_attn_info)):
236 | student_q.append(student_attn_info[layer][1])
237 | teacher_q.append(teacher_attn_info[layer][1])
238 | student_k.append(student_attn_info[layer][2])
239 | teacher_k.append(teacher_attn_info[layer][2])
240 | student_v.append(student_attn_info[layer][3])
241 | teacher_v.append(teacher_attn_info[layer][3])
242 |
243 | if isinstance(student_logit, tuple):
244 | cls_output = student_logit[0]
245 | dist_output = student_logit[1]
246 | soft_loss = self.KLSoft(dist_output, teacher_logit)
247 | hard_loss = self.Hard(cls_output, target)
248 |
249 | else:
250 | soft_loss = self.KLSoft(student_logit,teacher_logit)
251 | hard_loss = self.Hard(student_logit, target)
252 |
253 |
254 | q_loss = direction_matching_distillation(student_q, teacher_q)
255 | k_loss = direction_matching_distillation(student_k, teacher_k)
256 | v_loss = direction_matching_distillation(student_v, teacher_v)
257 |
258 | return soft_loss + hard_loss + q_loss + k_loss + v_loss
259 |
260 | class KLTokenMSELoss(torch.nn.Module):
261 | def __init__(
262 | self,
263 | alpha: float = 0.5,
264 | kd_type: str = "last",
265 | reduction: str = "mean",
266 | ):
267 | super().__init__()
268 | self.reduction = reduction
269 | self.alpha = alpha
270 | self.kl_loss = KLLossSoft(reduction=reduction)
271 | self.mse_loss = nn.MSELoss(reduction=reduction)
272 | self.kd_type = kd_type
273 |
274 | def _kl_loss(self, output, target):
275 | return self.kl_loss(output, target)
276 |
277 | def _mse_loss(self, output, target):
278 | mse_loss = 0
279 | if self.kd_type == "last":
280 | if isinstance(output, torch.Tensor):
281 | _, N, _ = target.size()
282 | mse_loss = self.mse_loss(output[:, -N:], target)
283 | else:
284 | _, N, _ = target[-1].size()
285 | mse_loss = self.mse_loss(output[-1][:, -N:], target[-1])
286 | elif self.kd_type == "all":
287 | if isinstance(output, torch.Tensor):
288 | _, N, _ = target.size()
289 | mse_loss = self.mse_loss(output[:, -N:], target)
290 | else:
291 | assert len(output) == len(target)
292 | for i in range(len(output)):
293 | _, N, _ = target[i].size()
294 | mse_loss += self.mse_loss(output[i][:, -N:], target[i])
295 | mse_loss = mse_loss / len(output)
296 | else:
297 | raise NotImplementedError
298 | return mse_loss
299 |
300 | def forward(self, output, target):
301 | assert len(output) == len(target)
302 |
303 | kl_loss = self.kl_loss(output[0], target[0])
304 | mse_loss = self._mse_loss(output[1], target[1])
305 | loss = kl_loss + self.alpha * mse_loss
306 | # print(f"KL loss {kl_loss}, MSE loss {mse_loss}, total loss {loss}")
307 |
308 | return loss
309 |
310 |
--------------------------------------------------------------------------------
/src/registry.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def register_model(func):
4 | """
5 | Fallback wrapper in case timm isn't installed
6 | """
7 | return func
8 |
--------------------------------------------------------------------------------
/src/swin.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional, Callable, List, Any
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn, Tensor
7 |
8 | from torchvision.ops.misc import MLP, Permute
9 | from torchvision.ops.stochastic_depth import StochasticDepth
10 | from torchvision.transforms._presets import ImageClassification, InterpolationMode
11 | from torchvision.utils import _log_api_usage_once
12 | from torchvision.models._api import WeightsEnum, Weights
13 | from torchvision.models._meta import _IMAGENET_CATEGORIES
14 | from torchvision.models._utils import _ovewrite_named_param
15 | import math
16 |
17 | from timm.models.registry import register_model
18 |
19 | __all__ = [
20 | "SwinTransformer",
21 | "Swin_T_Weights",
22 | "swin_t"
23 | ]
24 |
25 |
26 | class PatchMerging(nn.Module):
27 | """Patch Merging Layer.
28 | Args:
29 | dim (int): Number of input channels.
30 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
31 | """
32 |
33 | def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
34 | super().__init__()
35 | _log_api_usage_once(self)
36 | self.dim = dim
37 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
38 | self.norm = norm_layer(4 * dim)
39 |
40 | def forward(self, x):
41 | """
42 | Args:
43 | x (Tensor): input tensor with expected layout of [..., H, W, C]
44 | Returns:
45 | Tensor with layout of [..., H/2, W/2, 2*C]
46 | """
47 | features_x, attn_x = x
48 | H, W, _ = features_x.shape[-3:]
49 | features_x = F.pad(features_x, (0, 0, 0, W % 2, 0, H % 2))
50 |
51 | features_x0 = features_x[..., 0::2, 0::2, :] # ... H/2 W/2 C
52 | features_x1 = features_x[..., 1::2, 0::2, :] # ... H/2 W/2 C
53 | features_x2 = features_x[..., 0::2, 1::2, :] # ... H/2 W/2 C
54 | features_x3 = features_x[..., 1::2, 1::2, :] # ... H/2 W/2 C
55 | features_x = torch.cat([features_x0, features_x1, features_x2, features_x3], -1) # ... H/2 W/2 4*C
56 |
57 | features_x = self.norm(features_x)
58 | features_x = self.reduction(features_x) # ... H/2 W/2 2*C
59 | return features_x, attn_x
60 |
61 |
62 | def shifted_window_attention(
63 | input: Tensor,
64 | qkv_weight: Tensor,
65 | proj_weight: Tensor,
66 | relative_position_bias: Tensor,
67 | window_size: List[int],
68 | num_heads: int,
69 | shift_size: List[int],
70 | attention_dropout: float = 0.0,
71 | dropout: float = 0.0,
72 | qkv_bias: Optional[Tensor] = None,
73 | proj_bias: Optional[Tensor] = None,
74 | qqkkvv: bool = False
75 | ):
76 | """
77 | Window based multi-head self attention (W-MSA) module with relative position bias.
78 | It supports both of shifted and non-shifted window.
79 | Args:
80 | input (Tensor[N, H, W, C]): The input tensor or 4-dimensions.
81 | qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
82 | proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
83 | relative_position_bias (Tensor): The learned relative position bias added to attention.
84 | window_size (List[int]): Window size.
85 | num_heads (int): Number of attention heads.
86 | shift_size (List[int]): Shift size for shifted window attention.
87 | attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
88 | dropout (float): Dropout ratio of output. Default: 0.0.
89 | qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
90 | proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
91 | Returns:
92 | Tensor[N, H, W, C]: The output tensor after shifted window attention.
93 | """
94 | B, H, W, C = input.shape
95 | # pad feature maps to multiples of window size
96 | pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
97 | pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
98 | x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
99 | _, pad_H, pad_W, _ = x.shape
100 |
101 | # If window size is larger than feature size, there is no need to shift window
102 | if window_size[0] >= pad_H:
103 | shift_size[0] = 0
104 | if window_size[1] >= pad_W:
105 | shift_size[1] = 0
106 |
107 | # cyclic shift
108 | if sum(shift_size) > 0:
109 | x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
110 |
111 | # partition windows
112 | num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
113 | x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
114 | x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C
115 |
116 | # multi-head attention
117 | qkv = F.linear(x, qkv_weight, qkv_bias)
118 | qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
119 | q, k, v = qkv[0], qkv[1], qkv[2]
120 |
121 | attn = q.matmul(k.transpose(-2, -1)) * (C // num_heads) ** -0.5
122 | # add relative position bias
123 | attn = attn + relative_position_bias
124 |
125 | if sum(shift_size) > 0:
126 | # generate attention mask
127 | attn_mask = x.new_zeros((pad_H, pad_W))
128 | h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
129 | w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
130 | count = 0
131 | for h in h_slices:
132 | for w in w_slices:
133 | attn_mask[h[0] : h[1], w[0] : w[1]] = count
134 | count += 1
135 | attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
136 | attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
137 | attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
138 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
139 | attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
140 | attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
141 | attn = attn.view(-1, num_heads, x.size(1), x.size(1))
142 |
143 | attn = F.softmax(attn, dim=-1)
144 | attn = F.dropout(attn, p=attention_dropout)
145 |
146 | x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
147 | x = F.linear(x, proj_weight, proj_bias)
148 | x = F.dropout(x, p=dropout)
149 |
150 | # reverse windows
151 | x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
152 | x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)
153 |
154 | # reverse cyclic shift
155 | if sum(shift_size) > 0:
156 | x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
157 |
158 | # unpad features
159 | x = x[:, :H, :W, :].contiguous()
160 | if qqkkvv:
161 | q_score = torch.matmul(q, q.transpose(-1, -2))
162 | q_score = q_score / math.sqrt(C // num_heads)
163 | k_score = torch.matmul(k, k.transpose(-1, -2))
164 | k_score = k_score / math.sqrt(C // num_heads)
165 | v_score = torch.matmul(v, v.transpose(-1, -2))
166 | v_score = v_score / math.sqrt(C // num_heads)
167 |
168 | return x, (attn, q_score, k_score, v_score)
169 | else:
170 | return x, None
171 |
172 |
173 | torch.fx.wrap("shifted_window_attention")
174 |
175 |
176 | class ShiftedWindowAttention(nn.Module):
177 | """
178 | See :func:`shifted_window_attention`.
179 | """
180 |
181 | def __init__(
182 | self,
183 | dim: int,
184 | window_size: List[int],
185 | shift_size: List[int],
186 | num_heads: int,
187 | qkv_bias: bool = True,
188 | proj_bias: bool = True,
189 | attention_dropout: float = 0.0,
190 | dropout: float = 0.0,
191 | qqkkvv: bool = False,
192 | ):
193 | super().__init__()
194 | if len(window_size) != 2 or len(shift_size) != 2:
195 | raise ValueError("window_size and shift_size must be of length 2")
196 | self.dim = dim
197 | self.window_size = window_size
198 | self.shift_size = shift_size
199 | self.num_heads = num_heads
200 | self.attention_dropout = attention_dropout
201 | self.dropout = dropout
202 | self.qqkkvv = qqkkvv
203 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
204 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
205 |
206 | # define a parameter table of relative position bias
207 | self.relative_position_bias_table = nn.Parameter(
208 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
209 | ) # 2*Wh-1 * 2*Ww-1, nH
210 |
211 | # get pair-wise relative position index for each token inside the window
212 | coords_h = torch.arange(self.window_size[0])
213 | coords_w = torch.arange(self.window_size[1])
214 | coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
215 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
216 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
217 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
218 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
219 | relative_coords[:, :, 1] += self.window_size[1] - 1
220 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
221 | relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww
222 | self.register_buffer("relative_position_index", relative_position_index)
223 |
224 | nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
225 |
226 | def forward(self, x: Tensor):
227 | """
228 | Args:
229 | x (Tensor): Tensor with layout of [B, H, W, C]
230 | Returns:
231 | Tensor with same layout as input, i.e. [B, H, W, C]
232 | """
233 |
234 | N = self.window_size[0] * self.window_size[1]
235 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index]
236 | relative_position_bias = relative_position_bias.view(N, N, -1)
237 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
238 |
239 | return shifted_window_attention(
240 | x,
241 | self.qkv.weight,
242 | self.proj.weight,
243 | relative_position_bias,
244 | self.window_size,
245 | self.num_heads,
246 | shift_size=self.shift_size,
247 | attention_dropout=self.attention_dropout,
248 | dropout=self.dropout,
249 | qkv_bias=self.qkv.bias,
250 | proj_bias=self.proj.bias,
251 | qqkkvv = self.qqkkvv
252 | )
253 |
254 |
255 | class SwinTransformerBlock(nn.Module):
256 | """
257 | Swin Transformer Block.
258 | Args:
259 | dim (int): Number of input channels.
260 | num_heads (int): Number of attention heads.
261 | window_size (List[int]): Window size.
262 | shift_size (List[int]): Shift size for shifted window attention.
263 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
264 | dropout (float): Dropout rate. Default: 0.0.
265 | attention_dropout (float): Attention dropout rate. Default: 0.0.
266 | stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
267 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
268 | attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
269 | """
270 |
271 | def __init__(
272 | self,
273 | dim: int,
274 | num_heads: int,
275 | window_size: List[int],
276 | shift_size: List[int],
277 | mlp_ratio: float = 4.0,
278 | dropout: float = 0.0,
279 | qqkkvv: bool = False,
280 | attention_dropout: float = 0.0,
281 | stochastic_depth_prob: float = 0.0,
282 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
283 | attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
284 | ):
285 | super().__init__()
286 | _log_api_usage_once(self)
287 |
288 | self.norm1 = norm_layer(dim)
289 | self.attn = attn_layer(
290 | dim,
291 | window_size,
292 | shift_size,
293 | num_heads,
294 | attention_dropout=attention_dropout,
295 | dropout=dropout,
296 | qqkkvv=qqkkvv
297 | )
298 | self.qqkkvv = qqkkvv
299 | self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
300 | self.norm2 = norm_layer(dim)
301 | self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
302 |
303 | for m in self.mlp.modules():
304 | if isinstance(m, nn.Linear):
305 | nn.init.xavier_uniform_(m.weight)
306 | if m.bias is not None:
307 | nn.init.normal_(m.bias, std=1e-6)
308 |
309 | def forward(self, x):
310 |
311 | x = x[0]
312 |
313 | if self.qqkkvv:
314 | temp_x, attn_mtrx = self.attn(self.norm1(x))
315 | x = x + self.stochastic_depth(temp_x)
316 | x = x + self.stochastic_depth(self.mlp(self.norm2(x)))
317 | return x, attn_mtrx
318 | else:
319 | temp_x, _ = self.attn(self.norm1(x))
320 | x = x + self.stochastic_depth(temp_x)
321 | x = x + self.stochastic_depth(self.mlp(self.norm2(x)))
322 | return x, None
323 |
324 |
325 | class SwinTransformer(nn.Module):
326 | """
327 | Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
328 | Shifted Windows" `_ paper.
329 | Args:
330 | patch_size (List[int]): Patch size.
331 | embed_dim (int): Patch embedding dimension.
332 | depths (List(int)): Depth of each Swin Transformer layer.
333 | num_heads (List(int)): Number of attention heads in different layers.
334 | window_size (List[int]): Window size.
335 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
336 | dropout (float): Dropout rate. Default: 0.0.
337 | attention_dropout (float): Attention dropout rate. Default: 0.0.
338 | stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
339 | num_classes (int): Number of classes for classification head. Default: 1000.
340 | block (nn.Module, optional): SwinTransformer Block. Default: None.
341 | norm_layer (nn.Module, optional): Normalization layer. Default: None.
342 | """
343 |
344 | def __init__(
345 | self,
346 | patch_size: List[int],
347 | embed_dim: int,
348 | depths: List[int],
349 | num_heads: List[int],
350 | window_size: List[int],
351 | mlp_ratio: float = 4.0,
352 | dropout: float = 0.0,
353 | qqkkvv: bool = False,
354 | attention_dropout: float = 0.0,
355 | stochastic_depth_prob: float = 0.0,
356 | num_classes: int = 1000,
357 | norm_layer: Optional[Callable[..., nn.Module]] = None,
358 | block: Optional[Callable[..., nn.Module]] = None,
359 | ):
360 | super().__init__()
361 | _log_api_usage_once(self)
362 | self.num_classes = num_classes
363 |
364 | if block is None:
365 | block = SwinTransformerBlock
366 |
367 | if norm_layer is None:
368 | norm_layer = partial(nn.LayerNorm, eps=1e-5)
369 |
370 | layers: List[nn.Module] = []
371 | # split image into non-overlapping patches
372 | layers.append(
373 | nn.Sequential(
374 | nn.Conv2d(
375 | 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
376 | ),
377 | Permute([0, 2, 3, 1]),
378 | norm_layer(embed_dim),
379 | )
380 | )
381 |
382 |
383 | total_stage_blocks = sum(depths)
384 | stage_block_id = 0
385 | # build SwinTransformer blocks
386 | for i_stage in range(len(depths)):
387 | stage: List[nn.Module] = []
388 | dim = embed_dim * 2 ** i_stage
389 | for i_layer in range(depths[i_stage]):
390 | # adjust stochastic depth probability based on the depth of the stage block
391 | sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
392 | stage.append(
393 | block(
394 | dim,
395 | num_heads[i_stage],
396 | window_size=window_size,
397 | shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
398 | mlp_ratio=mlp_ratio,
399 | dropout=dropout,
400 | qqkkvv=qqkkvv,
401 | attention_dropout=attention_dropout,
402 | stochastic_depth_prob=sd_prob,
403 | norm_layer=norm_layer,
404 | )
405 | )
406 | stage_block_id += 1
407 | layers.append(nn.Sequential(*stage))
408 | # add patch merging layer
409 | if i_stage < (len(depths) - 1):
410 | layers.append(PatchMerging(dim, norm_layer))
411 | self.features = nn.Sequential(*layers)
412 | self.qqkkvv = qqkkvv
413 | num_features = embed_dim * 2 ** (len(depths) - 1)
414 | self.norm = norm_layer(num_features)
415 | self.avgpool = nn.AdaptiveAvgPool2d(1)
416 | self.head = nn.Linear(num_features, num_classes)
417 |
418 | for m in self.modules():
419 | if isinstance(m, nn.Linear):
420 | nn.init.trunc_normal_(m.weight, std=0.02)
421 | if m.bias is not None:
422 | nn.init.zeros_(m.bias)
423 |
424 |
425 | def forward_features(self, x):
426 |
427 | ## patch embeddings
428 | x = self.features[0](x)
429 | x = (x, None)
430 | ## blocks
431 | attn_matrixs = []
432 | for block in self.features[1:]:
433 | x, attn_matrix = block(x)
434 | x = (x, None)
435 | attn_matrixs.append(attn_matrix)
436 |
437 | return x[0], attn_matrixs
438 |
439 |
440 |
441 | def forward(self, x):
442 | x, attn_matrixs = self.forward_features(x)
443 | x = self.norm(x)
444 | x = x.permute(0, 3, 1, 2)
445 | x = self.avgpool(x)
446 | x = torch.flatten(x, 1)
447 | x = self.head(x)
448 | return x, attn_matrixs
449 |
450 |
451 | def _swin_transformer(
452 | patch_size: List[int],
453 | embed_dim: int,
454 | depths: List[int],
455 | num_heads: List[int],
456 | window_size: List[int],
457 | stochastic_depth_prob: float,
458 | weights: Optional[WeightsEnum],
459 | progress: bool,
460 | **kwargs: Any,
461 | ) -> SwinTransformer:
462 | if weights is not None:
463 | _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
464 |
465 | model = SwinTransformer(
466 | patch_size=patch_size,
467 | embed_dim=embed_dim,
468 | depths=depths,
469 | num_heads=num_heads,
470 | window_size=window_size,
471 | stochastic_depth_prob=stochastic_depth_prob,
472 | **kwargs,
473 | )
474 |
475 | if weights is not None:
476 | model.load_state_dict(weights.get_state_dict(progress=progress))
477 |
478 | return model
479 |
480 |
481 |
482 |
483 | _COMMON_META = {
484 | "categories": _IMAGENET_CATEGORIES,
485 | }
486 |
487 |
488 | class Swin_T_Weights(WeightsEnum):
489 | IMAGENET1K_V1 = Weights(
490 | url="https://download.pytorch.org/models/swin_t-704ceda3.pth",
491 | transforms=partial(
492 | ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
493 | ),
494 | meta={
495 | **_COMMON_META,
496 | "num_params": 28288354,
497 | "min_size": (224, 224),
498 | "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
499 | "_metrics": {
500 | "ImageNet-1K": {
501 | "acc@1": 81.474,
502 | "acc@5": 95.776,
503 | }
504 | },
505 | "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
506 | },
507 | )
508 | DEFAULT = IMAGENET1K_V1
509 |
510 |
511 | @register_model
512 | def swin_t(*, drop_path = 0.2, pretrained = False, progress: bool = True, **kwargs: Any) -> SwinTransformer:
513 | """
514 | Constructs a swin_tiny architecture from
515 | `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_.
516 |
517 | Args:
518 | weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The
519 | pretrained weights to use. See
520 | :class:`~torchvision.models.Swin_T_Weights` below for
521 | more details, and possible values. By default, no pre-trained
522 | weights are used.
523 | progress (bool, optional): If True, displays a progress bar of the
524 | download to stderr. Default is True.
525 | **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
526 | base class. Please refer to the `source code
527 | `_
528 | for more details about this class.
529 |
530 | .. autoclass:: torchvision.models.Swin_T_Weights
531 | :members:
532 | """
533 |
534 | model = _swin_transformer(
535 | patch_size=[4, 4],
536 | embed_dim=96,
537 | depths=[2, 2, 6, 2],
538 | num_heads=[3, 6, 12, 24],
539 | window_size=[7, 7],
540 | stochastic_depth_prob=drop_path,
541 | weights=None,
542 | progress=progress,
543 | **kwargs,
544 | )
545 | if pretrained and kwargs['num_classes'] == 1000:
546 | print("load pretrained")
547 | checkpoint = torch.hub.load_state_dict_from_url(
548 | url="https://download.pytorch.org/models/swin_t-704ceda3.pth",
549 | map_location="cpu", check_hash=True
550 | )
551 | model.load_state_dict(checkpoint)
552 |
553 | return model
554 |
555 |
556 |
557 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .transformers import *
3 |
--------------------------------------------------------------------------------
/src/utils/embedder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Embedder(nn.Module):
5 | def __init__(self,
6 | word_embedding_dim=300,
7 | vocab_size=100000,
8 | padding_idx=1,
9 | pretrained_weight=None,
10 | embed_freeze=False,
11 | *args, **kwargs):
12 | super(Embedder, self).__init__()
13 | self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \
14 | if pretrained_weight is not None else \
15 | nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx)
16 | self.embeddings.weight.requires_grad = not embed_freeze
17 |
18 | def forward_mask(self, mask):
19 | bsz, seq_len = mask.shape
20 | new_mask = mask.view(bsz, seq_len, 1)
21 | new_mask = new_mask.sum(-1)
22 | new_mask = (new_mask > 0)
23 | return new_mask
24 |
25 | def forward(self, x, mask=None):
26 | embed = self.embeddings(x)
27 | embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float()
28 | return embed, mask
29 |
30 | @staticmethod
31 | def init_weight(m):
32 | if isinstance(m, nn.Linear):
33 | nn.init.trunc_normal_(m.weight, std=.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 | else:
37 | nn.init.normal_(m.weight)
38 |
--------------------------------------------------------------------------------
/src/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | import logging
5 |
6 | _logger = logging.getLogger('train')
7 |
8 |
9 | def resize_pos_embed(posemb, posemb_new, num_tokens=1):
10 | # Copied from `timm` by Ross Wightman:
11 | # github.com/rwightman/pytorch-image-models
12 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
13 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
14 | ntok_new = posemb_new.shape[1]
15 | if num_tokens:
16 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
17 | ntok_new -= num_tokens
18 | else:
19 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
20 | gs_old = int(math.sqrt(len(posemb_grid)))
21 | gs_new = int(math.sqrt(ntok_new))
22 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
23 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
24 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
25 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
26 | return posemb
27 |
28 |
29 | def pe_check(model, state_dict, pe_key='classifier.positional_emb'):
30 | if pe_key is not None and pe_key in state_dict.keys() and pe_key in model.state_dict().keys():
31 | if model.state_dict()[pe_key].shape != state_dict[pe_key].shape:
32 | state_dict[pe_key] = resize_pos_embed(state_dict[pe_key],
33 | model.state_dict()[pe_key],
34 | num_tokens=model.classifier.num_tokens)
35 | return state_dict
36 |
37 |
38 | def fc_check(model, state_dict, fc_key='classifier.fc'):
39 | for key in [f'{fc_key}.weight', f'{fc_key}.bias']:
40 | if key is not None and key in state_dict.keys() and key in model.state_dict().keys():
41 | if model.state_dict()[key].shape != state_dict[key].shape:
42 | _logger.warning(f'Removing {key}, number of classes has changed.')
43 | state_dict[key] = model.state_dict()[key]
44 | return state_dict
45 |
--------------------------------------------------------------------------------
/src/utils/stochastic_depth.py:
--------------------------------------------------------------------------------
1 | # Thanks to rwightman's timm package
2 | # github.com:rwightman/pytorch-image-models
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | def drop_path(x, drop_prob: float = 0., training: bool = False):
9 | """
10 | Obtained from: github.com:rwightman/pytorch-image-models
11 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
16 | 'survival rate' as the argument.
17 | """
18 | if drop_prob == 0. or not training:
19 | return x
20 | keep_prob = 1 - drop_prob
21 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
22 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
23 | random_tensor.floor_() # binarize
24 | output = x.div(keep_prob) * random_tensor
25 | return output
26 |
27 |
28 | class DropPath(nn.Module):
29 | """
30 | Obtained from: github.com:rwightman/pytorch-image-models
31 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32 | """
33 |
34 | def __init__(self, drop_prob=None):
35 | super(DropPath, self).__init__()
36 | self.drop_prob = drop_prob
37 |
38 | def forward(self, x):
39 | return drop_path(x, self.drop_prob, self.training)
40 |
--------------------------------------------------------------------------------
/src/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Tokenizer(nn.Module):
7 | def __init__(self,
8 | kernel_size, stride, padding,
9 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
10 | n_conv_layers=1,
11 | n_input_channels=3,
12 | n_output_channels=64,
13 | in_planes=64,
14 | activation=None,
15 | max_pool=True,
16 | conv_bias=False):
17 | super(Tokenizer, self).__init__()
18 |
19 | n_filter_list = [n_input_channels] + \
20 | [in_planes for _ in range(n_conv_layers - 1)] + \
21 | [n_output_channels]
22 |
23 | self.conv_layers = nn.Sequential(
24 | *[nn.Sequential(
25 | nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
26 | kernel_size=(kernel_size, kernel_size),
27 | stride=(stride, stride),
28 | padding=(padding, padding), bias=conv_bias),
29 | nn.Identity() if activation is None else activation(),
30 | nn.MaxPool2d(kernel_size=pooling_kernel_size,
31 | stride=pooling_stride,
32 | padding=pooling_padding) if max_pool else nn.Identity()
33 | )
34 | for i in range(n_conv_layers)
35 | ])
36 | # print(self.conv_layers)
37 | self.flattener = nn.Flatten(2, 3)
38 | self.apply(self.init_weight)
39 |
40 | def sequence_length(self, n_channels=3, height=224, width=224):
41 | return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
42 |
43 | def forward(self, x):
44 | return self.flattener(self.conv_layers(x)).transpose(-2, -1)
45 |
46 | @staticmethod
47 | def init_weight(m):
48 | if isinstance(m, nn.Conv2d):
49 | nn.init.kaiming_normal_(m.weight)
50 |
51 |
52 | class TextTokenizer(nn.Module):
53 | def __init__(self,
54 | kernel_size, stride, padding,
55 | pooling_kernel_size=3, pooling_stride=2, pooling_padding=1,
56 | embedding_dim=300,
57 | n_output_channels=128,
58 | activation=None,
59 | max_pool=True,
60 | *args, **kwargs):
61 | super(TextTokenizer, self).__init__()
62 |
63 | self.max_pool = max_pool
64 | self.conv_layers = nn.Sequential(
65 | nn.Conv2d(1, n_output_channels,
66 | kernel_size=(kernel_size, embedding_dim),
67 | stride=(stride, 1),
68 | padding=(padding, 0), bias=False),
69 | nn.Identity() if activation is None else activation(),
70 | nn.MaxPool2d(
71 | kernel_size=(pooling_kernel_size, 1),
72 | stride=(pooling_stride, 1),
73 | padding=(pooling_padding, 0)
74 | ) if max_pool else nn.Identity()
75 | )
76 |
77 | self.apply(self.init_weight)
78 |
79 | def seq_len(self, seq_len=32, embed_dim=300):
80 | return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1]
81 |
82 | def forward_mask(self, mask):
83 | new_mask = mask.unsqueeze(1).float()
84 | cnn_weight = torch.ones(
85 | (1, 1, self.conv_layers[0].kernel_size[0]),
86 | device=mask.device,
87 | dtype=torch.float)
88 | new_mask = F.conv1d(
89 | new_mask, cnn_weight, None,
90 | self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], 1, 1)
91 | if self.max_pool:
92 | new_mask = F.max_pool1d(
93 | new_mask, self.conv_layers[2].kernel_size[0],
94 | self.conv_layers[2].stride[0], self.conv_layers[2].padding[0], 1, False, False)
95 | new_mask = new_mask.squeeze(1)
96 | new_mask = (new_mask > 0)
97 | return new_mask
98 |
99 | def forward(self, x, mask=None):
100 | x = x.unsqueeze(1)
101 | x = self.conv_layers(x)
102 | x = x.transpose(1, 3).squeeze(1)
103 | if mask is not None:
104 | mask = self.forward_mask(mask).unsqueeze(-1).float()
105 | x = x * mask
106 | return x, mask
107 |
108 | @staticmethod
109 | def init_weight(m):
110 | if isinstance(m, nn.Conv2d):
111 | nn.init.kaiming_normal_(m.weight)
112 |
--------------------------------------------------------------------------------
/src/utils/transformers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torch.nn import Module, ModuleList, Linear, Dropout, LayerNorm, Identity, Parameter, init
5 | import torch.nn.functional as F
6 | from .stochastic_depth import DropPath
7 |
8 |
9 | class Attention(Module):
10 | """
11 | Obtained from timm: github.com:rwightman/pytorch-image-models
12 | """
13 |
14 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1, use_skip=False):
15 | super().__init__()
16 | self.dim = dim
17 | self.num_heads = num_heads
18 | self.use_skip = use_skip
19 | self.attention_dropout = attention_dropout
20 | self.projection_dropout = projection_dropout
21 |
22 | head_dim = dim // self.num_heads
23 | self.scale = head_dim ** -0.5
24 |
25 | self.qkv = Linear(dim, dim * 3, bias=False)
26 | self.attn_drop = Dropout(attention_dropout)
27 | self.proj = Linear(dim, dim)
28 | self.proj_drop = Dropout(projection_dropout)
29 |
30 | def forward(self, x):
31 | B, N, C = x.shape
32 | # qkv dimension: 3, B, num_heads, N, C // num_heads
33 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
34 | q, k, v = qkv[0], qkv[1], qkv[2]
35 |
36 | if self.use_skip:
37 | res = x.reshape(
38 | B, N, self.num_heads, C // self.num_heads
39 | ).permute(0, 2, 1, 3)
40 | q += res
41 | k += res
42 | v += res
43 |
44 | attn = (q @ k.transpose(-2, -1)) * self.scale
45 | attn = attn.softmax(dim=-1)
46 | attn = self.attn_drop(attn)
47 |
48 |
49 |
50 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
51 | if not self.use_skip:
52 | x = self.proj(x)
53 | else:
54 | x = x + self.proj(x)
55 | x = self.proj_drop(x)
56 | return x
57 |
58 |
59 | class MaskedAttention(Module):
60 | def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
61 | super().__init__()
62 | self.num_heads = num_heads
63 | head_dim = dim // self.num_heads
64 | self.scale = head_dim ** -0.5
65 |
66 | self.qkv = Linear(dim, dim * 3, bias=False)
67 | self.attn_drop = Dropout(attention_dropout)
68 | self.proj = Linear(dim, dim)
69 | self.proj_drop = Dropout(projection_dropout)
70 |
71 | def forward(self, x, mask=None):
72 | B, N, C = x.shape
73 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
74 | q, k, v = qkv[0], qkv[1], qkv[2]
75 |
76 | attn = (q @ k.transpose(-2, -1)) * self.scale
77 |
78 | if mask is not None:
79 | mask_value = -torch.finfo(attn.dtype).max
80 | assert mask.shape[-1] == attn.shape[-1], 'mask has incorrect dimensions'
81 | mask = mask[:, None, :] * mask[:, :, None]
82 | mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
83 | attn.masked_fill_(~mask, mask_value)
84 |
85 | attn = attn.softmax(dim=-1)
86 | attn = self.attn_drop(attn)
87 |
88 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
89 | x = self.proj(x)
90 | x = self.proj_drop(x)
91 | return x
92 |
93 |
94 | class TransformerEncoderLayer(Module):
95 | """
96 | Inspired by torch.nn.TransformerEncoderLayer and timm.
97 | """
98 |
99 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
100 | attention_dropout=0.1, drop_path_rate=0.1,
101 | use_layer_scale=False, use_skip=False, use_relu=False):
102 | super(TransformerEncoderLayer, self).__init__()
103 | self.pre_norm = LayerNorm(d_model)
104 | self.self_attn = Attention(dim=d_model, num_heads=nhead,
105 | attention_dropout=attention_dropout, projection_dropout=dropout, use_skip=use_skip)
106 |
107 | self.linear1 = Linear(d_model, dim_feedforward)
108 | self.dropout1 = Dropout(dropout)
109 | self.norm1 = LayerNorm(d_model)
110 | self.linear2 = Linear(dim_feedforward, d_model)
111 | self.dropout2 = Dropout(dropout)
112 |
113 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()
114 |
115 | # self.activation = F.gelu if not use_relu else F.relu
116 | self.activation = nn.GELU() if not use_relu else nn.ReLU(inplace=True)
117 |
118 | init_values = 1e-4
119 | self.use_layer_scale = use_layer_scale
120 |
121 | if self.use_layer_scale:
122 | self.gamma_1 = torch.nn.Parameter(init_values * torch.ones(d_model), requires_grad=True)
123 | self.gamma_2 = torch.nn.Parameter(init_values * torch.ones(d_model), requires_grad=True)
124 |
125 | def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
126 | src2 = self.self_attn(self.pre_norm(src))
127 | src2 = self.gamma_1 * src2 if self.use_layer_scale else src2
128 | src = src + self.drop_path(src2)
129 | src = self.norm1(src)
130 | gelu_output = self.activation(self.linear1(src))
131 | src2 = self.linear2(self.dropout1(gelu_output))
132 | src2 = self.gamma_2 * src2 if self.use_layer_scale else src2
133 | src = src + self.drop_path(self.dropout2(src2))
134 | return src
135 |
136 |
137 | class MaskedTransformerEncoderLayer(Module):
138 | """
139 | Inspired by torch.nn.TransformerEncoderLayer and timm.
140 | """
141 |
142 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
143 | attention_dropout=0.1, drop_path_rate=0.1):
144 | super(MaskedTransformerEncoderLayer, self).__init__()
145 | self.pre_norm = LayerNorm(d_model)
146 | self.self_attn = MaskedAttention(dim=d_model, num_heads=nhead,
147 | attention_dropout=attention_dropout, projection_dropout=dropout)
148 |
149 | self.linear1 = Linear(d_model, dim_feedforward)
150 | self.dropout1 = Dropout(dropout)
151 | self.norm1 = LayerNorm(d_model)
152 | self.linear2 = Linear(dim_feedforward, d_model)
153 | self.dropout2 = Dropout(dropout)
154 |
155 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity()
156 |
157 | self.activation = F.gelu
158 |
159 | def forward(self, src: torch.Tensor, mask=None, *args, **kwargs) -> torch.Tensor:
160 | src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask))
161 | src = self.norm1(src)
162 | src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
163 | src = src + self.drop_path(self.dropout2(src2))
164 | return src
165 |
166 |
167 | class TransformerClassifier(Module):
168 | def __init__(self,
169 | seq_pool=True,
170 | embedding_dim=768,
171 | num_layers=12,
172 | num_heads=12,
173 | mlp_ratio=4.0,
174 | num_classes=1000,
175 | dropout=0.1,
176 | attention_dropout=0.1,
177 | stochastic_depth=0.1,
178 | positional_embedding='learnable',
179 | sequence_length=None,
180 | use_layer_scale=False,
181 | use_skip=False,
182 | use_relu=False):
183 | super().__init__()
184 |
185 | positional_embedding = positional_embedding if \
186 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
187 | dim_feedforward = int(embedding_dim * mlp_ratio)
188 | self.embedding_dim = embedding_dim
189 | self.sequence_length = sequence_length
190 | self.seq_pool = seq_pool
191 | self.num_tokens = 0
192 |
193 | assert sequence_length is not None or positional_embedding == 'none', \
194 | f"Positional embedding is set to {positional_embedding} and" \
195 | f" the sequence length was not specified."
196 |
197 | if not seq_pool:
198 | sequence_length += 1
199 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
200 | requires_grad=True)
201 | self.num_tokens = 1
202 | else:
203 | self.attention_pool = Linear(self.embedding_dim, 1)
204 |
205 | if positional_embedding != 'none':
206 | if positional_embedding == 'learnable':
207 | self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim),
208 | requires_grad=True)
209 | init.trunc_normal_(self.positional_emb, std=0.2)
210 | else:
211 | self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
212 | requires_grad=False)
213 | else:
214 | self.positional_emb = None
215 |
216 | self.dropout = Dropout(p=dropout)
217 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)]
218 | self.blocks = ModuleList([
219 | TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
220 | dim_feedforward=dim_feedforward, dropout=dropout,
221 | attention_dropout=attention_dropout,
222 | drop_path_rate=dpr[i], use_layer_scale=use_layer_scale,
223 | use_skip=use_skip, use_relu=use_relu)
224 | for i in range(num_layers)])
225 | self.norm = LayerNorm(embedding_dim)
226 |
227 | self.fc = Linear(embedding_dim, num_classes)
228 | self.apply(self.init_weight)
229 |
230 | def forward(self, x):
231 | if self.positional_emb is None and x.size(1) < self.sequence_length:
232 | x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
233 |
234 | features = []
235 |
236 | if not self.seq_pool:
237 | cls_token = self.class_emb.expand(x.shape[0], -1, -1)
238 | x = torch.cat((cls_token, x), dim=1)
239 |
240 | if self.positional_emb is not None:
241 | x += self.positional_emb
242 |
243 | x = self.dropout(x)
244 |
245 | for blk in self.blocks:
246 | features.append(x)
247 | x = blk(x)
248 | x = self.norm(x)
249 |
250 | features.append(x)
251 | if self.seq_pool:
252 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
253 | else:
254 | x = x[:, 0]
255 |
256 | x = self.fc(x)
257 | return x, features
258 |
259 | @staticmethod
260 | def init_weight(m):
261 | if isinstance(m, Linear):
262 | init.trunc_normal_(m.weight, std=.02)
263 | if isinstance(m, Linear) and m.bias is not None:
264 | init.constant_(m.bias, 0)
265 | elif isinstance(m, LayerNorm):
266 | init.constant_(m.bias, 0)
267 | init.constant_(m.weight, 1.0)
268 |
269 | @staticmethod
270 | def sinusoidal_embedding(n_channels, dim):
271 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
272 | for p in range(n_channels)])
273 | pe[:, 0::2] = torch.sin(pe[:, 0::2])
274 | pe[:, 1::2] = torch.cos(pe[:, 1::2])
275 | return pe.unsqueeze(0)
276 |
277 |
278 | class MaskedTransformerClassifier(Module):
279 | def __init__(self,
280 | seq_pool=True,
281 | embedding_dim=768,
282 | num_layers=12,
283 | num_heads=12,
284 | mlp_ratio=4.0,
285 | num_classes=1000,
286 | dropout=0.1,
287 | attention_dropout=0.1,
288 | stochastic_depth=0.1,
289 | positional_embedding='sine',
290 | seq_len=None,
291 | *args, **kwargs):
292 | super().__init__()
293 | positional_embedding = positional_embedding if \
294 | positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
295 | dim_feedforward = int(embedding_dim * mlp_ratio)
296 | self.embedding_dim = embedding_dim
297 | self.seq_len = seq_len
298 | self.seq_pool = seq_pool
299 | self.num_tokens = 0
300 |
301 | assert seq_len is not None or positional_embedding == 'none', \
302 | f"Positional embedding is set to {positional_embedding} and" \
303 | f" the sequence length was not specified."
304 |
305 | if not seq_pool:
306 | seq_len += 1
307 | self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim),
308 | requires_grad=True)
309 | self.num_tokens = 1
310 | else:
311 | self.attention_pool = Linear(self.embedding_dim, 1)
312 |
313 | if positional_embedding != 'none':
314 | if positional_embedding == 'learnable':
315 | seq_len += 1 # padding idx
316 | self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim),
317 | requires_grad=True)
318 | init.trunc_normal_(self.positional_emb, std=0.2)
319 | else:
320 | self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len,
321 | embedding_dim,
322 | padding_idx=True),
323 | requires_grad=False)
324 | else:
325 | self.positional_emb = None
326 |
327 | self.dropout = Dropout(p=dropout)
328 | dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)]
329 | self.blocks = ModuleList([
330 | MaskedTransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
331 | dim_feedforward=dim_feedforward, dropout=dropout,
332 | attention_dropout=attention_dropout, drop_path_rate=dpr[i])
333 | for i in range(num_layers)])
334 | self.norm = LayerNorm(embedding_dim)
335 |
336 | self.fc = Linear(embedding_dim, num_classes)
337 | self.apply(self.init_weight)
338 |
339 | def forward(self, x, mask=None):
340 | if self.positional_emb is None and x.size(1) < self.seq_len:
341 | x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
342 |
343 | if not self.seq_pool:
344 | cls_token = self.class_emb.expand(x.shape[0], -1, -1)
345 | x = torch.cat((cls_token, x), dim=1)
346 | if mask is not None:
347 | mask = torch.cat([torch.ones(size=(mask.shape[0], 1), device=mask.device), mask.float()], dim=1)
348 | mask = (mask > 0)
349 |
350 | if self.positional_emb is not None:
351 | x += self.positional_emb
352 |
353 | x = self.dropout(x)
354 |
355 | for blk in self.blocks:
356 | x = blk(x, mask=mask)
357 | x = self.norm(x)
358 |
359 | if self.seq_pool:
360 | x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
361 | else:
362 | x = x[:, 0]
363 |
364 | x = self.fc(x)
365 | return x
366 |
367 | @staticmethod
368 | def init_weight(m):
369 | if isinstance(m, Linear):
370 | init.trunc_normal_(m.weight, std=.02)
371 | if isinstance(m, Linear) and m.bias is not None:
372 | init.constant_(m.bias, 0)
373 | elif isinstance(m, LayerNorm):
374 | init.constant_(m.bias, 0)
375 | init.constant_(m.weight, 1.0)
376 |
377 | @staticmethod
378 | def sinusoidal_embedding(n_channels, dim, padding_idx=False):
379 | pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
380 | for p in range(n_channels)])
381 | pe[:, 0::2] = torch.sin(pe[:, 0::2])
382 | pe[:, 1::2] = torch.cos(pe[:, 1::2])
383 | pe = pe.unsqueeze(0)
384 | if padding_idx:
385 | return torch.cat([torch.zeros((1, 1, dim)), pe], dim=1)
386 | return pe
387 |
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BatchNorm(nn.modules.batchnorm._BatchNorm):
6 | def __init__(
7 | self,
8 | num_features,
9 | eps=1e-5,
10 | momentum=0.1,
11 | affine=True,
12 | track_running_stats=True,
13 | device=None,
14 | dtype=None,
15 | transpose=False,
16 | ):
17 | factory_kwargs = {'device': device, 'dtype': dtype}
18 | super().__init__(
19 | num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
20 | )
21 | self.transpose = transpose
22 |
23 | def _check_input_dim(self, input):
24 | if input.dim() < 2:
25 | raise ValueError(
26 | "expected at least 2D input (got {}D input)".format(input.dim())
27 | )
28 |
29 | def forward(self, input: torch.Tensor) -> torch.Tensor:
30 | if self.transpose:
31 | dim = input.ndim
32 | input = input.transpose(dim-1, dim-2)
33 | output = super().forward(input)
34 | if self.transpose:
35 | return output.transpose(dim-1, dim-2)
36 | else:
37 | return output
38 |
39 |
40 | def build_bn_from_ln(ln, transpose=False):
41 | assert isinstance(ln.normalized_shape, int) or len(ln.normalized_shape) == 1
42 | num_features = (
43 | ln.normalized_shape
44 | if isinstance(ln.normalized_shape, int)
45 | else ln.normalized_shape[0]
46 | )
47 | return BatchNorm(
48 | num_features=num_features,
49 | affine=ln.elementwise_affine,
50 | device=ln.weight.device,
51 | transpose=transpose,
52 | )
53 |
54 |
55 | def replace_module_by_module(module, m_to_replace, build_fn):
56 | module_output = module
57 | if isinstance(module, m_to_replace):
58 | module_output = build_fn(module)
59 | for name, child in module.named_children():
60 | module_output.add_module(
61 | name, replace_module_by_module(child, m_to_replace, build_fn)
62 | )
63 | del module
64 | return module_output
65 |
66 |
67 | def replace_ln_by_bn2d(module):
68 | return replace_module_by_module(
69 | module,
70 | nn.LayerNorm,
71 | lambda x: build_bn_from_ln(ln=x, tranpose=False),
72 | )
73 |
74 |
75 | def replace_ln_by_bn1d(module):
76 | return replace_module_by_module(
77 | module,
78 | nn.LayerNorm,
79 | lambda x: build_bn_from_ln(ln=x, transpose=True),
80 | )
81 |
82 |
--------------------------------------------------------------------------------
/timm_fix_imagenet_loading_bugs/dataset_factory.py:
--------------------------------------------------------------------------------
1 | """ Dataset Factory
2 |
3 | Hacked together by / Copyright 2021, Ross Wightman
4 | """
5 | import os
6 |
7 | from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
8 | try:
9 | from torchvision.datasets import Places365
10 | has_places365 = True
11 | except ImportError:
12 | has_places365 = False
13 | try:
14 | from torchvision.datasets import INaturalist
15 | has_inaturalist = True
16 | except ImportError:
17 | has_inaturalist = False
18 |
19 | from .dataset import IterableImageDataset, ImageDataset
20 |
21 | _TORCH_BASIC_DS = dict(
22 | cifar10=CIFAR10,
23 | cifar100=CIFAR100,
24 | mnist=MNIST,
25 | qmist=QMNIST,
26 | kmnist=KMNIST,
27 | fashion_mnist=FashionMNIST,
28 | )
29 | _TRAIN_SYNONYM = {'train', 'training'}
30 | _EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
31 |
32 |
33 | def _search_split(root, split):
34 | # look for sub-folder with name of split in root and use that if it exists
35 | split_name = split.split('[')[0]
36 | try_root = os.path.join(root, split_name)
37 | if os.path.exists(try_root):
38 | return try_root
39 |
40 | def _try(syn):
41 | for s in syn:
42 | try_root = os.path.join(root, s)
43 | if os.path.exists(try_root):
44 | return try_root
45 | return root
46 | if split_name in _TRAIN_SYNONYM:
47 | root = _try(_TRAIN_SYNONYM)
48 | elif split_name in _EVAL_SYNONYM:
49 | root = _try(_EVAL_SYNONYM)
50 | return root
51 |
52 |
53 | def create_dataset(
54 | name,
55 | root,
56 | split='validation',
57 | search_split=True,
58 | class_map=None,
59 | load_bytes=False,
60 | is_training=False,
61 | download=False,
62 | batch_size=None,
63 | repeats=0,
64 | **kwargs
65 | ):
66 | """ Dataset factory method
67 |
68 | In parenthesis after each arg are the type of dataset supported for each arg, one of:
69 | * folder - default, timm folder (or tar) based ImageDataset
70 | * torch - torchvision based datasets
71 | * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
72 | * all - any of the above
73 |
74 | Args:
75 | name: dataset name, empty is okay for folder based datasets
76 | root: root folder of dataset (all)
77 | split: dataset split (all)
78 | search_split: search for split specific child fold from root so one can specify
79 | `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
80 | class_map: specify class -> index mapping via text file or dict (folder)
81 | load_bytes: load data, return images as undecoded bytes (folder)
82 | download: download dataset if not present and supported (TFDS, torch)
83 | is_training: create dataset in train mode, this is different from the split.
84 | For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS)
85 | batch_size: batch size hint for (TFDS)
86 | repeats: dataset repeats per iteration i.e. epoch (TFDS)
87 | **kwargs: other args to pass to dataset
88 |
89 | Returns:
90 | Dataset object
91 | """
92 | name = name.lower()
93 | if name.startswith('torch/'):
94 | name = name.split('/', 2)[-1]
95 | torch_kwargs = dict(root=root, download=download, **kwargs)
96 | if name in _TORCH_BASIC_DS:
97 | ds_class = _TORCH_BASIC_DS[name]
98 | use_train = split in _TRAIN_SYNONYM
99 | ds = ds_class(train=use_train, **torch_kwargs)
100 | elif name == 'inaturalist' or name == 'inat':
101 | assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist'
102 | target_type = 'full'
103 | split_split = split.split('/')
104 | if len(split_split) > 1:
105 | target_type = split_split[0].split('_')
106 | if len(target_type) == 1:
107 | target_type = target_type[0]
108 | split = split_split[-1]
109 | if split in _TRAIN_SYNONYM:
110 | split = '2021_train'
111 | elif split in _EVAL_SYNONYM:
112 | split = '2021_valid'
113 | ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
114 | elif name == 'places365':
115 | assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'
116 | if split in _TRAIN_SYNONYM:
117 | split = 'train-standard'
118 | elif split in _EVAL_SYNONYM:
119 | split = 'val'
120 | ds = Places365(split=split, **torch_kwargs)
121 | elif name == 'imagenet':
122 | if split in _EVAL_SYNONYM:
123 | split = 'val'
124 | print('imagenet torch')
125 | torch_kwargs = dict(root=root, **kwargs)
126 | ds = ImageNet(split=split, **torch_kwargs)
127 | elif name == 'image_folder' or name == 'folder':
128 | # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
129 | if search_split and os.path.isdir(root):
130 | # look for split specific sub-folder in root
131 | root = _search_split(root, split)
132 | ds = ImageFolder(root, **kwargs)
133 | else:
134 | assert False, f"Unknown torchvision dataset {name}"
135 | elif name.startswith('tfds/'):
136 | ds = IterableImageDataset(
137 | root, parser=name, split=split, is_training=is_training,
138 | download=download, batch_size=batch_size, repeats=repeats, **kwargs)
139 | else:
140 | # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
141 | if search_split and os.path.isdir(root):
142 | # look for split specific sub-folder in root
143 | root = _search_split(root, split)
144 | ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
145 | return ds
146 |
--------------------------------------------------------------------------------
/train_scripts/deit_s/w2a2_deit_s.sh:
--------------------------------------------------------------------------------
1 | ## traing of 2-bit DeiT-S
2 | python3 train.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_small_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 140 \
7 | --weight-decay 0.05 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 5.47e-4 \
10 | --warmup-epochs 5 \
11 | --mixup 0.0 --cutmix 0.0 \
12 | --aq-enable \
13 | --aq-mode lsq \
14 | --aq-per-channel \
15 | --aq_clip_learnable \
16 | --aq-bitw 2 \
17 | --wq-enable \
18 | --wq-per-channel \
19 | --wq-bitw 2 \
20 | --wq-mode statsq \
21 | --model_type deit \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher deit_small_distilled_patch16_224 \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w2a2_deit_s_qkreparam/ \
31 | --visible_gpu '4,5,6,7' \
32 | --world_size '4' \
33 | --tcp_port '36969'
34 |
35 | ## Finetune 2-bit DeiT-S with CGA
36 | python3 cga.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_small_distilled_patch16_224 \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 140 \
41 | --weight-decay 0.05 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 5.47e-4 \
44 | --warmup-epochs 5 \
45 | --mixup 0.0 --cutmix 0.0 \
46 | --aq-enable \
47 | --aq-mode lsq \
48 | --aq-per-channel \
49 | --aq_clip_learnable \
50 | --aq-bitw 2 \
51 | --wq-enable \
52 | --wq-per-channel \
53 | --wq-bitw 2 \
54 | --wq-mode statsq \
55 | --model_type deit \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher deit_small_distilled_patch16_224 \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w2a2_deit_s_qkreparam_cga_0005/ \
68 | --visible_gpu '4,5,6,7' \
69 | --world_size '4' \
70 | --tcp_port '36969'
71 |
--------------------------------------------------------------------------------
/train_scripts/deit_s/w3a3_deit_s.sh:
--------------------------------------------------------------------------------
1 | ## traing of 3-bit DeiT-S
2 | python3 train.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 100 \
7 | --weight-decay 0.0 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 3.2e-4 \
10 | --warmup-epochs 0 \
11 | --aq-enable \
12 | --aq-mode lsq \
13 | --aq-per-channel \
14 | --aq_clip_learnable \
15 | --aq-bitw 3 \
16 | --wq-enable \
17 | --wq-per-channel \
18 | --wq-bitw 3 \
19 | --wq-mode statsq \
20 | --model_type deit \
21 | --quantized \
22 | --pretrained \
23 | --pretrained_initialized \
24 | --use-kd --teacher deit_small_distilled_patch16_224 \
25 | --kd_hard_and_soft 1 \
26 | --qk_reparam \
27 | --qk_reparam_type 0 \
28 | --teacher_pretrained \
29 | --output ./outputs/w3a3_deit_s_qkreparam/ \
30 | --visible_gpu '0,1,2,3,4,5,6,7' \
31 | --world_size '8' \
32 | --tcp_port '36969'
33 |
34 | ## Finetune 3-bit DeiT-S with CGA
35 | python3 cga.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \
36 | your_path/dataset/imagenet-1k/imagenet \
37 | --dataset 'torch/imagenet' \
38 | --epochs 300 \
39 | --batch-size 100 \
40 | --weight-decay 0.0 \
41 | --warmup-lr 1.0e-6 \
42 | --lr 3.2e-4 \
43 | --warmup-epochs 0 \
44 | --aq-enable \
45 | --aq-mode lsq \
46 | --aq-per-channel \
47 | --aq_clip_learnable \
48 | --aq-bitw 3 \
49 | --wq-enable \
50 | --wq-per-channel \
51 | --wq-bitw 3 \
52 | --wq-mode statsq \
53 | --model_type deit \
54 | --quantized \
55 | --pretrained \
56 | --pretrained_initialized \
57 | --use-kd --teacher deit_small_distilled_patch16_224 \
58 | --kd_hard_and_soft 1 \
59 | --qk_reparam \
60 | --qk_reparam_type 1 \
61 | --boundaryRange 0.005 \
62 | --freeze_for_n_epochs 30 \
63 | --teacher_pretrained \
64 | --resume put the model you wish to finetune here \
65 | --output ./outputs/w3a3_deit_s_qkreparam_cga_0005/ \
66 | --visible_gpu '0,1,2,3,4,5,6,7' \
67 | --world_size '8' \
68 | --tcp_port '36969'
69 |
70 |
71 |
--------------------------------------------------------------------------------
/train_scripts/deit_s/w4a4_deit_s.sh:
--------------------------------------------------------------------------------
1 | ## traing of 4-bit DeiT-S
2 | python3 train.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 100 \
7 | --weight-decay 0.0 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 3.2e-4 \
10 | --warmup-epochs 0 \
11 | --aq-enable \
12 | --aq-mode lsq \
13 | --aq-per-channel \
14 | --aq_clip_learnable \
15 | --aq-bitw 4 \
16 | --wq-enable \
17 | --wq-per-channel \
18 | --wq-bitw 4 \
19 | --wq-mode statsq \
20 | --model_type deit \
21 | --quantized \
22 | --pretrained \
23 | --pretrained_initialized \
24 | --use-kd --teacher deit_small_distilled_patch16_224 \
25 | --kd_hard_and_soft 1 \
26 | --qk_reparam \
27 | --qk_reparam_type 0 \
28 | --teacher_pretrained \
29 | --output ./outputs/w4a4_deit_s_qkreparam/ \
30 | --visible_gpu '0,1,2,3,4,5,6,7' \
31 | --world_size '8' \
32 | --tcp_port '36969'
33 |
34 | ## Finetune 4-bit DeiT-S with CGA
35 | python3 cga.py -c./configs/deit_default_imagent.attn_q.yml --model deit_small_distilled_patch16_224 \
36 | your_path/dataset/imagenet-1k/imagenet \
37 | --dataset 'torch/imagenet' \
38 | --epochs 300 \
39 | --batch-size 100 \
40 | --weight-decay 0.0 \
41 | --warmup-lr 1.0e-6 \
42 | --lr 3.2e-4 \
43 | --warmup-epochs 0 \
44 | --aq-enable \
45 | --aq-mode lsq \
46 | --aq-per-channel \
47 | --aq_clip_learnable \
48 | --aq-bitw 4 \
49 | --wq-enable \
50 | --wq-per-channel \
51 | --wq-bitw 4 \
52 | --wq-mode statsq \
53 | --model_type deit \
54 | --quantized \
55 | --pretrained \
56 | --pretrained_initialized \
57 | --use-kd --teacher deit_small_distilled_patch16_224 \
58 | --kd_hard_and_soft 1 \
59 | --qk_reparam \
60 | --qk_reparam_type 1 \
61 | --boundaryRange 0.005 \
62 | --freeze_for_n_epochs 30 \
63 | --teacher_pretrained \
64 | --resume put the model you wish to finetune here \
65 | --output ./outputs/w4a4_deit_s_qkreparam_cga_0005/ \
66 | --visible_gpu '0,1,2,3,4,5,6,7' \
67 | --world_size '8' \
68 | --tcp_port '36969'
69 |
70 |
71 |
--------------------------------------------------------------------------------
/train_scripts/deit_t/w2a2_deit_t.sh:
--------------------------------------------------------------------------------
1 | ## traing of 2-bit DeiT-T
2 | python3 train.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 140 \
7 | --weight-decay 0.05 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 5.47e-4 \
10 | --warmup-epochs 5 \
11 | --mixup 0.0 --cutmix 0.0 \
12 | --aq-enable \
13 | --aq-mode lsq \
14 | --aq-per-channel \
15 | --aq_clip_learnable \
16 | --aq-bitw 2 \
17 | --wq-enable \
18 | --wq-per-channel \
19 | --wq-bitw 2 \
20 | --wq-mode statsq \
21 | --model_type deit \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w2a2_deit_t_qkreparam/ \
31 | --visible_gpu '0,1,2,4' \
32 | --world_size '4' \
33 | --tcp_port '36969'
34 |
35 | ## Finetune 2-bit DeiT-T with CGA your_path/model_saved/w2a2_deit_t_lsq_auto_8_bits_fl_qkreparam_correct/model_best_modified.pth.tar
36 | python3 cga.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 140 \
41 | --weight-decay 0.05 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 5.47e-4 \
44 | --warmup-epochs 5 \
45 | --mixup 0.0 --cutmix 0.0 \
46 | --aq-enable \
47 | --aq-mode lsq \
48 | --aq-per-channel \
49 | --aq_clip_learnable \
50 | --aq-bitw 2 \
51 | --wq-enable \
52 | --wq-per-channel \
53 | --wq-bitw 2 \
54 | --wq-mode statsq \
55 | --model_type deit \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w2a2_deit_t_qkreparam_cga_0005/ \
68 | --visible_gpu '4,5,6,7' \
69 | --world_size '4' \
70 | --tcp_port '36969'
71 |
72 |
--------------------------------------------------------------------------------
/train_scripts/deit_t/w3a3_deit_t.sh:
--------------------------------------------------------------------------------
1 | ## traing of 3-bit DeiT-T
2 | python3 train.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 140 \
7 | --weight-decay 0.05 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 5.47e-4 \
10 | --warmup-epochs 5 \
11 | --mixup 0.0 --cutmix 0.0 \
12 | --aq-enable \
13 | --aq-mode lsq \
14 | --aq-per-channel \
15 | --aq_clip_learnable \
16 | --aq-bitw 3 \
17 | --wq-enable \
18 | --wq-per-channel \
19 | --wq-bitw 3 \
20 | --wq-mode statsq \
21 | --model_type deit \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w3a3_deit_t_qkreparam/ \
31 | --visible_gpu '0,1,2,4' \
32 | --world_size '4' \
33 | --tcp_port '36969'
34 |
35 | ## Finetune 3-bit DeiT-T with CGA
36 | python3 cga.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 140 \
41 | --weight-decay 0.05 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 5.47e-4 \
44 | --warmup-epochs 5 \
45 | --mixup 0.0 --cutmix 0.0 \
46 | --aq-enable \
47 | --aq-mode lsq \
48 | --aq-per-channel \
49 | --aq_clip_learnable \
50 | --aq-bitw 3 \
51 | --wq-enable \
52 | --wq-per-channel \
53 | --wq-bitw 3 \
54 | --wq-mode statsq \
55 | --model_type deit \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w3a3_deit_t_qkreparam_cga_0005/ \
68 | --visible_gpu '4,5,6,7' \
69 | --world_size '4' \
70 | --tcp_port '36969'
71 |
72 |
--------------------------------------------------------------------------------
/train_scripts/deit_t/w4a4_deit_t.sh:
--------------------------------------------------------------------------------
1 | ## traing of 4-bit DeiT-T
2 | python3 train.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 140 \
7 | --weight-decay 0.05 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 5.47e-4 \
10 | --warmup-epochs 5 \
11 | --mixup 0.0 --cutmix 0.0 \
12 | --aq-enable \
13 | --aq-mode lsq \
14 | --aq-per-channel \
15 | --aq_clip_learnable \
16 | --aq-bitw 4 \
17 | --wq-enable \
18 | --wq-per-channel \
19 | --wq-bitw 4 \
20 | --wq-mode statsq \
21 | --model_type deit \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w4a4_deit_t_qkreparam/ \
31 | --visible_gpu '0,1,2,4' \
32 | --world_size '4' \
33 | --tcp_port '36969'
34 |
35 | ## Finetune 4-bit DeiT-T with CGA
36 | python3 cga.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 140 \
41 | --weight-decay 0.05 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 5.47e-4 \
44 | --warmup-epochs 5 \
45 | --mixup 0.0 --cutmix 0.0 \
46 | --aq-enable \
47 | --aq-mode lsq \
48 | --aq-per-channel \
49 | --aq_clip_learnable \
50 | --aq-bitw 4 \
51 | --wq-enable \
52 | --wq-per-channel \
53 | --wq-bitw 4 \
54 | --wq-mode statsq \
55 | --model_type deit \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher deit_tiny_distilled_patch16_224 \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w4a4_deit_t_qkreparam_cga_0005/ \
68 | --visible_gpu '4,5,6,7' \
69 | --world_size '4' \
70 | --tcp_port '36969'
71 |
72 |
--------------------------------------------------------------------------------
/train_scripts/swin_t/w2a2_swin_t.sh:
--------------------------------------------------------------------------------
1 | ## traing of 2-bit Swin-T
2 | python3 train.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 64 \
7 | --weight-decay 0.05 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 5.0e-4 \
10 | --warmup-epochs 5 \
11 | --mixup 0.0 --cutmix 0.0 \
12 | --aq-enable \
13 | --aq-mode lsq \
14 | --aq-per-channel \
15 | --aq_clip_learnable \
16 | --aq-bitw 2 \
17 | --wq-enable \
18 | --wq-per-channel \
19 | --wq-bitw 2 \
20 | --wq-mode statsq \
21 | --model_type swin \
22 | --teacher_type swin \
23 | --quantized \
24 | --pretrained \
25 | --pretrained_initialized \
26 | --use-kd --teacher swin_t \
27 | --kd_hard_and_soft 1 \
28 | --qk_reparam \
29 | --qk_reparam_type 0 \
30 | --teacher_pretrained \
31 | --output ./outputs/w2a2_swin_t_qkreparam/ \
32 | --visible_gpu '0,1,2,3,4,5,6,7' \
33 | --world_size '8' \
34 | --tcp_port '12345'
35 |
36 | ## Fine tune with CGA
37 | python3 cga.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
38 | your_path/dataset/imagenet-1k/imagenet \
39 | --dataset 'torch/imagenet' \
40 | --epochs 300 \
41 | --batch-size 64 \
42 | --weight-decay 0.05 \
43 | --warmup-lr 1.0e-6 \
44 | --lr 5.0e-4 \
45 | --warmup-epochs 5 \
46 | --mixup 0.0 --cutmix 0.0 \
47 | --aq-enable \
48 | --aq-mode lsq \
49 | --aq-per-channel \
50 | --aq_clip_learnable \
51 | --aq-bitw 2 \
52 | --wq-enable \
53 | --wq-per-channel \
54 | --wq-bitw 2 \
55 | --wq-mode statsq \
56 | --model_type swin \
57 | --teacher_type swin \
58 | --quantized \
59 | --pretrained \
60 | --pretrained_initialized \
61 | --use-kd --teacher swin_t \
62 | --kd_hard_and_soft 1 \
63 | --qk_reparam \
64 | --qk_reparam_type 1 \
65 | --boundaryRange 0.005 \
66 | --freeze_for_n_epochs 30 \
67 | --teacher_pretrained \
68 | --resume put the model you wish to finetune here \
69 | --output ./outputs/w2a2_swin_t_qkreparam/ \
70 | --visible_gpu '0,1,2,3,4,5,6,7' \
71 | --world_size '8' \
72 | --tcp_port '12345'
--------------------------------------------------------------------------------
/train_scripts/swin_t/w3a3_swin_t.sh:
--------------------------------------------------------------------------------
1 | ## traing of 3-bit Swin-T
2 | python3 train.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 64 \
7 | --weight-decay 0.0 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 2.0e-4 \
10 | --warmup-epochs 0 \
11 | --aq-enable \
12 | --aq-mode lsq \
13 | --aq-per-channel \
14 | --aq_clip_learnable \
15 | --aq-bitw 3 \
16 | --wq-enable \
17 | --wq-per-channel \
18 | --wq-bitw 3 \
19 | --wq-mode statsq \
20 | --model_type swin \
21 | --teacher_type swin \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher swin_t \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w3a3_swin_t_qkreparam/ \
31 | --visible_gpu '0,1,2,3,4,5,6,7' \
32 | --world_size '8' \
33 | --tcp_port '12345'
34 |
35 | ## Fine tune with CGA
36 | python3 cga.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 64 \
41 | --weight-decay 0.0 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 2.0e-4 \
44 | --warmup-epochs 0 \
45 | --aq-enable \
46 | --aq-mode lsq \
47 | --aq-per-channel \
48 | --aq_clip_learnable \
49 | --aq-bitw 3 \
50 | --wq-enable \
51 | --wq-per-channel \
52 | --wq-bitw 3 \
53 | --wq-mode statsq \
54 | --model_type swin \
55 | --teacher_type swin \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher swin_t \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w3a3_swin_t_qkreparam/ \
68 | --visible_gpu '0,1,2,3,4,5,6,7' \
69 | --world_size '8' \
70 | --tcp_port '12345'
--------------------------------------------------------------------------------
/train_scripts/swin_t/w4a4_swin_t.sh:
--------------------------------------------------------------------------------
1 | ## traing of 4-bit Swin-T
2 | python3 train.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
3 | your_path/dataset/imagenet-1k/imagenet \
4 | --dataset 'torch/imagenet' \
5 | --epochs 300 \
6 | --batch-size 64 \
7 | --weight-decay 0.0 \
8 | --warmup-lr 1.0e-6 \
9 | --lr 2.0e-4 \
10 | --warmup-epochs 0 \
11 | --aq-enable \
12 | --aq-mode lsq \
13 | --aq-per-channel \
14 | --aq_clip_learnable \
15 | --aq-bitw 4 \
16 | --wq-enable \
17 | --wq-per-channel \
18 | --wq-bitw 4 \
19 | --wq-mode statsq \
20 | --model_type swin \
21 | --teacher_type swin \
22 | --quantized \
23 | --pretrained \
24 | --pretrained_initialized \
25 | --use-kd --teacher swin_t \
26 | --kd_hard_and_soft 1 \
27 | --qk_reparam \
28 | --qk_reparam_type 0 \
29 | --teacher_pretrained \
30 | --output ./outputs/w4a4_swin_t_qkreparam/ \
31 | --visible_gpu '0,1,2,3,4,5,6,7' \
32 | --world_size '8' \
33 | --tcp_port '12345'
34 |
35 | ## Fine tune with CGA
36 | python3 cga.py -c./configs/swin_t_imagenet.attn_q.yml --model swin_t \
37 | your_path/dataset/imagenet-1k/imagenet \
38 | --dataset 'torch/imagenet' \
39 | --epochs 300 \
40 | --batch-size 64 \
41 | --weight-decay 0.0 \
42 | --warmup-lr 1.0e-6 \
43 | --lr 2.0e-4 \
44 | --warmup-epochs 0 \
45 | --aq-enable \
46 | --aq-mode lsq \
47 | --aq-per-channel \
48 | --aq_clip_learnable \
49 | --aq-bitw 4 \
50 | --wq-enable \
51 | --wq-per-channel \
52 | --wq-bitw 4 \
53 | --wq-mode statsq \
54 | --model_type swin \
55 | --teacher_type swin \
56 | --quantized \
57 | --pretrained \
58 | --pretrained_initialized \
59 | --use-kd --teacher swin_t \
60 | --kd_hard_and_soft 1 \
61 | --qk_reparam \
62 | --qk_reparam_type 1 \
63 | --boundaryRange 0.005 \
64 | --freeze_for_n_epochs 30 \
65 | --teacher_pretrained \
66 | --resume put the model you wish to finetune here \
67 | --output ./outputs/w4a4_swin_t_qkreparam/ \
68 | --visible_gpu '0,1,2,3,4,5,6,7' \
69 | --world_size '8' \
70 | --tcp_port '12345'
--------------------------------------------------------------------------------