"""PyTorch implementation of Octave Convolution.

This is an implementation of the paper "Drop an Octave: Reducing Spatial
Redundancy in Convolutional Neural Networks with Octave Convolution"
(https://arxiv.org/abs/1904.05049). Works with (PyTorch) version 1.0.
"""
import torch
from functools import partial


class OctConv(torch.nn.Module):
    """Octave convolution layer (https://arxiv.org/pdf/1904.05049v1.pdf).

    Feature maps are split into a high-frequency part at full resolution and a
    low-frequency part at half resolution. Up to four convolutions exchange
    information between them:

    * ``H2H``: high -> high.
    * ``L2L``: low -> low.
    * ``H2L``: high -> low; the high input is 2x2 average-pooled (stride 2)
      before the convolution.
    * ``L2H``: low -> high; the convolution output is upsampled 2x (nearest).

    A path whose input or output channel count would be zero is not created,
    and its attribute is ``None``. For example, ``alpha_in == 0`` (first layer)
    has no ``L2L``/``L2H``, and ``alpha_out == 0`` (last layer) has no
    ``L2L``/``H2L``.

    Args:
        in_channels: total number of input channels (high + low).
        out_channels: total number of output channels (high + low).
        kernel_size: convolution kernel size; padding is ``kernel_size // 2``.
        stride: convolution stride, applied to every path.
        alpha_in: fraction of the input channels that are low-frequency.
        alpha_out: fraction of the output channels that are low-frequency.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 alpha_in=0.5, alpha_out=0.5):
        super(OctConv, self).__init__()
        self.alpha_in, self.alpha_out, self.kernel_size = alpha_in, alpha_out, kernel_size
        self.H2H, self.L2L, self.H2L, self.L2H = None, None, None, None

        # Channel split: the low part is truncated, the high part gets the rest.
        low_in = int(alpha_in * in_channels)
        high_in = in_channels - low_in
        low_out = int(alpha_out * out_channels)
        high_out = out_channels - low_out

        def make_conv(n_in, n_out):
            # A path with no input or no output channels is skipped entirely.
            # Otherwise a first layer (alpha_in == 0) would call a 0-channel
            # conv on a missing low-frequency tensor in forward().
            if n_in <= 0 or n_out <= 0:
                return None
            return torch.nn.Conv2d(n_in, n_out, kernel_size, stride, kernel_size // 2)

        # Same registration order as before (L2L, L2H, H2L, H2H) so that
        # parameters() ordering is unchanged.
        self.L2L = make_conv(low_in, low_out)
        self.L2H = make_conv(low_in, high_out)
        self.H2L = make_conv(high_in, low_out)
        self.H2H = make_conv(high_in, high_out)

        self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
        # The low-frequency branch lives at half resolution, so the high input
        # is always downsampled by exactly 2. This must not depend on
        # kernel_size, or H2L and L2L would produce different spatial sizes.
        self.avg_pool = partial(torch.nn.functional.avg_pool2d, kernel_size=2, stride=2)

    def forward(self, x):
        """Apply the octave convolution.

        Args:
            x: ``(hf, lf)`` tuple of high- and low-frequency feature maps. When
                the layer has no low-frequency input (``alpha_in == 0``), ``lf``
                is unused and may be ``None``. A plain tensor is also accepted
                and treated as ``(x, None)``.

        Returns:
            ``(hf_out, lf_out)``. An output part with no contributing path
            (e.g. the low-frequency output when ``alpha_out == 0``) is returned
            as the integer ``0``.
        """
        if isinstance(x, torch.Tensor):
            hf, lf = x, None
        else:
            hf, lf = x

        hf_out, lf_out = 0, 0
        if self.H2H is not None:
            hf_out = hf_out + self.H2H(hf)
        if self.L2H is not None:
            hf_out = hf_out + self.upsample(self.L2H(lf))
        if self.L2L is not None:
            lf_out = lf_out + self.L2L(lf)
        if self.H2L is not None:
            lf_out = lf_out + self.H2L(self.avg_pool(hf))
        return hf_out, lf_out