├── .gitignore
├── LICENSE
├── README.md
├── WCC
│   ├── QuantConv2d.py
│   ├── WCC.py
│   ├── __init__.py
│   ├── transform_model.py
│   └── util
│       ├── __init__.py
│       ├── quantization.py
│       └── wavelet.py
└── images
    └── semseg.png
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Shahaf E. Finder, Yair Zohav, Maor Ashkenazi, and Eran Treister

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# WaveletCompressedConvolution
Official implementation of [Wavelet Feature Maps Compression for Image-to-Image CNNs](https://arxiv.org/abs/2205.12268), NeurIPS 2022.

![Semantic segmentation results](images/semseg.png)

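In a WCC layer, the 1×1 convolution is applied in a compressed wavelet domain. Using the operator notation from the comments in `WCC/WCC.py`:

$$y = H^\top T^\top \left( K_{1\times 1} \, T H x \right)$$

where $H$ is a multi-level 2D wavelet transform of the input feature map $x$, $T$ keeps only the top-$k$ highest-energy coefficient locations (with $k = hw \cdot \texttt{compress\_rate}$), $K_{1\times 1}$ is the quantized pointwise convolution (applied as a `conv1d` over the kept coefficients), and $T^\top$, $H^\top$ scatter the result back and invert the transform.
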
# Use as a Drop-in Replacement

A code example is available in `WCC/transform_model.py`; a minimal usage sketch is shown below.
Note that it is common practice to avoid quantizing/compressing the first and last layers of the network.
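
For example (the torchvision model here is purely illustrative and not part of this repo):

```
import torchvision
from WCC.transform_model import wavelet_module

# Any image-to-image CNN will do; this torchvision model is just an example.
model = torchvision.models.segmentation.deeplabv3_resnet50(num_classes=21)

# Recursively replaces convolutions in-place: 1x1 convolutions become
# wavelet-compressed WCC layers, larger kernels become quantized QuantConv2d layers.
wavelet_module(model, levels=3, compress_rate=0.25, bit_w=8, bit_a=8)

# As noted above, consider restoring the original first and last layers afterwards
# (see quantize_deeplabmobilev2 / wavelet_deeplabmobilev2 in transform_model.py).
```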

# Recreating Paper Experiments

For best results, it is recommended to reduce the number of bits and the compression rate gradually.
E.g., load the quantized 8/8 checkpoint to train quantized 8/6, load the quantized 4/8 checkpoint to train WCC 4/8 at 50% shrinkage, and load the WCC 25% checkpoint to train WCC 12.5%.
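
For instance, a hypothetical schedule using the helpers in `WCC/transform_model.py` (`model`, `train(...)`, and the checkpoint path are placeholders):

```
import torch
from WCC.transform_model import (quantize_module, wavelet_module,
                                 change_module_bits, change_module_wt_params)

# Stage 1: quantize to 8/8 and fine-tune.
quantize_module(model, bit_w=8, bit_a=8)
train(model)  # placeholder for your training loop
torch.save(model.state_dict(), 'quant_8_8.pth')

# Stage 2: lower the activation bits to 8/6 and fine-tune again.
change_module_bits(model, bit_w=8, bit_a=6)
train(model)

# Stage 3: switch the 1x1 convolutions to WCC at 50% shrinkage, fine-tune,
# then tighten the rate gradually (50% -> 25% -> 12.5%).
wavelet_module(model, levels=3, compress_rate=0.5, bit_w=8, bit_a=6)
train(model)
change_module_wt_params(model, compress_rate=0.25, levels=3)
train(model)
```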

In all experiments we used a popular existing implementation and changed only its main file, applying the model transform right after the model is created. For example, for the DeepLab implementation:
```
WCC.wavelet_deeplabmobilev2(model, opts.wt_levels, opts.wt_compression, opts.bit_w, opts.bit_a)
```
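
In context, the surrounding code in such a main file would look roughly like this (`build_model` and the `opts` fields stand in for the external repo's own names):

```
model = build_model(opts)  # the external repo's model construction
WCC.wavelet_deeplabmobilev2(model, opts.wt_levels, opts.wt_compression,
                            opts.bit_w, opts.bit_a)
optimizer = torch.optim.SGD(model.parameters(), lr=opts.lr, momentum=0.9)  # unchanged
```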

## 1. Object Detection
We used the following implementation:
https://github.com/rwightman/efficientdet-pytorch

| Precision | Wavelet Shrinkage | BOPs (B) | mAP ↑ |
| :-------: | :---------------: | :------: | :---: |
| FP32 | --- | 6,144 | 40.08 |
| 4/8 | --- | 280.4 | 31.44 |
| 4/8 | 50% | 198.5 | 31.15 |
| 4/8 | 25% | 155.4 | 27.49 |


## 2. Semantic Segmentation
We used the following implementation:
https://github.com/VainF/DeepLabV3Plus-Pytorch

We trained the model with the optional flag `--separable_conv`.

Cityscapes results (see the paper for Pascal VOC results as well as more configurations):
| Precision | Wavelet Shrinkage | BOPs (B) | mIoU ↑ |
| :-------: | :---------------: | :------: | :----: |
| FP32 | --- | 36,377 | 0.717 |
| 8/8 | --- | 2,273 | 0.701 |
| 8/6 | --- | 1,705 | 0.683 |
| 8/4 | --- | 1,136 | 0.173 |
| 8/8 | 50% | 1,213 | 0.681 |
| 8/8 | 25% | 673 | 0.620 |
| 8/8 | 12.5% | 403 | 0.552 |

## 3. Depth Prediction
We used the following implementation:
https://github.com/nianticlabs/monodepth2

| Precision | Wavelet Shrinkage | BOPs (B) | AbsRel ↓ | RMSE ↓ |
| :-------: | :---------------: | :------: | :------: | :----: |
| FP32 | --- | 1,163.6 | 0.093 | 4.022 |
| 8/8 | --- | 133.6 | 0.092 | 4.018 |
| 8/4 | --- | 99.26 | 0.097 | 4.166 |
| 8/2 | --- | 82.1 | 0.268 | 8.223 |
| 8/8 | 50% | 103.9 | 0.098 | 4.217 |
| 8/8 | 25% | 88.5 | 0.112 | 4.663 |
| 8/8 | 12.5% | 80.8 | 0.131 | 5.046 |

## 4. Super-resolution
We used the following implementation:
https://github.com/sanghyun-son/EDSR-PyTorch


# Citation

```
@inproceedings{finder2022wavelet,
  title={Wavelet Feature Maps Compression for Image-to-Image CNNs},
  author={Finder, Shahaf E and Zohav, Yair and Ashkenazi, Maor and Treister, Eran},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
```

--------------------------------------------------------------------------------
/WCC/QuantConv2d.py:
--------------------------------------------------------------------------------
from typing import Tuple, Union

import torch.nn as nn
import torch.nn.functional as F

from .util.quantization import weight_quantize_fn, act_quantize_fn


class QuantConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size: Union[int, Tuple], stride: Union[int, Tuple] = 1,
                 padding: Union[int, Tuple] = 0, dilation: Union[int, Tuple] = 1, groups: int = 1, bias=False, bit_w=8,
                 bit_a=8):
        super(QuantConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
                                          bias)
        self.layer_type = 'QuantConv2d'
        self.bit_w = bit_w
        self.bit_a = bit_a
        self.weight_quant = weight_quantize_fn(self.bit_w)
        self.act_quant = act_quantize_fn(self.bit_a)

    def forward(self, x):
        weight_q = self.weight_quant(self.weight)  # quantize weights
        x = self.act_quant(x)  # quantize activations
        return F.conv2d(x, weight_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def change_bit(self, bit_w, bit_a):
        self.bit_w = bit_w
        self.bit_a = bit_a
        self.weight_quant.change_bit(bit_w)
        self.act_quant.change_bit(bit_a)
--------------------------------------------------------------------------------
/WCC/WCC.py:
--------------------------------------------------------------------------------
from typing import Union, Tuple

import torch
from torch import nn as nn
from torch.nn import functional as F

from .util.quantization import weight_quantize_fn, act_quantize_fn
from .util import wavelet


class WCC(nn.Conv1d):
    def __init__(self, in_channels: int,
                 out_channels: int,
                 stride: Union[int, Tuple] = 1,
                 padding: Union[int, Tuple] = 0,
                 dilation: Union[int, Tuple] = 1,
                 groups: int = 1,
                 bias: bool = False,
                 levels: int = 3,
                 compress_rate: float = 0.25,
                 bit_w: int = 8,
                 bit_a: int = 8,
                 wt_type: str = "db1"):
        super(WCC, self).__init__(in_channels, out_channels, 1, stride, padding, dilation, groups, bias)
        self.layer_type = 'WCC'
        self.bit_w = bit_w
        self.bit_a = bit_a

        self.weight_quant = weight_quantize_fn(self.bit_w)
        self.act_quant = act_quantize_fn(self.bit_a, signed=True)

        self.levels = levels
        self.wt_type = wt_type
        self.compress_rate = compress_rate

        dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type,
                                                                 in_size=in_channels,
                                                                 out_size=out_channels)
        self.wt_filters = nn.Parameter(dec_filters, requires_grad=False)
        self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False)
        self.wt = wavelet.get_transform(self.wt_filters, in_channels, levels)
        self.iwt = wavelet.get_inverse_transform(self.iwt_filters, out_channels, levels)

        self.get_pad = lambda n: ((2 ** levels) - n) % (2 ** levels)

    def forward(self, x):
        in_shape = x.shape
        # F.pad order is (w_left, w_right, h_top, h_bottom); pad each spatial dim
        # to a multiple of 2^levels.
        pads = (0, self.get_pad(in_shape[3]), 0, self.get_pad(in_shape[2]))
        x = F.pad(x, pads)

        weight_q = self.weight_quant(self.weight)  # quantize weights
        x = self.wt(x)  # H: multi-level wavelet transform
        topk, ids = self.compress(x)  # T: keep the top-k coefficient locations
        topk_q = self.act_quant(topk)  # quantize activations
        topk_q = F.conv1d(topk_q, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups)  # K_1x1
        x = self.decompress(topk_q, ids, x.shape)  # T^T: scatter back
        x = self.iwt(x)  # H^T: inverse wavelet transform

        x = x[:, :, :in_shape[2], :in_shape[3]]  # remove pads
        return x

    def compress(self, x):
        b, c, h, w = x.shape
        acc = x.norm(dim=1).pow(2)  # per-location energy across channels
        acc = acc.view(b, h * w)
        k = int(h * w * self.compress_rate)
        ids = acc.topk(k, dim=1, sorted=False)[1]  # indices of the k highest-energy locations
        ids.unsqueeze_(dim=1)
        topk = x.reshape((b, c, h * w)).gather(dim=2, index=ids.repeat(1, c, 1))
        return topk, ids

    def decompress(self, topk, ids, shape):
        b, _, h, w = shape
        ids = ids.repeat(1, self.out_channels, 1)
        x = torch.zeros(size=(b, self.out_channels, h * w), requires_grad=True, device=topk.device)
        x = x.scatter(dim=2, index=ids, src=topk)  # kept coefficients go back in place, zeros elsewhere
        x = x.reshape((b, self.out_channels, h, w))
        return x

    def change_wt_params(self, compress_rate, levels, wt_type="db1"):
        self.compress_rate = compress_rate
        self.levels = levels
        self.wt_type = wt_type
        dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type,
                                                                 in_size=self.in_channels,
                                                                 out_size=self.out_channels)
        self.wt_filters = nn.Parameter(dec_filters, requires_grad=False)
        self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False)
        self.wt = wavelet.get_transform(self.wt_filters, self.in_channels, levels)
        self.iwt = wavelet.get_inverse_transform(self.iwt_filters, self.out_channels, levels)
        self.get_pad = lambda n: ((2 ** levels) - n) % (2 ** levels)  # keep padding consistent with the new levels

    def change_bit(self, bit_w, bit_a):
        self.bit_w = bit_w
        self.bit_a = bit_a
        self.weight_quant.change_bit(bit_w)
        self.act_quant.change_bit(bit_a)
--------------------------------------------------------------------------------
/WCC/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BGUCompSci/WaveletCompressedConvolution/7744984627338c5d29be8ff6ef29deb648f52d97/WCC/__init__.py
--------------------------------------------------------------------------------
/WCC/transform_model.py:
--------------------------------------------------------------------------------
from torch import nn
from .WCC import WCC
from .QuantConv2d import QuantConv2d


def quantize_deeplabmobilev2(model, bit_w, bit_a):
    # Keep the first and last layers at full precision (common practice).
    first_conv = model.backbone.low_level_features[0][0]
    last_layer = model.classifier.classifier[3]
    quantize_module(model, bit_w, bit_a)
    model.backbone.low_level_features[0][0] = first_conv
    model.classifier.classifier[3] = last_layer


def wavelet_deeplabmobilev2(model, levels, compress_rate, bit_w, bit_a):
    # Keep the first and last layers uncompressed (common practice).
    first_conv = model.backbone.low_level_features[0][0]
    last_layer = model.classifier.classifier[3]
    wavelet_module(model, levels, compress_rate, bit_w, bit_a)
    model.backbone.low_level_features[0][0] = first_conv
    model.classifier.classifier[3] = last_layer


def quantize_module(module, bit_w, bit_a):
    new_module = module
    if isinstance(module, nn.Conv2d):
        new_module = QuantConv2d(module.in_channels,
                                 module.out_channels,
                                 module.kernel_size,
                                 module.stride,
                                 module.padding,
                                 module.dilation,
                                 module.groups,
                                 module.bias is not None,
                                 bit_w,
                                 bit_a)
        new_module.weight = module.weight
        new_module.bias = module.bias
    for name, child in module.named_children():
        new_module.add_module(name, quantize_module(child, bit_w, bit_a))
    return new_module


def change_module_bits(module, bit_w, bit_a):
    if isinstance(module, QuantConv2d) or isinstance(module, WCC):
        module.change_bit(bit_w, bit_a)
    else:
        for name, child in module.named_children():
            change_module_bits(child, bit_w, bit_a)


def wavelet_module(module, levels, compress_rate, bit_w, bit_a):
    new_module = module
    if isinstance(module, nn.Conv2d):
        if module.kernel_size[0] > 1:
            # Spatial convolutions are only quantized, not wavelet-compressed.
            new_module = QuantConv2d(module.in_channels,
                                     module.out_channels,
                                     module.kernel_size,
                                     module.stride,
                                     module.padding,
                                     module.dilation,
                                     module.groups,
                                     module.bias is not None,
                                     bit_w,
                                     bit_a)
            new_module.weight = module.weight
            new_module.bias = module.bias
        else:
            # 1x1 convolutions are replaced by wavelet-compressed convolutions.
            new_module = WCC(module.in_channels,
                             module.out_channels,
                             module.stride[0],
                             module.padding[0],
                             module.dilation[0],
                             module.groups,
                             module.bias is not None,
                             levels,
                             compress_rate,
                             bit_w,
                             bit_a)
            new_module.weight = nn.Parameter(module.weight.squeeze(-1))
            new_module.bias = module.bias
            if isinstance(module, QuantConv2d):
                # Carry over the learned quantization scales.
                new_module.act_quant.a_alpha = module.act_quant.a_alpha
                new_module.weight_quant.w_alpha = module.weight_quant.w_alpha
    else:
        for name, child in module.named_children():
            new_module.add_module(name, wavelet_module(child, levels, compress_rate, bit_w, bit_a))
    return new_module


def change_module_wt_params(module, compress_rate, levels):
    if isinstance(module, WCC):
        module.change_wt_params(compress_rate, levels)
    else:
        for name, child in module.named_children():
            change_module_wt_params(child, compress_rate, levels)
--------------------------------------------------------------------------------
/WCC/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BGUCompSci/WaveletCompressedConvolution/7744984627338c5d29be8ff6ef29deb648f52d97/WCC/util/__init__.py
--------------------------------------------------------------------------------
/WCC/util/quantization.py:
--------------------------------------------------------------------------------
import torch
from torch import nn as nn
from torch.nn import Parameter


def weight_quantization(b):
    def uniform_quant(x, b):
        xdiv = x.mul((2 ** b - 1))
        xhard = xdiv.round().div(2 ** b - 1)
        return xhard

    class _pq(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, alpha):
            input = input.div(alpha)  # weights are first divided by alpha (out-of-place, to avoid mutating the weights)
            input_c = input.clamp(min=-1, max=1)  # then clipped to [-1,1]
            sign = input_c.sign()
            input_abs = input_c.abs()
            input_q = uniform_quant(input_abs, b).mul(sign)
            ctx.save_for_backward(input, input_q)
            input_q = input_q.mul(alpha)  # rescale to the original range
            return input_q

        @staticmethod
        def backward(ctx, grad_output):
            grad_input = grad_output.clone()  # grad for weights will not be clipped
            input, input_q = ctx.saved_tensors
            i = (input.abs() > 1.).float()
            sign = input.sign()
            grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum()
            return grad_input, grad_alpha

    return _pq().apply


class weight_quantize_fn(nn.Module):
    def __init__(self, bit_w):
        super(weight_quantize_fn, self).__init__()
        assert bit_w > 0

        self.bit_w = bit_w - 1  # one bit is reserved for the sign
        self.weight_q = weight_quantization(b=self.bit_w)
        self.register_parameter('w_alpha', Parameter(torch.tensor(3.0), requires_grad=True))

    def forward(self, weight):
        mean = weight.data.mean()
        std = weight.data.std()
        weight = weight.add(-mean).div(std)  # weight normalization
        weight_q = self.weight_q(weight, self.w_alpha)
        return weight_q

    def change_bit(self, bit_w):
        self.bit_w = bit_w - 1
        self.weight_q = weight_quantization(b=self.bit_w)


def act_quantization(b, signed=False):
    def uniform_quant(x, b=3):
        xdiv = x.mul(2 ** b - 1)
        xhard = xdiv.round().div(2 ** b - 1)
        return xhard

    class _uq(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, alpha):
            input = input.div(alpha)
            input_c = input.clamp(min=-1, max=1) if signed else input.clamp(max=1)
            input_q = uniform_quant(input_c, b)
            ctx.save_for_backward(input, input_q)
            input_q = input_q.mul(alpha)
            return input_q

        @staticmethod
        def backward(ctx, grad_output):
            grad_input = grad_output.clone()
            input, input_q = ctx.saved_tensors
            i = (input.abs() > 1.).float()
            sign = input.sign()
            grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum()
            grad_input = grad_input * (1 - i)  # zero the gradient where the activation was clipped
            return grad_input, grad_alpha

    return _uq().apply


class act_quantize_fn(nn.Module):
    def __init__(self, bit_a, signed=False):
        super(act_quantize_fn, self).__init__()
        self.bit_a = bit_a
        self.signed = signed
        if signed:
            self.bit_a -= 1  # one bit is reserved for the sign
        assert bit_a > 0

        self.act_q = act_quantization(b=self.bit_a, signed=signed)
        self.register_parameter('a_alpha', Parameter(torch.tensor(8.0), requires_grad=True))

    def forward(self, x):
        return self.act_q(x, self.a_alpha)

    def change_bit(self, bit_a):
        self.bit_a = bit_a
        if self.signed:
            self.bit_a -= 1
        self.act_q = act_quantization(b=self.bit_a, signed=self.signed)
--------------------------------------------------------------------------------
/WCC/util/wavelet.py:
--------------------------------------------------------------------------------
import pywt
import pywt.data
import torch
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F


def create_wavelet_filter(wave, in_size, out_size, dtype=torch.float):
    w = pywt.Wavelet(wave)
    dec_hi = torch.tensor(w.dec_hi[::-1], dtype=dtype)
    dec_lo = torch.tensor(w.dec_lo[::-1], dtype=dtype)
    # Outer products of the 1D filters give the four 2D subband filters (LL, LH, HL, HH).
    dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),
                               dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),
                               dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),
                               dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)

    dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)

    rec_hi = torch.tensor(w.rec_hi[::-1], dtype=dtype).flip(dims=[0])
    rec_lo = torch.tensor(w.rec_lo[::-1], dtype=dtype).flip(dims=[0])
    rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),
                               rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),
                               rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),
                               rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)

    rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)

    return dec_filters, rec_filters


def wt(x, filters, in_size, level):
    _, _, h, w = x.shape
    pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
    res = F.conv2d(x, filters, stride=2, groups=in_size, padding=pad)
    if level > 1:
        # Recurse on the low-low (LL) subband of each channel only.
        res[:, ::4] = wt(res[:, ::4], filters, in_size, level - 1)
    # Repack the 4 subbands per channel into an (in_size, h, w) layout.
    res = res.reshape(-1, 2, h // 2, w // 2).transpose(1, 2).reshape(-1, in_size, h, w)
    return res


def iwt(x, inv_filters, in_size, level):
    _, _, h, w = x.shape
    pad = (inv_filters.shape[2] // 2 - 1, inv_filters.shape[3] // 2 - 1)
    # Unpack back into 4 * in_size subband channels of size (h/2, w/2).
    res = x.reshape(-1, h // 2, 2, w // 2).transpose(1, 2).reshape(-1, 4 * in_size, h // 2, w // 2)
    if level > 1:
        res[:, ::4] = iwt(res[:, ::4], inv_filters, in_size, level - 1)
    res = F.conv_transpose2d(res, inv_filters, stride=2, groups=in_size, padding=pad)
    return res


def get_inverse_transform(weights, in_size, level):
    class InverseWaveletTransform(Function):

        @staticmethod
        def forward(ctx, input):
            with torch.no_grad():
                x = iwt(input, weights, in_size, level)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            # For orthogonal wavelets (e.g. the default db1), the adjoint of iwt is wt.
            grad = wt(grad_output, weights, in_size, level)
            return grad, None

    return InverseWaveletTransform().apply


def get_transform(weights, in_size, level):
    class WaveletTransform(Function):

        @staticmethod
        def forward(ctx, input):
            with torch.no_grad():
                x = wt(input, weights, in_size, level)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            grad = iwt(grad_output, weights, in_size, level)
            return grad, None

    return WaveletTransform().apply
--------------------------------------------------------------------------------
/images/semseg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BGUCompSci/WaveletCompressedConvolution/7744984627338c5d29be8ff6ef29deb648f52d97/images/semseg.png
--------------------------------------------------------------------------------