├── lib
│   ├── __init__.py
│   ├── utils.py
│   ├── gumbel_module.py
│   ├── gates.py
│   └── skip_conv.py
├── resources
│   └── skipconv.png
├── LICENSE
├── main.py
└── README.md

/lib/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

/resources/skipconv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Qualcomm-AI-research/Skip-Conv/HEAD/resources/skipconv.png
--------------------------------------------------------------------------------

/lib/utils.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 Qualcomm Technologies, Inc.

# All Rights Reserved.

import torch


def roll_time(x: torch.Tensor) -> torch.Tensor:
    """Fold the temporal dimension into the batch: (b, t, ...) -> (b * t, ...)."""
    return x.view((-1,) + x.shape[2:])


def unroll_time(x: torch.Tensor, t: int) -> torch.Tensor:
    """Restore a temporal dimension of length t: (b * t, ...) -> (b, t, ...)."""
    return x.view((-1, t) + x.shape[1:])
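
# A minimal round-trip sanity check for the two helpers above (an illustrative
# sketch, not part of the original module):
#   x = torch.rand(2, 8, 32, 224, 224)            # (batch, time, c, h, w)
#   flat = roll_time(x)                           # -> shape (16, 32, 224, 224)
#   assert torch.equal(unroll_time(flat, t=8), x)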
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2021 Qualcomm Technologies, Inc.

All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

* Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------

/main.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 Qualcomm Technologies, Inc.

# All Rights Reserved.

import torch

from lib.gates import GateType
from lib.gates import NormGateType
from lib.skip_conv import SkipConv2d

device = "cuda" if torch.cuda.is_available() else "cpu"

b, t, c, h, w = 1, 8, 32, 224, 224

conv_ops = {
    "gate_type": GateType.GUMBEL_GATE,
    "in_channels": c,
    "out_channels": 64,
    "kernel_size": 3,
    "stride": 1,
    "padding": 1,
    "norm_gate_type": NormGateType.OUTPUT,
    "norm_gate_eps": 1e-1,
    "gumbel_gate_structure": 2,
}


def forward_train(model: SkipConv2d, x: torch.Tensor) -> None:
    """
    During training, the Skip-Convolution is fed clips of t frames.
    As such, the input tensor has shape (batchsize, n_frames, channels, height, width).
    The model is stateless in training mode.

    :param model: the skip-convolution module.
    :param x: input tensor of shape (batchsize, n_frames, channels, height, width).
    """
    model = model.train()

    y = model(x)
    print(y.shape)


def forward_test(model: SkipConv2d, x: torch.Tensor, reset_every: int = 4) -> None:
    """
    At test time, a sequence of t frames is fed iteratively, one frame per call.
    As such, each per-frame input tensor has shape (batchsize, channels, height, width).
    The model is stateful in eval mode: it stores the previous input and output tensors.
    Every `reset_every` frames, the state is reset and a new reference frame is instantiated.

    :param model: the skip-convolution module.
    :param x: input tensor of shape (batchsize, n_frames, channels, height, width).
    :param reset_every: interval between reference frames.
    """
    model = model.eval()

    y = []
    for frame_idx in range(x.shape[1]):
        if frame_idx % reset_every == 0:
            model.reset()

        y.append(model(x[:, frame_idx]))

    y = torch.stack(y, dim=1)
    print(y.shape)


def main():
    """
    Main function.
    The script calls two functions showcasing how the operator should be used
    in training (stateless, cumsum operation) and testing (stateful) within a backbone network.
    The example feeds random tensors and prints out the resulting shapes.
    """
    model = SkipConv2d(**conv_ops).to(device)
    x = torch.rand(b, t, c, h, w).to(device)

    forward_train(model, x)
    forward_test(model, x)


if __name__ == "__main__":
    main()
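
# A rough way to inspect the savings (an illustrative sketch, not part of the
# original script): in eval mode, SkipConv2d records the MAC count of its most
# recent forward pass in `model.mac`, which can be compared with the cost of a
# dense convolution on the same input:
#   model.eval()
#   dense_mac = h * w * c * 64 * 3 * 3      # stride 1, padding 1 -> 224x224 output
#   for i in range(t):
#       if i % 4 == 0:
#           model.reset()
#       _ = model(x[:, i])
#       print(f"frame {i}: {model.mac / dense_mac:.2%} of dense MACs")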
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Skip-Conv

This repository provides the Skip-Convolution module presented in:

[Amirhossein Habibian](https://habibian.github.io/),
[Davide Abati](https://davideabati.info/),
[Taco S. Cohen](https://tacocohen.wordpress.com/),
[Babak Ehteshami Bejnordi](http://babakint.com/),
"Skip-Convolutions for Efficient Video Processing", CVPR 2021. [[arXiv]](https://arxiv.org/abs/2104.11487)

Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.)

## Reference
If you find our work useful for your research, please cite:
```latex
@inproceedings{skipconv,
  title={Skip-Convolutions for Efficient Video Processing},
  author={Habibian, Amirhossein and Abati, Davide and Cohen, Taco and Bejnordi, Babak Ehteshami},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2021}
}
```

## Method
Skip-Convolutions save computation whenever a 2D CNN is applied to video frames.
By decomposing the convolution of the current frame as the sum of the convolved past frame and the convolved residual frame, we take advantage of the sparsity of the latter to reduce the number of operations.
![Skip-Conv overview](resources/skipconv.png)
We increase the savings by further sparsifying the residual. Given a residual frame or feature map, a gating module is queried for binary maps highlighting the locations where updated representations are needed.
In the image, only the blue locations are updated. At all other (orange) locations, no update is computed and the representations are copied from the previous timestep.

This repository contains the implementation of both gating modules discussed in the paper.
* **Norm gates**: we apply a threshold epsilon to the norm of the residual. This strategy requires no training and can easily be applied to any CNN to save computation.
* **Gumbel gates**: we feed the residual to a parametric function (a single convolution in our implementation). Gumbel gates require finetuning the network, but achieve better performance than Norm gates.

## Install
The code has been tested with `python3.6` and `pytorch==1.6.0`.
```bash
conda create -n skipconv python=3.6
conda activate skipconv
conda install pytorch=1.6.0
```


## Example use
The file [`lib/skip_conv.py`](lib/skip_conv.py) contains the definition of the Skip-Convolution class.

The file [`main.py`](main.py) highlights how to feed a SkipConv model during training (stateless, feed temporal clips) and inference (stateful, recursively feed single frames).
```bash
python3.6 main.py
```
The script calls two functions showcasing how the operator should be used in training and in testing within a backbone network. The example feeds random tensors and prints the resulting shapes.


To instantiate the different versions of the operator described in the paper, change the arguments of the `conv_ops` dictionary in `main.py`.
Specifically, you can experiment with the following (see the example after this list):
* `gate_type` (`GateType.NORM_GATE` or `GateType.GUMBEL_GATE`): selects between the two types of gates available, namely Norm gates and Gumbel gates.
* `norm_gate_type` (`NormGateType.INPUT` or `NormGateType.OUTPUT`): when Norm gates are selected, chooses between Input-norm and Output-norm gates.
* `norm_gate_eps`: the threshold applied by Norm gates. Only relevant if Norm gates are selected.
* `gumbel_gate_structure`: the size of the structuring element of Gumbel gates. Can be 1, 2, 4 or 8.
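
For instance, the following sketch (the argument values are illustrative) instantiates the Input-norm variant directly:
```python
from lib.gates import GateType, NormGateType
from lib.skip_conv import SkipConv2d

conv = SkipConv2d(
    gate_type=GateType.NORM_GATE,
    in_channels=32,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    norm_gate_type=NormGateType.INPUT,
    norm_gate_eps=1e-1,  # larger thresholds skip more locations, trading accuracy for MACs
)
```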
--------------------------------------------------------------------------------

/lib/gumbel_module.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 Qualcomm Technologies, Inc.

# All Rights Reserved.

# ============================================================================
# @@-COPYRIGHT-START-@@
#
# Adapted and modified from the code by Andreas Veit:
# https://github.com/andreasveit/convnet-aig/blob/master/gumbelmodule.py
# Gumbel Softmax Sampler
# Works for categorical and binary input
#
# BSD 3-Clause License
#
# Copyright (c) 2018, Andreas Veit
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# @@-COPYRIGHT-END-@@
# ============================================================================


import torch
import torch.nn as nn


class HardSoftmax(torch.autograd.Function):
    """Straight-through binarisation: thresholds at 0.5 in the forward pass,
    passes gradients through unchanged in the backward pass."""

    @staticmethod
    def forward(ctx, input):
        y_hard = input.clone()
        y_hard = y_hard.zero_()
        y_hard[input >= 0.5] = 1

        return y_hard

    @staticmethod
    def backward(ctx, grad_output):
        # The forward pass takes a single input, so exactly one gradient must be
        # returned; the original `return grad_output, None` would make autograd
        # complain about an extra gradient.
        return grad_output
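
# A tiny straight-through check (an illustrative sketch, not part of the
# original module):
#   p = torch.tensor([0.2, 0.7, 0.5], requires_grad=True)
#   y = HardSoftmax.apply(p)   # -> tensor([0., 1., 1.]): thresholded at 0.5
#   y.sum().backward()
#   p.grad                     # -> tensor([1., 1., 1.]): gradients pass through unchanged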


class GumbelSigmoid(torch.nn.Module):
    def __init__(self):
        """
        Implementation of the Gumbel-Softmax trick for the binary case, using a
        Gumbel sigmoid.
        """
        super(GumbelSigmoid, self).__init__()
        self.sigmoid = nn.Sigmoid()

    def sample_gumbel_like(self, template_tensor, eps=1e-10):
        uniform_samples_tensor = template_tensor.clone().uniform_()
        gumbel_samples_tensor = -torch.log(
            eps - torch.log(uniform_samples_tensor + eps)
        )

        return gumbel_samples_tensor

    def gumbel_sigmoid_sample(self, logits, temperature, inference=False):
        """Adds Gumbel noise to the logits and takes the sigmoid.
        No noise is added during inference."""
        if not inference:
            gumbel_samples_tensor = self.sample_gumbel_like(logits.data)
            gumbel_trick_log_prob_samples = logits + gumbel_samples_tensor.data
        else:
            gumbel_trick_log_prob_samples = logits
        soft_samples = self.sigmoid(gumbel_trick_log_prob_samples / temperature)

        return soft_samples

    def gumbel_sigmoid(self, logits, temperature=2 / 3, hard=False, inference=False):
        out = self.gumbel_sigmoid_sample(logits, temperature, inference)
        if hard:
            out = HardSoftmax.apply(out)

        return out

    def forward(self, logits, force_hard=False, temperature=2 / 3):
        inference = not self.training

        if self.training and not force_hard:
            return self.gumbel_sigmoid(
                logits, temperature=temperature, hard=False, inference=inference
            )
        else:
            return self.gumbel_sigmoid(
                logits, temperature=temperature, hard=True, inference=inference
            )
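
# Behaviour sketch (illustrative, not part of the original module):
#   gs = GumbelSigmoid()
#   logits = torch.randn(1, 1, 4, 4)
#   gs.train(); gs(logits)                   # noisy soft samples in (0, 1)
#   gs.train(); gs(logits, force_hard=True)  # noisy samples binarised by HardSoftmax
#   gs.eval();  gs(logits)                   # deterministic hard 0/1 mask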
--------------------------------------------------------------------------------

/lib/gates.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 Qualcomm Technologies, Inc.

# All Rights Reserved.

from enum import Enum

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.common_types import _size_2_t

from lib.gumbel_module import GumbelSigmoid


class GateType(Enum):
    NORM_GATE = "norm-gate"
    GUMBEL_GATE = "gumbel-gate"


class NormGateType(Enum):
    INPUT = "input"
    OUTPUT = "output"


class NormGate(nn.Module):
    def __init__(
        self,
        kernel_size: _size_2_t,
        stride: _size_2_t,
        padding: _size_2_t,
        type: NormGateType = NormGateType.OUTPUT,
        eps: float = 1e-1,
        norm: int = 1,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.type = type
        self.norm = float(norm)
        self.eps = eps
        self.qbin = 1e5  # number of quantisation bins used to discretise the norm

    def forward(self, r: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        if self.training:
            return self._forward_train(r, w)
        else:
            return self._forward_test(r, w)

    def _forward_train(self, r, w):
        # The weight argument is unused in training; `w` is rebound to the width below.
        n, _, h, w = r.shape
        return torch.ones(n, 1, h, w).to(r.device)  # no gating during training

    def _forward_test(self, r: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        if self.type is NormGateType.INPUT:
            return self._forward_input(r)
        elif self.type is NormGateType.OUTPUT:
            return self._forward_output(r, w)
        raise ValueError

    def _forward_input(self, r: torch.Tensor) -> torch.Tensor:
        """Input norm gates, Eq. (5)"""
        ri = F.avg_pool2d(r.abs(), self.kernel_size, self.stride, self.padding)
        ri_norm = torch.norm(ri, p=self.norm, dim=1, keepdim=True) / ri.size(2)
        ri_out_discrete = (ri_norm * self.qbin).floor() / self.qbin

        return (torch.sign(ri_out_discrete - self.eps) + 1) / 2

    def _forward_output(self, r: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Output norm gates, Eq. (7)"""
        ri = F.avg_pool2d(r.abs(), self.kernel_size, self.stride, self.padding)
        ri_norm = torch.norm(ri, p=self.norm, dim=1, keepdim=True) / ri.size(2)
        w_norm = torch.norm(w, p=self.norm) / w.numel()
        ri_out = ri_norm * w_norm
        ri_out_discrete = (ri_out * self.qbin).floor() / self.qbin

        return (torch.sign(ri_out_discrete - self.eps) + 1) / 2

    def get_mac(self, r: torch.Tensor, g: torch.Tensor) -> int:
        return 0  # norm gates are virtually free: no multiply-accumulate operations


class GumbelGate(nn.Module):
    def __init__(
        self,
        in_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        structure: int = 1,
    ):
        super(GumbelGate, self).__init__()
        self.gs = GumbelSigmoid()
        self.structure = structure
        self.in_channels = in_channels
        self.kernel_size = kernel_size

        gating_layers = [
            nn.Conv2d(
                in_channels, 1, kernel_size=kernel_size, stride=stride, padding=padding
            ),
        ]

        assert self.structure in [1, 2, 4, 8]
        if self.structure == 1:
            structure_layers = []
        else:
            # Pool/upsample pair forcing the gate to fire in structure x structure blocks.
            structure_layers = [
                nn.MaxPool2d(kernel_size=self.structure, stride=self.structure),
                nn.UpsamplingNearest2d(scale_factor=self.structure),
            ]

        self.gate_network = nn.Sequential(*gating_layers, *structure_layers)

        self.init_weights()

    def init_weights(self, gate_bias_init: float = 0.6) -> None:
        conv = self.gate_network[0]
        torch.nn.init.xavier_uniform_(conv.weight)
        conv.bias.data.fill_(gate_bias_init)

    def forward(self, gate_inp: torch.Tensor, _: torch.Tensor) -> torch.Tensor:
        """Gumbel gates, Eq. (8)"""
        pi_log = self.gate_network(gate_inp)
        return self.gs(pi_log, force_hard=True)

    def get_mac(self, r: torch.Tensor, g: torch.Tensor) -> int:
        n, c_in, h_in, w_in = r.shape
        n, _, h_out, w_out = g.shape
        if isinstance(self.kernel_size, tuple):
            k_h, k_w = self.kernel_size
        else:
            k_h = k_w = self.kernel_size

        # Cost of the single-output-channel gating convolution.
        mac_gates = n * h_out * w_out * c_in * 1 * k_h * k_w

        return mac_gates
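
# Threshold behaviour of the output-norm gate (an illustrative sketch, not part
# of the original module):
#   gate = NormGate(kernel_size=3, stride=1, padding=1, type=NormGateType.OUTPUT, eps=1e-1)
#   gate.eval()
#   r = torch.zeros(1, 32, 16, 16)
#   r[..., :8] = 1.0                 # the residual differs only on the left half
#   w = torch.randn(64, 32, 3, 3)
#   g = gate(r, w)                   # (1, 1, 16, 16) binary map, ~1 left, 0 right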
--------------------------------------------------------------------------------

/lib/skip_conv.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 Qualcomm Technologies, Inc.

# All Rights Reserved.

import torch
import torch.nn as nn
from torch.nn.common_types import _size_2_t

from lib import gates
from lib.gates import GateType
from lib.gates import NormGateType
from lib.utils import roll_time
from lib.utils import unroll_time


class SkipConv2d(nn.Conv2d):
    def __init__(
        self,
        gate_type: GateType,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        norm_gate_type: NormGateType = NormGateType.OUTPUT,
        norm_gate_eps: float = 1e-1,
        gumbel_gate_structure: int = 1,
    ):
        super(SkipConv2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
        )

        self.gate_type = gate_type
        if self.gate_type is GateType.NORM_GATE:
            self.norm_gate_type = norm_gate_type
            self.norm_gate_eps = norm_gate_eps
            self.gate = gates.NormGate(
                kernel_size, stride, padding, self.norm_gate_type, self.norm_gate_eps
            )
        elif self.gate_type is GateType.GUMBEL_GATE:
            self.gumbel_gate_structure = gumbel_gate_structure
            self.gate = gates.GumbelGate(
                in_channels, kernel_size, stride, padding, self.gumbel_gate_structure
            )
        else:
            raise ValueError

        self.z0 = None
        self.x0 = None
        self.mac = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            return self._forward_train(x)
        else:
            return self._forward_test(x)

    def _residual_forward(self, r: torch.Tensor) -> torch.Tensor:
        """Convolution of a residual. The bias term is removed, as it is already
        contributed by the reference frame and must not be accumulated again at
        every timestep by the cumulative (or recursive) sum."""
        zr = super(SkipConv2d, self).forward(r)
        if self.bias is not None:
            zr = zr - self.bias.view(1, -1, 1, 1)
        return zr

    def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 5
        t = x.shape[1]

        x0 = x[:, 0]
        z0 = super(SkipConv2d, self).forward(x0)

        r = roll_time(x[:, 1:] - x[:, :-1])
        zr = self._residual_forward(r)
        g = self.gate(r.abs(), self.weight)
        zr = zr * g

        z0 = unroll_time(z0, t=1)
        zr = unroll_time(zr, t=t - 1)
        z = torch.cat((z0, zr), dim=1)
        z = z.cumsum(dim=1)

        self.mac = self._get_mac_train(r, z, g)
        return z
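
    # Why the cumulative sum above reconstructs a dense convolution (an
    # illustrative sketch that assumes the all-ones training gate of NormGate):
    # convolution is linear, so conv(x_t) = conv(x_0) + sum_{i<=t} conv(x_i - x_{i-1}).
    #   conv = SkipConv2d(GateType.NORM_GATE, 8, 8, 3, padding=1).train()
    #   x = torch.rand(2, 4, 8, 16, 16)
    #   dense = torch.stack(
    #       [nn.functional.conv2d(x[:, i], conv.weight, conv.bias, padding=1) for i in range(4)],
    #       dim=1,
    #   )
    #   torch.allclose(conv(x), dense, atol=1e-5)  # -> True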
    def _forward_test(self, x: torch.Tensor) -> torch.Tensor:
        if self.x0 is None:
            z = super(SkipConv2d, self).forward(x)
            mac = self._get_mac_test_reference(z)
        else:
            x0, z0 = self.x0, self.z0
            r = x - x0
            g = self.gate(r.abs(), self.weight)
            zr = self._residual_forward(r)
            zr = zr * g
            z = z0 + zr
            mac = self._get_mac_test_residual(r, z, g)

        self.x0 = x
        self.z0 = z
        self.mac = mac
        return z

    def reset(self) -> None:
        assert (
            not self.training
        ), "reset() method should not be called in training mode."
        # Resets the state, used at test time.
        self.z0 = None
        self.x0 = None
        self.mac = None

    def eval(self):
        # Sets the model in evaluation mode, and also resets the state.
        ret = super(SkipConv2d, self).eval()
        self.reset()
        return ret

    # -----------------------
    # MAC computing functions
    # -----------------------

    def _get_mac_train(self, r: torch.Tensor, z: torch.Tensor, g: torch.Tensor) -> int:
        n, t, c_out, h_out, w_out = z.shape
        _, c_in, k_h, k_w = self.weight.shape

        mac_ref = 1 * h_out * w_out * c_in * c_out * k_h * k_w
        mac_res = g.sum().item() * c_in * c_out * k_h * k_w
        mac_gat = self.gate.get_mac(r, g)
        return mac_ref + mac_res + mac_gat

    def _get_mac_test_reference(self, z: torch.Tensor) -> int:
        n, c_out, h_out, w_out = z.shape
        _, c_in, k_h, k_w = self.weight.shape

        assert n == 1
        mac_ref = n * h_out * w_out * c_in * c_out * k_h * k_w
        return mac_ref

    def _get_mac_test_residual(
        self, r: torch.Tensor, z: torch.Tensor, g: torch.Tensor
    ) -> int:
        _, c_out, h_out, w_out = z.shape
        _, c_in, k_h, k_w = self.weight.shape

        mac_res = g.sum().item() * c_in * c_out * k_h * k_w
        mac_gat = self.gate.get_mac(r, g)
        return mac_res + mac_gat
--------------------------------------------------------------------------------