├── .gitignore ├── README.md ├── Test.py └── time_frequence.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.npy 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_audio 2 | audio processing module for pytorch 3 | 4 | ### FFT 5 | `time_frequence.fft()` 6 | ### IFFT 7 | `time_frequence.ifft()` 8 | ### STFT 9 | `time_frequence.stft()` 10 | ### ISTFT 11 | `time_frequence.istft()` 12 | -------------------------------------------------------------------------------- /Test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | import time_frequence as tf 5 | import numpy as np 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | import librosa 10 | 11 | 12 | def CalSNR(ref, sig): 13 | ref_p = np.mean(np.square(ref)) 14 | noi_p = np.mean(np.square(sig - ref)) 15 | return 10 * (np.log10(ref_p) - np.log10(noi_p)) 16 | 17 | 18 | class TimeFrequencyTestCase(unittest.TestCase): 19 | 20 | 21 | ''' 22 | def test_ifft(self): 23 | print('\n#########TESTING IFFT##########') 24 | N = 1024 25 | signal = np.random.random(N) 26 | input_ = np.fft.fft(signal, n=N) 27 | ac = Variable(torch.from_numpy(np.real(input_[0]) * np.ones((1, N, 1, 1))).float()) 28 | input_ = np.reshape(input_[1:N//2+1], (1, 1, N//2, 1)) 29 | input_real = Variable(torch.from_numpy(np.real(input_)).float()) 30 | input_imag = Variable(torch.from_numpy(np.imag(input_)).float()) 31 | 32 | 33 | 34 | model = tf.ifft(n_fft=N) 35 | output = model.forward(input_real, input_imag, ac).data.numpy().flatten() 36 | snr = CalSNR(signal, output) 37 | print("SNR:{} dB".format(snr)) 38 | self.assertTrue(snr > 60) 39 | 40 | print('#########IFFT TESTED##########\n') 41 | 42 | 43 | ''' 44 | 45 | def test_istft(self): 46 | 47 | print("###########TESTING ISTFT###########") 48 | signal = np.random.random(1016 * 1024) 49 | spec = librosa.stft(signal, n_fft=1024, hop_length=512, center=False) 50 | magn = np.real(spec)[np.newaxis, np.newaxis, :, :] 51 | phase = np.imag(spec)[np.newaxis, np.newaxis, :, :] 52 | 53 | ac = magn[:, :, 0, :] 54 | magn = magn[:, :, 1:, :] 55 | phase = phase[:, :, 1:, :] 56 | 57 | magn = Variable(torch.from_numpy(magn).float()) 58 | phase = Variable(torch.from_numpy(phase).float()) 59 | ac = Variable(torch.from_numpy(ac).float()) 60 | model = tf.istft(1024, 512) 61 | re_signal = model.forward(magn, phase, ac).data.numpy().flatten() 62 | 63 | snr = CalSNR(signal[1024:-1024], re_signal[1024:-1024]) 64 | print("SNR:{} dB".format(snr)) 65 | self.assertTrue(snr > 60) 66 | 67 | print("###########ISTFT TESTED###########\n") 68 | 69 | 70 | def test_stft(self): 71 | print("\n###########TESTING STFT###########") 72 | 73 | N = 1024 74 | signal = np.random.random(1016 * N) 75 | input = Variable(torch.from_numpy(signal[np.newaxis, :]).float()) 76 | stft_model = tf.stft() 77 | istft_model = tf.istft() 78 | 79 | magn, phase, ac = stft_model(input) 80 | 81 | re_signal = istft_model.forward(magn, phase, ac) 82 | re_signal = re_signal.data.numpy().flatten() 83 | 84 | 85 | 86 | snr = CalSNR(signal[N:-N], re_signal[N:-N]) 87 | print("SNR:{} dB".format(snr)) 88 | 89 | 90 | self.assertTrue(snr > 60) 91 | print("###########STFT TESTED###########\n") 92 | 93 | 94 | if __name__ == '__main__': 95 | unittest.main() 96 | 97 | -------------------------------------------------------------------------------- /time_frequence.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | import scipy.signal 9 | 10 | 11 | class ifft(nn.Module): 12 | def __init__(self, nfft=1024): 13 | super(ifft, self).__init__() 14 | assert nfft % 2 == 0 15 | self.nfft = int(nfft) 16 | self.n_freq = n_freq = int(nfft / 2) 17 | real_kernels, imag_kernels, self.ac_cof = _get_ifft_kernels(nfft) 18 | self.real_conv = nn.Conv2d(1, nfft, (n_freq, 1), stride=1, padding=0, bias=False) 19 | self.imag_conv = nn.Conv2d(1, nfft, (n_freq, 1), stride=1, padding=0, bias=False) 20 | 21 | self.real_conv.weight.data.copy_(real_kernels) 22 | self.imag_conv.weight.data.copy_(imag_kernels) 23 | self.real_model = nn.Sequential(self.real_conv) 24 | self.imag_model = nn.Sequential(self.imag_conv) 25 | def forward(self, magn, phase, ac=None): 26 | assert magn.size()[2] == phase.size()[2] == self.n_freq 27 | output = self.real_model(magn) - self.imag_model(phase) 28 | if ac is not None: 29 | output = output + ac * self.ac_cof 30 | return output / self.nfft 31 | 32 | 33 | 34 | 35 | 36 | 37 | def _get_ifft_kernels(nfft): 38 | nfft = int(nfft) 39 | assert nfft % 2 == 0 40 | def kernel_fn(time, freq): 41 | return np.exp(1j * (2 * np.pi * time * freq) / 1024.) 42 | 43 | kernels = np.fromfunction(kernel_fn, (int(nfft), int(nfft/2+1)), dtype=np.float64) 44 | 45 | kernels = np.zeros((1024, 513)) * 1j 46 | 47 | ''' 48 | for i in range(1024): 49 | for j in range(513): 50 | kernels[i, j] = kernel_fn(i, j) 51 | ''' 52 | 53 | ac_cof = float(np.real(kernels[0, 0])) 54 | 55 | kernels = 2 * kernels[:, 1:] 56 | kernels[:, -1] = kernels[:, -1] / 2.0 57 | 58 | real_kernels = np.real(kernels) 59 | imag_kernels = np.imag(kernels) 60 | 61 | real_kernels = torch.from_numpy(real_kernels[:, np.newaxis, :, np.newaxis]) 62 | imag_kernels = torch.from_numpy(imag_kernels[:, np.newaxis, :, np.newaxis]) 63 | return real_kernels, imag_kernels, ac_cof 64 | 65 | 66 | 67 | 68 | class istft(nn.Module): 69 | def __init__(self, nfft=1024, hop_length=512): 70 | super(istft, self).__init__() 71 | assert nfft % 2 == 0 72 | assert hop_length <= nfft 73 | self.hop_length = hop_length 74 | 75 | self.nfft = int(nfft) 76 | self.n_freq = n_freq = int(nfft / 2) 77 | self.real_kernels, self.imag_kernels, self.ac_cof = _get_istft_kernels(nfft) 78 | trans_kernels = np.zeros((nfft, nfft), np.float64) 79 | np.fill_diagonal(trans_kernels, np.ones((nfft, ), dtype=np.float64)) 80 | self.trans_kernels = nn.Parameter(torch.from_numpy(trans_kernels[:, np.newaxis, np.newaxis, :]).float()) 81 | 82 | def forward(self, magn, phase, ac): 83 | ''' 84 | batch None frequency frame 85 | ''' 86 | assert magn.size()[2] == phase.size()[2] == self.n_freq 87 | nfft = self.nfft 88 | hop = self.hop_length 89 | 90 | # complex conjugate 91 | phase = -1. * phase 92 | real_part = F.conv2d(magn, self.real_kernels) 93 | imag_part = F.conv2d(phase, self.imag_kernels) 94 | 95 | output = real_part - imag_part 96 | 97 | 98 | ac = ac.unsqueeze(1) 99 | ac = float(self.ac_cof) * ac.expand_as(output) 100 | output = output + ac 101 | output = output / float(self.nfft) 102 | 103 | output = F.conv_transpose2d(output, self.trans_kernels, stride=self.hop_length) 104 | output = output.squeeze(1) 105 | output = output.squeeze(1) 106 | return output 107 | 108 | def _get_istft_kernels(nfft): 109 | nfft = int(nfft) 110 | assert nfft % 2 == 0 111 | def kernel_fn(time, freq): 112 | return np.exp(1j * (2 * np.pi * time * freq) / nfft) 113 | kernels = np.fromfunction(kernel_fn, (int(nfft), int(nfft/2+1)), dtype=np.float64) 114 | 115 | ac_cof = float(np.real(kernels[0, 0])) 116 | 117 | kernels = 2 * kernels[:, 1:] 118 | kernels[:, -1] = kernels[:, -1] / 2.0 119 | 120 | real_kernels = np.real(kernels) 121 | imag_kernels = np.imag(kernels) 122 | 123 | real_kernels = nn.Parameter(torch.from_numpy(real_kernels[:, np.newaxis, :, np.newaxis]).float()) 124 | imag_kernels = nn.Parameter(torch.from_numpy(imag_kernels[:, np.newaxis, :, np.newaxis]).float()) 125 | return real_kernels, imag_kernels, ac_cof 126 | 127 | 128 | 129 | class stft(nn.Module): 130 | def __init__(self, nfft=1024, hop_length=512, window="hanning"): 131 | super(stft, self).__init__() 132 | assert nfft % 2 == 0 133 | 134 | self.hop_length = hop_length 135 | self.n_freq = n_freq = nfft//2 + 1 136 | 137 | self.real_kernels, self.imag_kernels = _get_stft_kernels(nfft, window) 138 | 139 | def forward(self, sample): 140 | sample = sample.unsqueeze(1) 141 | sample = sample.unsqueeze(1) 142 | 143 | magn = F.conv2d(sample, self.real_kernels, stride=self.hop_length) 144 | phase = F.conv2d(sample, self.imag_kernels, stride=self.hop_length) 145 | 146 | magn = magn.permute(0, 2, 1, 3) 147 | phase = phase.permute(0, 2, 1, 3) 148 | 149 | # complex conjugate 150 | phase = -1 * phase[:,:,1:,:] 151 | ac = magn[:,:,0,:] 152 | magn = magn[:,:,1:,:] 153 | return magn, phase, ac 154 | 155 | 156 | def _get_stft_kernels(nfft, window): 157 | nfft = int(nfft) 158 | assert nfft % 2 == 0 159 | 160 | def kernel_fn(freq, time): 161 | return np.exp(-1j * (2 * np.pi * time * freq) / float(nfft)) 162 | 163 | kernels = np.fromfunction(kernel_fn, (nfft//2+1, nfft), dtype=np.float64) 164 | 165 | if window == "hanning": 166 | win_cof = scipy.signal.get_window("hanning", nfft)[np.newaxis, :] 167 | else: 168 | win_cof = np.ones((1, nfft), dtype=np.float64) 169 | 170 | kernels = kernels[:, np.newaxis, np.newaxis, :] * win_cof 171 | 172 | real_kernels = nn.Parameter(torch.from_numpy(np.real(kernels)).float()) 173 | imag_kernels = nn.Parameter(torch.from_numpy(np.imag(kernels)).float()) 174 | 175 | return real_kernels, imag_kernels 176 | 177 | 178 | 179 | if __name__ == "__main__": 180 | signal = np.random.rand(1024 * 10) 181 | signal = signal - np.mean(signal) 182 | signal = signal[np.newaxis, :] 183 | model = stft(window="retangle") 184 | real, imag, ac = model.forward(Variable(torch.from_numpy(signal).float())) 185 | real = real.data.numpy() 186 | imag = imag.data.numpy() 187 | ac = ac.data.numpy() 188 | print(ac) 189 | 190 | 191 | --------------------------------------------------------------------------------