├── LICENSE
├── README.md
├── libs
│   ├── __init__.py
│   └── conv_stft.py
├── memonger
│   ├── __init__.py
│   ├── checkpoint.py
│   ├── memonger.py
│   └── resnet.py
└── nnet_dpccn.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Jyhan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DPCCN: Densely-Connected Pyramid Complex Convolutional Network

This repository provides a PyTorch implementation of the DPCCN model for single-channel speech separation. More details will be added soon.
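
## Usage

A minimal inference sketch (illustrative only: `DenseUNet` and its defaults come from `nnet_dpccn.py`, but no pretrained checkpoint ships with this repository, and the 16 kHz mixture below is a random stand-in):

```python
import torch
from nnet_dpccn import DenseUNet

model = DenseUNet()            # 2-speaker separation, 32 ms / 8 ms STFT
model.eval()

# Checkpointed re-forwarding only saves memory during training; fall back to
# plain sequential forwards for inference.
for m in model.modules():
    if hasattr(m, "set_reforward"):
        m.set_reforward(False)

mix = torch.randn(1, 32000)    # (batch, samples)
with torch.no_grad():
    est1, est2 = model(mix)    # one estimated waveform per speaker
```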
--------------------------------------------------------------------------------
/libs/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 8 10:07:52 2021

@author: Jyhan
"""

--------------------------------------------------------------------------------
/libs/conv_stft.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from scipy.signal import get_window


def init_kernels(win_len, win_inc, fft_len, win_type=None, inverse=False):
    """
    Return the (i)STFT convolution kernels and the window coefficients.
    """
    def sqrthann(win_len):
        return get_window("hann", win_len, fftbins=True)**0.5

    if win_type == 'None' or win_type is None:
        window = np.ones(win_len)
    elif win_type == "sqrthann":
        window = sqrthann(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)

    N = fft_len
    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
    real_kernel = np.real(fourier_basis)
    imag_kernel = np.imag(fourier_basis)
    kernel = np.concatenate([real_kernel, imag_kernel], 1).T

    if inverse:
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel*window
    kernel = kernel[:, None, :]
    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))


class ConvSTFT(nn.Module):

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
        super(ConvSTFT, self).__init__()

        if fft_len is None:
            # np.int was removed in NumPy 1.24; the builtin int is equivalent here.
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len

        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        inputs = F.pad(inputs, [self.win_len-self.stride, self.win_len-self.stride])
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == 'complex':
            return outputs
        else:
            dim = self.dim//2+1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real**2+imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase

class ConviSTFT(nn.Module):

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, inverse=True)
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """
        inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
        phase: [B, N//2+1, T] (if not None)
        """

        if phase is not None:
            real = inputs*torch.cos(phase)
            imag = inputs*torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # overlap-add normalization, from torch-stft: https://github.com/pseeth/torch-stft
        t = self.window.repeat(1, 1, inputs.size(-1))**2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs/(coff+1e-8)
        #outputs = torch.where(coff == 0, outputs, outputs/coff)
        outputs = outputs[..., self.win_len-self.stride:-(self.win_len-self.stride)]

        return outputs
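
if __name__ == "__main__":
    # Round-trip sanity check (added sketch, not part of the original file):
    # with matched analysis/synthesis kernels, STFT -> iSTFT should
    # reconstruct a random signal up to small numerical error.
    stft = ConvSTFT(512, 128, 512, win_type="sqrthann", feature_type="complex")
    istft = ConviSTFT(512, 128, 512, win_type="sqrthann", feature_type="complex")
    x = torch.randn(1, 16000)
    spec = stft(x)                                  # [1, fft_len+2, T]
    y = istft(spec).squeeze(1)[:, :x.shape[-1]]     # [1, 16000]
    print("round-trip MAE:", (x - y).abs().mean().item())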
--------------------------------------------------------------------------------
/memonger/__init__.py:
--------------------------------------------------------------------------------
from .memonger import SublinearSequential

--------------------------------------------------------------------------------
/memonger/checkpoint.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import warnings


def detach_variable(inputs):
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)


def check_backward_validity(inputs):
    if not any(inp.requires_grad for inp in inputs):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


# Global switch to toggle whether or not checkpointed passes stash and restore
# the RNG state. If True, any checkpoints making use of RNG should achieve deterministic
# output compared to non-checkpointed passes.
preserve_rng_state = True


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        if preserve_rng_state:
            # We can't know if the user will transfer some args from the host
            # to the device during their run_fn. Therefore, we stash both
            # the cpu and cuda rng states unconditionally.
            #
            # TODO:
            # We also can't know if the run_fn will internally move some args to a device
            # other than the current device, which would require logic to preserve
            # rng states for those devices as well. We could paranoically stash and restore
            # ALL the rng states for all visible devices, but that seems very wasteful for
            # most cases.
            ctx.fwd_cpu_rng_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
        with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
            if preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_rng_state)
                if ctx.had_cuda_in_fwd:
                    torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        return (None,) + tuple(inp.grad for inp in detached_inputs)


def checkpoint(function, *args):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in the backward pass. It can be applied to any
    part of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backward pass, the saved inputs and
    :attr:`function` are retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        If :attr:`function` invocation during backward does anything different
        than the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately it can't be
        detected.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    return CheckpointFunction.apply(function, *args)
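
# Usage sketch (added illustration, not part of the original file): wrap an
# expensive sub-module so its activations are recomputed during backward
# instead of being cached. `block` is a hypothetical nn.Module; at least one
# input must require grad for gradients to flow through the segment.
#
#     x = torch.randn(8, 64, requires_grad=True)
#     y = checkpoint(block, x)   # forward runs under torch.no_grad()
#     y.sum().backward()         # block is re-run here to rebuild activations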


def checkpoint_sequential(functions, segments, *inputs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        inputs: tuple of Tensors that are inputs to :attr:`functions`

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """

    def run_function(start, end, functions):
        def forward(*inputs):
            for j in range(start, end + 1):
                if isinstance(inputs, tuple):
                    inputs = functions[j](*inputs)
                else:
                    inputs = functions[j](inputs)
            return inputs
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        inputs = checkpoint(run_function(start, end, functions), *inputs)
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
    return run_function(end + 1, len(functions) - 1, functions)(*inputs)

--------------------------------------------------------------------------------
/memonger/memonger.py:
--------------------------------------------------------------------------------
from math import sqrt, log
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

from .checkpoint import checkpoint


def reforward_momentum_fix(origin_momentum):
    # After a re-forward each BN layer sees every batch twice; this rescaling
    # makes two running-stat updates with the fixed momentum match one update
    # with the original momentum: (1 - m')**2 == 1 - m.
    return (1 - sqrt(1 - origin_momentum))


class SublinearSequential(nn.Sequential):
    def __init__(self, *args):
        super(SublinearSequential, self).__init__(*args)
        self.reforward = False
        self.momentum_dict = {}
        self.set_reforward(True)

    def set_reforward(self, enabled=True):
        if not self.reforward and enabled:
            print("Rescale BN momentum for re-forwarding purpose")
            for n, m in self.named_modules():
                if isinstance(m, _BatchNorm):
                    self.momentum_dict[n] = m.momentum
                    m.momentum = reforward_momentum_fix(self.momentum_dict[n])
        if self.reforward and not enabled:
            print("Restore BN momentum")
            for n, m in self.named_modules():
                if isinstance(m, _BatchNorm):
                    m.momentum = self.momentum_dict[n]
        self.reforward = enabled

    def forward(self, input):
        if self.reforward:
            return self.sublinear_forward(input)
        else:
            return self.normal_forward(input)

    def normal_forward(self, input):
        for module in self._modules.values():
            input = module(input)
        return input

    def sublinear_forward(self, input):
        def run_function(start, end, functions):
            def forward(*inputs):
                input = inputs[0]
                for j in range(start, end + 1):
                    input = functions[j](input)
                return input

            return forward

        functions = list(self.children())
        # checkpoint ~sqrt(L) segments of ~sqrt(L) modules each
        segments = int(sqrt(len(functions)))
        segment_size = len(functions) // segments
        # the last chunk has to be non-volatile
        end = -1
        inputs = input if isinstance(input, tuple) else (input,)
        for start in range(0, segment_size * (segments - 1), segment_size):
            end = start + segment_size - 1
            inputs = checkpoint(run_function(start, end, functions), *inputs)
            if not isinstance(inputs, tuple):
                inputs = (inputs,)
        # output = run_function(end + 1, len(functions) - 1, functions)(*inputs)
        output = checkpoint(run_function(end + 1, len(functions) - 1, functions), *inputs)
        return output
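

if __name__ == "__main__":
    # Minimal smoke test (added sketch, not part of the original file; run as
    # `python -m memonger.memonger` so the relative import resolves).
    # SublinearSequential is a drop-in replacement for nn.Sequential that
    # checkpoints roughly sqrt(L) segments, trading one extra forward pass
    # for sublinear activation memory.
    net = SublinearSequential(
        *[nn.Sequential(nn.Linear(256, 256), nn.ReLU()) for _ in range(16)]
    )
    x = torch.randn(4, 256, requires_grad=True)
    net(x).sum().backward()      # segments are re-forwarded during backward
    net.set_reforward(False)     # plain nn.Sequential behaviour for inference
    print(net(x).shape)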

--------------------------------------------------------------------------------
/memonger/resnet.py:
--------------------------------------------------------------------------------
'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from .memonger import SublinearSequential


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=100):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return SublinearSequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])


def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])


def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])


def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])


def test():
    net = ResNet18()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

# test()
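
# Note (added): every ResNet stage above is a SublinearSequential, so BN
# momentum is rescaled at construction and each stage re-forwards its blocks
# inside backward; call `layer.set_reforward(False)` on the stages for plain
# (faster, more memory-hungry) forwards at evaluation time.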

--------------------------------------------------------------------------------
/nnet_dpccn.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sun May 23 14:42:21 2021

@author: Jyhan
"""

import torch as th
import torch.nn as nn

from typing import Tuple, List
from memonger import SublinearSequential

from libs.conv_stft import ConvSTFT, ConviSTFT


def param(nnet, Mb=True):
    """
    Return the number of parameters (not bytes) in nnet,
    in millions if Mb is True.
    """
    neles = sum([param.nelement() for param in nnet.parameters()])

    return neles / 10**6 if Mb else neles


class Conv2dBlock(nn.Module):
    def __init__(self,
                 in_dims: int = 16,
                 out_dims: int = 32,
                 kernel_size: Tuple[int, int] = (3, 3),
                 stride: Tuple[int, int] = (1, 1),
                 padding: Tuple[int, int] = (1, 1)) -> None:
        super(Conv2dBlock, self).__init__()
        self.conv2d = nn.Conv2d(in_dims, out_dims, kernel_size, stride, padding)
        self.elu = nn.ELU()
        self.norm = nn.InstanceNorm2d(out_dims)

    def forward(self, x: th.Tensor) -> th.Tensor:
        x = self.conv2d(x)
        x = self.elu(x)

        return self.norm(x)


class ConvTrans2dBlock(nn.Module):
    def __init__(self,
                 in_dims: int = 32,
                 out_dims: int = 16,
                 kernel_size: Tuple[int, int] = (3, 3),
                 stride: Tuple[int, int] = (1, 2),
                 padding: Tuple[int, int] = (1, 0),
                 output_padding: Tuple[int, int] = (0, 0)) -> None:
        super(ConvTrans2dBlock, self).__init__()
        self.convtrans2d = nn.ConvTranspose2d(in_dims, out_dims, kernel_size, stride, padding, output_padding)
        self.elu = nn.ELU()
        self.norm = nn.InstanceNorm2d(out_dims)

    def forward(self, x: th.Tensor) -> th.Tensor:
        x = self.convtrans2d(x)
        x = self.elu(x)

        return self.norm(x)


class DenseBlock(nn.Module):
    def __init__(self, in_dims, out_dims, mode="enc", **kargs):
        super(DenseBlock, self).__init__()
        if mode not in ["enc", "dec"]:
            raise RuntimeError("The mode option must be 'enc' or 'dec'!")

        # decoder blocks receive the skip connection concatenated on channels
        n = 1 if mode == "enc" else 2
        self.conv1 = Conv2dBlock(in_dims=in_dims*n, out_dims=in_dims, **kargs)
        self.conv2 = Conv2dBlock(in_dims=in_dims*(n+1), out_dims=in_dims, **kargs)
        self.conv3 = Conv2dBlock(in_dims=in_dims*(n+2), out_dims=in_dims, **kargs)
        self.conv4 = Conv2dBlock(in_dims=in_dims*(n+3), out_dims=in_dims, **kargs)
        self.conv5 = Conv2dBlock(in_dims=in_dims*(n+4), out_dims=out_dims, **kargs)

    def forward(self, x: th.Tensor) -> th.Tensor:
        y1 = self.conv1(x)
        y2 = self.conv2(th.cat([x, y1], 1))
        y3 = self.conv3(th.cat([x, y1, y2], 1))
        y4 = self.conv4(th.cat([x, y1, y2, y3], 1))
        y5 = self.conv5(th.cat([x, y1, y2, y3, y4], 1))

        return y5


class TCNBlock(nn.Module):
    """
    TCN block (with a residual connection):
        IN - ELU - DConv1D - IN - ELU - Conv1D
    """

    def __init__(self,
                 in_dims: int = 384,
                 out_dims: int = 384,
                 kernel_size: int = 3,
                 stride: int = 1,
                 paddings: int = 1,
                 dilation: int = 1,
                 causal: bool = False) -> None:
        super(TCNBlock, self).__init__()
        self.norm1 = nn.InstanceNorm1d(in_dims)
        self.elu1 = nn.ELU()
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        # dilated depthwise conv
        self.dconv1 = nn.Conv1d(
            in_dims,
            out_dims,
            kernel_size,
            padding=dconv_pad,
            dilation=dilation,
            groups=in_dims,
            bias=True)

        self.norm2 = nn.InstanceNorm1d(in_dims)
        self.elu2 = nn.ELU()
        self.dconv2 = nn.Conv1d(in_dims, out_dims, 1, bias=True)

        # different padding way
        self.causal = causal
        self.dconv_pad = dconv_pad

    def forward(self, x: th.Tensor) -> th.Tensor:
        y = self.elu1(self.norm1(x))
        y = self.dconv1(y)
        if self.causal:
            y = y[:, :, :-self.dconv_pad]
        y = self.elu2(self.norm2(y))
        y = self.dconv2(y)
        x = x + y

        return x
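

# Receptive-field note (added): with kernel_size=3 and dilations 1, 2, ...,
# 2**9, one repeat of 10 TCN blocks spans 1 + 2*(2**10 - 1) = 2047 frames;
# the default two repeats roughly double that again.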


class DenseUNet(nn.Module):
    def __init__(self,
                 win_len: int = 512,    # 32 ms
                 win_inc: int = 128,    # 8 ms
                 fft_len: int = 512,
                 win_type: str = "sqrthann",
                 kernel_size: Tuple[int, int] = (3, 3),
                 stride1: Tuple[int, int] = (1, 1),
                 stride2: Tuple[int, int] = (1, 2),
                 paddings: Tuple[int, int] = (1, 0),
                 output_padding: Tuple[int, int] = (0, 0),
                 tcn_dims: int = 384,
                 tcn_blocks: int = 10,
                 tcn_layers: int = 2,
                 causal: bool = False,
                 pool_size: Tuple[int, ...] = (4, 8, 16, 32),
                 num_spks: int = 2) -> None:
        super(DenseUNet, self).__init__()

        self.fft_len = fft_len
        self.num_spks = num_spks

        self.stft = ConvSTFT(win_len, win_inc, fft_len, win_type, 'complex')
        self.conv2d = nn.Conv2d(2, 16, kernel_size, stride1, paddings)
        self.encoder = self._build_encoder(
            kernel_size=kernel_size,
            stride=stride2,
            padding=paddings
        )
        self.tcn_layers = self._build_tcn_layers(
            tcn_layers,
            tcn_blocks,
            in_dims=tcn_dims,
            out_dims=tcn_dims,
            causal=causal
        )
        self.decoder = self._build_decoder(
            kernel_size=kernel_size,
            stride=stride2,
            padding=paddings,
            output_padding=output_padding
        )
        self.avg_pool = self._build_avg_pool(pool_size)
        self.avg_proj = nn.Conv2d(64, 32, 1, 1)

        self.deconv2d = nn.ConvTranspose2d(32, 2*num_spks, kernel_size, stride1, paddings)
        self.istft = ConviSTFT(win_len, win_inc, fft_len, win_type, 'complex')

    def _build_encoder(self, **enc_kargs):
        """
        Build encoder layers
        """
        encoder = nn.ModuleList()
        encoder.append(DenseBlock(16, 16, "enc"))
        for i in range(4):
            encoder.append(
                SublinearSequential(
                    Conv2dBlock(in_dims=16 if i == 0 else 32,
                                out_dims=32, **enc_kargs),
                    DenseBlock(32, 32, "enc")
                )
            )
        encoder.append(Conv2dBlock(in_dims=32, out_dims=64, **enc_kargs))
        encoder.append(Conv2dBlock(in_dims=64, out_dims=128, **enc_kargs))
        encoder.append(Conv2dBlock(in_dims=128, out_dims=384, **enc_kargs))

        return encoder

    def _build_decoder(self, **dec_kargs):
        """
        Build decoder layers
        """
        decoder = nn.ModuleList()
        decoder.append(ConvTrans2dBlock(in_dims=384*2, out_dims=128, **dec_kargs))
        decoder.append(ConvTrans2dBlock(in_dims=128*2, out_dims=64, **dec_kargs))
        decoder.append(ConvTrans2dBlock(in_dims=64*2, out_dims=32, **dec_kargs))
        for i in range(4):
            decoder.append(
                SublinearSequential(
                    DenseBlock(32, 64, "dec"),
                    ConvTrans2dBlock(in_dims=64,
                                     out_dims=32 if i != 3 else 16,
                                     **dec_kargs)
                )
            )
        decoder.append(DenseBlock(16, 32, "dec"))

        return decoder

    def _build_tcn_blocks(self, tcn_blocks, **tcn_kargs):
        """
        Build TCN blocks in each repeat (layer)
        """
        blocks = [
            TCNBlock(**tcn_kargs, dilation=(2**b))
            for b in range(tcn_blocks)
        ]

        return SublinearSequential(*blocks)

    def _build_tcn_layers(self, tcn_layers, tcn_blocks, **tcn_kargs):
        """
        Build TCN layers
        """
        layers = [
            self._build_tcn_blocks(tcn_blocks, **tcn_kargs)
            for _ in range(tcn_layers)
        ]

        return SublinearSequential(*layers)

    def _build_avg_pool(self, pool_size):
        """
        Build avg pooling layers
        """
        avg_pool = nn.ModuleList()
        for sz in pool_size:
            avg_pool.append(
                SublinearSequential(
                    nn.AvgPool2d(sz),
                    nn.Conv2d(32, 8, 1, 1)
                )
            )

        return avg_pool

    def wav2spec(self, x: th.Tensor, mags: bool = False) -> th.Tensor:
        """
        Convert waveform to spectrogram
        """
        assert x.dim() == 2
        x = x / th.std(x, -1, keepdim=True)  # variance normalization
        specs = self.stft(x)
        real = specs[:, :self.fft_len//2+1]
        imag = specs[:, self.fft_len//2+1:]
        spec = th.stack([real, imag], 1)
        spec = th.einsum("hijk->hikj", spec)  # (batchsize, 2, T, F)
        if mags:
            return th.sqrt(real**2 + imag**2 + 1e-8)
        else:
            return spec

    def sep(self, spec: th.Tensor) -> List[th.Tensor]:
        """
        spec: (batchsize, 2*num_spks, T, F)
        return the time-domain waveform for each speaker
        """
        spec = th.einsum("hijk->hikj", spec)  # (batchsize, 2*num_spks, F, T)
        spec = th.chunk(spec, self.num_spks, 1)
        B, N, F, T = spec[0].shape
        est1 = th.chunk(spec[0], 2, 1)  # [(B, 1, F, T), (B, 1, F, T)]
        est2 = th.chunk(spec[1], 2, 1)
        est1 = th.cat(est1, 2).reshape(B, -1, T)  # (B, 2F, T): real over imag
        est2 = th.cat(est2, 2).reshape(B, -1, T)
        return [th.squeeze(self.istft(est1)), th.squeeze(self.istft(est2))]
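
    # Shape walkthrough for the default config (added commentary): the STFT
    # yields (B, 2, T, 257); conv2d with frequency padding 0 trims this to
    # (B, 16, T, 255); each stride-(1, 2) encoder stage then halves the
    # frequency axis, 255 -> 127 -> 63 -> 31 -> 15 -> 7 -> 3 -> 1, so the TCN
    # stack runs on a (B, 384, T) sequence. The decoder mirrors these sizes
    # exactly with its transposed convolutions, and the final deconv2d maps
    # 255 back to 257 frequency bins for the iSTFT.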

    def forward(self, x: th.Tensor) -> List[th.Tensor]:
        if x.dim() == 1:
            x = th.unsqueeze(x, 0)

        spec = self.wav2spec(x)
        out = self.conv2d(spec)
        out_list = []
        for enc in self.encoder:
            out = enc(out)
            out_list.append(out)
        B, N, T, F = out.shape
        out = self.tcn_layers(out.reshape(B, N, T*F))
        out = th.unsqueeze(out, -1)

        out_list = out_list[::-1]
        for idx, dec in enumerate(self.decoder):
            out = dec(th.cat([out_list[idx], out], 1))

        # pyramidal pooling
        B, N, T, F = out.shape
        upsample = nn.Upsample(size=(T, F), mode='bilinear')
        pool_list = []
        for avg in self.avg_pool:
            pool_list.append(upsample(avg(out)))

        out = th.cat([out, *pool_list], 1)
        out = self.avg_proj(out)
        out = self.deconv2d(out)

        return self.sep(out)


def test_conv2d_block():
    x = th.randn(2, 16, 257, 200)
    conv = Conv2dBlock()
    y = conv(x)
    print(y.shape)
    convtrans = ConvTrans2dBlock()
    z = convtrans(y)
    print(z.shape)


def test_dense_block():
    x = th.randn(2, 16, 257, 200)
    dense = DenseBlock(16, 32, "enc")
    y = dense(x)
    print(y.shape)


def test_tcn_block():
    x = th.randn(2, 384, 1000)
    tcn = TCNBlock(dilation=128)
    print(tcn(x).shape)


if __name__ == "__main__":
    nnet = DenseUNet()
    print(param(nnet))
    x = th.randn(2, 32000)
    est1, est2 = nnet(x)

--------------------------------------------------------------------------------