├── README.md └── freq_aware_conv2d.py /README.md: -------------------------------------------------------------------------------- 1 | # frequency-aware-conv2d-layer-pytorch 2 | 3 | A Pytorch implementation of frequency-aware convolutional 2D layer. 4 | 5 | pytorch 1.1 6 | 7 | ## Usage 8 | 9 | Instead of `nn.Conv2d()`, use `FreqAwareConv2dLinearBiasOffset()`. 10 | The API of `FreqAwareConv2dLinearBiasOffset()` is same as that of `nn.Conv2d` as of pytorch 1.1. 11 | 12 | ## Reference 13 | 14 | [Acoustic Scene Classification and Audio Tagging with Receptive-Field-Regularized CNNs](https://www.researchgate.net/publication/334250606_CP-JKU_Submissions_to_DCASE'19_Acoustic_Scene_Classification_and_Audio_Tagging_with_Receptive-Field-Regularized_CNNs) by Khaled Koutini, Hamid Eghbal-Zadeh, and Gerhard Widmer, 2019, DCASE workshop. -------------------------------------------------------------------------------- /freq_aware_conv2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def _get_freq_map(min_freq, max_freq, num_freq, dtype=torch.float32): 7 | """Given params, it returns a frequency map. 8 | num_freq should be positive integer. 9 | """ 10 | if num_freq > 1: 11 | step = float(max_freq - min_freq) / (num_freq - 1) 12 | map = torch.arange(start=min_freq, 13 | end=max_freq + step, 14 | step=step, 15 | dtype=dtype) 16 | return torch.reshape(map, (1, 1, -1, 1)) 17 | elif num_freq == 1: 18 | return torch.tensor([float(max_freq + min_freq) / 2]).view([1, 1, -1, 1]) 19 | else: 20 | raise ValueError('num_freq should be positive but we got: {}'.format(num_freq)) 21 | 22 | 23 | class FreqAwareConv2dLinearBiasOffset(nn.Conv2d): 24 | """A modified conv2d layer that concats the frequency map along the channel axis. 25 | """ 26 | 27 | def __init__(self, in_channels, out_channels, kernel_size, 28 | stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros' 29 | ): 30 | super(FreqAwareConv2dLinearBiasOffset, self).__init__( 31 | in_channels=in_channels + 1, 32 | out_channels=out_channels, 33 | kernel_size=kernel_size, 34 | stride=stride, 35 | padding=padding, 36 | dilation=dilation, 37 | groups=groups, 38 | bias=bias, 39 | padding_mode=padding_mode) 40 | self.freq_map = None 41 | self.min_freq = 0.0 42 | self.max_freq = 1.0 43 | self.freq_axis = 2 44 | self.ch_axis = 1 # it follows torch convention, hence not allowing to change it. 45 | 46 | def forward(self, input: torch.tensor): 47 | """ 48 | 49 | Maybe the input shape is (batch, ch, freq, time), 50 | ..or whatever as long as the freq axis index == self.freq_axis 51 | 52 | Also assumes the input size is always the same so that it can reuse the 53 | same freq_map 54 | 55 | """ 56 | if self.freq_map is None: 57 | num_freq = input.shape[self.freq_axis] 58 | self.freq_map = _get_freq_map(self.min_freq, self.max_freq, num_freq, 59 | dtype=input.dtype).to(input.device) 60 | 61 | expand_shape = list(input.shape) 62 | expand_shape[self.ch_axis] = 1 63 | expanded_map = self.freq_map.expand(*expand_shape) 64 | 65 | input = torch.cat((input, expanded_map), 66 | dim=self.ch_axis) 67 | return super(FreqAwareConv2dLinearBiasOffset, self).forward(input) 68 | 69 | --------------------------------------------------------------------------------