├── models ├── __init__.py ├── quantization_utils │ ├── __init__.py │ ├── quant_utils.py │ └── quant_modules.py ├── utils.py ├── layers_quant.py ├── vit_quant.py └── swin_quant.py ├── overview.png ├── utils ├── __init__.py ├── data_utils.py ├── build_model.py └── kde.py ├── README.md ├── generate_data.py ├── test_quant.py └── LICENSE /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vit_quant import * 2 | from .swin_quant import * 3 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zkkli/PSAQ-ViT/HEAD/overview.png -------------------------------------------------------------------------------- /models/quantization_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .quant_modules import QuantLinear, QuantAct, QuantConv2d -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .build_model import * 2 | from .kde import KernelDensityEstimator 3 | from .data_utils import build_dataset 4 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import math 4 | from PIL import Image 5 | import torchvision.transforms as transforms 6 | import torchvision.datasets as datasets 7 | 8 | 9 | def build_dataset(args): 10 | model_type = args.model.split("_")[0] 11 | if model_type == "deit": 12 | mean = (0.485, 0.456, 0.406) 13 | std = (0.229, 0.224, 0.225) 14 | crop_pct = 0.875 15 | elif model_type == 'vit': 16 | mean = (0.5, 0.5, 0.5) 17 | std = (0.5, 0.5, 0.5) 18 | crop_pct = 0.9 19 | elif model_type == 'swin': 20 | mean = (0.485, 0.456, 0.406) 21 | std = (0.229, 0.224, 0.225) 22 | crop_pct = 0.9 23 | else: 24 | raise NotImplementedError 25 | 26 | train_transform = build_transform(mean=mean, std=std, crop_pct=crop_pct) 27 | val_transform = build_transform(mean=mean, std=std, crop_pct=crop_pct) 28 | 29 | # Data 30 | traindir = os.path.join(args.dataset, 'train') 31 | valdir = os.path.join(args.dataset, 'val') 32 | 33 | val_dataset = datasets.ImageFolder(valdir, val_transform) 34 | val_loader = torch.utils.data.DataLoader( 35 | val_dataset, 36 | batch_size=args.val_batchsize, 37 | shuffle=False, 38 | num_workers=args.num_workers, 39 | pin_memory=True, 40 | ) 41 | 42 | train_dataset = datasets.ImageFolder(traindir, train_transform) 43 | train_loader = torch.utils.data.DataLoader( 44 | train_dataset, 45 | batch_size=args.calib_batchsize, 46 | shuffle=True, 47 | num_workers=args.num_workers, 48 | pin_memory=True, 49 | drop_last=True, 50 | ) 51 | 52 | return train_loader, val_loader 53 | 54 | 55 | def build_transform(input_size=224, interpolation="bicubic", 56 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), 57 | crop_pct=0.875): 58 | def _pil_interp(method): 59 | if method == "bicubic": 60 | return Image.BICUBIC 61 | elif method == "lanczos": 62 | return Image.LANCZOS 63 | elif method == "hamming": 64 | return Image.HAMMING 65 | else: 66 | return Image.BILINEAR 67 | resize_im = input_size > 32 68 | t = [] 69 | if resize_im: 70 | size = int(math.floor(input_size / crop_pct)) 71 | ip = _pil_interp(interpolation) 72 | t.append( 73 | 
transforms.Resize( 74 | size, interpolation=ip 75 | ), # to maintain same ratio w.r.t. 224 images 76 | ) 77 | t.append(transforms.CenterCrop(input_size)) 78 | 79 | t.append(transforms.ToTensor()) 80 | t.append(transforms.Normalize(mean, std)) 81 | return transforms.Compose(t) 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | ![PSAQ-ViT overview](overview.png) 
 3 | 
4 | 
 5 | # Patch Similarity Aware Data-Free Quantization for Vision Transformers 
 6 | 
 7 | This repository contains the official PyTorch implementation of the ECCV 2022 paper 
 8 | *["Patch Similarity Aware Data-Free Quantization for Vision Transformers"](https://arxiv.org/abs/2203.02250).* To the best of our knowledge, this is the first work on data-free quantization for vision transformers. Below are instructions for reproducing the results. 
 9 | 
10 | ## Installation 
11 | 
12 | - **To install PSAQ-ViT** and develop locally: 
13 | 
14 | ```bash 
15 | git clone https://github.com/zkkli/PSAQ-ViT.git 
16 | cd PSAQ-ViT 
17 | ``` 
18 | 
19 | ## Quantization 
20 | 
21 | - You can quantize and evaluate a single model using the following command: 
22 | 
23 | ```bash 
24 | python test_quant.py [--model] [--dataset] [--w_bit] [--a_bit] [--mode] 
25 | 
26 | optional arguments: 
27 | --model: Model architecture; the choices are: 
28 |          deit_tiny, deit_small, deit_base, swin_tiny, and swin_small. 
29 | --dataset: Path to the ImageNet dataset. 
30 | --w_bit: Bit-precision of weights, default=8. 
31 | --a_bit: Bit-precision of activations, default=8. 
32 | --mode: Mode of calibration data, 
33 |         0: Generated fake data (PSAQ-ViT) 
34 |         1: Gaussian noise 
35 |         2: Real data 
36 | ``` 
37 | 
38 | - Example: Quantize DeiT-B with generated fake data **(PSAQ-ViT)**. 
39 | 
40 | ```bash 
41 | python test_quant.py --model deit_base --dataset <path-to-imagenet> --mode 0 
42 | ``` 
43 | 
44 | - Example: Quantize DeiT-B with Gaussian noise. 
45 | 
46 | ```bash 
47 | python test_quant.py --model deit_base --dataset <path-to-imagenet> --mode 1 
48 | ``` 
49 | 
50 | - Example: Quantize DeiT-B with real data. 
51 | 
52 | ```bash 
53 | python test_quant.py --model deit_base --dataset <path-to-imagenet> --mode 2 
54 | ``` 
55 | 
56 | ## Results 
57 | 
58 | Below are the experimental results of the proposed PSAQ-ViT on the ImageNet dataset, obtained with an RTX 3090 GPU. 
59 | 
60 | | Model | Prec. | Top-1(%) | Prec. 
| Top-1(%) | 61 | |:--------------:|:-----:|:--------:|:-----:|:--------:| 62 | | DeiT-T (72.21) | W4/A8 | 65.57 | W8/A8 | 71.56 | 63 | | DeiT-S (79.85) | W4/A8 | 73.23 | W8/A8 | 76.92 | 64 | | DeiT-B (81.85) | W4/A8 | 77.05 | W8/A8 | 79.10 | 65 | | Swin-T (81.35) | W4/A8 | 71.79 | W8/A8 | 75.35 | 66 | | Swin-S (83.20) | W4/A8 | 75.14 | W8/A8 | 76.64 | 67 | 68 | ## Citation 69 | 70 | We appreciate it if you would please cite the following paper if you found the implementation useful for your work: 71 | 72 | ```bash 73 | @inproceedings{li2022psaqvit, 74 | title={Patch Similarity Aware Data-Free Quantization for Vision Transformers}, 75 | author={Li, Zhikai and Ma, Liping and Chen, Mengjuan and Xiao, Junrui and Gu, Qingyi}, 76 | booktitle={European Conference on Computer Vision}, 77 | pages={154--170}, 78 | year={2022} 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /utils/build_model.py: -------------------------------------------------------------------------------- 1 | from types import MethodType 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import timm 6 | from timm.models import vision_transformer 7 | from timm.models.vision_transformer import Attention 8 | from timm.models.swin_transformer import WindowAttention 9 | 10 | def attention_forward(self, x): 11 | B, N, C = x.shape 12 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 13 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 14 | 15 | # attn = (q @ k.transpose(-2, -1)) * self.scale 16 | attn = self.matmul1(q, k.transpose(-2, -1)) * self.scale 17 | attn = attn.softmax(dim=-1) 18 | attn = self.attn_drop(attn) 19 | del q, k 20 | 21 | # x = (attn @ v).transpose(1, 2).reshape(B, N, C) 22 | x = self.matmul2(attn, v).transpose(1, 2).reshape(B, N, C) 23 | del attn, v 24 | x = self.proj(x) 25 | x = self.proj_drop(x) 26 | return x 27 | 28 | def window_attention_forward(self, x, mask = None): 29 | B_, N, C = x.shape 30 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 31 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 32 | 33 | q = q * self.scale 34 | # attn = (q @ k.transpose(-2, -1)) 35 | attn = self.matmul1(q, k.transpose(-2,-1)) 36 | 37 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 38 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 39 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 40 | attn = attn + relative_position_bias.unsqueeze(0) 41 | 42 | if mask is not None: 43 | nW = mask.shape[0] 44 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 45 | attn = attn.view(-1, self.num_heads, N, N) 46 | attn = self.softmax(attn) 47 | else: 48 | attn = self.softmax(attn) 49 | 50 | attn = self.attn_drop(attn) 51 | 52 | # x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 53 | x = self.matmul2(attn, v).transpose(1, 2).reshape(B_, N, C) 54 | x = self.proj(x) 55 | x = self.proj_drop(x) 56 | return x 57 | 58 | 59 | class MatMul(nn.Module): 60 | def forward(self, A, B): 61 | return A @ B 62 | 63 | 64 | def build_model(name, Pretrained=True): 65 | """ 66 | Get a vision transformer model. 67 | This will replace matrix multiplication operations with matmul modules in the model. 
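    Example (a minimal usage sketch added for illustration, not part of the
    original docstring; it assumes a CUDA device is available, since the model
    is moved to .cuda(), and that 'deit_tiny_patch16_224' exists in the
    installed timm version):

        >>> net = build_model('deit_tiny_patch16_224', Pretrained=True)
        >>> x = torch.randn(1, 3, 224, 224).cuda()
        >>> with torch.no_grad():
        ...     logits = net(x)   # (1, 1000) ImageNet logits; q@k and attn@v now run through MatMul modules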
68 | 69 | Currently support almost all models in timm.models.transformers, including: 70 | - vit_tiny/small/base/large_patch16/patch32_224/384, 71 | - deit_tiny/small/base(_distilled)_patch16_224, 72 | - deit_base(_distilled)_patch16_384, 73 | - swin_tiny/small/base/large_patch4_window7_224, 74 | - swin_base/large_patch4_window12_384 75 | 76 | These models are finetuned on imagenet-1k and should use ViTImageNetLoaderGenerator 77 | for calibration and testing. 78 | """ 79 | net = timm.create_model(name, pretrained=Pretrained) 80 | 81 | for name, module in net.named_modules(): 82 | if isinstance(module, Attention): 83 | setattr(module, "matmul1", MatMul()) 84 | setattr(module, "matmul2", MatMul()) 85 | module.forward = MethodType(attention_forward, module) 86 | if isinstance(module, WindowAttention): 87 | setattr(module, "matmul1", MatMul()) 88 | setattr(module, "matmul2", MatMul()) 89 | module.forward = MethodType(window_attention_forward, module) 90 | 91 | net = net.cuda() 92 | net.eval() 93 | return net 94 | -------------------------------------------------------------------------------- /models/quantization_utils/quant_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from torch.autograd import Function, Variable 4 | import torch 5 | 6 | 7 | def reshape_tensor(input, scale, zero_point, is_weight=True): 8 | if is_weight: 9 | if len(input.shape) == 4: 10 | range_shape = (-1, 1, 1, 1) 11 | elif len(input.shape) == 2: 12 | range_shape = (-1, 1) 13 | else: 14 | raise NotImplementedError 15 | else: 16 | if len(input.shape) == 2: 17 | range_shape = (1, -1) 18 | elif len(input.shape) == 3: 19 | range_shape = (1, 1, -1) 20 | elif len(input.shape) == 4: 21 | range_shape = (1, -1, 1, 1) 22 | else: 23 | raise NotImplementedError 24 | 25 | scale = scale.reshape(range_shape) 26 | zero_point = zero_point.reshape(range_shape) 27 | 28 | return scale, zero_point 29 | 30 | 31 | def symmetric_linear_quantization_params(num_bits, 32 | min_val, 33 | max_val): 34 | """ 35 | Compute the scaling factor and zeropoint with the given quantization range for symmetric quantization. 36 | Parameters: 37 | ---------- 38 | saturation_min: lower bound for quantization range 39 | saturation_max: upper bound for quantization range 40 | per_channel: if True, calculate the scaling factor per channel. 41 | """ 42 | qmax = 2 ** (num_bits - 1) - 1 43 | qmin = -(2 ** (num_bits - 1)) 44 | eps = torch.finfo(torch.float32).eps 45 | 46 | max_val = torch.max(-min_val, max_val) 47 | scale = max_val / (float(qmax - qmin) / 2) 48 | scale.clamp_(eps) 49 | zero_point = torch.zeros_like(max_val, dtype=torch.int64) 50 | 51 | return scale, zero_point, qmin, qmax 52 | 53 | 54 | def asymmetric_linear_quantization_params(num_bits, 55 | min_val, 56 | max_val): 57 | """ 58 | Compute the scaling factor and zeropoint with the given quantization range. 59 | saturation_min: lower bound for quantization range 60 | saturation_max: upper bound for quantization range 61 | """ 62 | qmax = 2 ** num_bits - 1 63 | qmin = 0 64 | eps = torch.finfo(torch.float32).eps 65 | 66 | scale = (max_val - min_val) / float(qmax - qmin) 67 | scale.clamp_(eps) 68 | zero_point = qmin - torch.round(min_val / scale) 69 | zero_point.clamp_(qmin, qmax) 70 | 71 | return scale, zero_point, qmin, qmax 72 | 73 | 74 | class SymmetricQuantFunction(Function): 75 | """ 76 | Class to quantize the given floating-point values with given range and bit-setting. 
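    A small worked example of the mapping implemented by forward() below (added
    for illustration; the numbers are hypothetical): with k=8, x_min=-1.0 and
    x_max=2.0, symmetric_linear_quantization_params() yields qmin=-128, qmax=127,
    zero_point=0 and scale = max(1.0, 2.0) / 127.5 ≈ 0.0157, so a weight value of
    0.10 is quantized to round(0.10 / 0.0157) = 6 and dequantized back to
    6 * 0.0157 ≈ 0.094.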
77 | Currently only support inference, but not support back-propagation. 78 | """ 79 | @staticmethod 80 | def forward(ctx, x, k, x_min=None, x_max=None): 81 | """ 82 | x: single-precision value to be quantized 83 | k: bit-setting for x 84 | x_min: lower bound for quantization range 85 | x_max=None 86 | """ 87 | scale, zero_point, qmin, qmax = symmetric_linear_quantization_params(k, x_min, x_max) 88 | scale, zero_point = reshape_tensor(x, scale, zero_point, is_weight=True) 89 | 90 | # quantize 91 | quant_x = x / scale + zero_point 92 | quant_x = quant_x.round().clamp(qmin, qmax) 93 | 94 | # dequantize 95 | quant_x = (quant_x - zero_point) * scale 96 | 97 | return torch.autograd.Variable(quant_x) 98 | 99 | @staticmethod 100 | def backward(ctx, grad_output): 101 | raise NotImplementedError 102 | 103 | 104 | class AsymmetricQuantFunction(Function): 105 | """ 106 | Class to quantize the given floating-point values with given range and bit-setting. 107 | Currently only support inference, but not support back-propagation. 108 | """ 109 | @staticmethod 110 | def forward(ctx, x, k, x_min=None, x_max=None): 111 | """ 112 | x: single-precision value to be quantized 113 | k: bit-setting for x 114 | x_min: lower bound for quantization range 115 | x_max=None 116 | """ 117 | scale, zero_point, qmin, qmax = asymmetric_linear_quantization_params(k, x_min, x_max) 118 | scale, zero_point = reshape_tensor(x, scale, zero_point, is_weight=False) 119 | 120 | # quantize 121 | quant_x = x / scale + zero_point 122 | quant_x = quant_x.round().clamp(qmin, qmax) 123 | 124 | # dequantize 125 | quant_x = (quant_x - zero_point) * scale 126 | 127 | return torch.autograd.Variable(quant_x) 128 | 129 | @staticmethod 130 | def backward(ctx, grad_output): 131 | raise NotImplementedError 132 | -------------------------------------------------------------------------------- /utils/kde.py: -------------------------------------------------------------------------------- 1 | """Implementation of Kernel Density Estimation (KDE) [1]. 2 | Kernel density estimation is a nonparameteric density estimation method. It works by 3 | placing kernels K on each point in a "training" dataset D. Then, for a test point x, 4 | p(x) is estimated as p(x) = 1 / |D| \sum_{x_i \in D} K(u(x, x_i)), where u is some 5 | function of x, x_i. In order for p(x) to be a valid probability distribution, the kernel 6 | K must also be a valid probability distribution. 7 | References (used throughout the file): 8 | [1]: https://en.wikipedia.org/wiki/Kernel_density_estimation 9 | """ 10 | 11 | import abc 12 | 13 | import numpy as np 14 | import torch 15 | from torch import nn 16 | 17 | 18 | class GenerativeModel(abc.ABC, nn.Module): 19 | """Base class inherited by all generative models in pytorch-generative. 20 | Provides: 21 | * An abstract `sample()` method which is implemented by subclasses that support 22 | generating samples. 23 | * Variables `self._c, self._h, self._w` which store the shape of the (first) 24 | image Tensor the model was trained with. Note that `forward()` must have been 25 | called at least once and the input must be an image for these variables to be 26 | available. 27 | * A `device` property which returns the device of the model's parameters. 
28 | """ 29 | 30 | def __call__(self, *args, **kwargs): 31 | if getattr(self, "_c", None) is None and len(args[0].shape) == 4: 32 | _, self._c, self._h, self._w = args[0].shape 33 | return super().__call__(*args, **kwargs) 34 | 35 | @property 36 | def device(self): 37 | return next(self.parameters()).device 38 | 39 | @abc.abstractmethod 40 | def sample(self, n_samples): 41 | ... 42 | 43 | 44 | class Kernel(abc.ABC, nn.Module): 45 | """Base class which defines the interface for all kernels.""" 46 | 47 | def __init__(self, bandwidth=0.01): 48 | """Initializes a new Kernel. 49 | Args: 50 | bandwidth: The kernel's (band)width. 51 | """ 52 | super().__init__() 53 | self.bandwidth = bandwidth 54 | 55 | def _diffs(self, test_Xs, train_Xs): 56 | """Computes difference between each x in test_Xs with all train_Xs.""" 57 | test_Xs = test_Xs.view(test_Xs.shape[0], test_Xs.shape[1], 1) 58 | train_Xs = train_Xs.view(train_Xs.shape[0], 1, *train_Xs.shape[1:]) 59 | return test_Xs - train_Xs 60 | 61 | @abc.abstractmethod 62 | def forward(self, test_Xs, train_Xs): 63 | """Computes p(x) for each x in test_Xs given train_Xs.""" 64 | 65 | @abc.abstractmethod 66 | def sample(self, train_Xs): 67 | """Generates samples from the kernel distribution.""" 68 | 69 | 70 | class ParzenWindowKernel(Kernel): 71 | """Implementation of the Parzen window kernel.""" 72 | 73 | def forward(self, test_Xs, train_Xs): 74 | abs_diffs = torch.abs(self._diffs(test_Xs, train_Xs)) 75 | dims = tuple(range(len(abs_diffs.shape))[1:]) 76 | dim = np.prod(abs_diffs.shape[1:]) 77 | inside = torch.sum(abs_diffs / self.bandwidth <= 0.5, dim=dims) == dim 78 | coef = 1 / self.bandwidth ** dim 79 | return (coef * inside) #.mean() #dim=1 80 | 81 | def sample(self, train_Xs): 82 | device = train_Xs.device 83 | noise = (torch.rand(train_Xs.shape, device=device) - 0.5) * self.bandwidth 84 | return train_Xs + noise 85 | 86 | 87 | class GaussianKernel(Kernel): 88 | """Implementation of the Gaussian kernel.""" 89 | 90 | def forward(self, test_Xs, train_Xs): 91 | diffs = self._diffs(test_Xs, train_Xs) 92 | dims = tuple(range(len(diffs.shape))[2:]) 93 | var = self.bandwidth ** 2 94 | exp = torch.exp(- torch.pow(diffs,2) / (2 * var)) 95 | coef = 1 / torch.sqrt(torch.tensor(2 * np.pi * var)) 96 | return (coef * exp).mean(dim=-1) 97 | 98 | def sample(self, train_Xs): 99 | device = train_Xs.device 100 | noise = torch.randn(train_Xs.shape) * self.bandwidth 101 | return train_Xs + noise 102 | 103 | 104 | class KernelDensityEstimator(GenerativeModel): 105 | """The KernelDensityEstimator model.""" 106 | 107 | def __init__(self, train_Xs, kernel=None): 108 | """Initializes a new KernelDensityEstimator. 109 | Args: 110 | train_Xs: The "training" data to use when estimating probabilities. 111 | kernel: The kernel to place on each of the train_Xs. 112 | """ 113 | super().__init__() 114 | self.kernel = kernel or GaussianKernel() 115 | self.train_Xs = train_Xs 116 | 117 | @property 118 | def device(self): 119 | return self.train_Xs.device 120 | 121 | # TODO(eugenhotaj): This method consumes O(train_Xs * x) memory. Implement an 122 | # iterative version instead. 
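    # Usage sketch (illustrative comments added here, mirroring how
    # generate_data.py uses this class; `sims` stands for a (B, N, N) tensor of
    # patch-wise cosine similarities):
    #   kde = KernelDensityEstimator(sims.view(B, -1))   # Gaussian kernel by default
    #   grid = torch.linspace(sims.min().item(), sims.max().item(), steps=10).repeat(B, 1).to(sims.device)
    #   pdf = kde(grid)                                   # (B, 10) density estimates of the similarity scores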
123 | def forward(self, x): 124 | return self.kernel(x, self.train_Xs) 125 | 126 | def sample(self, n_samples): 127 | idxs = np.random.choice(range(len(self.train_Xs)), size=n_samples) 128 | return self.kernel.sample(self.train_Xs[idxs]) 129 | -------------------------------------------------------------------------------- /models/quantization_utils/quant_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.multiprocessing as mp 7 | from torch.nn import Parameter 8 | 9 | from .quant_utils import * 10 | 11 | 12 | class QuantConv2d(nn.Conv2d): 13 | """ 14 | Class to quantize weights of given convolutional layer 15 | """ 16 | def __init__(self, 17 | weight_bit, 18 | in_channels, 19 | out_channels, 20 | kernel_size, 21 | stride=1, 22 | padding=0, 23 | dilation=1, 24 | groups=1, 25 | bias=True): 26 | super(QuantConv2d, self).__init__(in_channels=in_channels, 27 | out_channels=out_channels, 28 | kernel_size=kernel_size, 29 | stride=stride, 30 | padding=padding, 31 | dilation=dilation, 32 | groups=groups, 33 | bias=bias) 34 | self.weight_bit = weight_bit 35 | self.quant = False 36 | self.weight_function = SymmetricQuantFunction.apply 37 | 38 | def __repr__(self): 39 | s = super(QuantConv2d, self).__repr__() 40 | s = "(" + s + " weight_bit={})".format(self.weight_bit) 41 | return s 42 | 43 | def forward(self, x): 44 | """ 45 | using quantized weights to forward activation x 46 | """ 47 | if not self.quant: 48 | return F.conv2d( 49 | x, 50 | self.weight, 51 | self.bias, 52 | self.stride, 53 | self.padding, 54 | self.dilation, 55 | self.groups, 56 | ) 57 | 58 | v = self.weight 59 | v = v.reshape(v.shape[0], -1) 60 | v_max = v.max(axis=1).values 61 | v_min = v.min(axis=1).values 62 | w = self.weight_function(self.weight, self.weight_bit, v_min, v_max) 63 | 64 | return F.conv2d( 65 | x, 66 | w, 67 | self.bias, 68 | self.stride, 69 | self.padding, 70 | self.dilation, 71 | self.groups 72 | ) 73 | 74 | 75 | class QuantLinear(nn.Linear): 76 | """ 77 | Class to quantize weights of given Linear layer 78 | """ 79 | def __init__(self, 80 | weight_bit, 81 | in_features, 82 | out_features, 83 | bias=True): 84 | super(QuantLinear, self).__init__(in_features, out_features, bias) 85 | self.weight_bit = weight_bit 86 | self.quant = False 87 | self.weight_function = SymmetricQuantFunction.apply 88 | 89 | def __repr__(self): 90 | s = super(QuantLinear, self).__repr__() 91 | s = "(" + s + " weight_bit={})".format(self.weight_bit) 92 | return s 93 | 94 | def forward(self, x): 95 | """ 96 | using quantized weights to forward activation x 97 | """ 98 | if not self.quant: 99 | return F.linear( 100 | x, 101 | self.weight, 102 | self.bias 103 | ) 104 | 105 | v = self.weight 106 | v = v.reshape(v.shape[0], -1) 107 | v_max = v.max(axis=1).values 108 | v_min = v.min(axis=1).values 109 | w = self.weight_function(self.weight, self.weight_bit, v_min, v_max) 110 | 111 | return F.linear( 112 | x, 113 | weight=w, 114 | bias=self.bias 115 | ) 116 | 117 | 118 | class QuantAct(nn.Module): 119 | """ 120 | Class to quantize given activations 121 | """ 122 | def __init__(self, 123 | activation_bit, 124 | running_stat=True): 125 | super(QuantAct, self).__init__() 126 | self.activation_bit = activation_bit 127 | self.running_stat = running_stat 128 | self.quant = False 129 | self.act_function = AsymmetricQuantFunction.apply 130 | 131 | self.register_buffer('x_min', 
torch.zeros(1)) 132 | self.register_buffer('x_max', torch.zeros(1)) 133 | 134 | def __repr__(self): 135 | return "{0}(activation_bit={1}, running_stat={2}, Act_min: {3:.2f}, Act_max: {4:.2f})".format( 136 | self.__class__.__name__, self.activation_bit, self.running_stat, 137 | self.x_min.item(), self.x_max.item()) 138 | 139 | def fix(self): 140 | """ 141 | fix the activation range by setting running stat 142 | """ 143 | self.running_stat = False 144 | 145 | def unfix(self): 146 | """ 147 | unfix the activation range by setting running stat 148 | """ 149 | self.running_stat = True 150 | 151 | def forward(self, x): 152 | """ 153 | quantize given activation x 154 | """ 155 | if self.running_stat: 156 | cur_max = x.data.max() 157 | cur_min = x.data.min() 158 | if self.x_max == 0: 159 | self.x_max = cur_max 160 | self.x_min = cur_min 161 | else: 162 | self.x_max = torch.max(cur_max, self.x_max) 163 | self.x_min = torch.min(cur_min, self.x_min) 164 | 165 | if not self.quant: 166 | return x 167 | 168 | quant_act = self.act_function(x, self.activation_bit, self.x_min, self.x_max) 169 | 170 | return quant_act 171 | -------------------------------------------------------------------------------- /generate_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from tqdm import tqdm 8 | 9 | from utils import * 10 | 11 | 12 | model_zoo = {'deit_tiny': 'deit_tiny_patch16_224', 13 | 'deit_small': 'deit_small_patch16_224', 14 | 'deit_base': 'deit_base_patch16_224', 15 | 'swin_tiny': 'swin_tiny_patch4_window7_224', 16 | 'swin_small': 'swin_small_patch4_window7_224', 17 | } 18 | 19 | 20 | class AttentionMap: 21 | def __init__(self, module): 22 | self.hook = module.register_forward_hook(self.hook_fn) 23 | 24 | def hook_fn(self, module, input, output): 25 | self.feature = output 26 | 27 | def remove(self): 28 | self.hook.remove() 29 | 30 | 31 | def generate_data(args): 32 | args.batch_size = args.calib_batchsize 33 | 34 | # Load pretrained model 35 | p_model = build_model(model_zoo[args.model], Pretrained=True) 36 | 37 | # Hook the attention 38 | hooks = [] 39 | if 'swin' in args.model: 40 | for m in p_model.layers: 41 | for n in range(len(m.blocks)): 42 | hooks.append(AttentionMap(m.blocks[n].attn.matmul2)) 43 | else: 44 | for m in p_model.blocks: 45 | hooks.append(AttentionMap(m.attn.matmul2)) 46 | 47 | # Init Gaussian noise 48 | img = torch.randn((args.batch_size, 3, 224, 224)).cuda() 49 | img.requires_grad = True 50 | 51 | # Init optimizer 52 | args.lr = 0.25 if 'swin' in args.model else 0.20 53 | optimizer = optim.Adam([img], lr=args.lr, betas=[0.5, 0.9], eps=1e-8) 54 | 55 | # Set pseudo labels 56 | pred = torch.LongTensor([random.randint(0, 999) for _ in range(args.batch_size)]).to('cuda') 57 | var_pred = random.uniform(2500, 3000) # for batch_size 32 58 | 59 | criterion = nn.CrossEntropyLoss() 60 | 61 | # Train for two epochs 62 | for lr_it in range(2): 63 | if lr_it == 0: 64 | iterations_per_layer = 500 65 | lim = 15 66 | else: 67 | iterations_per_layer = 500 68 | lim = 30 69 | 70 | lr_scheduler = lr_cosine_policy(args.lr, 100, iterations_per_layer) 71 | 72 | with tqdm(range(iterations_per_layer)) as pbar: 73 | for itr in pbar: 74 | pbar.set_description(f"Epochs {lr_it+1}/{2}") 75 | 76 | # Learning rate scheduling 77 | lr_scheduler(optimizer, itr, itr) 78 | 79 | # Apply random jitter offsets (from 
DeepInversion[1]) 80 | # [1] Yin, Hongxu, et al. "Dreaming to distill: Data-free knowledge transfer via deepinversion.", CVPR2020. 81 | off = random.randint(-lim, lim) 82 | img_jit = torch.roll(img, shifts=(off, off), dims=(2, 3)) 83 | # Flipping 84 | flip = random.random() > 0.5 85 | if flip: 86 | img_jit = torch.flip(img_jit, dims=(3,)) 87 | 88 | # Forward pass 89 | optimizer.zero_grad() 90 | p_model.zero_grad() 91 | 92 | output = p_model(img_jit) 93 | 94 | loss_oh = criterion(output, pred) 95 | loss_tv = torch.norm(get_image_prior_losses(img_jit) - var_pred) 96 | 97 | loss_entropy = 0 98 | for itr_hook in range(len(hooks)): 99 | # Hook attention 100 | attention = hooks[itr_hook].feature 101 | attention_p = attention.mean(dim=1)[:, 1:, :] 102 | sims = torch.cosine_similarity(attention_p.unsqueeze(1), attention_p.unsqueeze(2), dim=3) 103 | 104 | # Compute differential entropy 105 | kde = KernelDensityEstimator(sims.view(args.batch_size, -1)) 106 | start_p = sims.min().item() 107 | end_p = sims.max().item() 108 | x_plot = torch.linspace(start_p, end_p, steps=10).repeat(args.batch_size, 1).cuda() 109 | kde_estimate = kde(x_plot) 110 | dif_entropy_estimated = differential_entropy(kde_estimate, x_plot) 111 | loss_entropy -= dif_entropy_estimated 112 | 113 | # Combine loss 114 | total_loss = loss_entropy + 1.0 * loss_oh + 0.05 * loss_tv 115 | 116 | # Do image update 117 | total_loss.backward() 118 | optimizer.step() 119 | 120 | # Clip color outliers 121 | img.data = clip(img.data) 122 | 123 | return img.detach() 124 | 125 | 126 | def differential_entropy(pdf, x_pdf): 127 | # pdf is a vector because we want to perform a numerical integration 128 | pdf = pdf + 1e-4 129 | f = -1 * pdf * torch.log(pdf) 130 | # Integrate using the composite trapezoidal rule 131 | ans = torch.trapz(f, x_pdf, dim=-1).mean() 132 | return ans 133 | 134 | 135 | def get_image_prior_losses(inputs_jit): 136 | # Compute total variation regularization loss 137 | diff1 = inputs_jit[:, :, :, :-1] - inputs_jit[:, :, :, 1:] 138 | diff2 = inputs_jit[:, :, :-1, :] - inputs_jit[:, :, 1:, :] 139 | diff3 = inputs_jit[:, :, 1:, :-1] - inputs_jit[:, :, :-1, 1:] 140 | diff4 = inputs_jit[:, :, :-1, :-1] - inputs_jit[:, :, 1:, 1:] 141 | 142 | loss_var_l2 = torch.norm(diff1) + torch.norm(diff2) + torch.norm(diff3) + torch.norm(diff4) 143 | return loss_var_l2 144 | 145 | 146 | def clip(image_tensor, use_fp16=False): 147 | # Adjust the input based on mean and variance 148 | if use_fp16: 149 | mean = np.array([0.485, 0.456, 0.406], dtype=np.float16) 150 | std = np.array([0.229, 0.224, 0.225], dtype=np.float16) 151 | else: 152 | mean = np.array([0.485, 0.456, 0.406]) 153 | std = np.array([0.229, 0.224, 0.225]) 154 | for c in range(3): 155 | m, s = mean[c], std[c] 156 | image_tensor[:, c] = torch.clamp(image_tensor[:, c], -m / s, (1 - m) / s) 157 | #image_tensor[:, c] = torch.clamp(image_tensor[:, c], 0, 1) 158 | return image_tensor 159 | 160 | 161 | def lr_policy(lr_fn): 162 | def _alr(optimizer, iteration, epoch): 163 | lr = lr_fn(iteration, epoch) 164 | for param_group in optimizer.param_groups: 165 | param_group['lr'] = lr 166 | 167 | return _alr 168 | 169 | 170 | def lr_cosine_policy(base_lr, warmup_length, epochs): 171 | def _lr_fn(iteration, epoch): 172 | if epoch < warmup_length: 173 | lr = base_lr * (epoch + 1) / warmup_length 174 | else: 175 | e = epoch - warmup_length 176 | es = epochs - warmup_length 177 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 178 | return lr 179 | 180 | return lr_policy(_lr_fn) 181 | 
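A minimal driver sketch for calling `generate_data` outside of `test_quant.py` (added for illustration and not part of the repository; the argument names follow `test_quant.py`'s parser, a CUDA device is assumed, and the output filename is hypothetical):

```python
import argparse
import torch

from generate_data import generate_data

if __name__ == "__main__":
    # generate_data() only reads args.model and args.calib_batchsize;
    # it sets args.batch_size and args.lr internally.
    args = argparse.Namespace(model="deit_tiny", calib_batchsize=32)
    fake_images = generate_data(args)                   # (32, 3, 224, 224) synthetic calibration images
    torch.save(fake_images.cpu(), "psaq_vit_calib.pt")  # hypothetical output file
```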
-------------------------------------------------------------------------------- /test_quant.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import sys 5 | import random 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | from models import * 11 | from utils import * 12 | 13 | from generate_data import generate_data 14 | 15 | 16 | def get_args_parser(): 17 | parser = argparse.ArgumentParser(description="PSAQ-ViT", add_help=False) 18 | parser.add_argument("--model", default="deit_tiny", 19 | choices=['deit_tiny', 'deit_small', 'deit_base', 'swin_tiny', 'swin_small'], 20 | help="model") 21 | parser.add_argument('--dataset', default="/Path/to/Dataset/", 22 | help='path to dataset') 23 | parser.add_argument("--calib-batchsize", default=32, 24 | type=int, help="batchsize of calibration set") 25 | parser.add_argument("--val-batchsize", default=200, 26 | type=int, help="batchsize of validation set") 27 | parser.add_argument("--num-workers", default=16, type=int, 28 | help="number of data loading workers (default: 16)") 29 | parser.add_argument("--device", default="cuda", type=str, help="device") 30 | parser.add_argument("--print-freq", default=100, 31 | type=int, help="print frequency") 32 | parser.add_argument("--seed", default=0, type=int, help="seed") 33 | 34 | parser.add_argument("--mode", default=0, 35 | type=int, help="mode of calibration data, 0: PSAQ-ViT, 1: Gaussian noise, 2: Real data") 36 | parser.add_argument('--w_bit', default=8, 37 | type=int, help='bit-precision of weights') 38 | parser.add_argument('--a_bit', default=8, 39 | type=int, help='bit-precision of activation') 40 | 41 | return parser 42 | 43 | 44 | class Config: 45 | def __init__(self, w_bit, a_bit): 46 | self.weight_bit = w_bit 47 | self.activation_bit = a_bit 48 | 49 | 50 | def str2model(name): 51 | model_zoo = {'deit_tiny': deit_tiny_patch16_224, 52 | 'deit_small': deit_small_patch16_224, 53 | 'deit_base': deit_base_patch16_224, 54 | 'swin_tiny': swin_tiny_patch4_window7_224, 55 | 'swin_small': swin_small_patch4_window7_224 56 | } 57 | print('Model: %s' % model_zoo[name].__name__) 58 | return model_zoo[name] 59 | 60 | 61 | def seed(seed=0): 62 | sys.setrecursionlimit(100000) 63 | os.environ["PYTHONHASHSEED"] = str(seed) 64 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 65 | torch.manual_seed(seed) 66 | torch.cuda.manual_seed_all(seed) 67 | torch.backends.cudnn.benchmark = False 68 | torch.backends.cudnn.deterministic = True 69 | np.random.seed(seed) 70 | random.seed(seed) 71 | 72 | 73 | def main(): 74 | print(args) 75 | seed(args.seed) 76 | 77 | device = torch.device(args.device) 78 | # Load bit-config 79 | cfg = Config(args.w_bit, args.a_bit) 80 | 81 | # Build model 82 | model = str2model(args.model)(pretrained=True, cfg=cfg) 83 | model = model.to(device) 84 | model.eval() 85 | 86 | # Build dataloader 87 | train_loader, val_loader = build_dataset(args) 88 | 89 | # Define loss function (criterion) 90 | criterion = nn.CrossEntropyLoss().to(device) 91 | 92 | # Get calibration set 93 | # Case 0: PASQ-ViT 94 | if args.mode == 0: 95 | print("Generating data...") 96 | calibrate_data = generate_data(args) 97 | print("Calibrating with generated data...") 98 | with torch.no_grad(): 99 | output = model(calibrate_data) 100 | # Case 1: Gaussian noise 101 | elif args.mode == 1: 102 | calibrate_data = torch.randn((args.calib_batchsize, 3, 224, 224)).to(device) 103 | print("Calibrating with Gaussian noise...") 104 | 
with torch.no_grad(): 105 | output = model(calibrate_data) 106 | # Case 2: Real data (Standard) 107 | elif args.mode == 2: 108 | for data, target in train_loader: 109 | calibrate_data = data.to(device) 110 | break 111 | print("Calibrating with real data...") 112 | with torch.no_grad(): 113 | output = model(calibrate_data) 114 | # Not implemented 115 | else: 116 | raise NotImplementedError 117 | 118 | # Freeze model 119 | model.model_quant() 120 | model.model_freeze() 121 | 122 | # Validate the quantized model 123 | print("Validating...") 124 | val_loss, val_prec1, val_prec5 = validate( 125 | args, val_loader, model, criterion, device 126 | ) 127 | 128 | 129 | def validate(args, val_loader, model, criterion, device): 130 | batch_time = AverageMeter() 131 | losses = AverageMeter() 132 | top1 = AverageMeter() 133 | top5 = AverageMeter() 134 | 135 | # Switch to evaluate mode 136 | model.eval() 137 | 138 | val_start_time = end = time.time() 139 | for i, (data, target) in enumerate(val_loader): 140 | target = target.to(device) 141 | data = data.to(device) 142 | target = target.to(device) 143 | 144 | with torch.no_grad(): 145 | output = model(data) 146 | loss = criterion(output, target) 147 | 148 | # Measure accuracy and record loss 149 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 150 | losses.update(loss.data.item(), data.size(0)) 151 | top1.update(prec1.data.item(), data.size(0)) 152 | top5.update(prec5.data.item(), data.size(0)) 153 | 154 | # Measure elapsed time 155 | batch_time.update(time.time() - end) 156 | end = time.time() 157 | 158 | if i % args.print_freq == 0: 159 | print( 160 | "Test: [{0}/{1}]\t" 161 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 162 | "Loss {loss.val:.4f} ({loss.avg:.4f})\t" 163 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t" 164 | "Prec@5 {top5.val:.3f} ({top5.avg:.3f})".format( 165 | i, 166 | len(val_loader), 167 | batch_time=batch_time, 168 | loss=losses, 169 | top1=top1, 170 | top5=top5, 171 | ) 172 | ) 173 | val_end_time = time.time() 174 | print(" * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Time {time:.3f}".format( 175 | top1=top1, top5=top5, time=val_end_time - val_start_time)) 176 | 177 | return losses.avg, top1.avg, top5.avg 178 | 179 | 180 | class AverageMeter(object): 181 | """Computes and stores the average and current value""" 182 | 183 | def __init__(self): 184 | self.reset() 185 | 186 | def reset(self): 187 | self.val = 0 188 | self.avg = 0 189 | self.sum = 0 190 | self.count = 0 191 | 192 | def update(self, val, n=1): 193 | self.val = val 194 | self.sum += val * n 195 | self.count += n 196 | self.avg = self.sum / self.count 197 | 198 | 199 | def accuracy(output, target, topk=(1,)): 200 | """Computes the precision@k for the specified values of k""" 201 | maxk = max(topk) 202 | batch_size = target.size(0) 203 | 204 | _, pred = output.topk(maxk, 1, True, True) 205 | pred = pred.t() 206 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 207 | 208 | res = [] 209 | for k in topk: 210 | correct_k = correct[:k].reshape(-1).float().sum(0) 211 | res.append(correct_k.mul_(100.0 / batch_size)) 212 | return res 213 | 214 | 215 | if __name__ == "__main__": 216 | parser = argparse.ArgumentParser('PSAQ', parents=[get_args_parser()]) 217 | args = parser.parse_args() 218 | main() 219 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 
| from torch import nn 7 | import torch.nn.functional as F 8 | 9 | 10 | @torch.no_grad() 11 | def load_weights_from_npz(model, url, check_hash=False, progress=False, prefix=''): 12 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 13 | """ 14 | 15 | def _n2p(w, t=True): 16 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 17 | w = w.flatten() 18 | if t: 19 | if w.ndim == 4: 20 | w = w.transpose([3, 2, 0, 1]) 21 | elif w.ndim == 3: 22 | w = w.transpose([2, 0, 1]) 23 | elif w.ndim == 2: 24 | w = w.transpose([1, 0]) 25 | return torch.from_numpy(w) 26 | 27 | def _get_cache_dir(child_dir=''): 28 | """ 29 | Returns the location of the directory where models are cached (and creates it if necessary). 30 | """ 31 | hub_dir = torch.hub.get_dir() 32 | child_dir = () if not child_dir else (child_dir,) 33 | model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) 34 | os.makedirs(model_dir, exist_ok=True) 35 | return model_dir 36 | 37 | def _download_cached_file(url, check_hash=True, progress=False): 38 | parts = torch.hub.urlparse(url) 39 | filename = os.path.basename(parts.path) 40 | cached_file = os.path.join(_get_cache_dir(), filename) 41 | if not os.path.exists(cached_file): 42 | hash_prefix = None 43 | if check_hash: 44 | r = torch.hub.HASH_REGEX.search( 45 | filename) # r is Optional[Match[str]] 46 | hash_prefix = r.group(1) if r else None 47 | torch.hub.download_url_to_file( 48 | url, cached_file, hash_prefix, progress=progress) 49 | return cached_file 50 | 51 | def adapt_input_conv(in_chans, conv_weight): 52 | conv_type = conv_weight.dtype 53 | # Some weights are in torch.half, ensure it's float for sum on CPU 54 | conv_weight = conv_weight.float() 55 | O, I, J, K = conv_weight.shape 56 | if in_chans == 1: 57 | if I > 3: 58 | assert conv_weight.shape[1] % 3 == 0 59 | # For models with space2depth stems 60 | conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) 61 | conv_weight = conv_weight.sum(dim=2, keepdim=False) 62 | else: 63 | conv_weight = conv_weight.sum(dim=1, keepdim=True) 64 | elif in_chans != 3: 65 | if I != 3: 66 | raise NotImplementedError( 67 | 'Weight format not supported by conversion.') 68 | else: 69 | # NOTE this strategy should be better than random init, but there could be other combinations of 70 | # the original RGB input layer weights that'd work better for specific cases. 71 | repeat = int(math.ceil(in_chans / 3)) 72 | conv_weight = conv_weight.repeat(1, repeat, 1, 1)[ 73 | :, :in_chans, :, :] 74 | conv_weight *= (3 / float(in_chans)) 75 | conv_weight = conv_weight.to(conv_type) 76 | return conv_weight 77 | 78 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 79 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 80 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 81 | ntok_new = posemb_new.shape[1] 82 | if num_tokens: 83 | posemb_tok, posemb_grid = posemb[:, 84 | :num_tokens], posemb[0, num_tokens:] 85 | ntok_new -= num_tokens 86 | else: 87 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 88 | gs_old = int(math.sqrt(len(posemb_grid))) 89 | if not len(gs_new): # backwards compatibility 90 | gs_new = [int(math.sqrt(ntok_new))] * 2 91 | assert len(gs_new) >= 2 92 | posemb_grid = posemb_grid.reshape( 93 | 1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 94 | posemb_grid = F.interpolate( 95 | posemb_grid, size=gs_new, mode='bicubic', align_corners=False) 96 | posemb_grid = posemb_grid.permute( 97 | 0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 98 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 99 | return posemb 100 | 101 | cached_file = _download_cached_file( 102 | url, check_hash=check_hash, progress=progress) 103 | 104 | w = np.load(cached_file) 105 | if not prefix and 'opt/target/embedding/kernel' in w: 106 | prefix = 'opt/target/' 107 | 108 | if hasattr(model.patch_embed, 'backbone'): 109 | # hybrid 110 | backbone = model.patch_embed.backbone 111 | stem_only = not hasattr(backbone, 'stem') 112 | stem = backbone if stem_only else backbone.stem 113 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 114 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 115 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 116 | if not stem_only: 117 | for i, stage in enumerate(backbone.stages): 118 | for j, block in enumerate(stage.blocks): 119 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 120 | for r in range(3): 121 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 122 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 123 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 124 | if block.downsample is not None: 125 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 126 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 127 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 128 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 129 | else: 130 | embed_conv_w = adapt_input_conv( 131 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 132 | model.patch_embed.proj.weight.copy_(embed_conv_w) 133 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 134 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 135 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 136 | if pos_embed_w.shape != model.pos_embed.shape: 137 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 138 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 139 | model.pos_embed.copy_(pos_embed_w) 140 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 141 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 142 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 143 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 144 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 145 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and 
f'{prefix}pre_logits/bias' in w: 146 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 147 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 148 | for i, block in enumerate(model.blocks.children()): 149 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 150 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 151 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 152 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 153 | block.attn.qkv.weight.copy_(torch.cat([ 154 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 155 | block.attn.qkv.bias.copy_(torch.cat([ 156 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 157 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 158 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 159 | for r in range(2): 160 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 161 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 162 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 163 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 164 | -------------------------------------------------------------------------------- /models/layers_quant.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from itertools import repeat 4 | import collections.abc 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .quantization_utils import QuantLinear, QuantConv2d, QuantAct 11 | 12 | 13 | def _ntuple(n): 14 | def parse(x): 15 | if isinstance(x, collections.abc.Iterable): 16 | return x 17 | return tuple(repeat(x, n)) 18 | 19 | return parse 20 | 21 | 22 | to_2tuple = _ntuple(2) 23 | 24 | 25 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 26 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 27 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 28 | def norm_cdf(x): 29 | # Computes standard normal cumulative distribution function 30 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 31 | 32 | if (mean < a - 2 * std) or (mean > b + 2 * std): 33 | warnings.warn( 34 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 35 | "The distribution of values may be incorrect.", 36 | stacklevel=2, 37 | ) 38 | 39 | with torch.no_grad(): 40 | # Values are generated by using a truncated uniform distribution and 41 | # then using the inverse CDF for the normal distribution. 42 | # Get upper and lower cdf values 43 | l = norm_cdf((a - mean) / std) 44 | u = norm_cdf((b - mean) / std) 45 | 46 | # Uniformly fill tensor with values from [l, u], then translate to 47 | # [2l-1, 2u-1]. 
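        # (added explanatory note: applying erfinv to this uniform sample gives a
        #  standard normal truncated to [(a - mean) / std, (b - mean) / std],
        #  scaled by 1/sqrt(2); the mul_(std * sqrt(2)) and add_(mean) that follow
        #  restore the requested mean, std and [a, b] range)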
48 | tensor.uniform_(2 * l - 1, 2 * u - 1) 49 | 50 | # Use inverse cdf transform for normal distribution to get truncated 51 | # standard normal 52 | tensor.erfinv_() 53 | 54 | # Transform to proper mean, std 55 | tensor.mul_(std * math.sqrt(2.0)) 56 | tensor.add_(mean) 57 | 58 | # Clamp to ensure it's in the proper range 59 | tensor.clamp_(min=a, max=b) 60 | return tensor 61 | 62 | 63 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 64 | # type: (Tensor, float, float, float, float) -> Tensor 65 | r"""Fills the input Tensor with values drawn from a truncated 66 | normal distribution. The values are effectively drawn from the 67 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 68 | with values outside :math:`[a, b]` redrawn until they are within 69 | the bounds. The method used for generating the random values works 70 | best when :math:`a \leq \text{mean} \leq b`. 71 | Args: 72 | tensor: an n-dimensional `torch.Tensor` 73 | mean: the mean of the normal distribution 74 | std: the standard deviation of the normal distribution 75 | a: the minimum cutoff value 76 | b: the maximum cutoff value 77 | Examples: 78 | >>> w = torch.empty(3, 5) 79 | >>> nn.init.trunc_normal_(w) 80 | """ 81 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 82 | 83 | 84 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 85 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 86 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 87 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 88 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 89 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 90 | 'survival rate' as the argument. 
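    A shape-level sketch of the behaviour described above (added for
    illustration; the numbers are hypothetical): with drop_prob=0.25 and x of
    shape (4, 197, 192), a (4, 1, 1) binary mask is drawn with keep probability
    0.75, dropped samples have their residual-branch output zeroed, and the kept
    ones are scaled by 1/0.75 so the expected output matches the input.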
91 | """ 92 | if drop_prob == 0.0 or not training: 93 | return x 94 | keep_prob = 1 - drop_prob 95 | shape = (x.shape[0],) + (1,) * ( 96 | x.ndim - 1 97 | ) # work with diff dim tensors, not just 2D ConvNets 98 | random_tensor = keep_prob + \ 99 | torch.rand(shape, dtype=x.dtype, device=x.device) 100 | random_tensor.floor_() # binarize 101 | output = x.div(keep_prob) * random_tensor 102 | return output 103 | 104 | 105 | class DropPath(nn.Module): 106 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 107 | 108 | def __init__(self, drop_prob=None): 109 | super(DropPath, self).__init__() 110 | self.drop_prob = drop_prob 111 | 112 | def forward(self, x): 113 | return drop_path(x, self.drop_prob, self.training) 114 | 115 | 116 | class Mlp(nn.Module): 117 | def __init__(self, 118 | in_features, 119 | hidden_features, 120 | out_features=None, 121 | act_layer=nn.GELU, 122 | drop=0.0, 123 | cfg=None): 124 | super().__init__() 125 | out_features = out_features or in_features 126 | hidden_features = hidden_features or in_features 127 | self.fc1 = QuantLinear(cfg.weight_bit, 128 | in_features, 129 | hidden_features) 130 | self.act = act_layer() 131 | self.QuantAct1 = QuantAct(cfg.activation_bit) 132 | self.fc2 = QuantLinear(cfg.weight_bit, 133 | hidden_features, 134 | out_features) 135 | self.QuantAct2 = QuantAct(cfg.activation_bit) 136 | self.drop = nn.Dropout(drop) 137 | 138 | def forward(self, x): 139 | x = self.fc1(x) 140 | x = self.act(x) 141 | x = self.QuantAct1(x) 142 | x = self.drop(x) 143 | x = self.fc2(x) 144 | x = self.QuantAct2(x) 145 | x = self.drop(x) 146 | return x 147 | 148 | 149 | class PatchEmbed(nn.Module): 150 | """Image to Patch Embedding""" 151 | 152 | def __init__(self, 153 | img_size=224, 154 | patch_size=16, 155 | in_chans=3, 156 | embed_dim=768, 157 | norm_layer=None, 158 | cfg=None): 159 | super().__init__() 160 | img_size = to_2tuple(img_size) 161 | patch_size = to_2tuple(patch_size) 162 | self.img_size = img_size 163 | self.patch_size = patch_size 164 | 165 | self.grid_size = (img_size[0] // patch_size[0], 166 | img_size[1] // patch_size[1]) 167 | self.num_patches = self.grid_size[0] * self.grid_size[1] 168 | 169 | self.proj = QuantConv2d(cfg.weight_bit, 170 | in_chans, 171 | embed_dim, 172 | kernel_size=patch_size, 173 | stride=patch_size) 174 | if norm_layer: 175 | self.QuantAct_before_norm = QuantAct(cfg.activation_bit) 176 | self.norm = norm_layer(embed_dim) 177 | self.QuantAct = QuantAct(cfg.activation_bit) 178 | else: 179 | self.QuantAct_before_norm = nn.Identity() 180 | self.norm = nn.Identity() 181 | self.QuantAct = QuantAct(cfg.activation_bit) 182 | 183 | def forward(self, x): 184 | B, C, H, W = x.shape 185 | # FIXME look at relaxing size constraints 186 | assert ( 187 | H == self.img_size[0] and W == self.img_size[1] 188 | ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 189 | x = self.proj(x).flatten(2).transpose(1, 2) 190 | x = self.QuantAct_before_norm(x) 191 | x = self.norm(x) 192 | x = self.QuantAct(x) 193 | return x 194 | 195 | 196 | class HybridEmbed(nn.Module): 197 | """CNN Feature Map Embedding 198 | Extract feature map from CNN, flatten, project to embedding dim. 
199 | """ 200 | 201 | def __init__( 202 | self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 203 | super().__init__() 204 | assert isinstance(backbone, nn.Module) 205 | img_size = to_2tuple(img_size) 206 | self.img_size = img_size 207 | self.backbone = backbone 208 | if feature_size is None: 209 | with torch.no_grad(): 210 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 211 | # map for all networks, the feature metadata has reliable channel and stride info, but using 212 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 213 | training = backbone.training 214 | if training: 215 | backbone.eval() 216 | o = self.backbone(torch.zeros( 217 | 1, in_chans, img_size[0], img_size[1])) 218 | if isinstance(o, (list, tuple)): 219 | # last feature if backbone outputs list/tuple of features 220 | o = o[-1] 221 | feature_size = o.shape[-2:] 222 | feature_dim = o.shape[1] 223 | backbone.train(training) 224 | else: 225 | feature_size = to_2tuple(feature_size) 226 | if hasattr(self.backbone, "feature_info"): 227 | feature_dim = self.backbone.feature_info.channels()[-1] 228 | else: 229 | feature_dim = self.backbone.num_features 230 | self.num_patches = feature_size[0] * feature_size[1] 231 | self.proj = nn.Conv2d(feature_dim, embed_dim, 1) 232 | 233 | def forward(self, x): 234 | x = self.backbone(x) 235 | if isinstance(x, (list, tuple)): 236 | x = x[-1] # last feature if backbone outputs list/tuple of features 237 | x = self.proj(x).flatten(2).transpose(1, 2) 238 | return x 239 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /models/vit_quant.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import re 4 | import warnings 5 | from itertools import repeat 6 | import collections.abc 7 | from collections import OrderedDict 8 | from functools import partial 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch import nn 13 | 14 | from .quantization_utils import QuantLinear, QuantConv2d, QuantAct 15 | from .layers_quant import PatchEmbed, HybridEmbed, Mlp, DropPath, trunc_normal_ 16 | 17 | 18 | __all__ = ['deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224'] 19 | 20 | 21 | class Attention(nn.Module): 22 | def __init__( 23 | self, 24 | dim, 25 | num_heads=8, 26 | qkv_bias=False, 27 | qk_scale=None, 28 | attn_drop=0.0, 29 | proj_drop=0.0, 30 | cfg=None): 31 | super().__init__() 32 | self.num_heads = num_heads 33 | head_dim = dim // num_heads 34 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 35 | self.scale = qk_scale or head_dim ** -0.5 36 | 37 | self.qkv = QuantLinear( 38 | cfg.weight_bit, 39 | dim, 40 | dim*3 41 | ) 42 | self.qact1 = QuantAct(cfg.activation_bit) 43 | self.qact2 = QuantAct(cfg.activation_bit) 44 | self.proj = QuantLinear( 45 | cfg.weight_bit, 46 | dim, 47 | dim 48 | ) 49 | self.qact3 = QuantAct(cfg.activation_bit) 50 | self.qact_attn1 = QuantAct(cfg.activation_bit) 51 | self.attn_drop = nn.Dropout(attn_drop) 52 | self.proj_drop = nn.Dropout(proj_drop) 53 | 54 | def forward(self, x): 55 | B, N, C = x.shape 56 | x = self.qkv(x) 57 | x = self.qact1(x) 58 | qkv = x.reshape(B, N, 3, self.num_heads, C // 59 | self.num_heads).permute(2, 0, 3, 1, 4) # (BN33) 60 | q, k, v = ( 61 | qkv[0], 62 | qkv[1], 63 | qkv[2] 64 | ) # make torchscript happy (cannot use tensor as tuple) 65 | attn = (q @ k.transpose(-2, -1)) * self.scale 66 | attn = self.qact_attn1(attn) 67 | attn = attn.softmax(dim=-1) 68 | attn = self.attn_drop(attn) 69 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 70 | x = self.qact2(x) 71 | x = self.proj(x) 72 | x = self.qact3(x) 73 | x = self.proj_drop(x) 74 | return x 75 | 76 | 77 | class Block(nn.Module): 78 | def __init__( 79 | self, 80 | dim, 81 | num_heads, 82 | mlp_ratio=4.0, 83 | qkv_bias=False, 84 | qk_scale=None, 85 | drop=0.0, 86 | attn_drop=0.0, 87 | drop_path=0.0, 88 | act_layer=nn.GELU, 89 | norm_layer=nn.LayerNorm, 90 | cfg=None): 91 | super().__init__() 92 | self.norm1 = norm_layer(dim) 93 | self.qact1 = QuantAct(cfg.activation_bit) 94 | self.attn = Attention( 95 | dim, 96 | num_heads=num_heads, 97 | qkv_bias=qkv_bias, 98 | qk_scale=qk_scale, 99 | attn_drop=attn_drop, 100 | proj_drop=drop, 101 | cfg=cfg 102 | ) 103 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 104 | self.drop_path = DropPath( 105 | drop_path) if drop_path > 0.0 else nn.Identity() 106 | self.qact2 = QuantAct(cfg.activation_bit) 107 | self.norm2 = norm_layer(dim) 108 | self.qact3 = QuantAct(cfg.activation_bit) 109 | mlp_hidden_dim = int(dim * mlp_ratio) 110 | self.mlp = Mlp( 111 | in_features=dim, 112 | hidden_features=mlp_hidden_dim, 113 | act_layer=act_layer, 114 | drop=drop, 115 | cfg=cfg 116 | ) 117 | self.qact4 = QuantAct(cfg.activation_bit) 118 | 119 | def forward(self, x): 120 | x = self.qact2( 121 | x + self.drop_path(self.attn(self.qact1(self.norm1(x))))) 122 | x = self.qact4(x + 
self.drop_path(self.mlp(self.qact3(self.norm2(x))))) 123 | return x 124 | 125 | 126 | class VisionTransformer(nn.Module): 127 | """Vision Transformer 128 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 129 | https://arxiv.org/abs/2010.11929 130 | """ 131 | 132 | def __init__( 133 | self, 134 | img_size=224, 135 | patch_size=16, 136 | in_chans=3, 137 | num_classes=1000, 138 | embed_dim=768, 139 | depth=12, 140 | num_heads=12, 141 | mlp_ratio=4.0, 142 | qkv_bias=True, 143 | qk_scale=None, 144 | representation_size=None, 145 | drop_rate=0.0, 146 | attn_drop_rate=0.0, 147 | drop_path_rate=0.0, 148 | hybrid_backbone=None, 149 | norm_layer=None, 150 | input_quant=False, 151 | cfg=None): 152 | super().__init__() 153 | self.num_classes = num_classes 154 | self.num_features = ( 155 | self.embed_dim 156 | ) = embed_dim # num_features for consistency with other models 157 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 158 | 159 | self.cfg = cfg 160 | self.input_quant = input_quant 161 | if input_quant: 162 | self.qact_input = QuantAct(cfg.activation_bit) 163 | 164 | if hybrid_backbone is not None: 165 | self.patch_embed = HybridEmbed( 166 | hybrid_backbone, 167 | img_size=img_size, 168 | in_chans=in_chans, 169 | embed_dim=embed_dim, 170 | ) 171 | else: 172 | self.patch_embed = PatchEmbed( 173 | img_size=img_size, 174 | patch_size=patch_size, 175 | in_chans=in_chans, 176 | embed_dim=embed_dim, 177 | cfg=cfg) 178 | num_patches = self.patch_embed.num_patches 179 | 180 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 181 | self.pos_embed = nn.Parameter( 182 | torch.zeros(1, num_patches + 1, embed_dim)) 183 | self.pos_drop = nn.Dropout(p=drop_rate) 184 | 185 | self.qact_embed = QuantAct(cfg.activation_bit) 186 | self.qact_pos = QuantAct(cfg.activation_bit) 187 | self.qact1 = QuantAct(cfg.activation_bit) 188 | 189 | dpr = [ 190 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 191 | ] # stochastic depth decay rule 192 | self.blocks = nn.ModuleList( 193 | [ 194 | Block( 195 | dim=embed_dim, 196 | num_heads=num_heads, 197 | mlp_ratio=mlp_ratio, 198 | qkv_bias=qkv_bias, 199 | qk_scale=qk_scale, 200 | drop=drop_rate, 201 | attn_drop=attn_drop_rate, 202 | drop_path=dpr[i], 203 | norm_layer=norm_layer, 204 | cfg=cfg 205 | ) 206 | for i in range(depth) 207 | ] 208 | ) 209 | self.norm = norm_layer(embed_dim) 210 | self.qact2 = QuantAct(cfg.activation_bit) 211 | 212 | # Representation layer 213 | if representation_size: 214 | self.num_features = representation_size 215 | self.pre_logits = nn.Sequential( 216 | OrderedDict( 217 | [ 218 | ("fc", nn.Linear(embed_dim, representation_size)), 219 | ("act", nn.Tanh()), 220 | ] 221 | ) 222 | ) 223 | else: 224 | self.pre_logits = nn.Identity() 225 | 226 | # Classifier head 227 | self.head = ( 228 | QuantLinear( 229 | cfg.weight_bit, 230 | self.num_features, 231 | num_classes 232 | ) 233 | if num_classes > 0 234 | else nn.Identity() 235 | ) 236 | self.act_out = QuantAct(cfg.activation_bit) 237 | trunc_normal_(self.pos_embed, std=0.02) 238 | trunc_normal_(self.cls_token, std=0.02) 239 | self.apply(self._init_weights) 240 | 241 | def _init_weights(self, m): 242 | if isinstance(m, nn.Linear): 243 | trunc_normal_(m.weight, std=0.02) 244 | if isinstance(m, nn.Linear) and m.bias is not None: 245 | nn.init.constant_(m.bias, 0) 246 | elif isinstance(m, nn.LayerNorm): 247 | nn.init.constant_(m.bias, 0) 248 | nn.init.constant_(m.weight, 1.0) 249 | 250 | @torch.jit.ignore 251 | def 
no_weight_decay(self): 252 | return {"pos_embed", "cls_token"} 253 | 254 | def model_quant(self): 255 | for m in self.modules(): 256 | if type(m) in [QuantLinear, QuantConv2d, QuantAct]: 257 | m.quant = True 258 | 259 | def model_freeze(self): 260 | for m in self.modules(): 261 | if type(m) in [QuantAct]: 262 | m.running_stat = False 263 | 264 | def model_unfreeze(self): 265 | for m in self.modules(): 266 | if type(m) in [QuantAct]: 267 | m.running_stat = True 268 | 269 | def forward_features(self, x): 270 | B = x.shape[0] 271 | 272 | if self.input_quant: 273 | x = self.qact_input(x) 274 | 275 | x = self.patch_embed(x) 276 | 277 | cls_tokens = self.cls_token.expand( 278 | B, -1, -1 279 | ) # stole cls_tokens impl from Phil Wang, thanks 280 | x = torch.cat((cls_tokens, x), dim=1) 281 | x = self.qact_embed(x) 282 | x = x + self.qact_pos(self.pos_embed) 283 | x = self.qact1(x) 284 | 285 | x = self.pos_drop(x) 286 | 287 | for blk in self.blocks: 288 | x = blk(x) 289 | 290 | x = self.norm(x)[:, 0] 291 | x = self.qact2(x) 292 | x = self.pre_logits(x) 293 | return x 294 | 295 | def forward(self, x): 296 | x = self.forward_features(x) 297 | x = self.head(x) 298 | x = self.act_out(x) 299 | return x 300 | 301 | 302 | def deit_tiny_patch16_224(pretrained=False, cfg=None, **kwargs): 303 | model = VisionTransformer( 304 | patch_size=16, 305 | embed_dim=192, 306 | depth=12, 307 | num_heads=3, 308 | mlp_ratio=4, 309 | qkv_bias=True, 310 | norm_layer=None, 311 | input_quant=True, 312 | cfg=cfg, 313 | **kwargs, 314 | ) 315 | if pretrained: 316 | checkpoint = torch.hub.load_state_dict_from_url( 317 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 318 | map_location="cpu", 319 | check_hash=True, 320 | ) 321 | model.load_state_dict(checkpoint["model"], strict=False) 322 | return model 323 | 324 | 325 | def deit_small_patch16_224(pretrained=False, cfg=None, **kwargs): 326 | model = VisionTransformer( 327 | patch_size=16, 328 | embed_dim=384, 329 | depth=12, 330 | num_heads=6, 331 | mlp_ratio=4, 332 | qkv_bias=True, 333 | norm_layer=None, 334 | input_quant=True, 335 | cfg=cfg, 336 | **kwargs 337 | ) 338 | if pretrained: 339 | checkpoint = torch.hub.load_state_dict_from_url( 340 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 341 | map_location="cpu", check_hash=True 342 | ) 343 | model.load_state_dict(checkpoint["model"], strict=False) 344 | return model 345 | 346 | 347 | def deit_base_patch16_224(pretrained=False, cfg=None, **kwargs): 348 | model = VisionTransformer( 349 | patch_size=16, 350 | embed_dim=768, 351 | depth=12, 352 | num_heads=12, 353 | mlp_ratio=4, 354 | qkv_bias=True, 355 | norm_layer=None, 356 | input_quant=True, 357 | cfg=cfg, 358 | **kwargs 359 | ) 360 | if pretrained: 361 | checkpoint = torch.hub.load_state_dict_from_url( 362 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 363 | map_location="cpu", check_hash=True 364 | ) 365 | model.load_state_dict(checkpoint["model"], strict=False) 366 | return model 367 | -------------------------------------------------------------------------------- /models/swin_quant.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.checkpoint as checkpoint 7 | 8 | from .quantization_utils import QuantLinear, QuantConv2d, QuantAct 9 | from .layers_quant import PatchEmbed, HybridEmbed, Mlp, DropPath, trunc_normal_, to_2tuple 
10 | 11 | 12 | __all__ = ['swin_tiny_patch4_window7_224', 'swin_small_patch4_window7_224'] 13 | 14 | 15 | def window_partition(x, window_size: int): 16 | """ 17 | Args: 18 | x: (B, H, W, C) 19 | window_size (int): window size 20 | 21 | Returns: 22 | windows: (num_windows*B, window_size, window_size, C) 23 | """ 24 | B, H, W, C = x.shape 25 | x = x.view(B, H // window_size, window_size, 26 | W // window_size, window_size, C) 27 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous( 28 | ).view(-1, window_size, window_size, C) 29 | return windows 30 | 31 | 32 | def window_reverse(windows, window_size: int, H: int, W: int): 33 | """ 34 | Args: 35 | windows: (num_windows*B, window_size, window_size, C) 36 | window_size (int): Window size 37 | H (int): Height of image 38 | W (int): Width of image 39 | 40 | Returns: 41 | x: (B, H, W, C) 42 | """ 43 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 44 | x = windows.view(B, H // window_size, W // window_size, 45 | window_size, window_size, -1) 46 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 47 | return x 48 | 49 | 50 | class WindowAttention(nn.Module): 51 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 52 | It supports both of shifted and non-shifted window. 53 | 54 | Args: 55 | dim (int): Number of input channels. 56 | window_size (tuple[int]): The height and width of the window. 57 | num_heads (int): Number of attention heads. 58 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 59 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 60 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 61 | """ 62 | 63 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., cfg=None): 64 | 65 | super().__init__() 66 | self.dim = dim 67 | self.window_size = window_size # Wh, Ww 68 | self.num_heads = num_heads 69 | head_dim = dim // num_heads 70 | self.scale = head_dim ** -0.5 71 | 72 | # define a parameter table of relative position bias 73 | self.relative_position_bias_table = nn.Parameter( 74 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 75 | 76 | # get pair-wise relative position index for each token inside the window 77 | coords_h = torch.arange(self.window_size[0]) 78 | coords_w = torch.arange(self.window_size[1]) 79 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 80 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 81 | relative_coords = coords_flatten[:, :, None] - \ 82 | coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 83 | relative_coords = relative_coords.permute( 84 | 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 85 | relative_coords[:, :, 0] += self.window_size[0] - \ 86 | 1 # shift to start from 0 87 | relative_coords[:, :, 1] += self.window_size[1] - 1 88 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 89 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 90 | self.register_buffer("relative_position_index", 91 | relative_position_index) 92 | 93 | self.qkv = QuantLinear( 94 | cfg.weight_bit, 95 | dim, 96 | dim*3 97 | ) 98 | self.qact1 = QuantAct(cfg.activation_bit) 99 | self.qact_attn1 = QuantAct(cfg.activation_bit) 100 | self.qact_table = QuantAct(cfg.activation_bit) 101 | self.qact2 = QuantAct(cfg.activation_bit) 102 | 103 | self.attn_drop = nn.Dropout(attn_drop) 104 | self.qact3 = QuantAct(cfg.activation_bit) 105 | self.qact4 = 
QuantAct(cfg.activation_bit) 106 | # self.proj = nn.Linear(dim, dim) 107 | self.proj = QuantLinear( 108 | cfg.weight_bit, 109 | dim, 110 | dim 111 | ) 112 | self.proj_drop = nn.Dropout(proj_drop) 113 | 114 | trunc_normal_(self.relative_position_bias_table, std=.02) 115 | self.softmax = nn.Softmax(dim=-1) 116 | 117 | def forward(self, x, mask: Optional[torch.Tensor] = None): 118 | """ 119 | Args: 120 | x: input features with shape of (num_windows*B, N, C) 121 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 122 | """ 123 | B_, N, C = x.shape 124 | x = self.qkv(x) 125 | x = self.qact1(x) 126 | qkv = x.reshape(B_, N, 3, self.num_heads, C // 127 | self.num_heads).permute(2, 0, 3, 1, 4) 128 | # make torchscript happy (cannot use tensor as tuple) 129 | q, k, v = qkv[0], qkv[1], qkv[2] 130 | 131 | q = q * self.scale 132 | attn = (q @ k.transpose(-2, -1)) 133 | attn = self.qact_attn1(attn) 134 | relative_position_bias_table_q = self.qact_table( 135 | self.relative_position_bias_table) 136 | relative_position_bias = relative_position_bias_table_q[self.relative_position_index.view(-1)].view( 137 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 138 | relative_position_bias = relative_position_bias.permute( 139 | 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 140 | attn = attn + relative_position_bias.unsqueeze(0) 141 | attn = self.qact2(attn) 142 | 143 | if mask is not None: 144 | nW = mask.shape[0] 145 | attn = attn.view(B_ // nW, nW, self.num_heads, N, 146 | N) + mask.unsqueeze(1).unsqueeze(0) 147 | attn = attn.view(-1, self.num_heads, N, N) 148 | attn = self.softmax(attn) 149 | else: 150 | attn = self.softmax(attn) 151 | 152 | attn = self.attn_drop(attn) 153 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 154 | x = self.qact3(x) 155 | x = self.proj(x) 156 | x = self.qact4(x) 157 | x = self.proj_drop(x) 158 | return x 159 | 160 | 161 | class SwinTransformerBlock(nn.Module): 162 | r""" Swin Transformer Block. 163 | 164 | Args: 165 | dim (int): Number of input channels. 166 | input_resolution (tuple[int]): Input resulotion. 167 | num_heads (int): Number of attention heads. 168 | window_size (int): Window size. 169 | shift_size (int): Shift size for SW-MSA. 170 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 171 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 172 | drop (float, optional): Dropout rate. Default: 0.0 173 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 174 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 175 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 176 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 177 | """ 178 | 179 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 180 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 181 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, cfg=None): 182 | super().__init__() 183 | self.dim = dim 184 | self.input_resolution = input_resolution 185 | self.num_heads = num_heads 186 | self.window_size = window_size 187 | self.shift_size = shift_size 188 | self.mlp_ratio = mlp_ratio 189 | if min(self.input_resolution) <= self.window_size: 190 | # if window size is larger than input resolution, we don't partition windows 191 | self.shift_size = 0 192 | self.window_size = min(self.input_resolution) 193 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 194 | 195 | self.norm1 = norm_layer(dim) 196 | self.qact1 = QuantAct(cfg.activation_bit) 197 | self.attn = WindowAttention( 198 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, 199 | attn_drop=attn_drop, proj_drop=drop, cfg=cfg) 200 | 201 | self.drop_path = DropPath( 202 | drop_path) if drop_path > 0. else nn.Identity() 203 | self.qact2 = QuantAct(cfg.activation_bit) 204 | self.norm2 = norm_layer(dim) 205 | self.qact3 = QuantAct(cfg.activation_bit) 206 | mlp_hidden_dim = int(dim * mlp_ratio) 207 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, 208 | drop=drop, cfg=cfg) 209 | self.qact4 = QuantAct(cfg.activation_bit) 210 | if self.shift_size > 0: 211 | # calculate attention mask for SW-MSA 212 | H, W = self.input_resolution 213 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 214 | h_slices = (slice(0, -self.window_size), 215 | slice(-self.window_size, -self.shift_size), 216 | slice(-self.shift_size, None)) 217 | w_slices = (slice(0, -self.window_size), 218 | slice(-self.window_size, -self.shift_size), 219 | slice(-self.shift_size, None)) 220 | cnt = 0 221 | for h in h_slices: 222 | for w in w_slices: 223 | img_mask[:, h, w, :] = cnt 224 | cnt += 1 225 | 226 | # nW, window_size, window_size, 1 227 | mask_windows = window_partition(img_mask, self.window_size) 228 | mask_windows = mask_windows.view(-1, 229 | self.window_size * self.window_size) 230 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 231 | attn_mask = attn_mask.masked_fill( 232 | attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 233 | else: 234 | attn_mask = None 235 | 236 | self.register_buffer("attn_mask", attn_mask) 237 | 238 | def forward(self, x): 239 | H, W = self.input_resolution 240 | B, L, C = x.shape 241 | assert L == H * W, "input feature has wrong size" 242 | 243 | shortcut = x 244 | x = self.norm1(x) 245 | x = self.qact1(x) 246 | x = x.view(B, H, W, C) 247 | 248 | # cyclic shift 249 | if self.shift_size > 0: 250 | shifted_x = torch.roll( 251 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 252 | else: 253 | shifted_x = x 254 | 255 | # partition windows 256 | # nW*B, window_size, window_size, C 257 | x_windows = window_partition(shifted_x, self.window_size) 258 | # nW*B, window_size*window_size, C 259 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 260 | 261 | # W-MSA/SW-MSA 262 | # nW*B, window_size*window_size, C 263 | attn_windows = self.attn(x_windows, mask=self.attn_mask) 264 | 265 | # merge windows 266 | attn_windows = attn_windows.view(-1, 267 | self.window_size, self.window_size, C) 268 | shifted_x = window_reverse( 269 | attn_windows, self.window_size, H, W) # B H' W' C 270 
| 271 | # reverse cyclic shift 272 | if self.shift_size > 0: 273 | x = torch.roll(shifted_x, shifts=( 274 | self.shift_size, self.shift_size), dims=(1, 2)) 275 | else: 276 | x = shifted_x 277 | x = x.view(B, H * W, C) 278 | 279 | # FFN 280 | x = shortcut + self.drop_path(x) 281 | x = self.qact2(x) 282 | x = x + self.drop_path(self.mlp(self.qact3(self.norm2(x)))) 283 | x = self.qact4(x) 284 | 285 | return x 286 | 287 | 288 | class PatchMerging(nn.Module): 289 | r""" Patch Merging Layer. 290 | 291 | Args: 292 | input_resolution (tuple[int]): Resolution of input feature. 293 | dim (int): Number of input channels. 294 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 295 | """ 296 | 297 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, cfg=None): 298 | super().__init__() 299 | self.input_resolution = input_resolution 300 | self.dim = dim 301 | 302 | self.norm = norm_layer(4 * dim) 303 | self.qact1 = QuantAct(cfg.activation_bit) 304 | # self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 305 | self.reduction = QuantLinear( 306 | cfg.weight_bit, 307 | 4 * dim, 308 | 2 * dim, 309 | bias=False 310 | ) 311 | self.qact2 = QuantAct(cfg.activation_bit) 312 | 313 | def forward(self, x): 314 | """ 315 | x: B, H*W, C 316 | """ 317 | H, W = self.input_resolution 318 | B, L, C = x.shape 319 | assert L == H * W, "input feature has wrong size" 320 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 321 | 322 | x = x.view(B, H, W, C) 323 | 324 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 325 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 326 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 327 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 328 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 329 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 330 | x = self.norm(x) 331 | x = self.qact1(x) 332 | x = self.reduction(x) 333 | x = self.qact2(x) 334 | return x 335 | 336 | def extra_repr(self) -> str: 337 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 338 | 339 | def flops(self): 340 | H, W = self.input_resolution 341 | flops = H * W * self.dim 342 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 343 | return flops 344 | 345 | 346 | class BasicLayer(nn.Module): 347 | """ A basic Swin Transformer layer for one stage. 348 | 349 | Args: 350 | dim (int): Number of input channels. 351 | input_resolution (tuple[int]): Input resolution. 352 | depth (int): Number of blocks. 353 | num_heads (int): Number of attention heads. 354 | window_size (int): Local window size. 355 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 356 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 357 | drop (float, optional): Dropout rate. Default: 0.0 358 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 359 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 360 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 361 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 362 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
363 | """ 364 | 365 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 366 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 367 | norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, 368 | cfg=None): 369 | 370 | super().__init__() 371 | self.dim = dim 372 | self.input_resolution = input_resolution 373 | self.depth = depth 374 | self.use_checkpoint = use_checkpoint 375 | 376 | # build blocks 377 | self.blocks = nn.ModuleList([ 378 | SwinTransformerBlock( 379 | dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, 380 | shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, 381 | qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, 382 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, 383 | cfg=cfg) 384 | for i in range(depth)]) 385 | 386 | # patch merging layer 387 | if downsample is not None: 388 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, cfg=cfg) 389 | else: 390 | self.downsample = None 391 | 392 | def forward(self, x): 393 | for i, blk in enumerate(self.blocks): 394 | if not torch.jit.is_scripting() and self.use_checkpoint: 395 | x = checkpoint.checkpoint(blk, x) 396 | else: 397 | x = blk(x) 398 | if self.downsample is not None: 399 | x = self.downsample(x) 400 | return x 401 | 402 | def extra_repr(self) -> str: 403 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 404 | 405 | 406 | class SwinTransformer(nn.Module): 407 | r""" Swin Transformer 408 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 409 | https://arxiv.org/pdf/2103.14030 410 | 411 | Args: 412 | img_size (int | tuple(int)): Input image size. Default 224 413 | patch_size (int | tuple(int)): Patch size. Default: 4 414 | in_chans (int): Number of input image channels. Default: 3 415 | num_classes (int): Number of classes for classification head. Default: 1000 416 | embed_dim (int): Patch embedding dimension. Default: 96 417 | depths (tuple(int)): Depth of each Swin Transformer layer. 418 | num_heads (tuple(int)): Number of attention heads in different layers. 419 | window_size (int): Window size. Default: 7 420 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 421 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 422 | drop_rate (float): Dropout rate. Default: 0 423 | attn_drop_rate (float): Attention dropout rate. Default: 0 424 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 425 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 426 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 427 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 428 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 429 | """ 430 | 431 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 432 | embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), 433 | window_size=7, mlp_ratio=4., qkv_bias=True, 434 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 435 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 436 | use_checkpoint=False, input_quant=False, cfg=None, **kwargs): 437 | super().__init__() 438 | 439 | self.num_classes = num_classes 440 | self.num_layers = len(depths) 441 | self.embed_dim = embed_dim 442 | self.ape = ape 443 | self.patch_norm = patch_norm 444 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 445 | self.mlp_ratio = mlp_ratio 446 | self.input_quant = input_quant 447 | self.cfg = cfg 448 | if input_quant: 449 | self.qact_input = QuantAct(cfg.activation_bit) 450 | # split image into non-overlapping patches 451 | self.patch_embed = PatchEmbed( 452 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 453 | norm_layer=norm_layer if self.patch_norm else None, cfg=cfg) 454 | num_patches = self.patch_embed.num_patches 455 | self.patch_grid = self.patch_embed.grid_size 456 | 457 | # absolute position embedding 458 | if self.ape: 459 | self.absolute_pos_embed = nn.Parameter( 460 | torch.zeros(1, num_patches, embed_dim)) 461 | trunc_normal_(self.absolute_pos_embed, std=.02) 462 | self.qact1 = QuantAct(cfg.activation_bit) 463 | else: 464 | self.absolute_pos_embed = None 465 | 466 | self.pos_drop = nn.Dropout(p=drop_rate) 467 | 468 | # stochastic depth 469 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 470 | sum(depths))] # stochastic depth decay rule 471 | 472 | # build layers 473 | layers = [] 474 | for i_layer in range(self.num_layers): 475 | layers += [BasicLayer( 476 | dim=int(embed_dim * 2 ** i_layer), 477 | input_resolution=( 478 | self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)), 479 | depth=depths[i_layer], 480 | num_heads=num_heads[i_layer], 481 | window_size=window_size, 482 | mlp_ratio=self.mlp_ratio, 483 | qkv_bias=qkv_bias, 484 | drop=drop_rate, 485 | attn_drop=attn_drop_rate, 486 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 487 | norm_layer=norm_layer, 488 | downsample=PatchMerging if ( 489 | i_layer < self.num_layers - 1) else None, 490 | use_checkpoint=use_checkpoint, 491 | cfg=cfg) 492 | ] 493 | self.layers = nn.Sequential(*layers) 494 | 495 | self.norm = norm_layer(self.num_features) 496 | 497 | self.qact2 = QuantAct(cfg.activation_bit) 498 | self.avgpool = nn.AdaptiveAvgPool1d(1) 499 | self.qact3 = QuantAct(cfg.activation_bit) 500 | self.head = QuantLinear( 501 | cfg.weight_bit, 502 | self.num_features, 503 | num_classes 504 | ) 505 | 506 | self.act_out = QuantAct(cfg.activation_bit) 507 | self.apply(self._init_weights) 508 | 509 | def _init_weights(self, m): 510 | if isinstance(m, nn.Linear): 511 | trunc_normal_(m.weight, std=.02) 512 | if isinstance(m, nn.Linear) and m.bias is not None: 513 | nn.init.constant_(m.bias, 0) 514 | elif isinstance(m, nn.LayerNorm): 515 | nn.init.constant_(m.bias, 0) 516 | nn.init.constant_(m.weight, 1.0) 517 | 518 | @torch.jit.ignore 519 | def no_weight_decay(self): 520 | return {'absolute_pos_embed'} 521 | 522 | @torch.jit.ignore 523 | def no_weight_decay_keywords(self): 524 | return {'relative_position_bias_table'} 525 | 526 | def model_quant(self): 527 | for m in self.modules(): 528 | if type(m) in [QuantLinear, QuantConv2d, QuantAct]: 529 | m.quant = True 530 | 531 | def 
model_freeze(self): 532 | for m in self.modules(): 533 | if type(m) in [QuantAct]: 534 | m.running_stat = False 535 | 536 | def model_unfreeze(self): 537 | for m in self.modules(): 538 | if type(m) in [QuantAct]: 539 | m.running_stat = True 540 | 541 | def forward_features(self, x): 542 | if self.input_quant: 543 | x = self.qact_input(x) 544 | x = self.patch_embed(x) 545 | if self.absolute_pos_embed is not None: 546 | x = x + self.absolute_pos_embed 547 | x = self.qact1(x) 548 | x = self.pos_drop(x) 549 | for i, layer in enumerate(self.layers): 550 | x = layer(x) 551 | 552 | x = self.norm(x) # B L C 553 | x = self.qact2(x) 554 | 555 | x = self.avgpool(x.transpose(1, 2)) # B C 1 556 | x = self.qact3(x) 557 | 558 | x = torch.flatten(x, 1) 559 | return x 560 | 561 | def forward(self, x): 562 | x = self.forward_features(x) 563 | x = self.head(x) 564 | x = self.act_out(x) 565 | return x 566 | 567 | 568 | def swin_tiny_patch4_window7_224(pretrained=False, cfg=None, **kwargs): 569 | """ Swin-T @ 224x224, trained ImageNet-1k 570 | """ 571 | model = SwinTransformer( 572 | patch_size=4, 573 | window_size=7, 574 | embed_dim=96, 575 | depths=(2, 2, 6, 2), 576 | num_heads=(3, 6, 12, 24), 577 | input_quant=True, 578 | cfg=cfg, 579 | **kwargs 580 | ) 581 | if pretrained: 582 | checkpoint = torch.hub.load_state_dict_from_url( 583 | url="https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth", 584 | map_location="cpu", check_hash=True 585 | ) 586 | model.load_state_dict(checkpoint["model"], strict=False) 587 | return model 588 | 589 | 590 | def swin_small_patch4_window7_224(pretrained=False, cfg=None, **kwargs): 591 | """ Swin-S @ 224x224, trained ImageNet-1k 592 | """ 593 | model = SwinTransformer( 594 | patch_size=4, 595 | window_size=7, 596 | embed_dim=96, 597 | depths=(2, 2, 18, 2), 598 | num_heads=(3, 6, 12, 24), 599 | input_quant=True, 600 | cfg=cfg, 601 | **kwargs 602 | ) 603 | if pretrained: 604 | checkpoint = torch.hub.load_state_dict_from_url( 605 | url="https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth", 606 | map_location="cpu", check_hash=True 607 | ) 608 | model.load_state_dict(checkpoint["model"], strict=False) 609 | return model 610 | --------------------------------------------------------------------------------
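Usage sketch: the model factories above (deit_*_patch16_224 and swin_*_patch4_window7_224) only read weight_bit and activation_bit from the cfg object they are given, and every model exposes model_quant(), model_freeze(), and model_unfreeze() to toggle the quant flag on QuantLinear/QuantConv2d/QuantAct and the running_stat flag on QuantAct. The snippet below is a minimal sketch of the implied calibrate-then-quantize flow, not the repository's own driver script; it assumes a bare SimpleNamespace config is sufficient (the actual entry point may pass a richer config object) and that QuantAct modules collect activation ranges whenever running_stat is enabled.

from types import SimpleNamespace

import torch

from models import deit_tiny_patch16_224

# Hypothetical minimal config: the quantized modules above only read these two fields.
cfg = SimpleNamespace(weight_bit=8, activation_bit=8)

model = deit_tiny_patch16_224(pretrained=False, cfg=cfg)
model.eval()

# Calibration: run a few batches so QuantAct modules can observe activation
# ranges while running_stat is True (model_unfreeze() makes this explicit).
model.model_unfreeze()
with torch.no_grad():
    model(torch.randn(8, 3, 224, 224))

# Freeze the collected statistics, then switch every QuantLinear / QuantConv2d /
# QuantAct module into quantized mode for evaluation.
model.model_freeze()
model.model_quant()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # expected shape: (1, 1000)

The same flow applies to the Swin variants, since SwinTransformer shares the model_quant / model_freeze / model_unfreeze interface with VisionTransformer.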