├── README.md └── model ├── __init__.py ├── layer_unit.py ├── module.py ├── module_stft.py ├── network.py └── network_stft.py /README.md: -------------------------------------------------------------------------------- 1 | # DSDPRNN 2 | Implementation of Dual-Stream DPRNN (paper: Nonlinear Residual Echo Suppression Based on Dual-Stream DPRNN) -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mo-yun/DSDPRNN/0430d87192a8d5fba3860e8115fc430086bd1652/model/__init__.py -------------------------------------------------------------------------------- /model/layer_unit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | EPS = 1e-8 5 | 6 | 7 | class DPRNN(nn.Module): 8 | def __init__(self, N, B, H, R, K, rnn_type, dropout=0, bidirectional=False): 9 | super().__init__() 10 | self.N, self.B, self.H, self.R, self.K = N, B, H, R, K 11 | self.rnn_type = rnn_type 12 | self.dropout = dropout 13 | self.bidirectional = bidirectional 14 | self.row_rnn = nn.ModuleList([]) 15 | self.col_rnn = nn.ModuleList([]) 16 | self.row_norm_x = nn.ModuleList([]) 17 | self.row_norm_y = nn.ModuleList([]) 18 | self.col_norm_x = nn.ModuleList([]) 19 | self.col_norm_y = nn.ModuleList([]) 20 | for i in range(R): 21 | self.row_rnn.append(ReferRNN(rnn_type, B, K, col=False, dropout=dropout, bidirectional=True)) 22 | self.col_rnn.append(ReferRNN(rnn_type, B, K, col=True, dropout=dropout, bidirectional=bidirectional)) 23 | if i < R - 1: 24 | self.row_norm_x.append(BlockGroupNorm(B, 2)) 25 | self.col_norm_x.append(BlockGroupNorm(B, 2)) 26 | self.row_norm_y.append(BlockGroupNorm(B, 2)) 27 | self.col_norm_y.append(BlockGroupNorm(B, 2)) 28 | self.output_x = nn.Sequential(nn.PReLU(), nn.Conv2d(B, B, 1)) 29 | 30 | def forward(self, xi, yi, inference=False): 31 | batch_size, B, T, K = xi.shape 32 | x = xi.permute(0, 2, 3, 1) 33 | y = yi.permute(0, 2, 3, 1) 34 | for i in range(self.R): 35 | mx, my = self.row_rnn[i](x, y) 36 | if i < self.R - 1: 37 | x = self.row_norm_x[i](mx + x) 38 | y = self.row_norm_y[i](my + y) 39 | else: 40 | x = mx + x 41 | y = my + y 42 | mx, my = self.col_rnn[i](x, y) 43 | if i < self.R - 1: 44 | x = self.col_norm_x[i](mx + x) 45 | y = self.col_norm_y[i](my + y) 46 | else: 47 | x = mx + x 48 | y = my + y 49 | 50 | x = x.permute(0, 3, 1, 2) 51 | x = self.output_x(x) 52 | return x 53 | 54 | 55 | class ReferRNN(nn.Module): 56 | def __init__(self, rnn_type, inputs_size, segment_size, col, dropout=0, bidirectional=False, online=False): 57 | super().__init__() 58 | self.rnn_type = rnn_type.upper() 59 | self.inputs_size = inputs_size 60 | self.segment_size = segment_size 61 | self.col = col 62 | self.dropout = dropout 63 | self.bidirectional = bidirectional 64 | self.online = online 65 | num_direction = 2 if bidirectional else 1 66 | assert inputs_size % num_direction == 0 67 | hidden_size = inputs_size // num_direction 68 | self.dpt = Dropout2d(dropout) 69 | self.rnn0 = getattr(nn, self.rnn_type)(inputs_size, hidden_size, batch_first=True, bidirectional=bidirectional) 70 | self.rnn1 = getattr(nn, self.rnn_type)(inputs_size, hidden_size, batch_first=True, bidirectional=bidirectional) 71 | self.msw0 = nn.Parameter(0.5 * torch.ones([1, 1, inputs_size])) 72 | self.msw1 = nn.Parameter(0.5 * torch.ones([1, 1, inputs_size])) 73 | if not self.col: 74 | self.projx = nn.Linear(segment_size * 2, segment_size) 75 | self.projy = nn.Linear(segment_size * 2, segment_size) 76 | else: 77 | self.projx = nn.Linear(inputs_size * 2, inputs_size) 78 | self.projy = nn.Linear(inputs_size * 2, inputs_size) 79 | self.hidden_states = None 80 | 81 | def reset_buffer(self): 82 | self.hidden_states = None 83 | 84 | def forward(self, xi, yi): 85 | self.rnn0.flatten_parameters() 86 | self.rnn1.flatten_parameters() 87 | batch_size, T, K, B = xi.size() 88 | if not self.col: 89 | mx = torch.reshape(xi, [batch_size * T, K, B]) 90 | my = torch.reshape(yi, [batch_size * T, K, B]) 91 | x, _ = self.rnn0(mx, None) 92 | y, _ = self.rnn1(my, None) 93 | x = torch.reshape(x, [batch_size, T, K, -1]) 94 | y = torch.reshape(y, [batch_size, T, K, -1]) 95 | else: 96 | mx = torch.transpose(xi, 1, 2) 97 | my = torch.transpose(yi, 1, 2) 98 | mx = torch.reshape(mx, [batch_size * K, T, B]) 99 | my = torch.reshape(my, [batch_size * K, T, B]) 100 | if self.online: 101 | if self.hidden_states is not None: 102 | hidden_states_0, hidden_states_1 = self.hidden_states 103 | else: 104 | hidden_states_0 = hidden_states_1 = None 105 | x, hidden_states_0 = self.rnn0(mx, hidden_states_0) 106 | y, hidden_states_1 = self.rnn1(my, hidden_states_1) 107 | self.hidden_states = (hidden_states_0, hidden_states_1) 108 | else: 109 | x, _ = self.rnn0(mx, None) 110 | y, _ = self.rnn1(my, None) 111 | x = torch.reshape(x, [batch_size, K, T, -1]) 112 | y = torch.reshape(y, [batch_size, K, T, -1]) 113 | x = torch.transpose(x, 1, 2) 114 | y = torch.transpose(y, 1, 2) 115 | x = self.dpt(x) 116 | y = self.dpt(y) 117 | mx = x + self.msw0 * y 118 | my = y + self.msw1 * x 119 | if not self.col: 120 | x = torch.cat([xi, mx], dim=2) 121 | y = torch.cat([yi, my], dim=2) 122 | x = torch.transpose(x, 3, 2) 123 | y = torch.transpose(y, 3, 2) 124 | x = self.projx(x) 125 | y = self.projy(y) 126 | x = torch.transpose(x, 3, 2) 127 | y = torch.transpose(y, 3, 2) 128 | else: 129 | x = torch.cat([xi, mx], dim=3) 130 | y = torch.cat([yi, my], dim=3) 131 | x = self.projx(x) 132 | y = self.projy(y) 133 | return x, y 134 | 135 | 136 | class BlockGroupNorm(nn.Module): 137 | ''' 138 | axis -1 should be the time index, axis 1 should be the channel index.\\ 139 | time delay: the block step. 140 | ''' 141 | 142 | def __init__(self, in_channels, num_groups, elementwise_affine=True, pow_para=0.5, online=False): 143 | super().__init__() 144 | assert in_channels % num_groups == 0, 'in_channels: {}, num_groups: {}'.format(in_channels, num_groups) 145 | self.in_channels = in_channels 146 | self.num_groups = num_groups 147 | self.elementwise_affine = elementwise_affine 148 | self.pow_para = pow_para 149 | self.online = online 150 | if self.elementwise_affine: 151 | self.gamma = nn.Parameter(torch.ones([1, 1, 1, in_channels])) # [1, N, 1] 152 | self.beta = nn.Parameter(torch.zeros([1, 1, 1, in_channels])) # [1, N, 1] 153 | else: 154 | self.register_parameter('gamma', None) 155 | self.register_parameter('bias', None) 156 | self.reset_parameters() 157 | 158 | def reset_parameters(self): 159 | if self.elementwise_affine: 160 | self.gamma.data.fill_(1) 161 | self.beta.data.zero_() 162 | 163 | def forward(self, inputs): 164 | batch_size, T, K, C = inputs.shape 165 | x = torch.reshape(inputs, (batch_size, T, K, self.num_groups, C // self.num_groups)) 166 | variances, means = torch.var_mean(x, dim=[2, 4], keepdim=True, unbiased=False) 167 | x = (x - means) / torch.pow(variances + EPS, self.pow_para) 168 | x = torch.reshape(x, [batch_size, T, K, C]) 169 | if self.elementwise_affine: 170 | x = x * self.gamma + self.beta 171 | return x 172 | 173 | 174 | class Dropout2d(nn.Dropout2d): 175 | def forward(self, x): 176 | x = x.contiguous() 177 | x = x.permute(0, 3, 1, 2) 178 | x = super().forward(x) 179 | x = x.permute(0, 2, 3, 1) 180 | return x 181 | -------------------------------------------------------------------------------- /model/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import math 5 | from .layer_unit import DPRNN 6 | alpha = 0.25 7 | EPS = 1e-8 8 | 9 | 10 | class Encoder(nn.Module): 11 | def __init__(self, N, L, stride=None, activation='relu'): 12 | super(Encoder, self).__init__() 13 | self.L, self.N = L, N 14 | self.activation = activation 15 | if activation is None: 16 | self.act = lambda x: x 17 | else: 18 | self.act = getattr(torch, activation.lower()) 19 | if stride is None: 20 | self.stride = L // 2 21 | else: 22 | self.stride = stride 23 | self.conv1d_U = nn.Conv1d(1, N, kernel_size=L, stride=self.stride, bias=False) 24 | 25 | def pad_segment(self, inputs): 26 | r = self.stride - (inputs.size(-1) - self.L) % self.stride 27 | if r > 0: 28 | x = F.pad(inputs, [0, r]) 29 | return x 30 | 31 | def forward(self, mixture): 32 | mixture = self.pad_segment(mixture) 33 | mixture = torch.unsqueeze(mixture, 1) 34 | mixture_w = self.conv1d_U(mixture) 35 | mixture_w = self.act(mixture_w) 36 | return mixture_w 37 | 38 | 39 | class Decoder(nn.Module): 40 | def __init__(self, N, L, stride=None): 41 | super(Decoder, self).__init__() 42 | self.N, self.L = N, L 43 | if stride is None: 44 | self.stride = L // 2 45 | else: 46 | self.stride = stride 47 | self.ConvTrans = nn.ConvTranspose1d(N, 1, kernel_size=L, stride=self.stride, bias=True) 48 | 49 | def forward(self, outputs_en): 50 | est_source = self.ConvTrans(outputs_en) 51 | est_source = torch.squeeze(est_source, 1) 52 | return est_source 53 | 54 | 55 | class DPRNN_ME(nn.Module): 56 | def __init__(self, N, B, H, R, K, rnn_type='LSTM', dropout=0, bidirectional=False, mask_nonlinear='relu'): 57 | super().__init__() 58 | self.N, self.B, self.H, self.R, self.K = N, B, H, R, K 59 | self.rnn_type = rnn_type 60 | self.dropout = dropout 61 | self.bidirectional = bidirectional 62 | self.stride = self.K // 2 63 | self.conv_noisy = nn.Conv1d(N, B, 1, bias=False) 64 | self.conv_refer = nn.Conv1d(N, B, 1, bias=False) 65 | self.dprnn = DPRNN(N, B, H, R, K, rnn_type, dropout=dropout, bidirectional=bidirectional) 66 | self.output_audio = Output(B, N, mask_nonlinear) 67 | 68 | def pad_segment(self, inputs): 69 | r = (inputs.size(2) - self.K) % self.stride 70 | if r > 0: 71 | r = self.stride - r 72 | inputs = F.pad(inputs, [0, r]) 73 | return inputs 74 | 75 | def split_feature(self, inputs): 76 | x = self.pad_segment(inputs) 77 | x = x.unfold(2, self.K, self.stride) 78 | return x 79 | 80 | def merge_feature(self, inputs): 81 | x = self.overlap_and_add(inputs, self.stride) 82 | return x 83 | 84 | @staticmethod 85 | def overlap_and_add(signal, frame_step): 86 | """Reconstructs a signal from a framed representation. 87 | 88 | Adds potentially overlapping frames of a signal with shape 89 | `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`. 90 | The resulting tensor has shape `[..., output_size]` where 91 | 92 | output_size = (frames - 1) * frame_step + frame_length 93 | 94 | Args: 95 | signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2. 96 | frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length. 97 | 98 | Returns: 99 | A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions. 100 | output_size = (frames - 1) * frame_step + frame_length 101 | 102 | Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py 103 | """ 104 | outer_dimensions = signal.size()[:-2] 105 | frames, frame_length = signal.size()[-2:] 106 | 107 | subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor 108 | subframe_step = frame_step // subframe_length 109 | subframes_per_frame = frame_length // subframe_length 110 | output_size = frame_step * (frames - 1) + frame_length 111 | output_subframes = output_size // subframe_length 112 | 113 | subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) 114 | 115 | frame = torch.arange(0, output_subframes, dtype=torch.int64, device=signal.device).unfold(0, subframes_per_frame, subframe_step) 116 | frame = frame.contiguous().view(-1) 117 | 118 | result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) 119 | result.index_add_(-2, frame, subframe_signal) 120 | result = result.view(*outer_dimensions, -1) 121 | return result 122 | 123 | def forward(self, noisy, refer, inference=False): 124 | length = noisy.size(2) 125 | x = noisy 126 | y = refer 127 | x = self.conv_noisy(x) 128 | y = self.conv_refer(y) 129 | x = self.split_feature(x) 130 | y = self.split_feature(y) 131 | x = self.dprnn(x, y, inference) 132 | # batch_size, B, T, K = x.shape 133 | x = self.merge_feature(x) 134 | x = x[:, :, :length] 135 | x = self.output_audio(x) 136 | return x 137 | 138 | 139 | class Output(nn.Module): 140 | def __init__(self, B, N, mask_nonlinear='relu'): 141 | super().__init__() 142 | self.B = B 143 | self.N = N 144 | mask_nonlinear = mask_nonlinear.lower() 145 | if mask_nonlinear == 'relu': 146 | self.mask = torch.relu 147 | elif mask_nonlinear == 'sigmoid': 148 | self.mask = torch.sigmoid 149 | elif mask_nonlinear == 'linear': 150 | self.mask = lambda x: x 151 | elif mask_nonlinear == 'lrelu': 152 | self.mask = nn.LeakyReLU() 153 | elif mask_nonlinear == 'elu': 154 | self.elu = nn.ELU() 155 | self.mask = lambda x: self.elu(x) + 1 156 | else: 157 | raise ValueError 158 | self.prelu = nn.PReLU() 159 | self.mask_conv = nn.Conv1d(B, N, 1) 160 | 161 | def forward(self, inputs): 162 | x = self.prelu(inputs) 163 | x = self.mask_conv(x) 164 | x = self.mask(x) 165 | return x 166 | -------------------------------------------------------------------------------- /model/module_stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .layer_unit import DPRNN 5 | alpha = 0.25 6 | EPS = 1e-8 7 | 8 | 9 | def complexMulti(a, b): 10 | c1 = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1] 11 | c2 = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0] 12 | c = torch.stack((c1, c2), -1) 13 | return c 14 | 15 | 16 | class Encoder(nn.Module): 17 | def __init__(self, N, L, stride=None): 18 | super(Encoder, self).__init__() 19 | self.L, self.N = L, N 20 | if stride is None: 21 | self.stride = L // 2 22 | else: 23 | self.stride = stride 24 | self.window = nn.Parameter(torch.hamming_window(self.L), requires_grad=False) 25 | 26 | def pad_segment(self, inputs): 27 | r = self.stride - (inputs.size(-1) - self.L) % self.stride 28 | if r > 0: 29 | x = F.pad(inputs, [0, r]) 30 | return x 31 | 32 | def stft(self, inputs): 33 | return torch.stft(inputs, self.L, self.stride, window=self.window, center=False, return_complex=False) 34 | 35 | def forward(self, mixture): 36 | mixture = self.pad_segment(mixture) 37 | x = self.stft(mixture) 38 | x = torch.transpose(x, 1, 3) 39 | return x 40 | 41 | 42 | class Decoder(nn.Module): 43 | def __init__(self, N, L, stride=None): 44 | super(Decoder, self).__init__() 45 | self.N, self.L = N, L 46 | if stride is None: 47 | self.stride = L // 2 48 | else: 49 | self.stride = stride 50 | self.conv_A = nn.ConvTranspose2d(N, 1, kernel_size=[1, 5], stride=[1, 2]) 51 | self.conv_P = nn.ConvTranspose2d(N, 2, kernel_size=[1, 5], stride=[1, 2]) 52 | self.window = nn.Parameter(torch.hamming_window(self.L), requires_grad=False) 53 | self.kernel_len = self.L 54 | 55 | def istft(self, inputs): 56 | length = (inputs.size(2) - 1) * self.stride + self.kernel_len 57 | return torch.istft(inputs, self.L, self.stride, window=self.window, center=False, length=length, return_complex=False) 58 | 59 | def forward(self, inputs_en, xi): 60 | Am = torch.relu(self.conv_A(xi)) 61 | P = self.conv_P(xi) 62 | norm_P = torch.norm(P, dim=1, keepdim=True) + EPS 63 | P = P / norm_P 64 | x = torch.norm(inputs_en, dim=1, keepdim=True) * Am * P 65 | x = torch.transpose(x, 1, 3) 66 | x = self.istft(x) 67 | return x 68 | 69 | 70 | class DPRNN_ME(nn.Module): 71 | def __init__(self, N, B, H, R, K, rnn_type='LSTM', dropout=0, bidirectional=False, mask_nonlinear='relu'): 72 | super().__init__() 73 | self.N, self.B, self.H, self.R, self.K = N, B, H, R, K 74 | self.rnn_type = rnn_type 75 | self.dropout = dropout 76 | self.bidirectional = bidirectional 77 | self.conv_noisy = nn.Conv2d(2, B, [5, 5], [1, 2], bias=False) 78 | self.conv_refer = nn.Conv2d(2, B, [5, 5], [1, 2], bias=False) 79 | self.dprnn = DPRNN(N, B, H, R, K, rnn_type, dropout=dropout, bidirectional=bidirectional) 80 | self.output_audio = Output(B, N, mask_nonlinear) 81 | 82 | def forward(self, noisy, refer, inference=False): 83 | x = noisy 84 | y = refer 85 | x = F.pad(x, (0, 0, 4, 0)) 86 | y = F.pad(y, (0, 0, 4, 0)) 87 | x = self.conv_noisy(x) 88 | y = self.conv_refer(y) 89 | x = self.dprnn(x, y, inference) 90 | x = self.output_audio(x) 91 | return x 92 | 93 | 94 | class Output(nn.Module): 95 | def __init__(self, B, N, mask_nonlinear='relu'): 96 | super().__init__() 97 | self.B = B 98 | self.N = N 99 | mask_nonlinear = mask_nonlinear.lower() 100 | if mask_nonlinear == 'relu': 101 | self.mask = torch.relu 102 | elif mask_nonlinear == 'sigmoid': 103 | self.mask = torch.sigmoid 104 | elif mask_nonlinear == 'linear': 105 | self.mask = lambda x: x 106 | elif mask_nonlinear == 'lrelu': 107 | self.mask = nn.LeakyReLU() 108 | elif mask_nonlinear == 'elu': 109 | self.elu = nn.ELU() 110 | self.mask = lambda x: self.elu(x) + 1 111 | else: 112 | raise ValueError 113 | self.prelu = nn.PReLU() 114 | self.mask_conv = nn.Conv2d(B, N, 1) 115 | 116 | def forward(self, inputs): 117 | x = self.prelu(inputs) 118 | x = self.mask_conv(x) 119 | x = self.mask(x) 120 | return x 121 | -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .module import Encoder, Decoder, DPRNN_ME 5 | EPS = 1e-8 6 | 7 | 8 | class Time_DSDPRNN(nn.Module): 9 | def __init__(self, N, L, B, H, R, K, rnn_type='LSTM', stride=None, dropout=0, bidirectional=False, encoder_activation='relu', mask=True, mask_nonlinear='relu'): 10 | super().__init__() 11 | self.N, self.L, self.B, self.H, self.R, self.K = N, L, B, H, R, K 12 | self.rnn_type = rnn_type 13 | self.dropout = dropout 14 | self.bidirectional = bidirectional 15 | self.encoder_activation = encoder_activation 16 | self.mask = mask 17 | self.encoder_noisy = Encoder(N, L, stride=stride, activation=encoder_activation) 18 | self.encoder_refer = Encoder(N, L, stride=stride, activation=encoder_activation) 19 | self.mask_estimator = DPRNN_ME(N, B, H, R, K, rnn_type, dropout, bidirectional, mask_nonlinear) 20 | self.decoder_audio = Decoder(N, L, stride=stride) 21 | 22 | def forward(self, noisy, refer, inference=False): 23 | # [batch_size, length] 24 | length = noisy.size(1) 25 | noisy_en = self.encoder_noisy(noisy) 26 | refer_en = self.encoder_refer(refer) 27 | clean = self.mask_estimator(noisy_en, refer_en, inference) 28 | if self.mask: 29 | clean_en = clean * noisy_en 30 | else: 31 | clean_en = clean 32 | clean = self.decoder_audio(clean_en) 33 | clean = clean[:, :length] 34 | return clean 35 | -------------------------------------------------------------------------------- /model/network_stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .module_stft import Encoder, Decoder, DPRNN_ME 5 | EPS = 1e-8 6 | 7 | 8 | class TF_DSDPRNN(nn.Module): 9 | def __init__(self, N, L, B, H, R, K, rnn_type='LSTM', stride=None, dropout=0, bidirectional=False, encoder_activation='relu', mask=True, mask_nonlinear='relu'): 10 | super().__init__() 11 | self.N, self.L, self.B, self.H, self.R, self.K = N, L, B, H, R, K 12 | self.rnn_type = rnn_type 13 | self.dropout = dropout 14 | self.bidirectional = bidirectional 15 | self.encoder_activation = encoder_activation 16 | self.mask = mask 17 | self.encoder_noisy = Encoder(N, L, stride=stride) 18 | self.encoder_refer = Encoder(N, L, stride=stride) 19 | self.mask_estimator = DPRNN_ME(N, B, H, R, K, rnn_type, dropout, bidirectional, mask_nonlinear) 20 | self.decoder_audio = Decoder(N, L, stride=stride) 21 | 22 | def forward(self, noisy, refer, inference=False): 23 | # [batch_size, length] 24 | length = noisy.size(1) 25 | noisy_en = self.encoder_noisy(noisy) 26 | refer_en = self.encoder_refer(refer) 27 | clean = self.mask_estimator(noisy_en, refer_en, inference) 28 | clean = self.decoder_audio(noisy_en, clean) 29 | clean = clean[:, :length] 30 | return clean 31 | --------------------------------------------------------------------------------