├── models
│   ├── __init__.py
│   ├── quantization_utils
│   │   ├── __init__.py
│   │   ├── quant_utils.py
│   │   └── quant_modules.py
│   ├── utils.py
│   ├── layers_quant.py
│   ├── vit_quant.py
│   └── swin_quant.py
├── overview.png
├── utils
│   ├── __init__.py
│   ├── data_utils.py
│   ├── build_model.py
│   └── kde.py
├── README.md
├── generate_data.py
├── test_quant.py
└── LICENSE
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vit_quant import *
2 | from .swin_quant import *
3 |
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zkkli/PSAQ-ViT/HEAD/overview.png
--------------------------------------------------------------------------------
/models/quantization_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .quant_modules import QuantLinear, QuantAct, QuantConv2d
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .build_model import *
2 | from .kde import KernelDensityEstimator
3 | from .data_utils import build_dataset
4 |
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import math
4 | from PIL import Image
5 | import torchvision.transforms as transforms
6 | import torchvision.datasets as datasets
7 |
8 |
9 | def build_dataset(args):
10 | model_type = args.model.split("_")[0]
11 | if model_type == "deit":
12 | mean = (0.485, 0.456, 0.406)
13 | std = (0.229, 0.224, 0.225)
14 | crop_pct = 0.875
15 | elif model_type == 'vit':
16 | mean = (0.5, 0.5, 0.5)
17 | std = (0.5, 0.5, 0.5)
18 | crop_pct = 0.9
19 | elif model_type == 'swin':
20 | mean = (0.485, 0.456, 0.406)
21 | std = (0.229, 0.224, 0.225)
22 | crop_pct = 0.9
23 | else:
24 | raise NotImplementedError
25 |
26 | train_transform = build_transform(mean=mean, std=std, crop_pct=crop_pct)
27 | val_transform = build_transform(mean=mean, std=std, crop_pct=crop_pct)
28 |
29 | # Data
30 | traindir = os.path.join(args.dataset, 'train')
31 | valdir = os.path.join(args.dataset, 'val')
32 |
33 | val_dataset = datasets.ImageFolder(valdir, val_transform)
34 | val_loader = torch.utils.data.DataLoader(
35 | val_dataset,
36 | batch_size=args.val_batchsize,
37 | shuffle=False,
38 | num_workers=args.num_workers,
39 | pin_memory=True,
40 | )
41 |
42 | train_dataset = datasets.ImageFolder(traindir, train_transform)
43 | train_loader = torch.utils.data.DataLoader(
44 | train_dataset,
45 | batch_size=args.calib_batchsize,
46 | shuffle=True,
47 | num_workers=args.num_workers,
48 | pin_memory=True,
49 | drop_last=True,
50 | )
51 |
52 | return train_loader, val_loader
53 |
54 |
55 | def build_transform(input_size=224, interpolation="bicubic",
56 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
57 | crop_pct=0.875):
58 | def _pil_interp(method):
59 | if method == "bicubic":
60 | return Image.BICUBIC
61 | elif method == "lanczos":
62 | return Image.LANCZOS
63 | elif method == "hamming":
64 | return Image.HAMMING
65 | else:
66 | return Image.BILINEAR
67 | resize_im = input_size > 32
68 | t = []
69 | if resize_im:
70 | size = int(math.floor(input_size / crop_pct))
71 | ip = _pil_interp(interpolation)
72 | t.append(
73 | transforms.Resize(
74 | size, interpolation=ip
75 | ), # to maintain same ratio w.r.t. 224 images
76 | )
77 | t.append(transforms.CenterCrop(input_size))
78 |
79 | t.append(transforms.ToTensor())
80 | t.append(transforms.Normalize(mean, std))
81 | return transforms.Compose(t)
82 |
--------------------------------------------------------------------------------
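
A minimal usage sketch of the transform pipeline above, assuming the repository root is on PYTHONPATH; the input image is synthetic and the constants are the DeiT defaults used in build_dataset().

```python
import numpy as np
from PIL import Image

from utils.data_utils import build_transform

# DeiT-style eval pipeline: resize to 224 / 0.875 = 256, center-crop 224, normalize.
transform = build_transform(input_size=224, interpolation="bicubic",
                            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                            crop_pct=0.875)

img = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
x = transform(img)
print(x.shape)  # torch.Size([3, 224, 224])
```
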
/README.md:
--------------------------------------------------------------------------------
1 |
 2 |
 3 |
4 |
5 | # Patch Similarity Aware Data-Free Quantization for Vision Transformers
6 |
7 | This repository contains the official PyTorch implementation for the ECCV 2022 paper
8 | *["Patch Similarity Aware Data-Free Quantization for Vision Transformers"](https://arxiv.org/abs/2203.02250).* To the best of our knowledge, this is the first work on data-free quantization for vision transformers. Below are instructions for reproducing the results.
9 |
10 | ## Installation
11 |
12 | - **To install PSAQ-ViT** and develop locally:
13 |
14 | ```bash
15 | git clone https://github.com/zkkli/PSAQ-ViT.git
16 | cd PSAQ-ViT
17 | ```
18 |
19 | ## Quantization
20 |
21 | - You can quantize and evaluate a single model using the following command:
22 |
23 | ```bash
24 | python test_quant.py [--model] [--dataset] [--w_bit] [--a_bit] [--mode]
25 |
26 | optional arguments:
 27 | --model: Model architecture, the choices can be:
28 | deit_tiny, deit_small, deit_base, swin_tiny, and swin_small.
29 | --dataset: Path to ImageNet dataset.
30 | --w_bit: Bit-precision of weights, default=8.
31 | --a_bit: Bit-precision of activation, default=8.
32 | --mode: Mode of calibration data,
33 | 0: Generated fake data (PSAQ-ViT)
34 | 1: Gaussian noise
35 | 2: Real data
36 | ```
37 |
38 | - Example: Quantize DeiT-B with generated fake data **(PSAQ-ViT)**.
39 |
40 | ```bash
 41 | python test_quant.py --model deit_base --dataset <path/to/ImageNet> --mode 0
42 | ```
43 |
44 | - Example: Quantize DeiT-B with Gaussian noise.
45 |
46 | ```bash
 47 | python test_quant.py --model deit_base --dataset <path/to/ImageNet> --mode 1
48 | ```
49 |
50 | - Example: Quantize DeiT-B with Real data.
51 |
52 | ```bash
 53 | python test_quant.py --model deit_base --dataset <path/to/ImageNet> --mode 2
54 | ```
55 |
56 | ## Results
57 |
 58 | Below are the experimental results of our proposed PSAQ-ViT that you should be able to reproduce on the ImageNet dataset with an RTX 3090 GPU (the value in parentheses after each model name is its full-precision top-1 accuracy).
59 |
60 | | Model | Prec. | Top-1(%) | Prec. | Top-1(%) |
61 | |:--------------:|:-----:|:--------:|:-----:|:--------:|
62 | | DeiT-T (72.21) | W4/A8 | 65.57 | W8/A8 | 71.56 |
63 | | DeiT-S (79.85) | W4/A8 | 73.23 | W8/A8 | 76.92 |
64 | | DeiT-B (81.85) | W4/A8 | 77.05 | W8/A8 | 79.10 |
65 | | Swin-T (81.35) | W4/A8 | 71.79 | W8/A8 | 75.35 |
66 | | Swin-S (83.20) | W4/A8 | 75.14 | W8/A8 | 76.64 |
67 |
68 | ## Citation
69 |
 70 | If you find this implementation useful for your work, please cite the following paper:
71 |
 72 | ```bibtex
73 | @inproceedings{li2022psaqvit,
74 | title={Patch Similarity Aware Data-Free Quantization for Vision Transformers},
75 | author={Li, Zhikai and Ma, Liping and Chen, Mengjuan and Xiao, Junrui and Gu, Qingyi},
76 | booktitle={European Conference on Computer Vision},
77 | pages={154--170},
78 | year={2022}
79 | }
80 | ```
81 |
--------------------------------------------------------------------------------
/utils/build_model.py:
--------------------------------------------------------------------------------
1 | from types import MethodType
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import timm
6 | from timm.models import vision_transformer
7 | from timm.models.vision_transformer import Attention
8 | from timm.models.swin_transformer import WindowAttention
9 |
10 | def attention_forward(self, x):
11 | B, N, C = x.shape
12 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
13 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
14 |
15 | # attn = (q @ k.transpose(-2, -1)) * self.scale
16 | attn = self.matmul1(q, k.transpose(-2, -1)) * self.scale
17 | attn = attn.softmax(dim=-1)
18 | attn = self.attn_drop(attn)
19 | del q, k
20 |
21 | # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
22 | x = self.matmul2(attn, v).transpose(1, 2).reshape(B, N, C)
23 | del attn, v
24 | x = self.proj(x)
25 | x = self.proj_drop(x)
26 | return x
27 |
28 | def window_attention_forward(self, x, mask = None):
29 | B_, N, C = x.shape
30 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
31 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
32 |
33 | q = q * self.scale
34 | # attn = (q @ k.transpose(-2, -1))
35 | attn = self.matmul1(q, k.transpose(-2,-1))
36 |
37 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
38 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
39 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
40 | attn = attn + relative_position_bias.unsqueeze(0)
41 |
42 | if mask is not None:
43 | nW = mask.shape[0]
44 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
45 | attn = attn.view(-1, self.num_heads, N, N)
46 | attn = self.softmax(attn)
47 | else:
48 | attn = self.softmax(attn)
49 |
50 | attn = self.attn_drop(attn)
51 |
52 | # x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
53 | x = self.matmul2(attn, v).transpose(1, 2).reshape(B_, N, C)
54 | x = self.proj(x)
55 | x = self.proj_drop(x)
56 | return x
57 |
58 |
59 | class MatMul(nn.Module):
60 | def forward(self, A, B):
61 | return A @ B
62 |
63 |
64 | def build_model(name, Pretrained=True):
65 | """
66 | Get a vision transformer model.
67 | This will replace matrix multiplication operations with matmul modules in the model.
68 |
69 | Currently support almost all models in timm.models.transformers, including:
70 | - vit_tiny/small/base/large_patch16/patch32_224/384,
71 | - deit_tiny/small/base(_distilled)_patch16_224,
72 | - deit_base(_distilled)_patch16_384,
73 | - swin_tiny/small/base/large_patch4_window7_224,
74 | - swin_base/large_patch4_window12_384
75 |
76 | These models are finetuned on imagenet-1k and should use ViTImageNetLoaderGenerator
77 | for calibration and testing.
78 | """
79 | net = timm.create_model(name, pretrained=Pretrained)
80 |
 81 |     for _, module in net.named_modules():
82 | if isinstance(module, Attention):
83 | setattr(module, "matmul1", MatMul())
84 | setattr(module, "matmul2", MatMul())
85 | module.forward = MethodType(attention_forward, module)
86 | if isinstance(module, WindowAttention):
87 | setattr(module, "matmul1", MatMul())
88 | setattr(module, "matmul2", MatMul())
89 | module.forward = MethodType(window_attention_forward, module)
90 |
91 | net = net.cuda()
92 | net.eval()
93 | return net
94 |
--------------------------------------------------------------------------------
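
A small sanity-check sketch for build_model() above, assuming timm, downloadable pretrained weights, and a CUDA device are available (the function moves the network to .cuda()); it confirms that the MatMul wrappers were injected into every attention block.

```python
from utils.build_model import build_model, MatMul

net = build_model("deit_tiny_patch16_224", Pretrained=True)

# Each Attention / WindowAttention block should now own matmul1 and matmul2.
matmul_names = [n for n, m in net.named_modules() if isinstance(m, MatMul)]
print(len(matmul_names))  # two wrappers per attention block (q @ k^T and attn @ v)
```
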
/models/quantization_utils/quant_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | from torch.autograd import Function, Variable
4 | import torch
5 |
6 |
7 | def reshape_tensor(input, scale, zero_point, is_weight=True):
8 | if is_weight:
9 | if len(input.shape) == 4:
10 | range_shape = (-1, 1, 1, 1)
11 | elif len(input.shape) == 2:
12 | range_shape = (-1, 1)
13 | else:
14 | raise NotImplementedError
15 | else:
16 | if len(input.shape) == 2:
17 | range_shape = (1, -1)
18 | elif len(input.shape) == 3:
19 | range_shape = (1, 1, -1)
20 | elif len(input.shape) == 4:
21 | range_shape = (1, -1, 1, 1)
22 | else:
23 | raise NotImplementedError
24 |
25 | scale = scale.reshape(range_shape)
26 | zero_point = zero_point.reshape(range_shape)
27 |
28 | return scale, zero_point
29 |
30 |
31 | def symmetric_linear_quantization_params(num_bits,
32 | min_val,
33 | max_val):
34 | """
 35 |     Compute the scaling factor and zero point with the given quantization range for symmetric quantization.
 36 |     Parameters:
 37 |     ----------
 38 |     num_bits: bit-precision of the quantized representation
 39 |     min_val: lower bound for quantization range
 40 |     max_val: upper bound for quantization range
41 | """
42 | qmax = 2 ** (num_bits - 1) - 1
43 | qmin = -(2 ** (num_bits - 1))
44 | eps = torch.finfo(torch.float32).eps
45 |
46 | max_val = torch.max(-min_val, max_val)
47 | scale = max_val / (float(qmax - qmin) / 2)
48 | scale.clamp_(eps)
49 | zero_point = torch.zeros_like(max_val, dtype=torch.int64)
50 |
51 | return scale, zero_point, qmin, qmax
52 |
53 |
54 | def asymmetric_linear_quantization_params(num_bits,
55 | min_val,
56 | max_val):
57 | """
 58 |     Compute the scaling factor and zero point with the given quantization range.
 59 |     min_val: lower bound for quantization range
 60 |     max_val: upper bound for quantization range
61 | """
62 | qmax = 2 ** num_bits - 1
63 | qmin = 0
64 | eps = torch.finfo(torch.float32).eps
65 |
66 | scale = (max_val - min_val) / float(qmax - qmin)
67 | scale.clamp_(eps)
68 | zero_point = qmin - torch.round(min_val / scale)
69 | zero_point.clamp_(qmin, qmax)
70 |
71 | return scale, zero_point, qmin, qmax
72 |
73 |
74 | class SymmetricQuantFunction(Function):
75 | """
76 | Class to quantize the given floating-point values with given range and bit-setting.
 77 |     Currently only supports inference; back-propagation is not supported.
78 | """
79 | @staticmethod
80 | def forward(ctx, x, k, x_min=None, x_max=None):
81 | """
82 | x: single-precision value to be quantized
83 | k: bit-setting for x
84 | x_min: lower bound for quantization range
 85 |         x_max: upper bound for quantization range
86 | """
87 | scale, zero_point, qmin, qmax = symmetric_linear_quantization_params(k, x_min, x_max)
88 | scale, zero_point = reshape_tensor(x, scale, zero_point, is_weight=True)
89 |
90 | # quantize
91 | quant_x = x / scale + zero_point
92 | quant_x = quant_x.round().clamp(qmin, qmax)
93 |
94 | # dequantize
95 | quant_x = (quant_x - zero_point) * scale
96 |
97 | return torch.autograd.Variable(quant_x)
98 |
99 | @staticmethod
100 | def backward(ctx, grad_output):
101 | raise NotImplementedError
102 |
103 |
104 | class AsymmetricQuantFunction(Function):
105 | """
106 | Class to quantize the given floating-point values with given range and bit-setting.
107 |     Currently only supports inference; back-propagation is not supported.
108 | """
109 | @staticmethod
110 | def forward(ctx, x, k, x_min=None, x_max=None):
111 | """
112 | x: single-precision value to be quantized
113 | k: bit-setting for x
114 | x_min: lower bound for quantization range
115 |         x_max: upper bound for quantization range
116 | """
117 | scale, zero_point, qmin, qmax = asymmetric_linear_quantization_params(k, x_min, x_max)
118 | scale, zero_point = reshape_tensor(x, scale, zero_point, is_weight=False)
119 |
120 | # quantize
121 | quant_x = x / scale + zero_point
122 | quant_x = quant_x.round().clamp(qmin, qmax)
123 |
124 | # dequantize
125 | quant_x = (quant_x - zero_point) * scale
126 |
127 | return torch.autograd.Variable(quant_x)
128 |
129 | @staticmethod
130 | def backward(ctx, grad_output):
131 | raise NotImplementedError
132 |
--------------------------------------------------------------------------------
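
A worked sketch of the asymmetric scheme implemented above: with k bits the uniform grid has 2^k - 1 steps, so scale = (max - min) / (2^k - 1) and the zero point shifts the grid so that min maps onto 0. The input values below are illustrative.

```python
import torch

from models.quantization_utils.quant_utils import AsymmetricQuantFunction

x = torch.tensor([[-1.0, -0.25, 0.0, 0.5, 1.0]])
x_min = x.min(dim=1).values   # -1.0
x_max = x.max(dim=1).values   #  1.0

# 8-bit fake quantization: scale = 2/255, zero point = 128 after rounding.
x_q = AsymmetricQuantFunction.apply(x, 8, x_min, x_max)
print((x - x_q).abs().max())  # round-trip error is at most scale / 2
```
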
/utils/kde.py:
--------------------------------------------------------------------------------
1 | """Implementation of Kernel Density Estimation (KDE) [1].
2 | Kernel density estimation is a nonparameteric density estimation method. It works by
3 | placing kernels K on each point in a "training" dataset D. Then, for a test point x,
4 | p(x) is estimated as p(x) = 1 / |D| \sum_{x_i \in D} K(u(x, x_i)), where u is some
5 | function of x, x_i. In order for p(x) to be a valid probability distribution, the kernel
6 | K must also be a valid probability distribution.
7 | References (used throughout the file):
8 | [1]: https://en.wikipedia.org/wiki/Kernel_density_estimation
9 | """
10 |
11 | import abc
12 |
13 | import numpy as np
14 | import torch
15 | from torch import nn
16 |
17 |
18 | class GenerativeModel(abc.ABC, nn.Module):
19 | """Base class inherited by all generative models in pytorch-generative.
20 | Provides:
21 | * An abstract `sample()` method which is implemented by subclasses that support
22 | generating samples.
23 | * Variables `self._c, self._h, self._w` which store the shape of the (first)
24 | image Tensor the model was trained with. Note that `forward()` must have been
25 | called at least once and the input must be an image for these variables to be
26 | available.
27 | * A `device` property which returns the device of the model's parameters.
28 | """
29 |
30 | def __call__(self, *args, **kwargs):
31 | if getattr(self, "_c", None) is None and len(args[0].shape) == 4:
32 | _, self._c, self._h, self._w = args[0].shape
33 | return super().__call__(*args, **kwargs)
34 |
35 | @property
36 | def device(self):
37 | return next(self.parameters()).device
38 |
39 | @abc.abstractmethod
40 | def sample(self, n_samples):
41 | ...
42 |
43 |
44 | class Kernel(abc.ABC, nn.Module):
45 | """Base class which defines the interface for all kernels."""
46 |
47 | def __init__(self, bandwidth=0.01):
48 | """Initializes a new Kernel.
49 | Args:
50 | bandwidth: The kernel's (band)width.
51 | """
52 | super().__init__()
53 | self.bandwidth = bandwidth
54 |
55 | def _diffs(self, test_Xs, train_Xs):
56 | """Computes difference between each x in test_Xs with all train_Xs."""
57 | test_Xs = test_Xs.view(test_Xs.shape[0], test_Xs.shape[1], 1)
58 | train_Xs = train_Xs.view(train_Xs.shape[0], 1, *train_Xs.shape[1:])
59 | return test_Xs - train_Xs
60 |
61 | @abc.abstractmethod
62 | def forward(self, test_Xs, train_Xs):
63 | """Computes p(x) for each x in test_Xs given train_Xs."""
64 |
65 | @abc.abstractmethod
66 | def sample(self, train_Xs):
67 | """Generates samples from the kernel distribution."""
68 |
69 |
70 | class ParzenWindowKernel(Kernel):
71 | """Implementation of the Parzen window kernel."""
72 |
73 | def forward(self, test_Xs, train_Xs):
74 | abs_diffs = torch.abs(self._diffs(test_Xs, train_Xs))
75 | dims = tuple(range(len(abs_diffs.shape))[1:])
76 | dim = np.prod(abs_diffs.shape[1:])
77 | inside = torch.sum(abs_diffs / self.bandwidth <= 0.5, dim=dims) == dim
78 | coef = 1 / self.bandwidth ** dim
79 | return (coef * inside) #.mean() #dim=1
80 |
81 | def sample(self, train_Xs):
82 | device = train_Xs.device
83 | noise = (torch.rand(train_Xs.shape, device=device) - 0.5) * self.bandwidth
84 | return train_Xs + noise
85 |
86 |
87 | class GaussianKernel(Kernel):
88 | """Implementation of the Gaussian kernel."""
89 |
90 | def forward(self, test_Xs, train_Xs):
91 | diffs = self._diffs(test_Xs, train_Xs)
92 | dims = tuple(range(len(diffs.shape))[2:])
93 | var = self.bandwidth ** 2
94 | exp = torch.exp(- torch.pow(diffs,2) / (2 * var))
95 | coef = 1 / torch.sqrt(torch.tensor(2 * np.pi * var))
96 | return (coef * exp).mean(dim=-1)
97 |
98 | def sample(self, train_Xs):
99 | device = train_Xs.device
100 |         noise = torch.randn(train_Xs.shape, device=device) * self.bandwidth
101 | return train_Xs + noise
102 |
103 |
104 | class KernelDensityEstimator(GenerativeModel):
105 | """The KernelDensityEstimator model."""
106 |
107 | def __init__(self, train_Xs, kernel=None):
108 | """Initializes a new KernelDensityEstimator.
109 | Args:
110 | train_Xs: The "training" data to use when estimating probabilities.
111 | kernel: The kernel to place on each of the train_Xs.
112 | """
113 | super().__init__()
114 | self.kernel = kernel or GaussianKernel()
115 | self.train_Xs = train_Xs
116 |
117 | @property
118 | def device(self):
119 | return self.train_Xs.device
120 |
121 | # TODO(eugenhotaj): This method consumes O(train_Xs * x) memory. Implement an
122 | # iterative version instead.
123 | def forward(self, x):
124 | return self.kernel(x, self.train_Xs)
125 |
126 | def sample(self, n_samples):
127 | idxs = np.random.choice(range(len(self.train_Xs)), size=n_samples)
128 | return self.kernel.sample(self.train_Xs[idxs])
129 |
--------------------------------------------------------------------------------
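
A minimal sketch of the estimator above on synthetic 1-D data; shapes follow Kernel._diffs, i.e. both the training samples and the query points are laid out as [batch, num_points].

```python
import torch

from utils.kde import KernelDensityEstimator, GaussianKernel

train_Xs = torch.randn(1, 500)                         # one batch of 500 scalar samples
kde = KernelDensityEstimator(train_Xs, kernel=GaussianKernel(bandwidth=0.1))

x_plot = torch.linspace(-3, 3, steps=10).repeat(1, 1)  # [1, 10] query grid
density = kde(x_plot)                                  # estimated p(x) at each query point
print(density.shape)                                   # torch.Size([1, 10])
```
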
/models/quantization_utils/quant_modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.multiprocessing as mp
7 | from torch.nn import Parameter
8 |
9 | from .quant_utils import *
10 |
11 |
12 | class QuantConv2d(nn.Conv2d):
13 | """
14 | Class to quantize weights of given convolutional layer
15 | """
16 | def __init__(self,
17 | weight_bit,
18 | in_channels,
19 | out_channels,
20 | kernel_size,
21 | stride=1,
22 | padding=0,
23 | dilation=1,
24 | groups=1,
25 | bias=True):
26 | super(QuantConv2d, self).__init__(in_channels=in_channels,
27 | out_channels=out_channels,
28 | kernel_size=kernel_size,
29 | stride=stride,
30 | padding=padding,
31 | dilation=dilation,
32 | groups=groups,
33 | bias=bias)
34 | self.weight_bit = weight_bit
35 | self.quant = False
36 | self.weight_function = SymmetricQuantFunction.apply
37 |
38 | def __repr__(self):
39 | s = super(QuantConv2d, self).__repr__()
40 | s = "(" + s + " weight_bit={})".format(self.weight_bit)
41 | return s
42 |
43 | def forward(self, x):
44 | """
45 | using quantized weights to forward activation x
46 | """
47 | if not self.quant:
48 | return F.conv2d(
49 | x,
50 | self.weight,
51 | self.bias,
52 | self.stride,
53 | self.padding,
54 | self.dilation,
55 | self.groups,
56 | )
57 |
58 | v = self.weight
59 | v = v.reshape(v.shape[0], -1)
60 | v_max = v.max(axis=1).values
61 | v_min = v.min(axis=1).values
62 | w = self.weight_function(self.weight, self.weight_bit, v_min, v_max)
63 |
64 | return F.conv2d(
65 | x,
66 | w,
67 | self.bias,
68 | self.stride,
69 | self.padding,
70 | self.dilation,
71 | self.groups
72 | )
73 |
74 |
75 | class QuantLinear(nn.Linear):
76 | """
77 | Class to quantize weights of given Linear layer
78 | """
79 | def __init__(self,
80 | weight_bit,
81 | in_features,
82 | out_features,
83 | bias=True):
84 | super(QuantLinear, self).__init__(in_features, out_features, bias)
85 | self.weight_bit = weight_bit
86 | self.quant = False
87 | self.weight_function = SymmetricQuantFunction.apply
88 |
89 | def __repr__(self):
90 | s = super(QuantLinear, self).__repr__()
91 | s = "(" + s + " weight_bit={})".format(self.weight_bit)
92 | return s
93 |
94 | def forward(self, x):
95 | """
96 | using quantized weights to forward activation x
97 | """
98 | if not self.quant:
99 | return F.linear(
100 | x,
101 | self.weight,
102 | self.bias
103 | )
104 |
105 | v = self.weight
106 | v = v.reshape(v.shape[0], -1)
107 | v_max = v.max(axis=1).values
108 | v_min = v.min(axis=1).values
109 | w = self.weight_function(self.weight, self.weight_bit, v_min, v_max)
110 |
111 | return F.linear(
112 | x,
113 | weight=w,
114 | bias=self.bias
115 | )
116 |
117 |
118 | class QuantAct(nn.Module):
119 | """
120 | Class to quantize given activations
121 | """
122 | def __init__(self,
123 | activation_bit,
124 | running_stat=True):
125 | super(QuantAct, self).__init__()
126 | self.activation_bit = activation_bit
127 | self.running_stat = running_stat
128 | self.quant = False
129 | self.act_function = AsymmetricQuantFunction.apply
130 |
131 | self.register_buffer('x_min', torch.zeros(1))
132 | self.register_buffer('x_max', torch.zeros(1))
133 |
134 | def __repr__(self):
135 | return "{0}(activation_bit={1}, running_stat={2}, Act_min: {3:.2f}, Act_max: {4:.2f})".format(
136 | self.__class__.__name__, self.activation_bit, self.running_stat,
137 | self.x_min.item(), self.x_max.item())
138 |
139 | def fix(self):
140 | """
141 | fix the activation range by setting running stat
142 | """
143 | self.running_stat = False
144 |
145 | def unfix(self):
146 | """
147 | unfix the activation range by setting running stat
148 | """
149 | self.running_stat = True
150 |
151 | def forward(self, x):
152 | """
153 | quantize given activation x
154 | """
155 | if self.running_stat:
156 | cur_max = x.data.max()
157 | cur_min = x.data.min()
158 | if self.x_max == 0:
159 | self.x_max = cur_max
160 | self.x_min = cur_min
161 | else:
162 | self.x_max = torch.max(cur_max, self.x_max)
163 | self.x_min = torch.min(cur_min, self.x_min)
164 |
165 | if not self.quant:
166 | return x
167 |
168 | quant_act = self.act_function(x, self.activation_bit, self.x_min, self.x_max)
169 |
170 | return quant_act
171 |
--------------------------------------------------------------------------------
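
A minimal sketch of the calibrate-then-quantize flow these modules support. The full models wrap this in model_quant()/model_freeze(); here the .quant flags are toggled by hand purely for illustration.

```python
import torch

from models.quantization_utils import QuantLinear, QuantAct

fc = QuantLinear(weight_bit=8, in_features=16, out_features=16)
act = QuantAct(activation_bit=8)

x = torch.randn(4, 16)
_ = act(fc(x))       # float calibration pass: QuantAct records x_min / x_max

fc.quant = True      # quantize weights on subsequent forwards
act.quant = True     # quantize activations with the recorded range
act.fix()            # freeze the running min/max statistics
y = act(fc(x))
print(y.shape)       # torch.Size([4, 16])
```
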
/generate_data.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from tqdm import tqdm
8 |
9 | from utils import *
10 |
11 |
12 | model_zoo = {'deit_tiny': 'deit_tiny_patch16_224',
13 | 'deit_small': 'deit_small_patch16_224',
14 | 'deit_base': 'deit_base_patch16_224',
15 | 'swin_tiny': 'swin_tiny_patch4_window7_224',
16 | 'swin_small': 'swin_small_patch4_window7_224',
17 | }
18 |
19 |
20 | class AttentionMap:
21 | def __init__(self, module):
22 | self.hook = module.register_forward_hook(self.hook_fn)
23 |
24 | def hook_fn(self, module, input, output):
25 | self.feature = output
26 |
27 | def remove(self):
28 | self.hook.remove()
29 |
30 |
31 | def generate_data(args):
32 | args.batch_size = args.calib_batchsize
33 |
34 | # Load pretrained model
35 | p_model = build_model(model_zoo[args.model], Pretrained=True)
36 |
37 | # Hook the attention
38 | hooks = []
39 | if 'swin' in args.model:
40 | for m in p_model.layers:
41 | for n in range(len(m.blocks)):
42 | hooks.append(AttentionMap(m.blocks[n].attn.matmul2))
43 | else:
44 | for m in p_model.blocks:
45 | hooks.append(AttentionMap(m.attn.matmul2))
46 |
47 | # Init Gaussian noise
48 | img = torch.randn((args.batch_size, 3, 224, 224)).cuda()
49 | img.requires_grad = True
50 |
51 | # Init optimizer
52 | args.lr = 0.25 if 'swin' in args.model else 0.20
53 | optimizer = optim.Adam([img], lr=args.lr, betas=[0.5, 0.9], eps=1e-8)
54 |
55 | # Set pseudo labels
56 | pred = torch.LongTensor([random.randint(0, 999) for _ in range(args.batch_size)]).to('cuda')
57 | var_pred = random.uniform(2500, 3000) # for batch_size 32
58 |
59 | criterion = nn.CrossEntropyLoss()
60 |
61 | # Train for two epochs
62 | for lr_it in range(2):
63 | if lr_it == 0:
64 | iterations_per_layer = 500
65 | lim = 15
66 | else:
67 | iterations_per_layer = 500
68 | lim = 30
69 |
70 | lr_scheduler = lr_cosine_policy(args.lr, 100, iterations_per_layer)
71 |
72 | with tqdm(range(iterations_per_layer)) as pbar:
73 | for itr in pbar:
74 | pbar.set_description(f"Epochs {lr_it+1}/{2}")
75 |
76 | # Learning rate scheduling
77 | lr_scheduler(optimizer, itr, itr)
78 |
79 | # Apply random jitter offsets (from DeepInversion[1])
80 | # [1] Yin, Hongxu, et al. "Dreaming to distill: Data-free knowledge transfer via deepinversion.", CVPR2020.
81 | off = random.randint(-lim, lim)
82 | img_jit = torch.roll(img, shifts=(off, off), dims=(2, 3))
83 | # Flipping
84 | flip = random.random() > 0.5
85 | if flip:
86 | img_jit = torch.flip(img_jit, dims=(3,))
87 |
88 | # Forward pass
89 | optimizer.zero_grad()
90 | p_model.zero_grad()
91 |
92 | output = p_model(img_jit)
93 |
94 | loss_oh = criterion(output, pred)
95 | loss_tv = torch.norm(get_image_prior_losses(img_jit) - var_pred)
96 |
97 | loss_entropy = 0
98 | for itr_hook in range(len(hooks)):
99 | # Hook attention
100 | attention = hooks[itr_hook].feature
101 | attention_p = attention.mean(dim=1)[:, 1:, :]
102 | sims = torch.cosine_similarity(attention_p.unsqueeze(1), attention_p.unsqueeze(2), dim=3)
103 |
104 | # Compute differential entropy
105 | kde = KernelDensityEstimator(sims.view(args.batch_size, -1))
106 | start_p = sims.min().item()
107 | end_p = sims.max().item()
108 | x_plot = torch.linspace(start_p, end_p, steps=10).repeat(args.batch_size, 1).cuda()
109 | kde_estimate = kde(x_plot)
110 | dif_entropy_estimated = differential_entropy(kde_estimate, x_plot)
111 | loss_entropy -= dif_entropy_estimated
112 |
113 | # Combine loss
114 | total_loss = loss_entropy + 1.0 * loss_oh + 0.05 * loss_tv
115 |
116 | # Do image update
117 | total_loss.backward()
118 | optimizer.step()
119 |
120 | # Clip color outliers
121 | img.data = clip(img.data)
122 |
123 | return img.detach()
124 |
125 |
126 | def differential_entropy(pdf, x_pdf):
127 | # pdf is a vector because we want to perform a numerical integration
128 | pdf = pdf + 1e-4
129 | f = -1 * pdf * torch.log(pdf)
130 | # Integrate using the composite trapezoidal rule
131 | ans = torch.trapz(f, x_pdf, dim=-1).mean()
132 | return ans
133 |
134 |
135 | def get_image_prior_losses(inputs_jit):
136 | # Compute total variation regularization loss
137 | diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:]
138 | diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :]
139 | diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:]
140 | diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:]
141 |
142 | loss_var_l2 = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4)
143 | return loss_var_l2
144 |
145 |
146 | def clip(image_tensor, use_fp16=False):
147 | # Adjust the input based on mean and variance
148 | if use_fp16:
149 | mean = np.array([0.485, 0.456, 0.406], dtype=np.float16)
150 | std = np.array([0.229, 0.224, 0.225], dtype=np.float16)
151 | else:
152 | mean = np.array([0.485, 0.456, 0.406])
153 | std = np.array([0.229, 0.224, 0.225])
154 | for c in range(3):
155 | m, s = mean[c], std[c]
156 | image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s)
157 | #image_tensor[:, c] = torch.clamp(image_tensor[:, c], 0, 1)
158 | return image_tensor
159 |
160 |
161 | def lr_policy(lr_fn):
162 | def _alr(optimizer, iteration, epoch):
163 | lr = lr_fn(iteration, epoch)
164 | for param_group in optimizer.param_groups:
165 | param_group['lr'] = lr
166 |
167 | return _alr
168 |
169 |
170 | def lr_cosine_policy(base_lr, warmup_length, epochs):
171 | def _lr_fn(iteration, epoch):
172 | if epoch < warmup_length:
173 | lr = base_lr * (epoch + 1) / warmup_length
174 | else:
175 | e = epoch - warmup_length
176 | es = epochs - warmup_length
177 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
178 | return lr
179 |
180 | return lr_policy(_lr_fn)
181 |
--------------------------------------------------------------------------------
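
A toy sketch of the patch-similarity entropy term optimized above: cosine similarities between patch responses are treated as samples, their density is estimated with the KDE, and -p log p is integrated with the trapezoidal rule. The tensor sizes are illustrative, not the ones used during generation.

```python
import torch

from utils.kde import KernelDensityEstimator

batch_size, tokens, dim = 2, 8, 16
features = torch.randn(batch_size, tokens, dim)      # stand-in for the hooked attention outputs
sims = torch.cosine_similarity(features.unsqueeze(1), features.unsqueeze(2), dim=3)

kde = KernelDensityEstimator(sims.view(batch_size, -1))
x_plot = torch.linspace(sims.min().item(), sims.max().item(), steps=10).repeat(batch_size, 1)
pdf = kde(x_plot) + 1e-4                             # as in differential_entropy()
entropy = torch.trapz(-pdf * torch.log(pdf), x_plot, dim=-1).mean()
print(entropy)
```
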
/test_quant.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import os
4 | import sys
5 | import random
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 |
10 | from models import *
11 | from utils import *
12 |
13 | from generate_data import generate_data
14 |
15 |
16 | def get_args_parser():
17 | parser = argparse.ArgumentParser(description="PSAQ-ViT", add_help=False)
18 | parser.add_argument("--model", default="deit_tiny",
19 | choices=['deit_tiny', 'deit_small', 'deit_base', 'swin_tiny', 'swin_small'],
20 | help="model")
21 | parser.add_argument('--dataset', default="/Path/to/Dataset/",
22 | help='path to dataset')
23 | parser.add_argument("--calib-batchsize", default=32,
24 | type=int, help="batchsize of calibration set")
25 | parser.add_argument("--val-batchsize", default=200,
26 | type=int, help="batchsize of validation set")
27 | parser.add_argument("--num-workers", default=16, type=int,
28 | help="number of data loading workers (default: 16)")
29 | parser.add_argument("--device", default="cuda", type=str, help="device")
30 | parser.add_argument("--print-freq", default=100,
31 | type=int, help="print frequency")
32 | parser.add_argument("--seed", default=0, type=int, help="seed")
33 |
34 | parser.add_argument("--mode", default=0,
35 | type=int, help="mode of calibration data, 0: PSAQ-ViT, 1: Gaussian noise, 2: Real data")
36 | parser.add_argument('--w_bit', default=8,
37 | type=int, help='bit-precision of weights')
38 | parser.add_argument('--a_bit', default=8,
39 | type=int, help='bit-precision of activation')
40 |
41 | return parser
42 |
43 |
44 | class Config:
45 | def __init__(self, w_bit, a_bit):
46 | self.weight_bit = w_bit
47 | self.activation_bit = a_bit
48 |
49 |
50 | def str2model(name):
51 | model_zoo = {'deit_tiny': deit_tiny_patch16_224,
52 | 'deit_small': deit_small_patch16_224,
53 | 'deit_base': deit_base_patch16_224,
54 | 'swin_tiny': swin_tiny_patch4_window7_224,
55 | 'swin_small': swin_small_patch4_window7_224
56 | }
57 | print('Model: %s' % model_zoo[name].__name__)
58 | return model_zoo[name]
59 |
60 |
61 | def seed(seed=0):
62 | sys.setrecursionlimit(100000)
63 | os.environ["PYTHONHASHSEED"] = str(seed)
64 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
65 | torch.manual_seed(seed)
66 | torch.cuda.manual_seed_all(seed)
67 | torch.backends.cudnn.benchmark = False
68 | torch.backends.cudnn.deterministic = True
69 | np.random.seed(seed)
70 | random.seed(seed)
71 |
72 |
73 | def main():
74 | print(args)
75 | seed(args.seed)
76 |
77 | device = torch.device(args.device)
78 | # Load bit-config
79 | cfg = Config(args.w_bit, args.a_bit)
80 |
81 | # Build model
82 | model = str2model(args.model)(pretrained=True, cfg=cfg)
83 | model = model.to(device)
84 | model.eval()
85 |
86 | # Build dataloader
87 | train_loader, val_loader = build_dataset(args)
88 |
89 | # Define loss function (criterion)
90 | criterion = nn.CrossEntropyLoss().to(device)
91 |
92 | # Get calibration set
 93 |     # Case 0: PSAQ-ViT
94 | if args.mode == 0:
95 | print("Generating data...")
96 | calibrate_data = generate_data(args)
97 | print("Calibrating with generated data...")
98 | with torch.no_grad():
99 | output = model(calibrate_data)
100 | # Case 1: Gaussian noise
101 | elif args.mode == 1:
102 | calibrate_data = torch.randn((args.calib_batchsize, 3, 224, 224)).to(device)
103 | print("Calibrating with Gaussian noise...")
104 | with torch.no_grad():
105 | output = model(calibrate_data)
106 | # Case 2: Real data (Standard)
107 | elif args.mode == 2:
108 | for data, target in train_loader:
109 | calibrate_data = data.to(device)
110 | break
111 | print("Calibrating with real data...")
112 | with torch.no_grad():
113 | output = model(calibrate_data)
114 | # Not implemented
115 | else:
116 | raise NotImplementedError
117 |
118 | # Freeze model
119 | model.model_quant()
120 | model.model_freeze()
121 |
122 | # Validate the quantized model
123 | print("Validating...")
124 | val_loss, val_prec1, val_prec5 = validate(
125 | args, val_loader, model, criterion, device
126 | )
127 |
128 |
129 | def validate(args, val_loader, model, criterion, device):
130 | batch_time = AverageMeter()
131 | losses = AverageMeter()
132 | top1 = AverageMeter()
133 | top5 = AverageMeter()
134 |
135 | # Switch to evaluate mode
136 | model.eval()
137 |
138 | val_start_time = end = time.time()
139 | for i, (data, target) in enumerate(val_loader):
140 |         data = data.to(device)
141 |         target = target.to(device)
142 |
143 |
144 | with torch.no_grad():
145 | output = model(data)
146 | loss = criterion(output, target)
147 |
148 | # Measure accuracy and record loss
149 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
150 | losses.update(loss.data.item(), data.size(0))
151 | top1.update(prec1.data.item(), data.size(0))
152 | top5.update(prec5.data.item(), data.size(0))
153 |
154 | # Measure elapsed time
155 | batch_time.update(time.time() - end)
156 | end = time.time()
157 |
158 | if i % args.print_freq == 0:
159 | print(
160 | "Test: [{0}/{1}]\t"
161 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
162 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
163 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
164 | "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format(
165 | i,
166 | len(val_loader),
167 | batch_time=batch_time,
168 | loss=losses,
169 | top1=top1,
170 | top5=top5,
171 | )
172 | )
173 | val_end_time = time.time()
174 | print(" * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Time {time:.3f}".format(
175 | top1=top1, top5=top5, time=val_end_time - val_start_time))
176 |
177 | return losses.avg, top1.avg, top5.avg
178 |
179 |
180 | class AverageMeter(object):
181 | """Computes and stores the average and current value"""
182 |
183 | def __init__(self):
184 | self.reset()
185 |
186 | def reset(self):
187 | self.val = 0
188 | self.avg = 0
189 | self.sum = 0
190 | self.count = 0
191 |
192 | def update(self, val, n=1):
193 | self.val = val
194 | self.sum += val * n
195 | self.count += n
196 | self.avg = self.sum / self.count
197 |
198 |
199 | def accuracy(output, target, topk=(1,)):
200 | """Computes the precision@k for the specified values of k"""
201 | maxk = max(topk)
202 | batch_size = target.size(0)
203 |
204 | _, pred = output.topk(maxk, 1, True, True)
205 | pred = pred.t()
206 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
207 |
208 | res = []
209 | for k in topk:
210 | correct_k = correct[:k].reshape(-1).float().sum(0)
211 | res.append(correct_k.mul_(100.0 / batch_size))
212 | return res
213 |
214 |
215 | if __name__ == "__main__":
216 | parser = argparse.ArgumentParser('PSAQ', parents=[get_args_parser()])
217 | args = parser.parse_args()
218 | main()
219 |
--------------------------------------------------------------------------------
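
A small hand-checked example for the accuracy() helper above, assuming the repository root is importable; the logits are constructed so that 3 of 4 samples are correct at top-1 and all 4 are correct at top-5.

```python
import torch

from test_quant import accuracy

logits = torch.zeros(4, 10)
logits[0, 0] = logits[1, 1] = logits[2, 2] = 1.0   # correct class ranked first
logits[3, 9], logits[3, 3] = 2.0, 1.0              # true class 3 only ranked second
target = torch.tensor([0, 1, 2, 3])

top1, top5 = accuracy(logits, target, topk=(1, 5))
print(top1.item(), top5.item())                    # 75.0 100.0
```
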
/models/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 |
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 |
10 | @torch.no_grad()
11 | def load_weights_from_npz(model, url, check_hash=False, progress=False, prefix=''):
12 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
13 | """
14 |
15 | def _n2p(w, t=True):
16 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
17 | w = w.flatten()
18 | if t:
19 | if w.ndim == 4:
20 | w = w.transpose([3, 2, 0, 1])
21 | elif w.ndim == 3:
22 | w = w.transpose([2, 0, 1])
23 | elif w.ndim == 2:
24 | w = w.transpose([1, 0])
25 | return torch.from_numpy(w)
26 |
27 | def _get_cache_dir(child_dir=''):
28 | """
29 | Returns the location of the directory where models are cached (and creates it if necessary).
30 | """
31 | hub_dir = torch.hub.get_dir()
32 | child_dir = () if not child_dir else (child_dir,)
33 | model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
34 | os.makedirs(model_dir, exist_ok=True)
35 | return model_dir
36 |
37 | def _download_cached_file(url, check_hash=True, progress=False):
38 | parts = torch.hub.urlparse(url)
39 | filename = os.path.basename(parts.path)
40 | cached_file = os.path.join(_get_cache_dir(), filename)
41 | if not os.path.exists(cached_file):
42 | hash_prefix = None
43 | if check_hash:
44 | r = torch.hub.HASH_REGEX.search(
45 | filename) # r is Optional[Match[str]]
46 | hash_prefix = r.group(1) if r else None
47 | torch.hub.download_url_to_file(
48 | url, cached_file, hash_prefix, progress=progress)
49 | return cached_file
50 |
51 | def adapt_input_conv(in_chans, conv_weight):
52 | conv_type = conv_weight.dtype
53 | # Some weights are in torch.half, ensure it's float for sum on CPU
54 | conv_weight = conv_weight.float()
55 | O, I, J, K = conv_weight.shape
56 | if in_chans == 1:
57 | if I > 3:
58 | assert conv_weight.shape[1] % 3 == 0
59 | # For models with space2depth stems
60 | conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
61 | conv_weight = conv_weight.sum(dim=2, keepdim=False)
62 | else:
63 | conv_weight = conv_weight.sum(dim=1, keepdim=True)
64 | elif in_chans != 3:
65 | if I != 3:
66 | raise NotImplementedError(
67 | 'Weight format not supported by conversion.')
68 | else:
69 | # NOTE this strategy should be better than random init, but there could be other combinations of
70 | # the original RGB input layer weights that'd work better for specific cases.
71 | repeat = int(math.ceil(in_chans / 3))
72 | conv_weight = conv_weight.repeat(1, repeat, 1, 1)[
73 | :, :in_chans, :, :]
74 | conv_weight *= (3 / float(in_chans))
75 | conv_weight = conv_weight.to(conv_type)
76 | return conv_weight
77 |
78 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
79 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
80 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
81 | ntok_new = posemb_new.shape[1]
82 | if num_tokens:
83 | posemb_tok, posemb_grid = posemb[:,
84 | :num_tokens], posemb[0, num_tokens:]
85 | ntok_new -= num_tokens
86 | else:
87 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
88 | gs_old = int(math.sqrt(len(posemb_grid)))
89 | if not len(gs_new): # backwards compatibility
90 | gs_new = [int(math.sqrt(ntok_new))] * 2
91 | assert len(gs_new) >= 2
92 | posemb_grid = posemb_grid.reshape(
93 | 1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
94 | posemb_grid = F.interpolate(
95 | posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
96 | posemb_grid = posemb_grid.permute(
97 | 0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
98 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
99 | return posemb
100 |
101 | cached_file = _download_cached_file(
102 | url, check_hash=check_hash, progress=progress)
103 |
104 | w = np.load(cached_file)
105 | if not prefix and 'opt/target/embedding/kernel' in w:
106 | prefix = 'opt/target/'
107 |
108 | if hasattr(model.patch_embed, 'backbone'):
109 | # hybrid
110 | backbone = model.patch_embed.backbone
111 | stem_only = not hasattr(backbone, 'stem')
112 | stem = backbone if stem_only else backbone.stem
113 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
114 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
115 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
116 | if not stem_only:
117 | for i, stage in enumerate(backbone.stages):
118 | for j, block in enumerate(stage.blocks):
119 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
120 | for r in range(3):
121 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
122 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
123 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
124 | if block.downsample is not None:
125 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
126 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
127 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
128 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
129 | else:
130 | embed_conv_w = adapt_input_conv(
131 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
132 | model.patch_embed.proj.weight.copy_(embed_conv_w)
133 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
134 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
135 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
136 | if pos_embed_w.shape != model.pos_embed.shape:
137 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
138 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
139 | model.pos_embed.copy_(pos_embed_w)
140 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
141 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
142 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
143 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
144 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
145 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
146 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
147 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
148 | for i, block in enumerate(model.blocks.children()):
149 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
150 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
151 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
152 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
153 | block.attn.qkv.weight.copy_(torch.cat([
154 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
155 | block.attn.qkv.bias.copy_(torch.cat([
156 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
157 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
158 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
159 | for r in range(2):
160 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
161 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
162 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
163 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
164 |
--------------------------------------------------------------------------------
/models/layers_quant.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | from itertools import repeat
4 | import collections.abc
5 |
6 | import torch
7 | from torch import nn
8 | import torch.nn.functional as F
9 |
10 | from .quantization_utils import QuantLinear, QuantConv2d, QuantAct
11 |
12 |
13 | def _ntuple(n):
14 | def parse(x):
15 | if isinstance(x, collections.abc.Iterable):
16 | return x
17 | return tuple(repeat(x, n))
18 |
19 | return parse
20 |
21 |
22 | to_2tuple = _ntuple(2)
23 |
24 |
25 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
26 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
27 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
28 | def norm_cdf(x):
29 | # Computes standard normal cumulative distribution function
30 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
31 |
32 | if (mean < a - 2 * std) or (mean > b + 2 * std):
33 | warnings.warn(
34 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
35 | "The distribution of values may be incorrect.",
36 | stacklevel=2,
37 | )
38 |
39 | with torch.no_grad():
40 | # Values are generated by using a truncated uniform distribution and
41 | # then using the inverse CDF for the normal distribution.
42 | # Get upper and lower cdf values
43 | l = norm_cdf((a - mean) / std)
44 | u = norm_cdf((b - mean) / std)
45 |
46 | # Uniformly fill tensor with values from [l, u], then translate to
47 | # [2l-1, 2u-1].
48 | tensor.uniform_(2 * l - 1, 2 * u - 1)
49 |
50 | # Use inverse cdf transform for normal distribution to get truncated
51 | # standard normal
52 | tensor.erfinv_()
53 |
54 | # Transform to proper mean, std
55 | tensor.mul_(std * math.sqrt(2.0))
56 | tensor.add_(mean)
57 |
58 | # Clamp to ensure it's in the proper range
59 | tensor.clamp_(min=a, max=b)
60 | return tensor
61 |
62 |
63 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
64 | # type: (Tensor, float, float, float, float) -> Tensor
65 | r"""Fills the input Tensor with values drawn from a truncated
66 | normal distribution. The values are effectively drawn from the
67 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
68 | with values outside :math:`[a, b]` redrawn until they are within
69 | the bounds. The method used for generating the random values works
70 | best when :math:`a \leq \text{mean} \leq b`.
71 | Args:
72 | tensor: an n-dimensional `torch.Tensor`
73 | mean: the mean of the normal distribution
74 | std: the standard deviation of the normal distribution
75 | a: the minimum cutoff value
76 | b: the maximum cutoff value
77 | Examples:
78 | >>> w = torch.empty(3, 5)
79 | >>> nn.init.trunc_normal_(w)
80 | """
81 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
82 |
83 |
84 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
85 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
86 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
87 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
88 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
89 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
90 | 'survival rate' as the argument.
91 | """
92 | if drop_prob == 0.0 or not training:
93 | return x
94 | keep_prob = 1 - drop_prob
95 | shape = (x.shape[0],) + (1,) * (
96 | x.ndim - 1
97 | ) # work with diff dim tensors, not just 2D ConvNets
98 | random_tensor = keep_prob + \
99 | torch.rand(shape, dtype=x.dtype, device=x.device)
100 | random_tensor.floor_() # binarize
101 | output = x.div(keep_prob) * random_tensor
102 | return output
103 |
104 |
105 | class DropPath(nn.Module):
106 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
107 |
108 | def __init__(self, drop_prob=None):
109 | super(DropPath, self).__init__()
110 | self.drop_prob = drop_prob
111 |
112 | def forward(self, x):
113 | return drop_path(x, self.drop_prob, self.training)
114 |
115 |
116 | class Mlp(nn.Module):
117 | def __init__(self,
118 | in_features,
119 | hidden_features,
120 | out_features=None,
121 | act_layer=nn.GELU,
122 | drop=0.0,
123 | cfg=None):
124 | super().__init__()
125 | out_features = out_features or in_features
126 | hidden_features = hidden_features or in_features
127 | self.fc1 = QuantLinear(cfg.weight_bit,
128 | in_features,
129 | hidden_features)
130 | self.act = act_layer()
131 | self.QuantAct1 = QuantAct(cfg.activation_bit)
132 | self.fc2 = QuantLinear(cfg.weight_bit,
133 | hidden_features,
134 | out_features)
135 | self.QuantAct2 = QuantAct(cfg.activation_bit)
136 | self.drop = nn.Dropout(drop)
137 |
138 | def forward(self, x):
139 | x = self.fc1(x)
140 | x = self.act(x)
141 | x = self.QuantAct1(x)
142 | x = self.drop(x)
143 | x = self.fc2(x)
144 | x = self.QuantAct2(x)
145 | x = self.drop(x)
146 | return x
147 |
148 |
149 | class PatchEmbed(nn.Module):
150 | """Image to Patch Embedding"""
151 |
152 | def __init__(self,
153 | img_size=224,
154 | patch_size=16,
155 | in_chans=3,
156 | embed_dim=768,
157 | norm_layer=None,
158 | cfg=None):
159 | super().__init__()
160 | img_size = to_2tuple(img_size)
161 | patch_size = to_2tuple(patch_size)
162 | self.img_size = img_size
163 | self.patch_size = patch_size
164 |
165 | self.grid_size = (img_size[0] // patch_size[0],
166 | img_size[1] // patch_size[1])
167 | self.num_patches = self.grid_size[0] * self.grid_size[1]
168 |
169 | self.proj = QuantConv2d(cfg.weight_bit,
170 | in_chans,
171 | embed_dim,
172 | kernel_size=patch_size,
173 | stride=patch_size)
174 | if norm_layer:
175 | self.QuantAct_before_norm = QuantAct(cfg.activation_bit)
176 | self.norm = norm_layer(embed_dim)
177 | self.QuantAct = QuantAct(cfg.activation_bit)
178 | else:
179 | self.QuantAct_before_norm = nn.Identity()
180 | self.norm = nn.Identity()
181 | self.QuantAct = QuantAct(cfg.activation_bit)
182 |
183 | def forward(self, x):
184 | B, C, H, W = x.shape
185 | # FIXME look at relaxing size constraints
186 | assert (
187 | H == self.img_size[0] and W == self.img_size[1]
188 | ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
189 | x = self.proj(x).flatten(2).transpose(1, 2)
190 | x = self.QuantAct_before_norm(x)
191 | x = self.norm(x)
192 | x = self.QuantAct(x)
193 | return x
194 |
195 |
196 | class HybridEmbed(nn.Module):
197 | """CNN Feature Map Embedding
198 | Extract feature map from CNN, flatten, project to embedding dim.
199 | """
200 |
201 | def __init__(
202 | self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
203 | super().__init__()
204 | assert isinstance(backbone, nn.Module)
205 | img_size = to_2tuple(img_size)
206 | self.img_size = img_size
207 | self.backbone = backbone
208 | if feature_size is None:
209 | with torch.no_grad():
210 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
211 | # map for all networks, the feature metadata has reliable channel and stride info, but using
212 | # stride to calc feature dim requires info about padding of each stage that isn't captured.
213 | training = backbone.training
214 | if training:
215 | backbone.eval()
216 | o = self.backbone(torch.zeros(
217 | 1, in_chans, img_size[0], img_size[1]))
218 | if isinstance(o, (list, tuple)):
219 | # last feature if backbone outputs list/tuple of features
220 | o = o[-1]
221 | feature_size = o.shape[-2:]
222 | feature_dim = o.shape[1]
223 | backbone.train(training)
224 | else:
225 | feature_size = to_2tuple(feature_size)
226 | if hasattr(self.backbone, "feature_info"):
227 | feature_dim = self.backbone.feature_info.channels()[-1]
228 | else:
229 | feature_dim = self.backbone.num_features
230 | self.num_patches = feature_size[0] * feature_size[1]
231 | self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
232 |
233 | def forward(self, x):
234 | x = self.backbone(x)
235 | if isinstance(x, (list, tuple)):
236 | x = x[-1] # last feature if backbone outputs list/tuple of features
237 | x = self.proj(x).flatten(2).transpose(1, 2)
238 | return x
239 |
--------------------------------------------------------------------------------
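
A minimal wiring sketch for the quantized Mlp block above; the cfg object only needs weight_bit and activation_bit attributes (mirroring the Config class in test_quant.py), and SimpleNamespace is used here purely for illustration.

```python
import torch
from types import SimpleNamespace

from models.layers_quant import Mlp

cfg = SimpleNamespace(weight_bit=8, activation_bit=8)
mlp = Mlp(in_features=192, hidden_features=768, cfg=cfg)   # DeiT-T sized block

x = torch.randn(2, 197, 192)        # [batch, tokens, embed_dim]
print(mlp(x).shape)                 # torch.Size([2, 197, 192])
```
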
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
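Applied in Python comment syntax, as the Appendix suggests for source files, the boilerplate notice would look like the following (the bracketed placeholders are left exactly as in the Appendix):

# Copyright [yyyy] [name of copyright owner]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.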
/models/vit_quant.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import re
4 | import warnings
5 | from itertools import repeat
6 | import collections.abc
7 | from collections import OrderedDict
8 | from functools import partial
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | from torch import nn
13 |
14 | from .quantization_utils import QuantLinear, QuantConv2d, QuantAct
15 | from .layers_quant import PatchEmbed, HybridEmbed, Mlp, DropPath, trunc_normal_
16 |
17 |
18 | __all__ = ['deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224']
19 |
20 |
21 | class Attention(nn.Module):
22 | def __init__(
23 | self,
24 | dim,
25 | num_heads=8,
26 | qkv_bias=False,
27 | qk_scale=None,
28 | attn_drop=0.0,
29 | proj_drop=0.0,
30 | cfg=None):
31 | super().__init__()
32 | self.num_heads = num_heads
33 | head_dim = dim // num_heads
34 |         # NOTE: qk_scale can override the default head_dim ** -0.5 scaling (kept for compatibility with previously trained weights)
35 | self.scale = qk_scale or head_dim ** -0.5
36 |
37 |         self.qkv = QuantLinear(
38 |             cfg.weight_bit,
39 |             dim,
40 |             dim * 3,
41 |             bias=qkv_bias)
42 | self.qact1 = QuantAct(cfg.activation_bit)
43 | self.qact2 = QuantAct(cfg.activation_bit)
44 | self.proj = QuantLinear(
45 | cfg.weight_bit,
46 | dim,
47 | dim
48 | )
49 | self.qact3 = QuantAct(cfg.activation_bit)
50 | self.qact_attn1 = QuantAct(cfg.activation_bit)
51 | self.attn_drop = nn.Dropout(attn_drop)
52 | self.proj_drop = nn.Dropout(proj_drop)
53 |
54 | def forward(self, x):
55 | B, N, C = x.shape
56 | x = self.qkv(x)
57 | x = self.qact1(x)
58 |         qkv = x.reshape(B, N, 3, self.num_heads,
59 |                         C // self.num_heads).permute(2, 0, 3, 1, 4)  # (3, B, num_heads, N, C // num_heads)
60 | q, k, v = (
61 | qkv[0],
62 | qkv[1],
63 | qkv[2]
64 | ) # make torchscript happy (cannot use tensor as tuple)
65 | attn = (q @ k.transpose(-2, -1)) * self.scale
66 | attn = self.qact_attn1(attn)
67 | attn = attn.softmax(dim=-1)
68 | attn = self.attn_drop(attn)
69 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
70 | x = self.qact2(x)
71 | x = self.proj(x)
72 | x = self.qact3(x)
73 | x = self.proj_drop(x)
74 | return x
75 |
76 |
77 | class Block(nn.Module):
78 | def __init__(
79 | self,
80 | dim,
81 | num_heads,
82 | mlp_ratio=4.0,
83 | qkv_bias=False,
84 | qk_scale=None,
85 | drop=0.0,
86 | attn_drop=0.0,
87 | drop_path=0.0,
88 | act_layer=nn.GELU,
89 | norm_layer=nn.LayerNorm,
90 | cfg=None):
91 | super().__init__()
92 | self.norm1 = norm_layer(dim)
93 | self.qact1 = QuantAct(cfg.activation_bit)
94 | self.attn = Attention(
95 | dim,
96 | num_heads=num_heads,
97 | qkv_bias=qkv_bias,
98 | qk_scale=qk_scale,
99 | attn_drop=attn_drop,
100 | proj_drop=drop,
101 | cfg=cfg
102 | )
103 |         # NOTE: DropPath applies stochastic depth to the residual branches
104 | self.drop_path = DropPath(
105 | drop_path) if drop_path > 0.0 else nn.Identity()
106 | self.qact2 = QuantAct(cfg.activation_bit)
107 | self.norm2 = norm_layer(dim)
108 | self.qact3 = QuantAct(cfg.activation_bit)
109 | mlp_hidden_dim = int(dim * mlp_ratio)
110 | self.mlp = Mlp(
111 | in_features=dim,
112 | hidden_features=mlp_hidden_dim,
113 | act_layer=act_layer,
114 | drop=drop,
115 | cfg=cfg
116 | )
117 | self.qact4 = QuantAct(cfg.activation_bit)
118 |
119 | def forward(self, x):
120 | x = self.qact2(
121 | x + self.drop_path(self.attn(self.qact1(self.norm1(x)))))
122 | x = self.qact4(x + self.drop_path(self.mlp(self.qact3(self.norm2(x)))))
123 | return x
124 |
125 |
126 | class VisionTransformer(nn.Module):
127 |     """Vision Transformer with quantized linear layers and activation quantizers (QuantLinear / QuantConv2d / QuantAct).
128 |     A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
129 | https://arxiv.org/abs/2010.11929
130 | """
131 |
132 | def __init__(
133 | self,
134 | img_size=224,
135 | patch_size=16,
136 | in_chans=3,
137 | num_classes=1000,
138 | embed_dim=768,
139 | depth=12,
140 | num_heads=12,
141 | mlp_ratio=4.0,
142 | qkv_bias=True,
143 | qk_scale=None,
144 | representation_size=None,
145 | drop_rate=0.0,
146 | attn_drop_rate=0.0,
147 | drop_path_rate=0.0,
148 | hybrid_backbone=None,
149 | norm_layer=None,
150 | input_quant=False,
151 | cfg=None):
152 | super().__init__()
153 | self.num_classes = num_classes
154 | self.num_features = (
155 | self.embed_dim
156 | ) = embed_dim # num_features for consistency with other models
157 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
158 |
159 | self.cfg = cfg
160 | self.input_quant = input_quant
161 | if input_quant:
162 | self.qact_input = QuantAct(cfg.activation_bit)
163 |
164 | if hybrid_backbone is not None:
165 | self.patch_embed = HybridEmbed(
166 | hybrid_backbone,
167 | img_size=img_size,
168 | in_chans=in_chans,
169 | embed_dim=embed_dim,
170 | )
171 | else:
172 | self.patch_embed = PatchEmbed(
173 | img_size=img_size,
174 | patch_size=patch_size,
175 | in_chans=in_chans,
176 | embed_dim=embed_dim,
177 | cfg=cfg)
178 | num_patches = self.patch_embed.num_patches
179 |
180 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
181 | self.pos_embed = nn.Parameter(
182 | torch.zeros(1, num_patches + 1, embed_dim))
183 | self.pos_drop = nn.Dropout(p=drop_rate)
184 |
185 | self.qact_embed = QuantAct(cfg.activation_bit)
186 | self.qact_pos = QuantAct(cfg.activation_bit)
187 | self.qact1 = QuantAct(cfg.activation_bit)
188 |
189 | dpr = [
190 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
191 | ] # stochastic depth decay rule
192 | self.blocks = nn.ModuleList(
193 | [
194 | Block(
195 | dim=embed_dim,
196 | num_heads=num_heads,
197 | mlp_ratio=mlp_ratio,
198 | qkv_bias=qkv_bias,
199 | qk_scale=qk_scale,
200 | drop=drop_rate,
201 | attn_drop=attn_drop_rate,
202 | drop_path=dpr[i],
203 | norm_layer=norm_layer,
204 | cfg=cfg
205 | )
206 | for i in range(depth)
207 | ]
208 | )
209 | self.norm = norm_layer(embed_dim)
210 | self.qact2 = QuantAct(cfg.activation_bit)
211 |
212 | # Representation layer
213 | if representation_size:
214 | self.num_features = representation_size
215 | self.pre_logits = nn.Sequential(
216 | OrderedDict(
217 | [
218 | ("fc", nn.Linear(embed_dim, representation_size)),
219 | ("act", nn.Tanh()),
220 | ]
221 | )
222 | )
223 | else:
224 | self.pre_logits = nn.Identity()
225 |
226 | # Classifier head
227 | self.head = (
228 | QuantLinear(
229 | cfg.weight_bit,
230 | self.num_features,
231 | num_classes
232 | )
233 | if num_classes > 0
234 | else nn.Identity()
235 | )
236 | self.act_out = QuantAct(cfg.activation_bit)
237 | trunc_normal_(self.pos_embed, std=0.02)
238 | trunc_normal_(self.cls_token, std=0.02)
239 | self.apply(self._init_weights)
240 |
241 | def _init_weights(self, m):
242 | if isinstance(m, nn.Linear):
243 | trunc_normal_(m.weight, std=0.02)
244 | if isinstance(m, nn.Linear) and m.bias is not None:
245 | nn.init.constant_(m.bias, 0)
246 | elif isinstance(m, nn.LayerNorm):
247 | nn.init.constant_(m.bias, 0)
248 | nn.init.constant_(m.weight, 1.0)
249 |
250 | @torch.jit.ignore
251 | def no_weight_decay(self):
252 | return {"pos_embed", "cls_token"}
253 |
254 | def model_quant(self):
255 | for m in self.modules():
256 | if type(m) in [QuantLinear, QuantConv2d, QuantAct]:
257 | m.quant = True
258 |
259 | def model_freeze(self):
260 | for m in self.modules():
261 | if type(m) in [QuantAct]:
262 | m.running_stat = False
263 |
264 | def model_unfreeze(self):
265 | for m in self.modules():
266 | if type(m) in [QuantAct]:
267 | m.running_stat = True
268 |
269 | def forward_features(self, x):
270 | B = x.shape[0]
271 |
272 | if self.input_quant:
273 | x = self.qact_input(x)
274 |
275 | x = self.patch_embed(x)
276 |
277 | cls_tokens = self.cls_token.expand(
278 | B, -1, -1
279 | ) # stole cls_tokens impl from Phil Wang, thanks
280 | x = torch.cat((cls_tokens, x), dim=1)
281 | x = self.qact_embed(x)
282 | x = x + self.qact_pos(self.pos_embed)
283 | x = self.qact1(x)
284 |
285 | x = self.pos_drop(x)
286 |
287 | for blk in self.blocks:
288 | x = blk(x)
289 |
290 | x = self.norm(x)[:, 0]
291 | x = self.qact2(x)
292 | x = self.pre_logits(x)
293 | return x
294 |
295 | def forward(self, x):
296 | x = self.forward_features(x)
297 | x = self.head(x)
298 | x = self.act_out(x)
299 | return x
300 |
301 |
302 | def deit_tiny_patch16_224(pretrained=False, cfg=None, **kwargs):
303 | model = VisionTransformer(
304 | patch_size=16,
305 | embed_dim=192,
306 | depth=12,
307 | num_heads=3,
308 | mlp_ratio=4,
309 | qkv_bias=True,
310 | norm_layer=None,
311 | input_quant=True,
312 | cfg=cfg,
313 | **kwargs,
314 | )
315 | if pretrained:
316 | checkpoint = torch.hub.load_state_dict_from_url(
317 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
318 | map_location="cpu",
319 | check_hash=True,
320 | )
321 | model.load_state_dict(checkpoint["model"], strict=False)
322 | return model
323 |
324 |
325 | def deit_small_patch16_224(pretrained=False, cfg=None, **kwargs):
326 | model = VisionTransformer(
327 | patch_size=16,
328 | embed_dim=384,
329 | depth=12,
330 | num_heads=6,
331 | mlp_ratio=4,
332 | qkv_bias=True,
333 | norm_layer=None,
334 | input_quant=True,
335 | cfg=cfg,
336 | **kwargs
337 | )
338 | if pretrained:
339 | checkpoint = torch.hub.load_state_dict_from_url(
340 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
341 | map_location="cpu", check_hash=True
342 | )
343 | model.load_state_dict(checkpoint["model"], strict=False)
344 | return model
345 |
346 |
347 | def deit_base_patch16_224(pretrained=False, cfg=None, **kwargs):
348 | model = VisionTransformer(
349 | patch_size=16,
350 | embed_dim=768,
351 | depth=12,
352 | num_heads=12,
353 | mlp_ratio=4,
354 | qkv_bias=True,
355 | norm_layer=None,
356 | input_quant=True,
357 | cfg=cfg,
358 | **kwargs
359 | )
360 | if pretrained:
361 | checkpoint = torch.hub.load_state_dict_from_url(
362 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
363 | map_location="cpu", check_hash=True
364 | )
365 | model.load_state_dict(checkpoint["model"], strict=False)
366 | return model
367 |
--------------------------------------------------------------------------------
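The deit_* factories above build the model in floating point; quantization is turned on afterwards through the model_quant() / model_freeze() / model_unfreeze() switches defined on VisionTransformer. A minimal usage sketch, not taken from the repository: it assumes the package root is on PYTHONPATH, that a config object only needs the weight_bit and activation_bit fields read by the modules above, and that QuantAct collects activation ranges while running_stat is True before quantization is enabled (test_quant.py remains the authoritative flow).

from types import SimpleNamespace

import torch

from models import deit_tiny_patch16_224

# Assumed minimal config: the modules above only read cfg.weight_bit / cfg.activation_bit.
cfg = SimpleNamespace(weight_bit=8, activation_bit=8)
model = deit_tiny_patch16_224(pretrained=False, cfg=cfg).eval()

# 1) Run calibration data through the float model so that every QuantAct
#    can record activation ranges (assumes running_stat defaults to True).
calib_images = torch.randn(8, 3, 224, 224)  # stand-in for real calibration images
with torch.no_grad():
    model(calib_images)

# 2) Freeze the recorded ranges and switch every quant module to quantized inference.
model.model_freeze()
model.model_quant()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])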
/models/swin_quant.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.utils.checkpoint as checkpoint
7 |
8 | from .quantization_utils import QuantLinear, QuantConv2d, QuantAct
9 | from .layers_quant import PatchEmbed, HybridEmbed, Mlp, DropPath, trunc_normal_, to_2tuple
10 |
11 |
12 | __all__ = ['swin_tiny_patch4_window7_224', 'swin_small_patch4_window7_224']
13 |
14 |
15 | def window_partition(x, window_size: int):
16 | """
17 | Args:
18 | x: (B, H, W, C)
19 | window_size (int): window size
20 |
21 | Returns:
22 | windows: (num_windows*B, window_size, window_size, C)
23 | """
24 | B, H, W, C = x.shape
25 | x = x.view(B, H // window_size, window_size,
26 | W // window_size, window_size, C)
27 |     windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(
28 |         -1, window_size, window_size, C)
29 | return windows
30 |
31 |
32 | def window_reverse(windows, window_size: int, H: int, W: int):
33 | """
34 | Args:
35 | windows: (num_windows*B, window_size, window_size, C)
36 | window_size (int): Window size
37 | H (int): Height of image
38 | W (int): Width of image
39 |
40 | Returns:
41 | x: (B, H, W, C)
42 | """
43 | B = int(windows.shape[0] / (H * W / window_size / window_size))
44 | x = windows.view(B, H // window_size, W // window_size,
45 | window_size, window_size, -1)
46 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
47 | return x
48 |
49 |
50 | class WindowAttention(nn.Module):
51 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
52 |     It supports both shifted and non-shifted windows.
53 |
54 | Args:
55 | dim (int): Number of input channels.
56 | window_size (tuple[int]): The height and width of the window.
57 | num_heads (int): Number of attention heads.
58 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
59 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
60 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
61 | """
62 |
63 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., cfg=None):
64 |
65 | super().__init__()
66 | self.dim = dim
67 | self.window_size = window_size # Wh, Ww
68 | self.num_heads = num_heads
69 | head_dim = dim // num_heads
70 | self.scale = head_dim ** -0.5
71 |
72 | # define a parameter table of relative position bias
73 | self.relative_position_bias_table = nn.Parameter(
74 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
75 |
76 | # get pair-wise relative position index for each token inside the window
77 | coords_h = torch.arange(self.window_size[0])
78 | coords_w = torch.arange(self.window_size[1])
79 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
80 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
81 | relative_coords = coords_flatten[:, :, None] - \
82 | coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
83 | relative_coords = relative_coords.permute(
84 | 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
85 |         # shift to start from 0
86 |         relative_coords[:, :, 0] += self.window_size[0] - 1
87 |         relative_coords[:, :, 1] += self.window_size[1] - 1
88 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
89 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
90 | self.register_buffer("relative_position_index",
91 | relative_position_index)
92 |
93 |         self.qkv = QuantLinear(
94 |             cfg.weight_bit,
95 |             dim,
96 |             dim * 3,
97 |             bias=qkv_bias)
98 | self.qact1 = QuantAct(cfg.activation_bit)
99 | self.qact_attn1 = QuantAct(cfg.activation_bit)
100 | self.qact_table = QuantAct(cfg.activation_bit)
101 | self.qact2 = QuantAct(cfg.activation_bit)
102 |
103 | self.attn_drop = nn.Dropout(attn_drop)
104 | self.qact3 = QuantAct(cfg.activation_bit)
105 | self.qact4 = QuantAct(cfg.activation_bit)
106 |         # replaces nn.Linear(dim, dim) with a quantized linear layer
107 | self.proj = QuantLinear(
108 | cfg.weight_bit,
109 | dim,
110 | dim
111 | )
112 | self.proj_drop = nn.Dropout(proj_drop)
113 |
114 | trunc_normal_(self.relative_position_bias_table, std=.02)
115 | self.softmax = nn.Softmax(dim=-1)
116 |
117 | def forward(self, x, mask: Optional[torch.Tensor] = None):
118 | """
119 | Args:
120 | x: input features with shape of (num_windows*B, N, C)
121 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
122 | """
123 | B_, N, C = x.shape
124 | x = self.qkv(x)
125 | x = self.qact1(x)
126 |         qkv = x.reshape(B_, N, 3, self.num_heads,
127 |                         C // self.num_heads).permute(2, 0, 3, 1, 4)
128 | # make torchscript happy (cannot use tensor as tuple)
129 | q, k, v = qkv[0], qkv[1], qkv[2]
130 |
131 | q = q * self.scale
132 | attn = (q @ k.transpose(-2, -1))
133 | attn = self.qact_attn1(attn)
134 | relative_position_bias_table_q = self.qact_table(
135 | self.relative_position_bias_table)
136 | relative_position_bias = relative_position_bias_table_q[self.relative_position_index.view(-1)].view(
137 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
138 | relative_position_bias = relative_position_bias.permute(
139 | 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
140 | attn = attn + relative_position_bias.unsqueeze(0)
141 | attn = self.qact2(attn)
142 |
143 | if mask is not None:
144 | nW = mask.shape[0]
145 | attn = attn.view(B_ // nW, nW, self.num_heads, N,
146 | N) + mask.unsqueeze(1).unsqueeze(0)
147 | attn = attn.view(-1, self.num_heads, N, N)
148 | attn = self.softmax(attn)
149 | else:
150 | attn = self.softmax(attn)
151 |
152 | attn = self.attn_drop(attn)
153 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
154 | x = self.qact3(x)
155 | x = self.proj(x)
156 | x = self.qact4(x)
157 | x = self.proj_drop(x)
158 | return x
159 |
160 |
161 | class SwinTransformerBlock(nn.Module):
162 | r""" Swin Transformer Block.
163 |
164 | Args:
165 | dim (int): Number of input channels.
166 |         input_resolution (tuple[int]): Input resolution.
167 | num_heads (int): Number of attention heads.
168 | window_size (int): Window size.
169 | shift_size (int): Shift size for SW-MSA.
170 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
171 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
172 | drop (float, optional): Dropout rate. Default: 0.0
173 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
174 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
175 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
176 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
177 | """
178 |
179 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
180 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
181 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, cfg=None):
182 | super().__init__()
183 | self.dim = dim
184 | self.input_resolution = input_resolution
185 | self.num_heads = num_heads
186 | self.window_size = window_size
187 | self.shift_size = shift_size
188 | self.mlp_ratio = mlp_ratio
189 | if min(self.input_resolution) <= self.window_size:
190 | # if window size is larger than input resolution, we don't partition windows
191 | self.shift_size = 0
192 | self.window_size = min(self.input_resolution)
193 |         assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
194 |
195 | self.norm1 = norm_layer(dim)
196 | self.qact1 = QuantAct(cfg.activation_bit)
197 | self.attn = WindowAttention(
198 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
199 | attn_drop=attn_drop, proj_drop=drop, cfg=cfg)
200 |
201 | self.drop_path = DropPath(
202 | drop_path) if drop_path > 0. else nn.Identity()
203 | self.qact2 = QuantAct(cfg.activation_bit)
204 | self.norm2 = norm_layer(dim)
205 | self.qact3 = QuantAct(cfg.activation_bit)
206 | mlp_hidden_dim = int(dim * mlp_ratio)
207 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
208 | drop=drop, cfg=cfg)
209 | self.qact4 = QuantAct(cfg.activation_bit)
210 | if self.shift_size > 0:
211 | # calculate attention mask for SW-MSA
212 | H, W = self.input_resolution
213 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
214 | h_slices = (slice(0, -self.window_size),
215 | slice(-self.window_size, -self.shift_size),
216 | slice(-self.shift_size, None))
217 | w_slices = (slice(0, -self.window_size),
218 | slice(-self.window_size, -self.shift_size),
219 | slice(-self.shift_size, None))
220 | cnt = 0
221 | for h in h_slices:
222 | for w in w_slices:
223 | img_mask[:, h, w, :] = cnt
224 | cnt += 1
225 |
226 | # nW, window_size, window_size, 1
227 | mask_windows = window_partition(img_mask, self.window_size)
228 | mask_windows = mask_windows.view(-1,
229 | self.window_size * self.window_size)
230 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
231 | attn_mask = attn_mask.masked_fill(
232 | attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
233 | else:
234 | attn_mask = None
235 |
236 | self.register_buffer("attn_mask", attn_mask)
237 |
238 | def forward(self, x):
239 | H, W = self.input_resolution
240 | B, L, C = x.shape
241 | assert L == H * W, "input feature has wrong size"
242 |
243 | shortcut = x
244 | x = self.norm1(x)
245 | x = self.qact1(x)
246 | x = x.view(B, H, W, C)
247 |
248 | # cyclic shift
249 | if self.shift_size > 0:
250 | shifted_x = torch.roll(
251 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
252 | else:
253 | shifted_x = x
254 |
255 | # partition windows
256 | # nW*B, window_size, window_size, C
257 | x_windows = window_partition(shifted_x, self.window_size)
258 | # nW*B, window_size*window_size, C
259 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
260 |
261 | # W-MSA/SW-MSA
262 | # nW*B, window_size*window_size, C
263 | attn_windows = self.attn(x_windows, mask=self.attn_mask)
264 |
265 | # merge windows
266 | attn_windows = attn_windows.view(-1,
267 | self.window_size, self.window_size, C)
268 | shifted_x = window_reverse(
269 | attn_windows, self.window_size, H, W) # B H' W' C
270 |
271 | # reverse cyclic shift
272 | if self.shift_size > 0:
273 | x = torch.roll(shifted_x, shifts=(
274 | self.shift_size, self.shift_size), dims=(1, 2))
275 | else:
276 | x = shifted_x
277 | x = x.view(B, H * W, C)
278 |
279 | # FFN
280 | x = shortcut + self.drop_path(x)
281 | x = self.qact2(x)
282 | x = x + self.drop_path(self.mlp(self.qact3(self.norm2(x))))
283 | x = self.qact4(x)
284 |
285 | return x
286 |
287 |
288 | class PatchMerging(nn.Module):
289 | r""" Patch Merging Layer.
290 |
291 | Args:
292 | input_resolution (tuple[int]): Resolution of input feature.
293 | dim (int): Number of input channels.
294 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
295 | """
296 |
297 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, cfg=None):
298 | super().__init__()
299 | self.input_resolution = input_resolution
300 | self.dim = dim
301 |
302 | self.norm = norm_layer(4 * dim)
303 | self.qact1 = QuantAct(cfg.activation_bit)
304 |         # replaces nn.Linear(4 * dim, 2 * dim, bias=False) with a quantized linear layer
305 | self.reduction = QuantLinear(
306 | cfg.weight_bit,
307 | 4 * dim,
308 | 2 * dim,
309 | bias=False
310 | )
311 | self.qact2 = QuantAct(cfg.activation_bit)
312 |
313 | def forward(self, x):
314 | """
315 | x: B, H*W, C
316 | """
317 | H, W = self.input_resolution
318 | B, L, C = x.shape
319 | assert L == H * W, "input feature has wrong size"
320 |         assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
321 |
322 | x = x.view(B, H, W, C)
323 |
324 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
325 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
326 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
327 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
328 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
329 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
330 | x = self.norm(x)
331 | x = self.qact1(x)
332 | x = self.reduction(x)
333 | x = self.qact2(x)
334 | return x
335 |
336 | def extra_repr(self) -> str:
337 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
338 |
339 | def flops(self):
340 | H, W = self.input_resolution
341 | flops = H * W * self.dim
342 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
343 | return flops
344 |
345 |
346 | class BasicLayer(nn.Module):
347 | """ A basic Swin Transformer layer for one stage.
348 |
349 | Args:
350 | dim (int): Number of input channels.
351 | input_resolution (tuple[int]): Input resolution.
352 | depth (int): Number of blocks.
353 | num_heads (int): Number of attention heads.
354 | window_size (int): Local window size.
355 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
356 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
357 | drop (float, optional): Dropout rate. Default: 0.0
358 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
359 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
360 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
361 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
362 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
363 | """
364 |
365 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
366 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
367 | norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
368 | cfg=None):
369 |
370 | super().__init__()
371 | self.dim = dim
372 | self.input_resolution = input_resolution
373 | self.depth = depth
374 | self.use_checkpoint = use_checkpoint
375 |
376 | # build blocks
377 | self.blocks = nn.ModuleList([
378 | SwinTransformerBlock(
379 | dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
380 | shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
381 | qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
382 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer,
383 | cfg=cfg)
384 | for i in range(depth)])
385 |
386 | # patch merging layer
387 | if downsample is not None:
388 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, cfg=cfg)
389 | else:
390 | self.downsample = None
391 |
392 | def forward(self, x):
393 | for i, blk in enumerate(self.blocks):
394 | if not torch.jit.is_scripting() and self.use_checkpoint:
395 | x = checkpoint.checkpoint(blk, x)
396 | else:
397 | x = blk(x)
398 | if self.downsample is not None:
399 | x = self.downsample(x)
400 | return x
401 |
402 | def extra_repr(self) -> str:
403 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
404 |
405 |
406 | class SwinTransformer(nn.Module):
407 |     r""" Swin Transformer with quantization modules (QuantLinear / QuantConv2d / QuantAct).
408 |     A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
409 | https://arxiv.org/pdf/2103.14030
410 |
411 | Args:
412 | img_size (int | tuple(int)): Input image size. Default 224
413 | patch_size (int | tuple(int)): Patch size. Default: 4
414 | in_chans (int): Number of input image channels. Default: 3
415 | num_classes (int): Number of classes for classification head. Default: 1000
416 | embed_dim (int): Patch embedding dimension. Default: 96
417 | depths (tuple(int)): Depth of each Swin Transformer layer.
418 | num_heads (tuple(int)): Number of attention heads in different layers.
419 | window_size (int): Window size. Default: 7
420 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
421 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
422 | drop_rate (float): Dropout rate. Default: 0
423 | attn_drop_rate (float): Attention dropout rate. Default: 0
424 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
425 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
426 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
427 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
428 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
429 | """
430 |
431 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
432 | embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
433 | window_size=7, mlp_ratio=4., qkv_bias=True,
434 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
435 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
436 | use_checkpoint=False, input_quant=False, cfg=None, **kwargs):
437 | super().__init__()
438 |
439 | self.num_classes = num_classes
440 | self.num_layers = len(depths)
441 | self.embed_dim = embed_dim
442 | self.ape = ape
443 | self.patch_norm = patch_norm
444 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
445 | self.mlp_ratio = mlp_ratio
446 | self.input_quant = input_quant
447 | self.cfg = cfg
448 | if input_quant:
449 | self.qact_input = QuantAct(cfg.activation_bit)
450 | # split image into non-overlapping patches
451 | self.patch_embed = PatchEmbed(
452 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
453 | norm_layer=norm_layer if self.patch_norm else None, cfg=cfg)
454 | num_patches = self.patch_embed.num_patches
455 | self.patch_grid = self.patch_embed.grid_size
456 |
457 | # absolute position embedding
458 | if self.ape:
459 | self.absolute_pos_embed = nn.Parameter(
460 | torch.zeros(1, num_patches, embed_dim))
461 | trunc_normal_(self.absolute_pos_embed, std=.02)
462 | self.qact1 = QuantAct(cfg.activation_bit)
463 | else:
464 | self.absolute_pos_embed = None
465 |
466 | self.pos_drop = nn.Dropout(p=drop_rate)
467 |
468 | # stochastic depth
469 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
470 | sum(depths))] # stochastic depth decay rule
471 |
472 | # build layers
473 | layers = []
474 | for i_layer in range(self.num_layers):
475 | layers += [BasicLayer(
476 | dim=int(embed_dim * 2 ** i_layer),
477 | input_resolution=(
478 | self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)),
479 | depth=depths[i_layer],
480 | num_heads=num_heads[i_layer],
481 | window_size=window_size,
482 | mlp_ratio=self.mlp_ratio,
483 | qkv_bias=qkv_bias,
484 | drop=drop_rate,
485 | attn_drop=attn_drop_rate,
486 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
487 | norm_layer=norm_layer,
488 | downsample=PatchMerging if (
489 | i_layer < self.num_layers - 1) else None,
490 | use_checkpoint=use_checkpoint,
491 | cfg=cfg)
492 | ]
493 | self.layers = nn.Sequential(*layers)
494 |
495 | self.norm = norm_layer(self.num_features)
496 |
497 | self.qact2 = QuantAct(cfg.activation_bit)
498 | self.avgpool = nn.AdaptiveAvgPool1d(1)
499 | self.qact3 = QuantAct(cfg.activation_bit)
500 | self.head = QuantLinear(
501 | cfg.weight_bit,
502 | self.num_features,
503 | num_classes
504 | )
505 |
506 | self.act_out = QuantAct(cfg.activation_bit)
507 | self.apply(self._init_weights)
508 |
509 | def _init_weights(self, m):
510 | if isinstance(m, nn.Linear):
511 | trunc_normal_(m.weight, std=.02)
512 | if isinstance(m, nn.Linear) and m.bias is not None:
513 | nn.init.constant_(m.bias, 0)
514 | elif isinstance(m, nn.LayerNorm):
515 | nn.init.constant_(m.bias, 0)
516 | nn.init.constant_(m.weight, 1.0)
517 |
518 | @torch.jit.ignore
519 | def no_weight_decay(self):
520 | return {'absolute_pos_embed'}
521 |
522 | @torch.jit.ignore
523 | def no_weight_decay_keywords(self):
524 | return {'relative_position_bias_table'}
525 |
526 | def model_quant(self):
527 | for m in self.modules():
528 | if type(m) in [QuantLinear, QuantConv2d, QuantAct]:
529 | m.quant = True
530 |
531 | def model_freeze(self):
532 | for m in self.modules():
533 | if type(m) in [QuantAct]:
534 | m.running_stat = False
535 |
536 | def model_unfreeze(self):
537 | for m in self.modules():
538 | if type(m) in [QuantAct]:
539 | m.running_stat = True
540 |
541 | def forward_features(self, x):
542 | if self.input_quant:
543 | x = self.qact_input(x)
544 | x = self.patch_embed(x)
545 | if self.absolute_pos_embed is not None:
546 | x = x + self.absolute_pos_embed
547 | x = self.qact1(x)
548 | x = self.pos_drop(x)
549 | for i, layer in enumerate(self.layers):
550 | x = layer(x)
551 |
552 | x = self.norm(x) # B L C
553 | x = self.qact2(x)
554 |
555 | x = self.avgpool(x.transpose(1, 2)) # B C 1
556 | x = self.qact3(x)
557 |
558 | x = torch.flatten(x, 1)
559 | return x
560 |
561 | def forward(self, x):
562 | x = self.forward_features(x)
563 | x = self.head(x)
564 | x = self.act_out(x)
565 | return x
566 |
567 |
568 | def swin_tiny_patch4_window7_224(pretrained=False, cfg=None, **kwargs):
569 | """ Swin-T @ 224x224, trained ImageNet-1k
570 | """
571 | model = SwinTransformer(
572 | patch_size=4,
573 | window_size=7,
574 | embed_dim=96,
575 | depths=(2, 2, 6, 2),
576 | num_heads=(3, 6, 12, 24),
577 | input_quant=True,
578 | cfg=cfg,
579 | **kwargs
580 | )
581 | if pretrained:
582 | checkpoint = torch.hub.load_state_dict_from_url(
583 | url="https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth",
584 | map_location="cpu", check_hash=True
585 | )
586 | model.load_state_dict(checkpoint["model"], strict=False)
587 | return model
588 |
589 |
590 | def swin_small_patch4_window7_224(pretrained=False, cfg=None, **kwargs):
591 | """ Swin-S @ 224x224, trained ImageNet-1k
592 | """
593 | model = SwinTransformer(
594 | patch_size=4,
595 | window_size=7,
596 | embed_dim=96,
597 | depths=(2, 2, 18, 2),
598 | num_heads=(3, 6, 12, 24),
599 | input_quant=True,
600 | cfg=cfg,
601 | **kwargs
602 | )
603 | if pretrained:
604 | checkpoint = torch.hub.load_state_dict_from_url(
605 | url="https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth",
606 | map_location="cpu", check_hash=True
607 | )
608 | model.load_state_dict(checkpoint["model"], strict=False)
609 | return model
610 |
--------------------------------------------------------------------------------
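window_partition and window_reverse above are exact inverses, which is what lets SwinTransformerBlock run attention window by window and then stitch the (possibly shifted) feature map back together. A small sanity-check sketch, not part of the repository, assuming only that the package root is on PYTHONPATH so models.swin_quant imports cleanly:

import torch

from models.swin_quant import window_partition, window_reverse

B, H, W, C, window_size = 2, 14, 14, 8, 7
x = torch.randn(B, H, W, C)

# (B, H, W, C) -> (num_windows * B, window_size, window_size, C)
windows = window_partition(x, window_size)
assert windows.shape == (B * (H // window_size) * (W // window_size),
                         window_size, window_size, C)

# (num_windows * B, window_size, window_size, C) -> (B, H, W, C)
x_back = window_reverse(windows, window_size, H, W)
assert torch.equal(x, x_back)  # the round trip is lossless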