├── .gitignore
├── LICENSE
├── README.md
├── WCC
│   ├── QuantConv2d.py
│   ├── WCC.py
│   ├── __init__.py
│   ├── transform_model.py
│   └── util
│       ├── __init__.py
│       ├── quantization.py
│       └── wavelet.py
└── images
    └── semseg.png

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Shahaf E. Finder, Yair Zohav, Maor Ashkenazi, and Eran Treister

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# WaveletCompressedConvolution
Official implementation of [Wavelet Feature Maps Compression for Image-to-Image CNNs](https://arxiv.org/abs/2205.12268), NeurIPS 2022.

<div align="center">
<img src="images/semseg.png" alt="Semantic segmentation results"/>
</div>

# Use as a Drop-in Replacement

A code example is available at `WCC/transform_model.py`.
Note that it is common practice to avoid quantizing/compressing the first and last layers of the network.
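For instance, a minimal sketch on a toy network (the network itself is a placeholder; only the `wavelet_module` call is the repository's API, which replaces 1×1 convolutions with WCC layers and larger kernels with quantized convolutions):

```python
import torch
from torch import nn

from WCC.transform_model import wavelet_module

# Toy stand-in for a real image-to-image CNN.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=1),  # 1x1 conv: becomes a WCC layer
    nn.Conv2d(16, 3, kernel_size=3, padding=1),
)

model = wavelet_module(model, levels=3, compress_rate=0.25, bit_w=8, bit_a=8)

x = torch.randn(1, 3, 64, 64)
y = model(x)  # the transformed model keeps the original interface
```

The transform recurses through `named_children`, so arbitrarily nested models work the same way.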

# Recreating Paper Experiments

For best results, it is recommended to reduce the number of bits and the compression rate gradually:
e.g., load the quantized 8/8 checkpoint to train the quantized 8/6 model, load the quantized 4/8 checkpoint to train WCC 4/8 at 50% compression, and load the WCC 25% checkpoint to train WCC at 12.5%.

In all experiments we used a popular existing implementation and changed only the main file, applying the model transform right after the model is created. For example, in the DeepLab implementation:
```python
WCC.wavelet_deeplabmobilev2(model, opts.wt_levels, opts.wt_compression, opts.bit_w, opts.bit_a)
```
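A hedged sketch of one such gradual schedule, using the helpers from `WCC/transform_model.py` (the model is a toy stand-in and the checkpoint path is illustrative):

```python
import torch
from torch import nn

from WCC.transform_model import change_module_bits, change_module_wt_params, wavelet_module

# Toy stand-in; in practice, start from your trained full-precision model.
model = nn.Sequential(nn.Conv2d(3, 8, 1), nn.ReLU(), nn.Conv2d(8, 3, 1))

# Stage 1: 8/8 quantization with mild 50% wavelet shrinkage, then fine-tune.
model = wavelet_module(model, levels=2, compress_rate=0.5, bit_w=8, bit_a=8)
# ... train, then torch.save(model.state_dict(), "wcc_8_8_50.pth")

# Stage 2: resume from the previous stage and tighten the compression rate.
# model.load_state_dict(torch.load("wcc_8_8_50.pth"))  # illustrative path
change_module_wt_params(model, compress_rate=0.25, levels=2)

# Stage 3: lower the activation precision on top of the 25% model.
change_module_bits(model, bit_w=8, bit_a=6)
```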
## 1. Object Detection
We used the following implementation:
https://github.com/rwightman/efficientdet-pytorch

| Precision (W/A) | Wavelet Shrinkage | BOPs (billions) | mAP ↑ |
| :-------------: | :---------------: | :-------------: | :---: |
| FP32 | --- | 6,144 | 40.08 |
| 4/8 | --- | 280.4 | 31.44 |
| 4/8 | 50% | 198.5 | 31.15 |
| 4/8 | 25% | 155.4 | 27.49 |

## 2. Semantic Segmentation
We used the following implementation:
https://github.com/VainF/DeepLabV3Plus-Pytorch

We trained the model with the optional flag `--separable_conv`.

Cityscapes results (see the paper for Pascal VOC and additional configurations):

| Precision (W/A) | Wavelet Shrinkage | BOPs (billions) | mIoU ↑ |
| :-------------: | :---------------: | :-------------: | :----: |
| FP32 | --- | 36,377 | 0.717 |
| 8/8 | --- | 2,273 | 0.701 |
| 8/6 | --- | 1,705 | 0.683 |
| 8/4 | --- | 1,136 | 0.173 |
| 8/8 | 50% | 1,213 | 0.681 |
| 8/8 | 25% | 673 | 0.620 |
| 8/8 | 12.5% | 403 | 0.552 |

## 3. Depth Prediction
We used the following implementation:
https://github.com/nianticlabs/monodepth2

| Precision (W/A) | Wavelet Shrinkage | BOPs (billions) | AbsRel ↓ | RMSE ↓ |
| :-------------: | :---------------: | :-------------: | :------: | :----: |
| FP32 | --- | 1,163.6 | 0.093 | 4.022 |
| 8/8 | --- | 133.6 | 0.092 | 4.018 |
| 8/4 | --- | 99.26 | 0.097 | 4.166 |
| 8/2 | --- | 82.1 | 0.268 | 8.223 |
| 8/8 | 50% | 103.9 | 0.098 | 4.217 |
| 8/8 | 25% | 88.5 | 0.112 | 4.663 |
| 8/8 | 12.5% | 80.8 | 0.131 | 5.046 |

## 4. Super-resolution
We used the following implementation:
https://github.com/sanghyun-son/EDSR-PyTorch

# Citation

```
@inproceedings{finder2022wavelet,
  title={Wavelet Feature Maps Compression for Image-to-Image CNNs},
  author={Finder, Shahaf E and Zohav, Yair and Ashkenazi, Maor and Treister, Eran},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
```
--------------------------------------------------------------------------------
/WCC/QuantConv2d.py:
--------------------------------------------------------------------------------
from typing import Tuple, Union

import torch.nn as nn
import torch.nn.functional as F

from .util.quantization import weight_quantize_fn, act_quantize_fn


class QuantConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size: Union[int, Tuple], stride: Union[int, Tuple] = 1,
                 padding: Union[int, Tuple] = 0, dilation: Union[int, Tuple] = 1, groups: int = 1, bias=False,
                 bit_w=8, bit_a=8):
        super(QuantConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                                          bias)
        self.layer_type = 'QuantConv2d'
        self.bit_w = bit_w
        self.bit_a = bit_a
        self.weight_quant = weight_quantize_fn(self.bit_w)
        self.act_quant = act_quantize_fn(self.bit_a)

    def forward(self, x):
        # Fake-quantize weights and activations, then run a regular convolution.
        weight_q = self.weight_quant(self.weight)
        x = self.act_quant(x)
        return F.conv2d(x, weight_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def change_bit(self, bit_w, bit_a):
        self.bit_w = bit_w
        self.bit_a = bit_a
        self.weight_quant.change_bit(bit_w)
        self.act_quant.change_bit(bit_a)
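A quick usage sketch of `QuantConv2d` (shapes and bit widths are illustrative):

```python
import torch

from WCC.QuantConv2d import QuantConv2d

# Same constructor arguments as nn.Conv2d, plus the two bit widths.
conv = QuantConv2d(16, 32, kernel_size=3, padding=1, bit_w=4, bit_a=8)
x = torch.rand(2, 16, 28, 28)
y = conv(x)            # weights and activations are fake-quantized on the fly
conv.change_bit(4, 6)  # lower the precision later without rebuilding the layer
```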
--------------------------------------------------------------------------------
/WCC/WCC.py:
--------------------------------------------------------------------------------
from typing import Union, Tuple

import torch
from torch import nn as nn
from torch.nn import functional as F

from .util.quantization import weight_quantize_fn, act_quantize_fn
from .util import wavelet


class WCC(nn.Conv1d):
    def __init__(self, in_channels: int,
                 out_channels: int,
                 stride: Union[int, Tuple] = 1,
                 padding: Union[int, Tuple] = 0,
                 dilation: Union[int, Tuple] = 1,
                 groups: int = 1,
                 bias: bool = False,
                 levels: int = 3,
                 compress_rate: float = 0.25,
                 bit_w: int = 8,
                 bit_a: int = 8,
                 wt_type: str = "db1"):
        super(WCC, self).__init__(in_channels, out_channels, 1, stride, padding, dilation, groups, bias)
        self.layer_type = 'WCC'
        self.bit_w = bit_w
        self.bit_a = bit_a

        self.weight_quant = weight_quantize_fn(self.bit_w)
        self.act_quant = act_quantize_fn(self.bit_a, signed=True)

        self.levels = levels
        self.wt_type = wt_type
        self.compress_rate = compress_rate

        dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type,
                                                                 in_size=in_channels,
                                                                 out_size=out_channels)
        self.wt_filters = nn.Parameter(dec_filters, requires_grad=False)
        self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False)
        self.wt = wavelet.get_transform(self.wt_filters, in_channels, levels)
        self.iwt = wavelet.get_inverse_transform(self.iwt_filters, out_channels, levels)

        self.get_pad = lambda n: ((2 ** levels) - n) % (2 ** levels)

    def forward(self, x):
        in_shape = x.shape
        # F.pad orders pads as (w_left, w_right, h_top, h_bottom), so the width pad
        # must come from in_shape[3] and the height pad from in_shape[2].
        pads = (0, self.get_pad(in_shape[3]), 0, self.get_pad(in_shape[2]))
        x = F.pad(x, pads)  # pad H and W to multiples of 2^levels

        weight_q = self.weight_quant(self.weight)  # quantize weights
        x = self.wt(x)  # H
        topk, ids = self.compress(x)  # T
        topk_q = self.act_quant(topk)  # quantize activations
        topk_q = F.conv1d(topk_q, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups)  # K_1x1
        x = self.decompress(topk_q, ids, x.shape)  # T^T
        x = self.iwt(x)  # H^T

        x = x[:, :, :in_shape[2], :in_shape[3]]  # remove pads
        return x

    def compress(self, x):
        # Keep only the k spatial locations with the largest channel-wise energy.
        b, c, h, w = x.shape
        acc = x.norm(dim=1).pow(2)
        acc = acc.view(b, h * w)
        k = int(h * w * self.compress_rate)
        ids = acc.topk(k, dim=1, sorted=False)[1]
        ids.unsqueeze_(dim=1)
        topk = x.reshape((b, c, h * w)).gather(dim=2, index=ids.repeat(1, c, 1))
        return topk, ids

    def decompress(self, topk, ids, shape):
        # Scatter the processed coefficients back to their original positions.
        b, _, h, w = shape
        ids = ids.repeat(1, self.out_channels, 1)
        x = torch.zeros(size=(b, self.out_channels, h * w), requires_grad=True, device=topk.device)
        x = x.scatter(dim=2, index=ids, src=topk)
        x = x.reshape((b, self.out_channels, h, w))
        return x

    def change_wt_params(self, compress_rate, levels, wt_type="db1"):
        self.compress_rate = compress_rate
        self.levels = levels
        self.wt_type = wt_type
        dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type,
                                                                 in_size=self.in_channels,
                                                                 out_size=self.out_channels)
        self.wt_filters = nn.Parameter(dec_filters, requires_grad=False)
        self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False)
        self.wt = wavelet.get_transform(self.wt_filters, self.in_channels, levels)
        self.iwt = wavelet.get_inverse_transform(self.iwt_filters, self.out_channels, levels)
        self.get_pad = lambda n: ((2 ** levels) - n) % (2 ** levels)  # keep padding in sync with the new levels

    def change_bit(self, bit_w, bit_a):
        self.bit_w = bit_w
        self.bit_a = bit_a
        self.weight_quant.change_bit(bit_w)
        self.act_quant.change_bit(bit_a)
--------------------------------------------------------------------------------
/WCC/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BGUCompSci/WaveletCompressedConvolution/7744984627338c5d29be8ff6ef29deb648f52d97/WCC/__init__.py
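A usage sketch of the layer on its own (sizes are illustrative):

```python
import torch

from WCC.WCC import WCC

# A 1x1 convolution computed on the top 25% of wavelet coefficients.
layer = WCC(in_channels=8, out_channels=16, levels=2, compress_rate=0.25)
x = torch.randn(2, 8, 33, 33)  # odd sizes are padded internally to a multiple of 2^levels
y = layer(x)
print(y.shape)  # torch.Size([2, 16, 33, 33]) -- spatial size is preserved
```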
--------------------------------------------------------------------------------
/WCC/transform_model.py:
--------------------------------------------------------------------------------
from torch import nn
from .WCC import WCC
from .QuantConv2d import QuantConv2d


def quantize_deeplabmobilev2(model, bit_w, bit_a):
    # Quantize everything except the first and last layers of the network.
    first_conv = model.backbone.low_level_features[0][0]
    last_layer = model.classifier.classifier[3]
    quantize_module(model, bit_w, bit_a)
    model.backbone.low_level_features[0][0] = first_conv
    model.classifier.classifier[3] = last_layer


def wavelet_deeplabmobilev2(model, levels, compress_rate, bit_w, bit_a):
    first_conv = model.backbone.low_level_features[0][0]
    last_layer = model.classifier.classifier[3]
    wavelet_module(model, levels, compress_rate, bit_w, bit_a)
    model.backbone.low_level_features[0][0] = first_conv
    model.classifier.classifier[3] = last_layer


def quantize_module(module, bit_w, bit_a):
    # Recursively replace every nn.Conv2d with a QuantConv2d sharing its parameters.
    new_module = module
    if isinstance(module, nn.Conv2d):
        new_module = QuantConv2d(module.in_channels,
                                 module.out_channels,
                                 module.kernel_size,
                                 module.stride,
                                 module.padding,
                                 module.dilation,
                                 module.groups,
                                 module.bias is not None,
                                 bit_w,
                                 bit_a)
        new_module.weight = module.weight
        new_module.bias = module.bias
    for name, child in module.named_children():
        new_module.add_module(name, quantize_module(child, bit_w, bit_a))
    return new_module


def change_module_bits(module, bit_w, bit_a):
    if isinstance(module, QuantConv2d) or isinstance(module, WCC):
        module.change_bit(bit_w, bit_a)
    else:
        for name, child in module.named_children():
            change_module_bits(child, bit_w, bit_a)


def wavelet_module(module, levels, compress_rate, bit_w, bit_a):
    # 1x1 convolutions become WCC layers; larger kernels are only quantized.
    new_module = module
    if isinstance(module, nn.Conv2d):
        if module.kernel_size[0] > 1:
            new_module = QuantConv2d(module.in_channels,
                                     module.out_channels,
                                     module.kernel_size,
                                     module.stride,
                                     module.padding,
                                     module.dilation,
                                     module.groups,
                                     module.bias is not None,
                                     bit_w,
                                     bit_a)
            new_module.weight = module.weight
            new_module.bias = module.bias
        else:
            new_module = WCC(module.in_channels,
                             module.out_channels,
                             module.stride[0],
                             module.padding[0],
                             module.dilation[0],
                             module.groups,
                             module.bias is not None,
                             levels,
                             compress_rate,
                             bit_w,
                             bit_a)
            new_module.weight = nn.Parameter(module.weight.squeeze(-1))
            new_module.bias = module.bias
        if isinstance(module, QuantConv2d):
            # Carry over the learned clipping parameters when converting an
            # already-quantized checkpoint.
            new_module.act_quant.a_alpha = module.act_quant.a_alpha
            new_module.weight_quant.w_alpha = module.weight_quant.w_alpha
    else:
        for name, child in module.named_children():
            new_module.add_module(name, wavelet_module(child, levels, compress_rate, bit_w, bit_a))
    return new_module


def change_module_wt_params(module, compress_rate, levels):
    if isinstance(module, WCC):
        module.change_wt_params(compress_rate, levels)
    else:
        for name, child in module.named_children():
            change_module_wt_params(child, compress_rate, levels)
--------------------------------------------------------------------------------
/WCC/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BGUCompSci/WaveletCompressedConvolution/7744984627338c5d29be8ff6ef29deb648f52d97/WCC/util/__init__.py
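The `*_deeplabmobilev2` helpers above hard-code the first and last layers of DeepLab-MobileNetV2. For other architectures, the same keep-the-ends pattern is a few lines; a sketch with a toy model standing in (layer indices depend on your network):

```python
from torch import nn

from WCC.transform_model import wavelet_module

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.Conv2d(8, 8, kernel_size=1),
    nn.Conv2d(8, 3, kernel_size=3, padding=1),
)
first, last = model[0], model[2]

model = wavelet_module(model, levels=2, compress_rate=0.5, bit_w=8, bit_a=8)

# Restore the untouched first and last layers, as recommended in the README.
model[0], model[2] = first, last
```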
--------------------------------------------------------------------------------
/WCC/util/quantization.py:
--------------------------------------------------------------------------------
import torch
from torch import nn as nn
from torch.nn import Parameter


def weight_quantization(b):
    def uniform_quant(x, b):
        xdiv = x.mul((2 ** b - 1))
        xhard = xdiv.round().div(2 ** b - 1)
        return xhard

    class _pq(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, alpha):
            # in-place is safe here: weight_quantize_fn passes a fresh normalized copy
            input.div_(alpha)  # weights are first divided by alpha
            input_c = input.clamp(min=-1, max=1)  # then clipped to [-1, 1]
            sign = input_c.sign()
            input_abs = input_c.abs()
            input_q = uniform_quant(input_abs, b).mul(sign)
            ctx.save_for_backward(input, input_q)
            input_q = input_q.mul(alpha)  # rescale to the original range
            return input_q

        @staticmethod
        def backward(ctx, grad_output):
            grad_input = grad_output.clone()  # straight-through: weight grads are not clipped
            input, input_q = ctx.saved_tensors
            i = (input.abs() > 1.).float()
            sign = input.sign()
            grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum()
            return grad_input, grad_alpha

    return _pq().apply


class weight_quantize_fn(nn.Module):
    def __init__(self, bit_w):
        super(weight_quantize_fn, self).__init__()
        assert bit_w > 0

        self.bit_w = bit_w - 1  # one bit is used by the sign
        self.weight_q = weight_quantization(b=self.bit_w)
        self.register_parameter('w_alpha', Parameter(torch.tensor(3.0), requires_grad=True))

    def forward(self, weight):
        mean = weight.data.mean()
        std = weight.data.std()
        weight = weight.add(-mean).div(std)  # weight normalization
        weight_q = self.weight_q(weight, self.w_alpha)
        return weight_q

    def change_bit(self, bit_w):
        self.bit_w = bit_w - 1
        self.weight_q = weight_quantization(b=self.bit_w)


def act_quantization(b, signed=False):
    def uniform_quant(x, b=3):
        xdiv = x.mul(2 ** b - 1)
        xhard = xdiv.round().div(2 ** b - 1)
        return xhard

    class _uq(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, alpha):
            input = input.div(alpha)
            input_c = input.clamp(min=-1, max=1) if signed else input.clamp(max=1)
            input_q = uniform_quant(input_c, b)
            ctx.save_for_backward(input, input_q)
            input_q = input_q.mul(alpha)
            return input_q

        @staticmethod
        def backward(ctx, grad_output):
            grad_input = grad_output.clone()
            input, input_q = ctx.saved_tensors
            i = (input.abs() > 1.).float()
            sign = input.sign()
            grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum()
            grad_input = grad_input * (1 - i)  # zero gradients outside the clipping range
            return grad_input, grad_alpha

    return _uq().apply


class act_quantize_fn(nn.Module):
    def __init__(self, bit_a, signed=False):
        super(act_quantize_fn, self).__init__()
        self.bit_a = bit_a
        self.signed = signed
        if signed:
            self.bit_a -= 1  # one bit is used by the sign
        assert bit_a > 0

        self.act_q = act_quantization(b=self.bit_a, signed=signed)
        self.register_parameter('a_alpha', Parameter(torch.tensor(8.0), requires_grad=True))

    def forward(self, x):
        return self.act_q(x, self.a_alpha)

    def change_bit(self, bit_a):
        self.bit_a = bit_a
        if self.signed:
            self.bit_a -= 1
        self.act_q = act_quantization(b=self.bit_a, signed=self.signed)
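The quantizers can be probed in isolation: with `b` effective bits, inputs are clipped to the learned range `±a_alpha` and snapped to a uniform grid with step `a_alpha / (2^b - 1)`. A small sketch (the bit width is illustrative):

```python
import torch

from WCC.util.quantization import act_quantize_fn

q = act_quantize_fn(bit_a=3, signed=True)  # 2 effective bits after the sign bit
x = torch.linspace(-10.0, 10.0, steps=9)
print(q(x))  # clipped at +/- a_alpha (initially 8.0) and snapped to the grid
```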
--------------------------------------------------------------------------------
/WCC/util/wavelet.py:
--------------------------------------------------------------------------------
import pywt
import pywt.data
import torch
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F


def create_wavelet_filter(wave, in_size, out_size, type=torch.float):
    w = pywt.Wavelet(wave)
    dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)
    dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)
    # 2D filters are outer products of the 1D low/high-pass pairs: LL, LH, HL, HH.
    dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
                               dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
                               dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
                               dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)

    dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)

    rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])
    rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])
    rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
                               rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
                               rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
                               rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)

    rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)

    return dec_filters, rec_filters


def wt(x, filters, in_size, level):
    _, _, h, w = x.shape
    pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    res = F.conv2d(x, filters, stride=2, groups=in_size, padding=pad)
    if level > 1:
        # every 4th channel holds the LL band of one input channel; recurse on it
        res[:, ::4] = wt(res[:, ::4], filters, in_size, level - 1)
    # pack the four subbands of each channel back into an h x w mosaic
    res = res.reshape(-1, 2, h // 2, w // 2).transpose(1, 2).reshape(-1, in_size, h, w)
    return res


def iwt(x, inv_filters, in_size, level):
    _, _, h, w = x.shape
    pad = (inv_filters.shape[2] // 2 - 1, inv_filters.shape[3] // 2 - 1)
    # unpack the mosaic into four subband channels per input channel
    res = x.reshape(-1, h // 2, 2, w // 2).transpose(1, 2).reshape(-1, 4 * in_size, h // 2, w // 2)
    if level > 1:
        res[:, ::4] = iwt(res[:, ::4], inv_filters, in_size, level - 1)
    res = F.conv_transpose2d(res, inv_filters, stride=2, groups=in_size, padding=pad)
    return res


def get_inverse_transform(weights, in_size, level):
    class InverseWaveletTransform(Function):

        @staticmethod
        def forward(ctx, input):
            with torch.no_grad():
                x = iwt(input, weights, in_size, level)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            # the transform is orthogonal, so the backward pass is the forward transform
            grad = wt(grad_output, weights, in_size, level)
            return grad, None

    return InverseWaveletTransform().apply


def get_transform(weights, in_size, level):
    class WaveletTransform(Function):

        @staticmethod
        def forward(ctx, input):
            with torch.no_grad():
                x = wt(input, weights, in_size, level)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            grad = iwt(grad_output, weights, in_size, level)
            return grad, None

    return WaveletTransform().apply
--------------------------------------------------------------------------------
/images/semseg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BGUCompSci/WaveletCompressedConvolution/7744984627338c5d29be8ff6ef29deb648f52d97/images/semseg.png
--------------------------------------------------------------------------------
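Finally, a round-trip sanity check for the transforms in `WCC/util/wavelet.py` (sizes are illustrative; for an orthonormal wavelet such as `db1`, the inverse should reconstruct the input up to floating-point error):

```python
import torch

from WCC.util import wavelet

dec, rec = wavelet.create_wavelet_filter("db1", in_size=4, out_size=4)
fwd = wavelet.get_transform(dec, in_size=4, level=2)
inv = wavelet.get_inverse_transform(rec, in_size=4, level=2)

x = torch.randn(1, 4, 16, 16)  # spatial dims divisible by 2^level
print(torch.allclose(inv(fwd(x)), x, atol=1e-5))  # expected: True
```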