├── .gitignore ├── Conv-Tasnet-Deep-w-PRelu.py ├── Conv-Tasnet-Deep-w-dilation.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /Conv-Tasnet-Deep-w-PRelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class GlobalLayerNorm(nn.Module): 6 | ''' 7 | Calculate Global Layer Normalization 8 | dim: (int or list or torch.Size) – 9 | input shape from an expected input of size 10 | eps: a value added to the denominator for numerical stability. 11 | elementwise_affine: a boolean value that when set to True, 12 | this module has learnable per-element affine parameters 13 | initialized to ones (for weights) and zeros (for biases). 14 | ''' 15 | 16 | def __init__(self, dim, eps=1e-05, elementwise_affine=True): 17 | super(GlobalLayerNorm, self).__init__() 18 | self.dim = dim 19 | self.eps = eps 20 | self.elementwise_affine = elementwise_affine 21 | 22 | if self.elementwise_affine: 23 | self.weight = nn.Parameter(torch.ones(self.dim, 1)) 24 | self.bias = nn.Parameter(torch.zeros(self.dim, 1)) 25 | else: 26 | self.register_parameter('weight', None) 27 | self.register_parameter('bias', None) 28 | 29 | def forward(self, x): 30 | # x = N x C x L 31 | # N x 1 x 1 32 | # cln: mean,var N x 1 x L 33 | # gln: mean,var N x 1 x 1 34 | if x.dim() != 3: 35 | raise RuntimeError("{} accept 3D tensor as input".format( 36 | self.__name__)) 37 | 38 | mean = torch.mean(x, (1, 2), keepdim=True) 39 | var = torch.mean((x-mean)**2, (1, 2), keepdim=True) 40 | # N x C x L 41 | if self.elementwise_affine: 42 | x = self.weight*(x-mean)/torch.sqrt(var+self.eps)+self.bias 43 | else: 44 | x = (x-mean)/torch.sqrt(var+self.eps) 45 | return x 46 | 47 | 48 | class CumulativeLayerNorm(nn.LayerNorm): 49 | ''' 50 | Calculate Cumulative Layer Normalization 51 | dim: you want to norm dim 52 | elementwise_affine: learnable per-element affine parameters 53 | ''' 54 | 55 | def __init__(self, dim, elementwise_affine=True): 56 | super(CumulativeLayerNorm, self).__init__( 57 | dim, elementwise_affine=elementwise_affine) 58 | 59 | def forward(self, x): 60 | # x: N x C x L 61 | # N x L x C 62 | x = torch.transpose(x, 1, 2) 63 | # N x L x C == only channel norm 64 | x = super().forward(x) 65 | # N x C x L 66 | x = torch.transpose(x, 1, 2) 67 | return x 68 | 69 | 70 | def select_norm(norm, dim): 71 | if norm not in ['gln', 'cln', 'bn']: 72 | if x.dim() != 3: 73 | raise RuntimeError("{} accept 3D tensor as input".format( 74 | self.__name__)) 75 | 76 | if norm == 'gln': 77 | return GlobalLayerNorm(dim, elementwise_affine=True) 78 | if norm == 'cln': 79 | return CumulativeLayerNorm(dim, elementwise_affine=True) 80 | else: 81 | return nn.BatchNorm1d(dim) 82 | 83 | 84 | class Encoder(nn.Module): 85 | 86 | def __init__(self, in_channels, out_channels, kernel_size, stride): 87 | super(Encoder, self).__init__() 88 | self.sequential = nn.Sequential( 89 | Conv1D(in_channels, out_channels, kernel_size, stride=stride), 90 | Conv1D(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 91 | nn.PReLU(), 92 | Conv1D(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 93 | nn.PReLU(), 94 | Conv1D(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 95 | nn.PReLU() 96 | ) 97 | def forward(self, x): 98 | ''' 99 | x: [B, T] 100 | out: [B, N, T] 101 | ''' 102 | x = self.sequential(x) 103 | return x 104 | 105 | class Decoder(nn.Module): 106 | ''' 107 | Decoder 108 | This module can be seen as the gradient of Conv1d with respect to its input. 109 | It is also known as a fractionally-strided convolution 110 | or a deconvolution (although it is not an actual deconvolution operation). 111 | ''' 112 | 113 | def __init__(self, N, kernel_size=16, stride=16 // 2): 114 | super(Decoder, self).__init__() 115 | self.sequential = nn.Sequential( 116 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, padding=1), 117 | nn.PReLU(), 118 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, padding=1), 119 | nn.PReLU(), 120 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, padding=1), 121 | nn.PReLU(), 122 | nn.ConvTranspose1d(N, 1, kernel_size=kernel_size, stride=stride, bias=True) 123 | ) 124 | 125 | def forward(self, x): 126 | """ 127 | x: N x L or N x C x L 128 | """ 129 | x = self.sequential(x) 130 | if torch.squeeze(x).dim() == 1: 131 | x = torch.squeeze(x, dim=1) 132 | else: 133 | x = torch.squeeze(x) 134 | 135 | return x 136 | 137 | 138 | class Conv1D(nn.Conv1d): 139 | ''' 140 | Applies a 1D convolution over an input signal composed of several input planes. 141 | ''' 142 | 143 | def __init__(self, *args, **kwargs): 144 | super(Conv1D, self).__init__(*args, **kwargs) 145 | 146 | def forward(self, x, squeeze=False): 147 | # x: N x C x L 148 | if x.dim() not in [2, 3]: 149 | raise RuntimeError("{} accept 2/3D tensor as input".format( 150 | self.__name__)) 151 | x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) 152 | if squeeze: 153 | x = torch.squeeze(x) 154 | return x 155 | 156 | 157 | class ConvTrans1D(nn.ConvTranspose1d): 158 | ''' 159 | This module can be seen as the gradient of Conv1d with respect to its input. 160 | It is also known as a fractionally-strided convolution 161 | or a deconvolution (although it is not an actual deconvolution operation). 162 | ''' 163 | 164 | def __init__(self, *args, **kwargs): 165 | super(ConvTrans1D, self).__init__(*args, **kwargs) 166 | 167 | def forward(self, x, squeeze=False): 168 | """ 169 | x: N x L or N x C x L 170 | """ 171 | if x.dim() not in [2, 3]: 172 | raise RuntimeError("{} accept 2/3D tensor as input".format( 173 | self.__name__)) 174 | x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) 175 | if squeeze: 176 | x = torch.squeeze(x) 177 | return x 178 | 179 | 180 | class Conv1D_Block(nn.Module): 181 | ''' 182 | Consider only residual links 183 | ''' 184 | 185 | def __init__(self, in_channels=256, out_channels=512, 186 | kernel_size=3, dilation=1, norm='gln', causal=False, skip_con='True'): 187 | super(Conv1D_Block, self).__init__() 188 | # conv 1 x 1 189 | self.conv1x1 = Conv1D(in_channels, out_channels, 1) 190 | self.PReLU_1 = nn.PReLU() 191 | self.norm_1 = select_norm(norm, out_channels) 192 | # not causal don't need to padding, causal need to pad+1 = kernel_size 193 | self.pad = (dilation * (kernel_size - 1)) // 2 if not causal else ( 194 | dilation * (kernel_size - 1)) 195 | # depthwise convolution 196 | self.dwconv = Conv1D(out_channels, out_channels, kernel_size, 197 | groups=out_channels, padding=self.pad, dilation=dilation) 198 | self.PReLU_2 = nn.PReLU() 199 | self.norm_2 = select_norm(norm, out_channels) 200 | self.Sc_conv = nn.Conv1d(out_channels, in_channels, 1, bias=True) 201 | self.Output = nn.Conv1d(out_channels, in_channels, 1, bias=True) 202 | self.causal = causal 203 | self.skip_con = skip_con 204 | 205 | def forward(self, x): 206 | # x: N x C x L 207 | # N x O_C x L 208 | c = self.conv1x1(x) 209 | # N x O_C x L 210 | c = self.PReLU_1(c) 211 | c = self.norm_1(c) 212 | # causal: N x O_C x (L+pad) 213 | # noncausal: N x O_C x L 214 | c = self.dwconv(c) 215 | c = self.PReLU_2(c) 216 | c = self.norm_2(c) 217 | # N x O_C x L 218 | if self.causal: 219 | c = c[:, :, :-self.pad] 220 | if self.skip_con: 221 | Sc = self.Sc_conv(c) 222 | c = self.Output(c) 223 | return Sc, c+x 224 | c = self.Output(c) 225 | return x+c 226 | 227 | 228 | class Separation(nn.Module): 229 | ''' 230 | R Number of repeats 231 | X Number of convolutional blocks in each repeat 232 | B Number of channels in bottleneck and the residual paths’ 1 × 1-conv blocks 233 | H Number of channels in convolutional blocks 234 | P Kernel size in convolutional blocks 235 | norm The type of normalization(gln, cl, bn) 236 | causal Two choice(causal or noncausal) 237 | skip_con Whether to use skip connection 238 | ''' 239 | 240 | def __init__(self, R, X, B, H, P, norm='gln', causal=False, skip_con=True): 241 | super(Separation, self).__init__() 242 | self.separation = nn.ModuleList([]) 243 | for r in range(R): 244 | for x in range(X): 245 | self.separation.append(Conv1D_Block( 246 | B, H, P, 2**x, norm, causal, skip_con)) 247 | self.skip_con = skip_con 248 | 249 | def forward(self, x): 250 | ''' 251 | x: [B, N, L] 252 | out: [B, N, L] 253 | ''' 254 | if self.skip_con: 255 | skip_connection = 0 256 | for i in range(len(self.separation)): 257 | skip, out = self.separation[i](x) 258 | skip_connection = skip_connection + skip 259 | x = out 260 | return skip_connection 261 | else: 262 | for i in range(len(self.separation)): 263 | out = self.separation[i](x) 264 | x = out 265 | return x 266 | 267 | 268 | class ConvTasNet(nn.Module): 269 | ''' 270 | ConvTasNet module 271 | N Number of filters in autoencoder 272 | L Length of the filters (in samples) 273 | B Number of channels in bottleneck and the residual paths’ 1 × 1-conv blocks 274 | Sc Number of channels in skip-connection paths’ 1 × 1-conv blocks 275 | H Number of channels in convolutional blocks 276 | P Kernel size in convolutional blocks 277 | X Number of convolutional blocks in each repeat 278 | R Number of repeats 279 | ''' 280 | 281 | def __init__(self, 282 | N=512, 283 | L=16, 284 | B=128, 285 | H=512, 286 | P=3, 287 | X=8, 288 | R=3, 289 | norm="gln", 290 | num_spks=2, 291 | activate="relu", 292 | causal=False, 293 | skip_con=False): 294 | super(ConvTasNet, self).__init__() 295 | # n x 1 x T => n x N x T 296 | self.encoder = Encoder(1, N, L, stride=L // 2) 297 | # n x N x T Layer Normalization of Separation 298 | self.LayerN_S = select_norm('cln', N) 299 | # n x B x T Conv 1 x 1 of Separation 300 | self.BottleN_S = Conv1D(N, B, 1) 301 | # Separation block 302 | # n x B x T => n x B x T 303 | self.separation = Separation(R, X, B, H, P ,norm=norm, causal=causal, skip_con=skip_con) 304 | # n x B x T => n x 2*N x T 305 | self.gen_masks = Conv1D(B, num_spks*N, 1) 306 | # n x N x T => n x 1 x L 307 | self.decoder = Decoder(N, L, stride=L//2) 308 | # activation function 309 | active_f = { 310 | 'relu': nn.ReLU(), 311 | 'sigmoid': nn.Sigmoid(), 312 | 'softmax': nn.Softmax(dim=0) 313 | } 314 | self.activation_type = activate 315 | self.activation = active_f[activate] 316 | self.num_spks = num_spks 317 | 318 | def forward(self, x): 319 | if x.dim() >= 3: 320 | raise RuntimeError( 321 | "{} accept 1/2D tensor as input, but got {:d}".format( 322 | self.__name__, x.dim())) 323 | if x.dim() == 1: 324 | x = torch.unsqueeze(x, 0) 325 | # x: n x 1 x L => n x N x T 326 | w = self.encoder(x) 327 | # n x N x L => n x B x L 328 | e = self.LayerN_S(w) 329 | e = self.BottleN_S(e) 330 | # n x B x L => n x B x L 331 | e = self.separation(e) 332 | # n x B x L => n x num_spk*N x L 333 | m = self.gen_masks(e) 334 | # n x N x L x num_spks 335 | m = torch.chunk(m, chunks=self.num_spks, dim=1) 336 | # num_spks x n x N x L 337 | m = self.activation(torch.stack(m, dim=0)) 338 | d = [w*m[i] for i in range(self.num_spks)] 339 | # decoder part num_spks x n x L 340 | s = [self.decoder(d[i]) for i in range(self.num_spks)] 341 | return s 342 | 343 | 344 | def check_parameters(net): 345 | ''' 346 | Returns module parameters. Mb 347 | ''' 348 | parameters = sum(param.numel() for param in net.parameters()) 349 | return parameters / 10**6 350 | 351 | 352 | def test_convtasnet(): 353 | x = torch.randn(4, 32000) 354 | nnet = ConvTasNet() 355 | s = nnet(x) 356 | print(str(check_parameters(nnet))+' Mb') 357 | print(nnet) 358 | 359 | 360 | if __name__ == "__main__": 361 | test_convtasnet() -------------------------------------------------------------------------------- /Conv-Tasnet-Deep-w-dilation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class GlobalLayerNorm(nn.Module): 6 | ''' 7 | Calculate Global Layer Normalization 8 | dim: (int or list or torch.Size) – 9 | input shape from an expected input of size 10 | eps: a value added to the denominator for numerical stability. 11 | elementwise_affine: a boolean value that when set to True, 12 | this module has learnable per-element affine parameters 13 | initialized to ones (for weights) and zeros (for biases). 14 | ''' 15 | 16 | def __init__(self, dim, eps=1e-05, elementwise_affine=True): 17 | super(GlobalLayerNorm, self).__init__() 18 | self.dim = dim 19 | self.eps = eps 20 | self.elementwise_affine = elementwise_affine 21 | 22 | if self.elementwise_affine: 23 | self.weight = nn.Parameter(torch.ones(self.dim, 1)) 24 | self.bias = nn.Parameter(torch.zeros(self.dim, 1)) 25 | else: 26 | self.register_parameter('weight', None) 27 | self.register_parameter('bias', None) 28 | 29 | def forward(self, x): 30 | # x = N x C x L 31 | # N x 1 x 1 32 | # cln: mean,var N x 1 x L 33 | # gln: mean,var N x 1 x 1 34 | if x.dim() != 3: 35 | raise RuntimeError("{} accept 3D tensor as input".format( 36 | self.__name__)) 37 | 38 | mean = torch.mean(x, (1, 2), keepdim=True) 39 | var = torch.mean((x-mean)**2, (1, 2), keepdim=True) 40 | # N x C x L 41 | if self.elementwise_affine: 42 | x = self.weight*(x-mean)/torch.sqrt(var+self.eps)+self.bias 43 | else: 44 | x = (x-mean)/torch.sqrt(var+self.eps) 45 | return x 46 | 47 | 48 | class CumulativeLayerNorm(nn.LayerNorm): 49 | ''' 50 | Calculate Cumulative Layer Normalization 51 | dim: you want to norm dim 52 | elementwise_affine: learnable per-element affine parameters 53 | ''' 54 | 55 | def __init__(self, dim, elementwise_affine=True): 56 | super(CumulativeLayerNorm, self).__init__( 57 | dim, elementwise_affine=elementwise_affine) 58 | 59 | def forward(self, x): 60 | # x: N x C x L 61 | # N x L x C 62 | x = torch.transpose(x, 1, 2) 63 | # N x L x C == only channel norm 64 | x = super().forward(x) 65 | # N x C x L 66 | x = torch.transpose(x, 1, 2) 67 | return x 68 | 69 | 70 | def select_norm(norm, dim): 71 | if norm not in ['gln', 'cln', 'bn']: 72 | if x.dim() != 3: 73 | raise RuntimeError("{} accept 3D tensor as input".format( 74 | self.__name__)) 75 | 76 | if norm == 'gln': 77 | return GlobalLayerNorm(dim, elementwise_affine=True) 78 | if norm == 'cln': 79 | return CumulativeLayerNorm(dim, elementwise_affine=True) 80 | else: 81 | return nn.BatchNorm1d(dim) 82 | 83 | 84 | class Encoder(nn.Module): 85 | 86 | def __init__(self, in_channels, out_channels, kernel_size, stride): 87 | super(Encoder, self).__init__() 88 | self.sequential = nn.Sequential( 89 | Conv1D(in_channels, out_channels, kernel_size, stride=stride), 90 | Conv1D(out_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding=1), 91 | nn.PReLU(), 92 | Conv1D(out_channels, out_channels, kernel_size=3, stride=1, dilation=2, padding=2), 93 | nn.PReLU(), 94 | Conv1D(out_channels, out_channels, kernel_size=3, stride=1, dilation=4, padding=4), 95 | nn.PReLU(), 96 | Conv1D(out_channels, out_channels, kernel_size=3, stride=1, dilation=8, padding=8), 97 | nn.PReLU() 98 | ) 99 | def forward(self, x): 100 | ''' 101 | x: [B, T] 102 | out: [B, N, T] 103 | ''' 104 | x = self.sequential(x) 105 | return x 106 | 107 | class Decoder(nn.Module): 108 | ''' 109 | Decoder 110 | This module can be seen as the gradient of Conv1d with respect to its input. 111 | It is also known as a fractionally-strided convolution 112 | or a deconvolution (although it is not an actual deconvolution operation). 113 | ''' 114 | 115 | def __init__(self, N, kernel_size=16, stride=16 // 2): 116 | super(Decoder, self).__init__() 117 | self.sequential = nn.Sequential( 118 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=8, padding=8), 119 | nn.PReLU(), 120 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=4, padding=4), 121 | nn.PReLU(), 122 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=2, padding=2), 123 | nn.PReLU(), 124 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=1, padding=1), 125 | nn.PReLU(), 126 | nn.ConvTranspose1d(N, 1, kernel_size=kernel_size, stride=stride, bias=True) 127 | ) 128 | 129 | def forward(self, x): 130 | """ 131 | x: N x L or N x C x L 132 | """ 133 | x = self.sequential(x) 134 | if torch.squeeze(x).dim() == 1: 135 | x = torch.squeeze(x, dim=1) 136 | else: 137 | x = torch.squeeze(x) 138 | 139 | return x 140 | 141 | 142 | class Conv1D(nn.Conv1d): 143 | ''' 144 | Applies a 1D convolution over an input signal composed of several input planes. 145 | ''' 146 | 147 | def __init__(self, *args, **kwargs): 148 | super(Conv1D, self).__init__(*args, **kwargs) 149 | 150 | def forward(self, x, squeeze=False): 151 | # x: N x C x L 152 | if x.dim() not in [2, 3]: 153 | raise RuntimeError("{} accept 2/3D tensor as input".format( 154 | self.__name__)) 155 | x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) 156 | if squeeze: 157 | x = torch.squeeze(x) 158 | return x 159 | 160 | 161 | class ConvTrans1D(nn.ConvTranspose1d): 162 | ''' 163 | This module can be seen as the gradient of Conv1d with respect to its input. 164 | It is also known as a fractionally-strided convolution 165 | or a deconvolution (although it is not an actual deconvolution operation). 166 | ''' 167 | 168 | def __init__(self, *args, **kwargs): 169 | super(ConvTrans1D, self).__init__(*args, **kwargs) 170 | 171 | def forward(self, x, squeeze=False): 172 | """ 173 | x: N x L or N x C x L 174 | """ 175 | if x.dim() not in [2, 3]: 176 | raise RuntimeError("{} accept 2/3D tensor as input".format( 177 | self.__name__)) 178 | x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) 179 | if squeeze: 180 | x = torch.squeeze(x) 181 | return x 182 | 183 | 184 | class Conv1D_Block(nn.Module): 185 | ''' 186 | Consider only residual links 187 | ''' 188 | 189 | def __init__(self, in_channels=256, out_channels=512, 190 | kernel_size=3, dilation=1, norm='gln', causal=False, skip_con='True'): 191 | super(Conv1D_Block, self).__init__() 192 | # conv 1 x 1 193 | self.conv1x1 = Conv1D(in_channels, out_channels, 1) 194 | self.PReLU_1 = nn.PReLU() 195 | self.norm_1 = select_norm(norm, out_channels) 196 | # not causal don't need to padding, causal need to pad+1 = kernel_size 197 | self.pad = (dilation * (kernel_size - 1)) // 2 if not causal else ( 198 | dilation * (kernel_size - 1)) 199 | # depthwise convolution 200 | self.dwconv = Conv1D(out_channels, out_channels, kernel_size, 201 | groups=out_channels, padding=self.pad, dilation=dilation) 202 | self.PReLU_2 = nn.PReLU() 203 | self.norm_2 = select_norm(norm, out_channels) 204 | self.Sc_conv = nn.Conv1d(out_channels, in_channels, 1, bias=True) 205 | self.Output = nn.Conv1d(out_channels, in_channels, 1, bias=True) 206 | self.causal = causal 207 | self.skip_con = skip_con 208 | 209 | def forward(self, x): 210 | # x: N x C x L 211 | # N x O_C x L 212 | c = self.conv1x1(x) 213 | # N x O_C x L 214 | c = self.PReLU_1(c) 215 | c = self.norm_1(c) 216 | # causal: N x O_C x (L+pad) 217 | # noncausal: N x O_C x L 218 | c = self.dwconv(c) 219 | c = self.PReLU_2(c) 220 | c = self.norm_2(c) 221 | # N x O_C x L 222 | if self.causal: 223 | c = c[:, :, :-self.pad] 224 | if self.skip_con: 225 | Sc = self.Sc_conv(c) 226 | c = self.Output(c) 227 | return Sc, c+x 228 | c = self.Output(c) 229 | return x+c 230 | 231 | 232 | class Separation(nn.Module): 233 | ''' 234 | R Number of repeats 235 | X Number of convolutional blocks in each repeat 236 | B Number of channels in bottleneck and the residual paths’ 1 × 1-conv blocks 237 | H Number of channels in convolutional blocks 238 | P Kernel size in convolutional blocks 239 | norm The type of normalization(gln, cl, bn) 240 | causal Two choice(causal or noncausal) 241 | skip_con Whether to use skip connection 242 | ''' 243 | 244 | def __init__(self, R, X, B, H, P, norm='gln', causal=False, skip_con=True): 245 | super(Separation, self).__init__() 246 | self.separation = nn.ModuleList([]) 247 | for r in range(R): 248 | for x in range(X): 249 | self.separation.append(Conv1D_Block( 250 | B, H, P, 2**x, norm, causal, skip_con)) 251 | self.skip_con = skip_con 252 | 253 | def forward(self, x): 254 | ''' 255 | x: [B, N, L] 256 | out: [B, N, L] 257 | ''' 258 | if self.skip_con: 259 | skip_connection = 0 260 | for i in range(len(self.separation)): 261 | skip, out = self.separation[i](x) 262 | skip_connection = skip_connection + skip 263 | x = out 264 | return skip_connection 265 | else: 266 | for i in range(len(self.separation)): 267 | out = self.separation[i](x) 268 | x = out 269 | return x 270 | 271 | 272 | class ConvTasNet(nn.Module): 273 | ''' 274 | ConvTasNet module 275 | N Number of filters in autoencoder 276 | L Length of the filters (in samples) 277 | B Number of channels in bottleneck and the residual paths’ 1 × 1-conv blocks 278 | Sc Number of channels in skip-connection paths’ 1 × 1-conv blocks 279 | H Number of channels in convolutional blocks 280 | P Kernel size in convolutional blocks 281 | X Number of convolutional blocks in each repeat 282 | R Number of repeats 283 | ''' 284 | 285 | def __init__(self, 286 | N=512, 287 | L=16, 288 | B=128, 289 | H=512, 290 | P=3, 291 | X=8, 292 | R=3, 293 | norm="gln", 294 | num_spks=2, 295 | activate="relu", 296 | causal=False, 297 | skip_con=False): 298 | super(ConvTasNet, self).__init__() 299 | # n x 1 x T => n x N x T 300 | self.encoder = Encoder(1, N, L, stride=L // 2) 301 | # n x N x T Layer Normalization of Separation 302 | self.LayerN_S = select_norm('cln', N) 303 | # n x B x T Conv 1 x 1 of Separation 304 | self.BottleN_S = Conv1D(N, B, 1) 305 | # Separation block 306 | # n x B x T => n x B x T 307 | self.separation = Separation(R, X, B, H, P ,norm=norm, causal=causal, skip_con=skip_con) 308 | # n x B x T => n x 2*N x T 309 | self.gen_masks = Conv1D(B, num_spks*N, 1) 310 | # n x N x T => n x 1 x L 311 | self.decoder = Decoder(N, L, stride=L//2) 312 | # activation function 313 | active_f = { 314 | 'relu': nn.ReLU(), 315 | 'sigmoid': nn.Sigmoid(), 316 | 'softmax': nn.Softmax(dim=0) 317 | } 318 | self.activation_type = activate 319 | self.activation = active_f[activate] 320 | self.num_spks = num_spks 321 | 322 | def forward(self, x): 323 | if x.dim() >= 3: 324 | raise RuntimeError( 325 | "{} accept 1/2D tensor as input, but got {:d}".format( 326 | self.__name__, x.dim())) 327 | if x.dim() == 1: 328 | x = torch.unsqueeze(x, 0) 329 | # x: n x 1 x L => n x N x T 330 | w = self.encoder(x) 331 | # n x N x L => n x B x L 332 | e = self.LayerN_S(w) 333 | e = self.BottleN_S(e) 334 | # n x B x L => n x B x L 335 | e = self.separation(e) 336 | # n x B x L => n x num_spk*N x L 337 | m = self.gen_masks(e) 338 | # n x N x L x num_spks 339 | m = torch.chunk(m, chunks=self.num_spks, dim=1) 340 | # num_spks x n x N x L 341 | m = self.activation(torch.stack(m, dim=0)) 342 | d = [w*m[i] for i in range(self.num_spks)] 343 | # decoder part num_spks x n x L 344 | s = [self.decoder(d[i]) for i in range(self.num_spks)] 345 | return s 346 | 347 | 348 | def check_parameters(net): 349 | ''' 350 | Returns module parameters. Mb 351 | ''' 352 | parameters = sum(param.numel() for param in net.parameters()) 353 | return parameters / 10**6 354 | 355 | 356 | def test_convtasnet(): 357 | x = torch.randn(4, 32000) 358 | nnet = ConvTasNet() 359 | s = nnet(x) 360 | print(str(check_parameters(nnet))+' Mb') 361 | print(nnet) 362 | 363 | 364 | if __name__ == "__main__": 365 | test_convtasnet() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A PyTorch implementation of "An Empirical study of Conv-TasNet" 2 | 3 | ## Conv-Tasnet(Deep w / dilation) 4 | ```python 5 | Encoder: 6 | Conv1D(1, N, kernel_size, stride=stride) 7 | Conv1D(N, N, kernel_size=3, stride=1, dilation=1,padding=1) 8 | Conv1D(N, N, kernel_size=3, stride=1, dilation=2, padding=2) 9 | Conv1D(N, N, kernel_size=3, stride=1, dilation=4, padding=4) 10 | Conv1D(N, N, kernel_size=3, stride=1, dilation=8, padding=8) 11 | 12 | Decoder: 13 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=8, padding=8) 14 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=4, padding=4) 15 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=2, padding=2) 16 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, dilation=1, padding=1) 17 | nn.ConvTranspose1d(N, 1, kernel_size=kernel_size, stride=stride) 18 | ``` 19 | 20 | ## Conv-Tasnet(Deep w / PRelu) 21 | ```python 22 | Encoder: 23 | Conv1D(1, N, kernel_size, stride=stride) 24 | Conv1D(N, N, kernel_size=3, stride=1,padding=1) 25 | Conv1D(N, N, kernel_size=3, stride=1, padding=1) 26 | Conv1D(N, N, kernel_size=3, stride=1, padding=1) 27 | Conv1D(N, N, kernel_size=3, stride=1, padding=1) 28 | 29 | Decoder: 30 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, padding=1) 31 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, padding=1) 32 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, padding=1) 33 | nn.ConvTranspose1d(N, N, kernel_size=3, stride=1, padding=1) 34 | nn.ConvTranspose1d(N, 1, kernel_size=kernel_size, stride=stride) 35 | ``` 36 | 37 | ## How to use 38 | 39 | You can replace the Conv_TasNet.py file in the [Conv-TasNet](https://github.com/JusperLee/Conv-TasNet) repository for training. 40 | 41 | ## Result 42 | Conv-Tasnet(Deep w / dilation): SI-SNRi:16.346, SDRi:16.61 43 | 44 | ## Reference 45 | [1]. Kadioglu B, Horgan M, Liu X, et al. An empirical study of Conv-TasNet[J]. arXiv preprint arXiv:2002.08688, 2020. 46 | --------------------------------------------------------------------------------